/* * Copyright (c) Atmosphère-NX * * This program is free software; you can redistribute it and/or modify it * under the terms and conditions of the GNU General Public License, * version 2, as published by the Free Software Foundation. * * This program is distributed in the hope it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for * more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ #include #include "htc_rpc_client.hpp" namespace ams::htc::server::rpc { namespace { constexpr inline size_t ThreadStackSize = os::MemoryPageSize; alignas(os::ThreadStackAlignment) constinit u8 g_receive_thread_stack[ThreadStackSize]; alignas(os::ThreadStackAlignment) constinit u8 g_send_thread_stack[ThreadStackSize]; constinit os::SdkMutex g_rpc_mutex; constinit RpcTaskIdFreeList g_task_id_free_list; constinit RpcTaskTable g_task_table; } RpcClient::RpcClient(driver::IDriver *driver, htclow::ChannelId channel) : m_allocator(nullptr), m_driver(driver), m_channel_id(channel), m_receive_thread_stack(g_receive_thread_stack), m_send_thread_stack(g_send_thread_stack), m_mutex(g_rpc_mutex), m_task_id_free_list(g_task_id_free_list), m_task_table(g_task_table), m_task_active(), m_is_htcs_task(), m_task_queue(), m_cancelled(false), m_thread_running(false) { /* Initialize all events. */ for (size_t i = 0; i < MaxRpcCount; ++i) { os::InitializeEvent(std::addressof(m_receive_buffer_available_events[i]), false, os::EventClearMode_AutoClear); os::InitializeEvent(std::addressof(m_send_buffer_available_events[i]), false, os::EventClearMode_AutoClear); } } RpcClient::RpcClient(mem::StandardAllocator *allocator, driver::IDriver *driver, htclow::ChannelId channel) : m_allocator(allocator), m_driver(driver), m_channel_id(channel), m_receive_thread_stack(m_allocator->Allocate(ThreadStackSize, os::ThreadStackAlignment)), m_send_thread_stack(m_allocator->Allocate(ThreadStackSize, os::ThreadStackAlignment)), m_mutex(g_rpc_mutex), m_task_id_free_list(g_task_id_free_list), m_task_table(g_task_table), m_task_active(), m_is_htcs_task(), m_task_queue(), m_cancelled(false), m_thread_running(false) { /* Initialize all events. */ for (size_t i = 0; i < MaxRpcCount; ++i) { os::InitializeEvent(std::addressof(m_receive_buffer_available_events[i]), false, os::EventClearMode_AutoClear); os::InitializeEvent(std::addressof(m_send_buffer_available_events[i]), false, os::EventClearMode_AutoClear); } } RpcClient::~RpcClient() { /* Finalize all events. */ for (size_t i = 0; i < MaxRpcCount; ++i) { os::FinalizeEvent(std::addressof(m_receive_buffer_available_events[i])); os::FinalizeEvent(std::addressof(m_send_buffer_available_events[i])); } /* Free the thread stacks. */ if (m_allocator != nullptr) { m_allocator->Free(m_receive_thread_stack); m_allocator->Free(m_send_thread_stack); } m_receive_thread_stack = nullptr; m_send_thread_stack = nullptr; /* Free all tasks. */ for (u32 i = 0; i < MaxRpcCount; ++i) { if (m_task_active[i]) { std::scoped_lock lk(m_mutex); m_task_table.Delete(i); m_task_id_free_list.Free(i); } } } void RpcClient::Open() { R_ABORT_UNLESS(m_driver->Open(m_channel_id)); } void RpcClient::Close() { m_driver->Close(m_channel_id); } Result RpcClient::Start() { /* Connect. */ R_TRY(m_driver->Connect(m_channel_id)); /* Initialize our task queue. */ m_task_queue.Initialize(); /* Create our threads. */ R_ABORT_UNLESS(os::CreateThread(std::addressof(m_receive_thread), ReceiveThreadEntry, this, m_receive_thread_stack, ThreadStackSize, AMS_GET_SYSTEM_THREAD_PRIORITY(htc, HtcmiscReceive))); R_ABORT_UNLESS(os::CreateThread(std::addressof(m_send_thread), SendThreadEntry, this, m_send_thread_stack, ThreadStackSize, AMS_GET_SYSTEM_THREAD_PRIORITY(htc, HtcmiscSend))); /* Set thread name pointers. */ os::SetThreadNamePointer(std::addressof(m_receive_thread), AMS_GET_SYSTEM_THREAD_NAME(htc, HtcmiscReceive)); os::SetThreadNamePointer(std::addressof(m_send_thread), AMS_GET_SYSTEM_THREAD_NAME(htc, HtcmiscSend)); /* Start threads. */ os::StartThread(std::addressof(m_receive_thread)); os::StartThread(std::addressof(m_send_thread)); /* Set initial state. */ m_cancelled = false; m_thread_running = true; /* Clear events. */ for (size_t i = 0; i < MaxRpcCount; ++i) { os::ClearEvent(std::addressof(m_receive_buffer_available_events[i])); os::ClearEvent(std::addressof(m_send_buffer_available_events[i])); } R_SUCCEED(); } void RpcClient::Cancel() { /* Set cancelled. */ m_cancelled = true; /* Signal all events. */ for (size_t i = 0; i < MaxRpcCount; ++i) { os::SignalEvent(std::addressof(m_receive_buffer_available_events[i])); os::SignalEvent(std::addressof(m_send_buffer_available_events[i])); } /* Cancel our queue. */ m_task_queue.Cancel(); } void RpcClient::Wait() { /* Wait for thread to not be running. */ if (m_thread_running) { os::WaitThread(std::addressof(m_receive_thread)); os::WaitThread(std::addressof(m_send_thread)); os::DestroyThread(std::addressof(m_receive_thread)); os::DestroyThread(std::addressof(m_send_thread)); } m_thread_running = false; /* Lock ourselves. */ std::scoped_lock lk(m_mutex); /* Finalize the task queue. */ m_task_queue.Finalize(); /* Cancel all tasks. */ for (size_t i = 0; i < MaxRpcCount; ++i) { if (m_task_active[i]) { m_task_table.Get(i)->Cancel(RpcTaskCancelReason::ClientFinalized); } } } int RpcClient::WaitAny(htclow::ChannelState state, os::EventType *event) { /* Check if we're already signaled. */ if (os::TryWaitEvent(event)) { return 1; } /* Wait. */ while (m_driver->GetChannelState(m_channel_id) != state) { const auto idx = os::WaitAny(m_driver->GetChannelStateEvent(m_channel_id), event); if (idx != 0) { return idx; } /* Clear the channel state event. */ os::ClearEvent(m_driver->GetChannelStateEvent(m_channel_id)); } return 0; } Result RpcClient::ReceiveThread() { /* Loop forever. */ auto *header = reinterpret_cast(m_receive_buffer); while (true) { /* Try to receive a packet header. */ R_TRY(this->ReceiveHeader(header)); /* Track how much we've received. */ size_t received = sizeof(*header); /* If the packet has one, receive its body. */ if (header->body_size > 0) { /* Sanity check the task id. */ AMS_ABORT_UNLESS(header->task_id < static_cast(MaxRpcCount)); /* Sanity check the body size. */ AMS_ABORT_UNLESS(util::IsIntValueRepresentable(header->body_size)); AMS_ABORT_UNLESS(static_cast(header->body_size) <= sizeof(m_receive_buffer) - received); /* Receive the body. */ R_TRY(this->ReceiveBody(header->data, header->body_size)); /* Note that we received the body. */ received += header->body_size; } /* Acquire exclusive access to the task tables. */ std::scoped_lock lk(m_mutex); /* Get the specified task. */ Task *task = m_task_table.Get(header->task_id); R_UNLESS(task != nullptr, htc::ResultInvalidTaskId()); /* If the task is canceled, free it. */ if (task->GetTaskState() == RpcTaskState::Cancelled) { m_task_active[header->task_id] = false; m_is_htcs_task[header->task_id] = false; m_task_table.Delete(header->task_id); m_task_id_free_list.Free(header->task_id); continue; } /* Handle the packet. */ switch (header->category) { case PacketCategory::Response: R_TRY(task->ProcessResponse(m_receive_buffer, received)); break; case PacketCategory::Notification: R_TRY(task->ProcessNotification(m_receive_buffer, received)); break; default: R_THROW(htc::ResultInvalidCategory()); } /* If we used the receive buffer, signal that we're done with it. */ if (task->IsReceiveBufferRequired()) { os::SignalEvent(std::addressof(m_receive_buffer_available_events[header->task_id])); } } } Result RpcClient::ReceiveHeader(RpcPacket *header) { /* Receive. */ s64 received; R_TRY(m_driver->Receive(std::addressof(received), reinterpret_cast(header), sizeof(*header), m_channel_id, htclow::ReceiveOption_ReceiveAllData)); /* Check size. */ R_UNLESS(static_cast(received) == sizeof(*header), htc::ResultInvalidSize()); R_SUCCEED(); } Result RpcClient::ReceiveBody(char *dst, size_t size) { /* Receive. */ s64 received; R_TRY(m_driver->Receive(std::addressof(received), dst, size, m_channel_id, htclow::ReceiveOption_ReceiveAllData)); /* Check size. */ R_UNLESS(static_cast(received) == size, htc::ResultInvalidSize()); R_SUCCEED(); } Result RpcClient::SendThread() { while (true) { /* Get a task. */ Task *task; u32 task_id{}; PacketCategory category{}; do { /* Dequeue a task. */ R_TRY(m_task_queue.Take(std::addressof(task_id), std::addressof(category))); /* Get the task from the table. */ std::scoped_lock lk(m_mutex); task = m_task_table.Get(task_id); } while (task == nullptr); /* If required, wait for the send buffer to become available. */ if (task->IsSendBufferRequired()) { os::WaitEvent(std::addressof(m_send_buffer_available_events[task_id])); /* Check if we've been cancelled. */ if (m_cancelled) { break; } } /* Handle the task. */ size_t packet_size; switch (category) { case PacketCategory::Request: R_TRY(task->CreateRequest(std::addressof(packet_size), m_send_buffer, sizeof(m_send_buffer), task_id)); break; case PacketCategory::Notification: R_TRY(task->CreateNotification(std::addressof(packet_size), m_send_buffer, sizeof(m_send_buffer), task_id)); break; AMS_UNREACHABLE_DEFAULT_CASE(); } /* Send the request. */ R_TRY(this->SendRequest(m_send_buffer, packet_size)); } R_THROW(htc::ResultCancelled()); } Result RpcClient::SendRequest(const char *src, size_t size) { /* Sanity check our size. */ AMS_ASSERT(util::IsIntValueRepresentable(size)); /* Send the data. */ s64 sent; R_TRY(m_driver->Send(std::addressof(sent), src, static_cast(size), m_channel_id)); /* Check that we sent the right amount. */ R_UNLESS(sent == static_cast(size), htc::ResultInvalidSize()); R_SUCCEED(); } void RpcClient::CancelBySocket(s32 handle) { /* Check if we need to cancel each task. */ for (size_t i = 0; i < MaxRpcCount; ++i) { /* Lock ourselves. */ std::scoped_lock lk(m_mutex); /* Check that the task is active and is an htcs task. */ if (!m_task_active[i] || !m_is_htcs_task[i]) { continue; } /* Get the htcs task. */ auto *htcs_task = m_task_table.Get(i); /* Handle the case where the task handle is the one we're cancelling. */ if (this->GetTaskHandle(i) == handle) { /* If the task is complete, free it. */ if (htcs_task->GetTaskState() == RpcTaskState::Completed) { m_task_active[i] = false; m_is_htcs_task[i] = false; m_task_table.Delete(i); m_task_id_free_list.Free(i); } else { /* If the task is a send task, notify. */ if (htcs_task->GetTaskType() == htcs::impl::rpc::HtcsTaskType::Send) { m_task_queue.Add(i, PacketCategory::Notification); } /* Cancel the task. */ htcs_task->Cancel(RpcTaskCancelReason::BySocket); } /* The task has been cancelled, so we can move on. */ continue; } /* Handle the case where the task is a select task. */ if (htcs_task->GetTaskType() == htcs::impl::rpc::HtcsTaskType::Select) { /* Get the select task. */ auto *select_task = m_task_table.Get(i); /* Get the handle counts. */ const auto num_read = select_task->GetReadHandleCount(); const auto num_write = select_task->GetWriteHandleCount(); const auto num_exception = select_task->GetExceptionHandleCount(); const auto total = num_read + num_write + num_exception; /* Get the handle array. */ const auto *handles = select_task->GetHandles(); /* Check each handle. */ for (auto handle_idx = 0; handle_idx < total; ++handle_idx) { if (handles[handle_idx] == handle) { /* If the select is complete, free it. */ if (select_task->GetTaskState() == RpcTaskState::Completed) { m_task_active[i] = false; m_is_htcs_task[i] = false; m_task_table.Delete(i); m_task_id_free_list.Free(i); } else { /* Cancel the task. */ select_task->Cancel(RpcTaskCancelReason::BySocket); } } } } } } s32 RpcClient::GetTaskHandle(u32 task_id) { /* TODO: Why is this necessary to avoid a bogus array-bounds warning? */ AMS_ASSUME(task_id < MaxRpcCount); /* Check pre-conditions. */ AMS_ASSERT(m_task_active[task_id]); AMS_ASSERT(m_is_htcs_task[task_id]); /* Get the htcs task. */ auto *task = m_task_table.Get(task_id); /* Check that the task has a handle. */ if (!m_task_active[task_id] || !m_is_htcs_task[task_id] || task == nullptr) { return -1; } /* Get the task's type. */ const auto type = task->GetTaskType(); /* Check that the task is new enough. */ if (task->GetVersion() == 3) { if (type == htcs::impl::rpc::HtcsTaskType::Receive || type == htcs::impl::rpc::HtcsTaskType::Send) { return -1; } } /* Get the handle from the task. */ switch (type) { case htcs::impl::rpc::HtcsTaskType::Receive: return static_cast(task)->GetHandle(); case htcs::impl::rpc::HtcsTaskType::Send: return static_cast(task)->GetHandle(); case htcs::impl::rpc::HtcsTaskType::Shutdown: return static_cast(task)->GetHandle(); case htcs::impl::rpc::HtcsTaskType::Close: return -1; case htcs::impl::rpc::HtcsTaskType::Connect: return static_cast(task)->GetHandle(); case htcs::impl::rpc::HtcsTaskType::Listen: return static_cast(task)->GetHandle(); case htcs::impl::rpc::HtcsTaskType::Accept: return static_cast(task)->GetServerHandle(); case htcs::impl::rpc::HtcsTaskType::Socket: return -1; case htcs::impl::rpc::HtcsTaskType::Bind: return static_cast(task)->GetHandle(); case htcs::impl::rpc::HtcsTaskType::Fcntl: return static_cast(task)->GetHandle(); case htcs::impl::rpc::HtcsTaskType::ReceiveSmall: return static_cast(task)->GetHandle(); case htcs::impl::rpc::HtcsTaskType::SendSmall: return static_cast(task)->GetHandle(); case htcs::impl::rpc::HtcsTaskType::Select: return -1; AMS_UNREACHABLE_DEFAULT_CASE(); } } }