mirror of
https://github.com/Atmosphere-NX/Atmosphere
synced 2025-01-08 21:47:57 +00:00
kern: partially implement Receive half of ReplyAndReceive
This commit is contained in:
parent
1b2203d102
commit
84b1be1d58
5 changed files with 315 additions and 33 deletions
|
@ -46,6 +46,8 @@ namespace ams::kern {
|
|||
|
||||
static void PostDestroy(uintptr_t arg);
|
||||
|
||||
virtual KProcess *GetOwner() const override { return this->owner; }
|
||||
|
||||
KReadableEvent &GetReadableEvent() { return this->readable_event; }
|
||||
KWritableEvent &GetWritableEvent() { return this->writable_event; }
|
||||
};
|
||||
|
|
|
@ -155,21 +155,7 @@ namespace ams::kern {
|
|||
}
|
||||
}
|
||||
|
||||
template<typename T = KAutoObject>
|
||||
ALWAYS_INLINE KScopedAutoObject<T> GetObjectForIpc(ams::svc::Handle handle) const {
|
||||
static_assert(!std::is_base_of<KInterruptEvent, T>::value);
|
||||
|
||||
/* Handle pseudo-handles. */
|
||||
if constexpr (std::is_base_of<T, KProcess>::value) {
|
||||
if (handle == ams::svc::PseudoHandle::CurrentProcess) {
|
||||
return GetCurrentProcessPointer();
|
||||
}
|
||||
} else if constexpr (std::is_base_of<T, KThread>::value) {
|
||||
if (handle == ams::svc::PseudoHandle::CurrentThread) {
|
||||
return GetCurrentThreadPointer();
|
||||
}
|
||||
}
|
||||
|
||||
ALWAYS_INLINE KScopedAutoObject<KAutoObject> GetObjectForIpcWithoutPseudoHandle(ams::svc::Handle handle) const {
|
||||
/* Lock and look up in table. */
|
||||
KScopedDisableDispatch dd;
|
||||
KScopedSpinLock lk(this->lock);
|
||||
|
@ -178,15 +164,20 @@ namespace ams::kern {
|
|||
if (obj->DynamicCast<KInterruptEvent *>() != nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
if constexpr (std::is_same<T, KAutoObject>::value) {
|
||||
|
||||
return obj;
|
||||
} else {
|
||||
if (auto *obj = this->GetObjectImpl(handle); obj != nullptr) {
|
||||
return obj->DynamicCast<T*>();
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ALWAYS_INLINE KScopedAutoObject<KAutoObject> GetObjectForIpc(ams::svc::Handle handle, KThread *cur_thread) const {
|
||||
/* Handle pseudo-handles. */
|
||||
if (handle == ams::svc::PseudoHandle::CurrentProcess) {
|
||||
return static_cast<KAutoObject *>(static_cast<void *>(cur_thread->GetOwnerProcess()));
|
||||
}
|
||||
if (handle == ams::svc::PseudoHandle::CurrentThread) {
|
||||
return static_cast<KAutoObject *>(cur_thread);
|
||||
}
|
||||
|
||||
return GetObjectForIpcWithoutPseudoHandle(handle);
|
||||
}
|
||||
|
||||
ALWAYS_INLINE KScopedAutoObject<KAutoObject> GetObjectByIndex(ams::svc::Handle *out_handle, size_t index) const {
|
||||
|
|
|
@ -128,6 +128,11 @@ namespace ams::kern {
|
|||
size_t GetSize() const { return this->size; }
|
||||
KProcess *GetServerProcess() const { return this->server; }
|
||||
|
||||
void SetServerProcess(KProcess *process) {
|
||||
this->server = process;
|
||||
this->server->Open();
|
||||
}
|
||||
|
||||
void ClearThread() { this->thread = nullptr; }
|
||||
void ClearEvent() { this->event = nullptr; }
|
||||
|
||||
|
|
|
@ -17,8 +17,246 @@
|
|||
|
||||
namespace ams::kern {
|
||||
|
||||
namespace ipc {
|
||||
|
||||
using MessageBuffer = ams::svc::ipc::MessageBuffer;
|
||||
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
class ReceiveList {
|
||||
private:
|
||||
u32 data[ipc::MessageBuffer::MessageHeader::ReceiveListCountType_CountMax * ipc::MessageBuffer::ReceiveListEntry::GetDataSize() / sizeof(u32)];
|
||||
s32 recv_list_count;
|
||||
uintptr_t msg_buffer_end;
|
||||
uintptr_t msg_buffer_space_end;
|
||||
public:
|
||||
static constexpr int GetEntryCount(const ipc::MessageBuffer::MessageHeader &header) {
|
||||
const auto count = header.GetReceiveListCount();
|
||||
switch (count) {
|
||||
case ipc::MessageBuffer::MessageHeader::ReceiveListCountType_None:
|
||||
return 0;
|
||||
case ipc::MessageBuffer::MessageHeader::ReceiveListCountType_ToMessageBuffer:
|
||||
return 0;
|
||||
case ipc::MessageBuffer::MessageHeader::ReceiveListCountType_ToSingleBuffer:
|
||||
return 1;
|
||||
default:
|
||||
return count - ipc::MessageBuffer::MessageHeader::ReceiveListCountType_CountOffset;
|
||||
}
|
||||
}
|
||||
public:
|
||||
ReceiveList(const u32 *dst_msg, const ipc::MessageBuffer::MessageHeader &dst_header, const ipc::MessageBuffer::SpecialHeader &dst_special_header, size_t msg_size, size_t out_offset, s32 dst_recv_list_idx) {
|
||||
this->recv_list_count = dst_header.GetReceiveListCount();
|
||||
this->msg_buffer_end = reinterpret_cast<uintptr_t>(dst_msg) + sizeof(u32) * out_offset;
|
||||
this->msg_buffer_space_end = reinterpret_cast<uintptr_t>(dst_msg) + msg_size;
|
||||
|
||||
const u32 *recv_list = dst_msg + dst_recv_list_idx;
|
||||
__builtin_memcpy(this->data, recv_list, GetEntryCount(dst_header) * ipc::MessageBuffer::ReceiveListEntry::GetDataSize());
|
||||
}
|
||||
|
||||
constexpr bool IsIndex() const {
|
||||
return this->recv_list_count > ipc::MessageBuffer::MessageHeader::ReceiveListCountType_CountOffset;
|
||||
}
|
||||
};
|
||||
|
||||
template<bool MoveHandleAllowed>
|
||||
ALWAYS_INLINE Result ProcessMessageSpecialData(int &offset, KProcess &dst_process, KProcess &src_process, KThread &src_thread, const ipc::MessageBuffer &dst_msg, const ipc::MessageBuffer &src_msg, const ipc::MessageBuffer::SpecialHeader &src_special_header) {
|
||||
/* Copy the special header to the destination. */
|
||||
offset = dst_msg.Set(src_special_header);
|
||||
|
||||
/* Copy the process ID. */
|
||||
if (src_special_header.GetHasProcessId()) {
|
||||
/* TODO: Atmosphere mitm extension support. */
|
||||
offset = dst_msg.SetProcessId(offset, src_process.GetId());
|
||||
}
|
||||
|
||||
/* Prepare to process handles. */
|
||||
auto &dst_handle_table = dst_process.GetHandleTable();
|
||||
auto &src_handle_table = src_process.GetHandleTable();
|
||||
Result result = ResultSuccess();
|
||||
|
||||
/* Process copy handles. */
|
||||
for (auto i = 0; i < src_special_header.GetCopyHandleCount(); ++i) {
|
||||
/* Get the handles. */
|
||||
const ams::svc::Handle src_handle = src_msg.GetHandle(offset);
|
||||
ams::svc::Handle dst_handle = ams::svc::InvalidHandle;
|
||||
|
||||
/* If we're in a success state, try to move the handle to the new table. */
|
||||
if (R_SUCCEEDED(result) && src_handle != ams::svc::InvalidHandle) {
|
||||
KScopedAutoObject obj = src_handle_table.GetObjectForIpc(src_handle, std::addressof(src_thread));
|
||||
if (obj.IsNotNull()) {
|
||||
Result add_result = dst_handle_table.Add(std::addressof(dst_handle), obj.GetPointerUnsafe());
|
||||
if (R_FAILED(add_result)) {
|
||||
result = add_result;
|
||||
dst_handle = ams::svc::InvalidHandle;
|
||||
}
|
||||
} else {
|
||||
result = svc::ResultInvalidHandle();
|
||||
}
|
||||
}
|
||||
|
||||
/* Set the handle. */
|
||||
offset = dst_msg.SetHandle(offset, dst_handle);
|
||||
}
|
||||
|
||||
/* Process move handles. */
|
||||
if constexpr (MoveHandleAllowed) {
|
||||
for (auto i = 0; i < src_special_header.GetMoveHandleCount(); ++i) {
|
||||
/* Get the handles. */
|
||||
const ams::svc::Handle src_handle = src_msg.GetHandle(offset);
|
||||
ams::svc::Handle dst_handle = ams::svc::InvalidHandle;
|
||||
|
||||
/* Whether or not we've succeeded, we need to remove the handles from the source table. */
|
||||
if (src_handle != ams::svc::InvalidHandle) {
|
||||
if (R_SUCCEEDED(result)) {
|
||||
KScopedAutoObject obj = src_handle_table.GetObjectForIpcWithoutPseudoHandle(src_handle);
|
||||
if (obj.IsNotNull()) {
|
||||
Result add_result = dst_handle_table.Add(std::addressof(dst_handle), obj.GetPointerUnsafe());
|
||||
|
||||
src_handle_table.Remove(src_handle);
|
||||
|
||||
if (R_FAILED(add_result)) {
|
||||
result = add_result;
|
||||
dst_handle = ams::svc::InvalidHandle;
|
||||
}
|
||||
} else {
|
||||
result = svc::ResultInvalidHandle();
|
||||
}
|
||||
} else {
|
||||
src_handle_table.Remove(src_handle);
|
||||
}
|
||||
}
|
||||
|
||||
/* Set the handle. */
|
||||
offset = dst_msg.SetHandle(offset, dst_handle);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
ALWAYS_INLINE Result ReceiveMessage(bool &recv_list_broken, uintptr_t dst_message_buffer, size_t dst_buffer_size, KPhysicalAddress dst_message_paddr, KThread &src_thread, uintptr_t src_message_buffer, size_t src_buffer_size, KServerSession *session, KSessionRequest *request) {
|
||||
/* Prepare variables for receive. */
|
||||
const KThread &dst_thread = GetCurrentThread();
|
||||
KProcess &dst_process = *(dst_thread.GetOwnerProcess());
|
||||
KProcess &src_process = *(src_thread.GetOwnerProcess());
|
||||
auto &dst_page_table = dst_process.GetPageTable();
|
||||
auto &src_page_table = src_process.GetPageTable();
|
||||
|
||||
/* The receive list is initially not broken. */
|
||||
recv_list_broken = false;
|
||||
|
||||
/* Set the server process for the request. */
|
||||
request->SetServerProcess(std::addressof(dst_process));
|
||||
|
||||
/* Determine the message buffers. */
|
||||
u32 *dst_msg_ptr, *src_msg_ptr;
|
||||
bool dst_user, src_user;
|
||||
|
||||
if (dst_message_buffer) {
|
||||
dst_msg_ptr = GetPointer<u32>(KPageTable::GetHeapVirtualAddress(dst_message_paddr));
|
||||
dst_user = true;
|
||||
} else {
|
||||
dst_msg_ptr = static_cast<ams::svc::ThreadLocalRegion *>(dst_thread.GetThreadLocalRegionHeapAddress())->message_buffer;
|
||||
dst_buffer_size = sizeof(ams::svc::ThreadLocalRegion{}.message_buffer);
|
||||
dst_message_buffer = GetInteger(dst_thread.GetThreadLocalRegionAddress());
|
||||
dst_user = false;
|
||||
}
|
||||
|
||||
if (src_message_buffer) {
|
||||
/* NOTE: Nintendo does not check the result of this GetPhysicalAddress call. */
|
||||
KPhysicalAddress src_message_paddr;
|
||||
src_page_table.GetPhysicalAddress(std::addressof(src_message_paddr), src_message_buffer);
|
||||
|
||||
src_msg_ptr = GetPointer<u32>(KPageTable::GetHeapVirtualAddress(src_message_paddr));
|
||||
src_user = true;
|
||||
} else {
|
||||
src_msg_ptr = static_cast<ams::svc::ThreadLocalRegion *>(src_thread.GetThreadLocalRegionHeapAddress())->message_buffer;
|
||||
src_buffer_size = sizeof(ams::svc::ThreadLocalRegion{}.message_buffer);
|
||||
src_message_buffer = GetInteger(src_thread.GetThreadLocalRegionAddress());
|
||||
src_user = false;
|
||||
}
|
||||
|
||||
/* Parse the headers. */
|
||||
const ipc::MessageBuffer dst_msg(dst_msg_ptr, dst_buffer_size);
|
||||
const ipc::MessageBuffer src_msg(src_msg_ptr, src_buffer_size);
|
||||
const ipc::MessageBuffer::MessageHeader dst_header(dst_msg);
|
||||
const ipc::MessageBuffer::MessageHeader src_header(src_msg);
|
||||
const ipc::MessageBuffer::SpecialHeader dst_special_header(dst_msg, dst_header);
|
||||
const ipc::MessageBuffer::SpecialHeader src_special_header(src_msg, src_header);
|
||||
|
||||
/* Get the end of the source message. */
|
||||
const size_t src_end_offset = ipc::MessageBuffer::GetRawDataIndex(src_header, src_special_header) + src_header.GetRawCount();
|
||||
|
||||
/* Ensure that the headers fit. */
|
||||
R_UNLESS(ipc::MessageBuffer::GetMessageBufferSize(dst_header, dst_special_header) <= dst_buffer_size, svc::ResultInvalidCombination());
|
||||
R_UNLESS(ipc::MessageBuffer::GetMessageBufferSize(src_header, src_special_header) <= src_buffer_size, svc::ResultInvalidCombination());
|
||||
|
||||
/* Ensure the receive list offset is after the end of raw data. */
|
||||
if (dst_header.GetReceiveListOffset()) {
|
||||
R_UNLESS(dst_header.GetReceiveListOffset() >= ipc::MessageBuffer::GetRawDataIndex(dst_header, dst_special_header) + dst_header.GetRawCount(), svc::ResultInvalidCombination());
|
||||
}
|
||||
|
||||
/* Ensure that the destination buffer is big enough to receive the source. */
|
||||
R_UNLESS(dst_buffer_size >= src_end_offset * sizeof(u32), svc::ResultMessageTooLarge());
|
||||
|
||||
/* Get the receive list. */
|
||||
const s32 dst_recv_list_idx = static_cast<s32>(ipc::MessageBuffer::GetReceiveListIndex(dst_header, dst_special_header));
|
||||
ReceiveList dst_recv_list(dst_msg_ptr, dst_header, dst_special_header, dst_buffer_size, src_end_offset, dst_recv_list_idx);
|
||||
|
||||
/* Ensure that the source special header isn't invalid. */
|
||||
const bool src_has_special_header = src_header.GetHasSpecialHeader();
|
||||
if (src_has_special_header) {
|
||||
/* Sending move handles from client -> server is not allowed. */
|
||||
R_UNLESS(src_special_header.GetMoveHandleCount() == 0, svc::ResultInvalidCombination());
|
||||
}
|
||||
|
||||
/* Prepare for further processing. */
|
||||
int pointer_key = 0;
|
||||
int offset = dst_msg.Set(src_header);
|
||||
|
||||
/* Set up a guard to make sure that we end up in a clean state on error. */
|
||||
auto cleanup_guard = SCOPE_GUARD {
|
||||
/* TODO */
|
||||
MESOSPHERE_UNIMPLEMENTED();
|
||||
};
|
||||
|
||||
/* Process any special data. */
|
||||
if (src_header.GetHasSpecialHeader()) {
|
||||
/* After we process, make sure we track whether the receive list is broken. */
|
||||
ON_SCOPE_EXIT { if (offset > dst_recv_list_idx) { recv_list_broken = true; } };
|
||||
|
||||
/* Process special data. */
|
||||
R_TRY(ProcessMessageSpecialData<false>(offset, dst_process, src_process, src_thread, dst_msg, src_msg, src_special_header));
|
||||
}
|
||||
|
||||
/* Process any pointer buffers. */
|
||||
for (auto i = 0; i < src_header.GetPointerCount(); ++i) {
|
||||
MESOSPHERE_UNIMPLEMENTED();
|
||||
}
|
||||
|
||||
/* Process any map alias buffers. */
|
||||
for (auto i = 0; i < src_header.GetMapAliasCount(); ++i) {
|
||||
MESOSPHERE_UNIMPLEMENTED();
|
||||
}
|
||||
|
||||
/* Process any raw data. */
|
||||
if (src_header.GetRawCount()) {
|
||||
MESOSPHERE_UNIMPLEMENTED();
|
||||
}
|
||||
|
||||
/* TODO: Remove this when done, as these variables will be used by unimplemented stuff above. */
|
||||
static_cast<void>(dst_page_table);
|
||||
static_cast<void>(dst_user);
|
||||
static_cast<void>(src_user);
|
||||
static_cast<void>(pointer_key);
|
||||
|
||||
/* We succeeded! */
|
||||
cleanup_guard.Cancel();
|
||||
return ResultSuccess();
|
||||
}
|
||||
|
||||
ALWAYS_INLINE void ReplyAsyncError(KProcess *to_process, uintptr_t to_msg_buf, size_t to_msg_buf_size, Result result) {
|
||||
/* Convert the buffer to a physical address. */
|
||||
KPhysicalAddress phys_addr;
|
||||
|
@ -28,7 +266,7 @@ namespace ams::kern {
|
|||
u32 *to_msg = GetPointer<u32>(KPageTable::GetHeapVirtualAddress(phys_addr));
|
||||
|
||||
/* Set the error. */
|
||||
ams::svc::ipc::MessageBuffer msg(to_msg, to_msg_buf_size);
|
||||
ipc::MessageBuffer msg(to_msg, to_msg_buf_size);
|
||||
msg.SetAsyncResult(result);
|
||||
}
|
||||
|
||||
|
@ -44,10 +282,56 @@ namespace ams::kern {
|
|||
this->parent->Close();
|
||||
}
|
||||
|
||||
Result KServerSession::ReceiveRequest(uintptr_t message, uintptr_t buffer_size, KPhysicalAddress message_paddr) {
|
||||
Result KServerSession::ReceiveRequest(uintptr_t server_message, uintptr_t server_buffer_size, KPhysicalAddress server_message_paddr) {
|
||||
MESOSPHERE_ASSERT_THIS();
|
||||
|
||||
/* Lock the session. */
|
||||
KScopedLightLock lk(this->lock);
|
||||
|
||||
/* Get the request and client thread. */
|
||||
KSessionRequest *request;
|
||||
KScopedAutoObject<KThread> client_thread;
|
||||
{
|
||||
KScopedSchedulerLock sl;
|
||||
|
||||
/* Ensure that we can service the request. */
|
||||
R_UNLESS(!this->parent->IsClientClosed(), svc::ResultSessionClosed());
|
||||
|
||||
/* Ensure we aren't already servicing a request. */
|
||||
R_UNLESS(this->current_request == nullptr, svc::ResultNotFound());
|
||||
|
||||
/* Ensure we have a request to service. */
|
||||
R_UNLESS(!this->request_list.empty(), svc::ResultNotFound());
|
||||
|
||||
/* Pop the first request from the list. */
|
||||
request = std::addressof(this->request_list.front());
|
||||
this->request_list.pop_front();
|
||||
|
||||
/* Get the thread for the request. */
|
||||
client_thread = KScopedAutoObject<KThread>(request->GetThread());
|
||||
R_UNLESS(client_thread.IsNotNull(), svc::ResultSessionClosed());
|
||||
}
|
||||
|
||||
/* Set the request as our current. */
|
||||
this->current_request = request;
|
||||
|
||||
/* Get the client address. */
|
||||
uintptr_t client_message = request->GetAddress();
|
||||
size_t client_buffer_size = request->GetSize();
|
||||
bool recv_list_broken = false;
|
||||
|
||||
/* Receive the message. */
|
||||
Result result = ReceiveMessage(recv_list_broken, server_message, server_buffer_size, server_message_paddr, *client_thread.GetPointerUnsafe(), client_message, client_buffer_size, this, request);
|
||||
|
||||
/* Handle cleanup on receive failure. */
|
||||
if (R_FAILED(result)) {
|
||||
/* TODO */
|
||||
MESOSPHERE_UNIMPLEMENTED();
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
Result KServerSession::SendReply(uintptr_t message, uintptr_t buffer_size, KPhysicalAddress message_paddr) {
|
||||
MESOSPHERE_UNIMPLEMENTED();
|
||||
}
|
||||
|
|
|
@ -439,34 +439,34 @@ namespace ams::svc::ipc {
|
|||
return index + (spc.GetHeaderSize() / sizeof(*this->buffer));
|
||||
}
|
||||
|
||||
ALWAYS_INLINE s32 SetHandle(s32 index, const ::ams::svc::Handle &hnd) {
|
||||
ALWAYS_INLINE s32 SetHandle(s32 index, const ::ams::svc::Handle &hnd) const {
|
||||
static_assert(util::IsAligned(sizeof(hnd), sizeof(*this->buffer)));
|
||||
__builtin_memcpy(this->buffer + index, std::addressof(hnd), sizeof(hnd));
|
||||
return index + (sizeof(hnd) / sizeof(*this->buffer));
|
||||
}
|
||||
|
||||
ALWAYS_INLINE s32 SetProcessId(s32 index, const u64 pid) {
|
||||
ALWAYS_INLINE s32 SetProcessId(s32 index, const u64 pid) const {
|
||||
static_assert(util::IsAligned(sizeof(pid), sizeof(*this->buffer)));
|
||||
__builtin_memcpy(this->buffer + index, std::addressof(pid), sizeof(pid));
|
||||
return index + (sizeof(pid) / sizeof(*this->buffer));
|
||||
}
|
||||
|
||||
ALWAYS_INLINE s32 Set(s32 index, const MapAliasDescriptor &desc) {
|
||||
ALWAYS_INLINE s32 Set(s32 index, const MapAliasDescriptor &desc) const {
|
||||
__builtin_memcpy(this->buffer + index, desc.GetData(), desc.GetDataSize());
|
||||
return index + (desc.GetDataSize() / sizeof(*this->buffer));
|
||||
}
|
||||
|
||||
ALWAYS_INLINE s32 Set(s32 index, const PointerDescriptor &desc) {
|
||||
ALWAYS_INLINE s32 Set(s32 index, const PointerDescriptor &desc) const {
|
||||
__builtin_memcpy(this->buffer + index, desc.GetData(), desc.GetDataSize());
|
||||
return index + (desc.GetDataSize() / sizeof(*this->buffer));
|
||||
}
|
||||
|
||||
ALWAYS_INLINE s32 Set(s32 index, const ReceiveListEntry &desc) {
|
||||
ALWAYS_INLINE s32 Set(s32 index, const ReceiveListEntry &desc) const {
|
||||
__builtin_memcpy(this->buffer + index, desc.GetData(), desc.GetDataSize());
|
||||
return index + (desc.GetDataSize() / sizeof(*this->buffer));
|
||||
}
|
||||
|
||||
ALWAYS_INLINE s32 Set(s32 index, const u32 val) {
|
||||
ALWAYS_INLINE s32 Set(s32 index, const u32 val) const {
|
||||
static_assert(util::IsAligned(sizeof(val), sizeof(*this->buffer)));
|
||||
__builtin_memcpy(this->buffer + index, std::addressof(val), sizeof(val));
|
||||
return index + (sizeof(val) / sizeof(*this->buffer));
|
||||
|
@ -521,7 +521,7 @@ namespace ams::svc::ipc {
|
|||
}
|
||||
}
|
||||
|
||||
static constexpr ALWAYS_INLINE s32 GetMessageBufferSize(const MessageHeader &hdr, const SpecialHeader &spc) {
|
||||
static constexpr ALWAYS_INLINE size_t GetMessageBufferSize(const MessageHeader &hdr, const SpecialHeader &spc) {
|
||||
/* Get the size of the plain message. */
|
||||
size_t msg_size = GetReceiveListIndex(hdr, spc) * sizeof(util::BitPack32);
|
||||
|
||||
|
|
Loading…
Reference in a new issue