/*
 * Copyright (c) 2018-2020 Atmosphère-NX
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms and conditions of the GNU General Public License,
 * version 2, as published by the Free Software Foundation.
 *
 * This program is distributed in the hope it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
 * more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
#pragma once
#include <stratosphere.hpp>
#include "../driver/htc_i_driver.hpp"
#include "htc_rpc_task_table.hpp"
#include "htc_rpc_task_queue.hpp"
#include "htc_rpc_task_id_free_list.hpp"
#include "../../../htcs/impl/rpc/htcs_rpc_tasks.hpp"

namespace ams::htc::server::rpc {

    template<typename T>
    concept IsRpcTask = std::derived_from<T, Task>;

    struct RpcTaskFunctionTraits {
        public:
            template<typename R, typename C, typename... A>
            static std::tuple<A...> GetSetArgumentsImpl(R (C::*)(A...));

            template<typename R, typename C, typename... A>
            static std::tuple<A...> GetGetResultImpl(R (C::*)(A...) const);
    };

    template<typename T> requires IsRpcTask<T>
    using RpcTaskArgumentsType = decltype(RpcTaskFunctionTraits::GetSetArgumentsImpl(&T::SetArguments));

    template<typename T> requires IsRpcTask<T>
    using RpcTaskResultsType = decltype(RpcTaskFunctionTraits::GetGetResultImpl(&T::GetResult));

    template<typename T, size_t Ix> requires IsRpcTask<T>
    using RpcTaskArgumentType = typename std::tuple_element<Ix, RpcTaskArgumentsType<T>>::type;

    template<typename T, size_t Ix> requires IsRpcTask<T>
    using RpcTaskResultType = typename std::tuple_element<Ix, RpcTaskResultsType<T>>::type;
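    /* The traits above deduce a task's argument and result tuples from the signatures
     * of its SetArguments and GetResult members. A minimal sketch of the deduction,
     * assuming a hypothetical ExampleTask (illustrative only, not defined in this file):
     *
     *     struct ExampleTask : public Task {
     *         Result SetArguments(s32 a, s64 b);
     *         Result GetResult(s32 *out) const;
     *     };
     *
     *     static_assert(std::is_same<RpcTaskArgumentsType<ExampleTask>, std::tuple<s32, s64>>::value);
     *     static_assert(std::is_same<RpcTaskResultsType<ExampleTask>, std::tuple<s32 *>>::value);
     */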
    class RpcClient {
        private:
            /* TODO: where is this value coming from, again? */
            static constexpr size_t BufferSize = 0xE400;
        private:
            mem::StandardAllocator *m_allocator;
            driver::IDriver *m_driver;
            htclow::ChannelId m_channel_id;
            void *m_receive_thread_stack;
            void *m_send_thread_stack;
            os::ThreadType m_receive_thread;
            os::ThreadType m_send_thread;
            os::SdkMutex &m_mutex;
            RpcTaskIdFreeList &m_task_id_free_list;
            RpcTaskTable &m_task_table;
            bool m_task_active[MaxRpcCount];
            bool m_is_htcs_task[MaxRpcCount];
            RpcTaskQueue m_task_queue;
            bool m_cancelled;
            bool m_thread_running;
            os::EventType m_receive_buffer_available_events[MaxRpcCount];
            os::EventType m_send_buffer_available_events[MaxRpcCount];
            char m_receive_buffer[BufferSize];
            char m_send_buffer[BufferSize];
        private:
            static void ReceiveThreadEntry(void *arg) { static_cast<RpcClient *>(arg)->ReceiveThread(); }
            static void SendThreadEntry(void *arg) { static_cast<RpcClient *>(arg)->SendThread(); }

            Result ReceiveThread();
            Result SendThread();
        public:
            RpcClient(driver::IDriver *driver, htclow::ChannelId channel);
            RpcClient(mem::StandardAllocator *allocator, driver::IDriver *driver, htclow::ChannelId channel);
            ~RpcClient();
        public:
            void Open();
            void Close();

            Result Start();
            void Cancel();
            void Wait();

            int WaitAny(htclow::ChannelState state, os::EventType *event);
        private:
            Result ReceiveHeader(RpcPacket *header);
            Result ReceiveBody(char *dst, size_t size);
            Result SendRequest(const char *src, size_t size);
        private:
            template<typename T, size_t... Ix> requires IsRpcTask<T>
            ALWAYS_INLINE Result BeginImpl(std::index_sequence<Ix...>, u32 *out_task_id, RpcTaskArgumentType<T, Ix>... args) {
                /* Lock ourselves. */
                std::scoped_lock lk(m_mutex);

                /* Allocate a free task id. */
                u32 task_id;
                R_TRY(m_task_id_free_list.Allocate(std::addressof(task_id)));

                /* Create the new task. */
                T *task = m_task_table.New<T>(task_id);
                m_task_active[task_id] = true;
                m_is_htcs_task[task_id] = htcs::impl::rpc::IsHtcsTask<T>;

                /* Ensure we clean up the task, if we fail after this. */
                auto task_guard = SCOPE_GUARD {
                    m_task_active[task_id] = false;
                    m_is_htcs_task[task_id] = false;
                    m_task_table.Delete<T>(task_id);
                    m_task_id_free_list.Free(task_id);
                };

                /* Set the task arguments. */
                R_TRY(task->SetArguments(args...));

                /* Clear the task's events. */
                os::ClearEvent(std::addressof(m_receive_buffer_available_events[task_id]));
                os::ClearEvent(std::addressof(m_send_buffer_available_events[task_id]));

                /* Add the task to our queue if we can, or cancel it. */
                if (m_thread_running) {
                    m_task_queue.Add(task_id, PacketCategory::Request);
                } else {
                    task->Cancel(RpcTaskCancelReason::QueueNotAvailable);
                }

                /* Set the output task id. */
                *out_task_id = task_id;

                /* We succeeded. */
                task_guard.Cancel();
                return ResultSuccess();
            }

            template<typename T, size_t... Ix> requires IsRpcTask<T>
            ALWAYS_INLINE Result GetResultImpl(std::index_sequence<Ix...>, u32 task_id, RpcTaskResultType<T, Ix>... args) {
                /* Lock ourselves. */
                std::scoped_lock lk(m_mutex);

                /* Get the task. */
                T *task = m_task_table.Get<T>(task_id);
                R_UNLESS(task != nullptr, htc::ResultInvalidTaskId());

                /* Check that the task is completed. */
                R_UNLESS(task->GetTaskState() == RpcTaskState::Completed, htc::ResultTaskNotCompleted());

                /* Get the task's result. */
                R_TRY(task->GetResult(args...));

                return ResultSuccess();
            }

            template<typename T, size_t... Ix> requires IsRpcTask<T>
            ALWAYS_INLINE Result EndImpl(std::index_sequence<Ix...>, u32 task_id, RpcTaskResultType<T, Ix>... args) {
                /* Lock ourselves. */
                std::scoped_lock lk(m_mutex);

                /* Get the task. */
                T *task = m_task_table.Get<T>(task_id);
                R_UNLESS(task != nullptr, htc::ResultInvalidTaskId());

                /* Ensure the task is freed if it needs to be, when we're done. */
                auto task_guard = SCOPE_GUARD {
                    m_task_active[task_id] = false;
                    m_is_htcs_task[task_id] = false;
                    m_task_table.Delete<T>(task_id);
                    m_task_id_free_list.Free(task_id);
                };

                /* If the task was cancelled, handle that. */
                if (task->GetTaskState() == RpcTaskState::Cancelled) {
                    switch (task->GetTaskCancelReason()) {
                        case RpcTaskCancelReason::BySocket:
                            task_guard.Cancel();
                            return htc::ResultTaskCancelled();
                        case RpcTaskCancelReason::ClientFinalized:
                            return htc::ResultCancelled();
                        case RpcTaskCancelReason::QueueNotAvailable:
                            return htc::ResultTaskQueueNotAvailable();
                        AMS_UNREACHABLE_DEFAULT_CASE();
                    }
                }

                /* Get the task's result. */
                R_TRY(task->GetResult(args...));

                return ResultSuccess();
            }

            s32 GetTaskHandle(u32 task_id);
        public:
            void Wait(u32 task_id) {
                os::WaitEvent(m_task_table.Get<Task>(task_id)->GetEvent());
            }

            Handle DetachReadableHandle(u32 task_id) {
                return os::DetachReadableHandleOfSystemEvent(m_task_table.Get<Task>(task_id)->GetSystemEvent());
            }

            void CancelBySocket(s32 handle);
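            /* The Begin/GetResult/End wrappers below drive the client-side task lifecycle:
             * Begin allocates a task id and queues the request, Wait blocks on the task's
             * event, and End fetches the result and frees the task. A minimal sketch of
             * the intended call sequence, assuming an RpcClient instance `client`, the
             * hypothetical ExampleTask from above, and illustrative variables a/b/out:
             *
             *     u32 task_id;
             *     R_TRY(client.Begin<ExampleTask>(std::addressof(task_id), a, b));
             *     client.Wait(task_id);
             *     R_TRY(client.End<ExampleTask>(task_id, std::addressof(out)));
             */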
            template<typename T, typename... Args> requires (IsRpcTask<T> && sizeof...(Args) == std::tuple_size<RpcTaskArgumentsType<T>>::value)
            Result Begin(u32 *out_task_id, Args &&... args) {
                return this->BeginImpl<T>(std::make_index_sequence<std::tuple_size<RpcTaskArgumentsType<T>>::value>(), out_task_id, std::forward<Args>(args)...);
            }

            template<typename T, typename... Args> requires (IsRpcTask<T> && sizeof...(Args) == std::tuple_size<RpcTaskResultsType<T>>::value)
            Result GetResult(u32 task_id, Args &&... args) {
                return this->GetResultImpl<T>(std::make_index_sequence<std::tuple_size<RpcTaskResultsType<T>>::value>(), task_id, std::forward<Args>(args)...);
            }

            template<typename T, typename... Args> requires (IsRpcTask<T> && sizeof...(Args) == std::tuple_size<RpcTaskResultsType<T>>::value)
            Result End(u32 task_id, Args &&... args) {
                return this->EndImpl<T>(std::make_index_sequence<std::tuple_size<RpcTaskResultsType<T>>::value>(), task_id, std::forward<Args>(args)...);
            }

            template<typename T> requires IsRpcTask<T>
            Result VerifyTaskIdWitHandle(u32 task_id, s32 handle) {
                /* Lock ourselves. */
                std::scoped_lock lk(m_mutex);

                /* Get the task. */
                T *task = m_task_table.Get<T>(task_id);
                R_UNLESS(task != nullptr, htc::ResultInvalidTaskId());

                /* Check the task handle. */
                R_UNLESS(task->GetHandle() == handle, htc::ResultInvalidTaskId());

                return ResultSuccess();
            }

            template<typename T> requires IsRpcTask<T>
            Result Notify(u32 task_id) {
                /* Lock ourselves. */
                std::scoped_lock lk(m_mutex);

                /* Check that our queue is available. */
                R_UNLESS(m_thread_running, htc::ResultTaskQueueNotAvailable());

                /* Get the task. */
                T *task = m_task_table.Get<T>(task_id);
                R_UNLESS(task != nullptr, htc::ResultInvalidTaskId());

                /* Add notification to our queue. */
                m_task_queue.Add(task_id, PacketCategory::Notification);

                return ResultSuccess();
            }

            template<typename T> requires IsRpcTask<T>
            void WaitNotification(u32 task_id) {
                /* Get the task from the table, releasing our lock afterwards. */
                T *task;
                {
                    /* Lock ourselves. */
                    std::scoped_lock lk(m_mutex);

                    /* Get the task. */
                    task = m_task_table.Get<T>(task_id);
                }

                /* Wait for a notification. */
                task->WaitNotification();
            }

            template<typename T> requires IsRpcTask<T>
            bool IsCancelled(u32 task_id) {
                /* Lock ourselves. */
                std::scoped_lock lk(m_mutex);

                /* Get the task. */
                T *task = m_task_table.Get<T>(task_id);

                /* Check the task state. */
                return task != nullptr && task->GetTaskState() == RpcTaskState::Cancelled;
            }

            template<typename T> requires IsRpcTask<T>
            bool IsCompleted(u32 task_id) {
                /* Lock ourselves. */
                std::scoped_lock lk(m_mutex);

                /* Get the task. */
                T *task = m_task_table.Get<T>(task_id);

                /* Check the task state. */
                return task != nullptr && task->GetTaskState() == RpcTaskState::Completed;
            }

            template<typename T> requires IsRpcTask<T>
            Result SendContinue(u32 task_id, const void *buffer, s64 buffer_size) {
                /* Lock ourselves. */
                std::scoped_lock lk(m_mutex);

                /* Get the task. */
                T *task = m_task_table.Get<T>(task_id);
                R_UNLESS(task != nullptr, htc::ResultInvalidTaskId());

                /* If the task was cancelled, handle that. */
                if (task->GetTaskState() == RpcTaskState::Cancelled) {
                    switch (task->GetTaskCancelReason()) {
                        case RpcTaskCancelReason::QueueNotAvailable:
                            return htc::ResultTaskQueueNotAvailable();
                        default:
                            return htc::ResultTaskCancelled();
                    }
                }

                /* Set the task's buffer. */
                if (buffer_size > 0) {
                    task->SetBuffer(buffer, buffer_size);
                    os::SignalEvent(std::addressof(m_send_buffer_available_events[task_id]));
                }

                return ResultSuccess();
            }

            template<typename T> requires IsRpcTask<T>
            Result ReceiveContinue(u32 task_id, void *buffer, s64 buffer_size) {
                /* Get the task's buffer, and prepare to receive. */
                const void *result_buffer;
                s64 result_size;
                {
                    /* Lock ourselves. */
                    std::scoped_lock lk(m_mutex);

                    /* Get the task. */
                    T *task = m_task_table.Get<T>(task_id);
                    R_UNLESS(task != nullptr, htc::ResultInvalidTaskId());

                    /* If the task was cancelled, handle that. */
                    if (task->GetTaskState() == RpcTaskState::Cancelled) {
                        switch (task->GetTaskCancelReason()) {
                            case RpcTaskCancelReason::QueueNotAvailable:
                                return htc::ResultTaskQueueNotAvailable();
                            default:
                                return htc::ResultTaskCancelled();
                        }
                    }

                    /* Get the result size. */
                    result_size = task->GetResultSize();
                    R_SUCCEED_IF(result_size == 0);

                    /* Get the result buffer. */
                    result_buffer = task->GetBuffer();
                }

                /* Wait for the receive buffer to become available. */
                os::WaitEvent(std::addressof(m_receive_buffer_available_events[task_id]));

                /* Check that we weren't cancelled. */
                R_UNLESS(!m_cancelled, htc::ResultCancelled());

                /* Copy the received data. */
                AMS_ASSERT(0 <= result_size && result_size <= buffer_size);
                std::memcpy(buffer, result_buffer, result_size);

                return ResultSuccess();
            }
    };

}
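/* SendContinue/ReceiveContinue above stream additional payload for a task that has
 * already been begun. A minimal sketch of a continued receive, assuming the
 * hypothetical ExampleTask from above and an illustrative chunk size:
 *
 *     char chunk[0x1000];
 *     R_TRY(client.ReceiveContinue<ExampleTask>(task_id, chunk, sizeof(chunk)));
 */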