kern: partially implement Receive half of ReplyAndReceive

This commit is contained in:
Michael Scire 2020-07-10 00:03:50 -07:00
parent 1b2203d102
commit 84b1be1d58
5 changed files with 315 additions and 33 deletions

View file

@ -46,6 +46,8 @@ namespace ams::kern {
static void PostDestroy(uintptr_t arg);
virtual KProcess *GetOwner() const override { return this->owner; }
KReadableEvent &GetReadableEvent() { return this->readable_event; }
KWritableEvent &GetWritableEvent() { return this->writable_event; }
};

View file

@ -155,21 +155,7 @@ namespace ams::kern {
}
}
template<typename T = KAutoObject>
ALWAYS_INLINE KScopedAutoObject<T> GetObjectForIpc(ams::svc::Handle handle) const {
static_assert(!std::is_base_of<KInterruptEvent, T>::value);
/* Handle pseudo-handles. */
if constexpr (std::is_base_of<T, KProcess>::value) {
if (handle == ams::svc::PseudoHandle::CurrentProcess) {
return GetCurrentProcessPointer();
}
} else if constexpr (std::is_base_of<T, KThread>::value) {
if (handle == ams::svc::PseudoHandle::CurrentThread) {
return GetCurrentThreadPointer();
}
}
ALWAYS_INLINE KScopedAutoObject<KAutoObject> GetObjectForIpcWithoutPseudoHandle(ams::svc::Handle handle) const {
/* Lock and look up in table. */
KScopedDisableDispatch dd;
KScopedSpinLock lk(this->lock);
@ -178,15 +164,20 @@ namespace ams::kern {
if (obj->DynamicCast<KInterruptEvent *>() != nullptr) {
return nullptr;
}
if constexpr (std::is_same<T, KAutoObject>::value) {
return obj;
} else {
if (auto *obj = this->GetObjectImpl(handle); obj != nullptr) {
return obj->DynamicCast<T*>();
} else {
return nullptr;
}
return obj;
}
ALWAYS_INLINE KScopedAutoObject<KAutoObject> GetObjectForIpc(ams::svc::Handle handle, KThread *cur_thread) const {
/* Handle pseudo-handles. */
if (handle == ams::svc::PseudoHandle::CurrentProcess) {
return static_cast<KAutoObject *>(static_cast<void *>(cur_thread->GetOwnerProcess()));
}
if (handle == ams::svc::PseudoHandle::CurrentThread) {
return static_cast<KAutoObject *>(cur_thread);
}
return GetObjectForIpcWithoutPseudoHandle(handle);
}
ALWAYS_INLINE KScopedAutoObject<KAutoObject> GetObjectByIndex(ams::svc::Handle *out_handle, size_t index) const {

View file

@ -128,6 +128,11 @@ namespace ams::kern {
size_t GetSize() const { return this->size; }
KProcess *GetServerProcess() const { return this->server; }
void SetServerProcess(KProcess *process) {
this->server = process;
this->server->Open();
}
void ClearThread() { this->thread = nullptr; }
void ClearEvent() { this->event = nullptr; }

View file

@ -17,8 +17,246 @@
namespace ams::kern {
namespace ipc {
using MessageBuffer = ams::svc::ipc::MessageBuffer;
}
namespace {
class ReceiveList {
private:
u32 data[ipc::MessageBuffer::MessageHeader::ReceiveListCountType_CountMax * ipc::MessageBuffer::ReceiveListEntry::GetDataSize() / sizeof(u32)];
s32 recv_list_count;
uintptr_t msg_buffer_end;
uintptr_t msg_buffer_space_end;
public:
static constexpr int GetEntryCount(const ipc::MessageBuffer::MessageHeader &header) {
const auto count = header.GetReceiveListCount();
switch (count) {
case ipc::MessageBuffer::MessageHeader::ReceiveListCountType_None:
return 0;
case ipc::MessageBuffer::MessageHeader::ReceiveListCountType_ToMessageBuffer:
return 0;
case ipc::MessageBuffer::MessageHeader::ReceiveListCountType_ToSingleBuffer:
return 1;
default:
return count - ipc::MessageBuffer::MessageHeader::ReceiveListCountType_CountOffset;
}
}
public:
ReceiveList(const u32 *dst_msg, const ipc::MessageBuffer::MessageHeader &dst_header, const ipc::MessageBuffer::SpecialHeader &dst_special_header, size_t msg_size, size_t out_offset, s32 dst_recv_list_idx) {
this->recv_list_count = dst_header.GetReceiveListCount();
this->msg_buffer_end = reinterpret_cast<uintptr_t>(dst_msg) + sizeof(u32) * out_offset;
this->msg_buffer_space_end = reinterpret_cast<uintptr_t>(dst_msg) + msg_size;
const u32 *recv_list = dst_msg + dst_recv_list_idx;
__builtin_memcpy(this->data, recv_list, GetEntryCount(dst_header) * ipc::MessageBuffer::ReceiveListEntry::GetDataSize());
}
constexpr bool IsIndex() const {
return this->recv_list_count > ipc::MessageBuffer::MessageHeader::ReceiveListCountType_CountOffset;
}
};
template<bool MoveHandleAllowed>
ALWAYS_INLINE Result ProcessMessageSpecialData(int &offset, KProcess &dst_process, KProcess &src_process, KThread &src_thread, const ipc::MessageBuffer &dst_msg, const ipc::MessageBuffer &src_msg, const ipc::MessageBuffer::SpecialHeader &src_special_header) {
/* Copy the special header to the destination. */
offset = dst_msg.Set(src_special_header);
/* Copy the process ID. */
if (src_special_header.GetHasProcessId()) {
/* TODO: Atmosphere mitm extension support. */
offset = dst_msg.SetProcessId(offset, src_process.GetId());
}
/* Prepare to process handles. */
auto &dst_handle_table = dst_process.GetHandleTable();
auto &src_handle_table = src_process.GetHandleTable();
Result result = ResultSuccess();
/* Process copy handles. */
for (auto i = 0; i < src_special_header.GetCopyHandleCount(); ++i) {
/* Get the handles. */
const ams::svc::Handle src_handle = src_msg.GetHandle(offset);
ams::svc::Handle dst_handle = ams::svc::InvalidHandle;
/* If we're in a success state, try to move the handle to the new table. */
if (R_SUCCEEDED(result) && src_handle != ams::svc::InvalidHandle) {
KScopedAutoObject obj = src_handle_table.GetObjectForIpc(src_handle, std::addressof(src_thread));
if (obj.IsNotNull()) {
Result add_result = dst_handle_table.Add(std::addressof(dst_handle), obj.GetPointerUnsafe());
if (R_FAILED(add_result)) {
result = add_result;
dst_handle = ams::svc::InvalidHandle;
}
} else {
result = svc::ResultInvalidHandle();
}
}
/* Set the handle. */
offset = dst_msg.SetHandle(offset, dst_handle);
}
/* Process move handles. */
if constexpr (MoveHandleAllowed) {
for (auto i = 0; i < src_special_header.GetMoveHandleCount(); ++i) {
/* Get the handles. */
const ams::svc::Handle src_handle = src_msg.GetHandle(offset);
ams::svc::Handle dst_handle = ams::svc::InvalidHandle;
/* Whether or not we've succeeded, we need to remove the handles from the source table. */
if (src_handle != ams::svc::InvalidHandle) {
if (R_SUCCEEDED(result)) {
KScopedAutoObject obj = src_handle_table.GetObjectForIpcWithoutPseudoHandle(src_handle);
if (obj.IsNotNull()) {
Result add_result = dst_handle_table.Add(std::addressof(dst_handle), obj.GetPointerUnsafe());
src_handle_table.Remove(src_handle);
if (R_FAILED(add_result)) {
result = add_result;
dst_handle = ams::svc::InvalidHandle;
}
} else {
result = svc::ResultInvalidHandle();
}
} else {
src_handle_table.Remove(src_handle);
}
}
/* Set the handle. */
offset = dst_msg.SetHandle(offset, dst_handle);
}
}
return result;
}
ALWAYS_INLINE Result ReceiveMessage(bool &recv_list_broken, uintptr_t dst_message_buffer, size_t dst_buffer_size, KPhysicalAddress dst_message_paddr, KThread &src_thread, uintptr_t src_message_buffer, size_t src_buffer_size, KServerSession *session, KSessionRequest *request) {
/* Prepare variables for receive. */
const KThread &dst_thread = GetCurrentThread();
KProcess &dst_process = *(dst_thread.GetOwnerProcess());
KProcess &src_process = *(src_thread.GetOwnerProcess());
auto &dst_page_table = dst_process.GetPageTable();
auto &src_page_table = src_process.GetPageTable();
/* The receive list is initially not broken. */
recv_list_broken = false;
/* Set the server process for the request. */
request->SetServerProcess(std::addressof(dst_process));
/* Determine the message buffers. */
u32 *dst_msg_ptr, *src_msg_ptr;
bool dst_user, src_user;
if (dst_message_buffer) {
dst_msg_ptr = GetPointer<u32>(KPageTable::GetHeapVirtualAddress(dst_message_paddr));
dst_user = true;
} else {
dst_msg_ptr = static_cast<ams::svc::ThreadLocalRegion *>(dst_thread.GetThreadLocalRegionHeapAddress())->message_buffer;
dst_buffer_size = sizeof(ams::svc::ThreadLocalRegion{}.message_buffer);
dst_message_buffer = GetInteger(dst_thread.GetThreadLocalRegionAddress());
dst_user = false;
}
if (src_message_buffer) {
/* NOTE: Nintendo does not check the result of this GetPhysicalAddress call. */
KPhysicalAddress src_message_paddr;
src_page_table.GetPhysicalAddress(std::addressof(src_message_paddr), src_message_buffer);
src_msg_ptr = GetPointer<u32>(KPageTable::GetHeapVirtualAddress(src_message_paddr));
src_user = true;
} else {
src_msg_ptr = static_cast<ams::svc::ThreadLocalRegion *>(src_thread.GetThreadLocalRegionHeapAddress())->message_buffer;
src_buffer_size = sizeof(ams::svc::ThreadLocalRegion{}.message_buffer);
src_message_buffer = GetInteger(src_thread.GetThreadLocalRegionAddress());
src_user = false;
}
/* Parse the headers. */
const ipc::MessageBuffer dst_msg(dst_msg_ptr, dst_buffer_size);
const ipc::MessageBuffer src_msg(src_msg_ptr, src_buffer_size);
const ipc::MessageBuffer::MessageHeader dst_header(dst_msg);
const ipc::MessageBuffer::MessageHeader src_header(src_msg);
const ipc::MessageBuffer::SpecialHeader dst_special_header(dst_msg, dst_header);
const ipc::MessageBuffer::SpecialHeader src_special_header(src_msg, src_header);
/* Get the end of the source message. */
const size_t src_end_offset = ipc::MessageBuffer::GetRawDataIndex(src_header, src_special_header) + src_header.GetRawCount();
/* Ensure that the headers fit. */
R_UNLESS(ipc::MessageBuffer::GetMessageBufferSize(dst_header, dst_special_header) <= dst_buffer_size, svc::ResultInvalidCombination());
R_UNLESS(ipc::MessageBuffer::GetMessageBufferSize(src_header, src_special_header) <= src_buffer_size, svc::ResultInvalidCombination());
/* Ensure the receive list offset is after the end of raw data. */
if (dst_header.GetReceiveListOffset()) {
R_UNLESS(dst_header.GetReceiveListOffset() >= ipc::MessageBuffer::GetRawDataIndex(dst_header, dst_special_header) + dst_header.GetRawCount(), svc::ResultInvalidCombination());
}
/* Ensure that the destination buffer is big enough to receive the source. */
R_UNLESS(dst_buffer_size >= src_end_offset * sizeof(u32), svc::ResultMessageTooLarge());
/* Get the receive list. */
const s32 dst_recv_list_idx = static_cast<s32>(ipc::MessageBuffer::GetReceiveListIndex(dst_header, dst_special_header));
ReceiveList dst_recv_list(dst_msg_ptr, dst_header, dst_special_header, dst_buffer_size, src_end_offset, dst_recv_list_idx);
/* Ensure that the source special header isn't invalid. */
const bool src_has_special_header = src_header.GetHasSpecialHeader();
if (src_has_special_header) {
/* Sending move handles from client -> server is not allowed. */
R_UNLESS(src_special_header.GetMoveHandleCount() == 0, svc::ResultInvalidCombination());
}
/* Prepare for further processing. */
int pointer_key = 0;
int offset = dst_msg.Set(src_header);
/* Set up a guard to make sure that we end up in a clean state on error. */
auto cleanup_guard = SCOPE_GUARD {
/* TODO */
MESOSPHERE_UNIMPLEMENTED();
};
/* Process any special data. */
if (src_header.GetHasSpecialHeader()) {
/* After we process, make sure we track whether the receive list is broken. */
ON_SCOPE_EXIT { if (offset > dst_recv_list_idx) { recv_list_broken = true; } };
/* Process special data. */
R_TRY(ProcessMessageSpecialData<false>(offset, dst_process, src_process, src_thread, dst_msg, src_msg, src_special_header));
}
/* Process any pointer buffers. */
for (auto i = 0; i < src_header.GetPointerCount(); ++i) {
MESOSPHERE_UNIMPLEMENTED();
}
/* Process any map alias buffers. */
for (auto i = 0; i < src_header.GetMapAliasCount(); ++i) {
MESOSPHERE_UNIMPLEMENTED();
}
/* Process any raw data. */
if (src_header.GetRawCount()) {
MESOSPHERE_UNIMPLEMENTED();
}
/* TODO: Remove this when done, as these variables will be used by unimplemented stuff above. */
static_cast<void>(dst_page_table);
static_cast<void>(dst_user);
static_cast<void>(src_user);
static_cast<void>(pointer_key);
/* We succeeded! */
cleanup_guard.Cancel();
return ResultSuccess();
}
ALWAYS_INLINE void ReplyAsyncError(KProcess *to_process, uintptr_t to_msg_buf, size_t to_msg_buf_size, Result result) {
/* Convert the buffer to a physical address. */
KPhysicalAddress phys_addr;
@ -28,7 +266,7 @@ namespace ams::kern {
u32 *to_msg = GetPointer<u32>(KPageTable::GetHeapVirtualAddress(phys_addr));
/* Set the error. */
ams::svc::ipc::MessageBuffer msg(to_msg, to_msg_buf_size);
ipc::MessageBuffer msg(to_msg, to_msg_buf_size);
msg.SetAsyncResult(result);
}
@ -44,8 +282,54 @@ namespace ams::kern {
this->parent->Close();
}
Result KServerSession::ReceiveRequest(uintptr_t message, uintptr_t buffer_size, KPhysicalAddress message_paddr) {
MESOSPHERE_UNIMPLEMENTED();
Result KServerSession::ReceiveRequest(uintptr_t server_message, uintptr_t server_buffer_size, KPhysicalAddress server_message_paddr) {
MESOSPHERE_ASSERT_THIS();
/* Lock the session. */
KScopedLightLock lk(this->lock);
/* Get the request and client thread. */
KSessionRequest *request;
KScopedAutoObject<KThread> client_thread;
{
KScopedSchedulerLock sl;
/* Ensure that we can service the request. */
R_UNLESS(!this->parent->IsClientClosed(), svc::ResultSessionClosed());
/* Ensure we aren't already servicing a request. */
R_UNLESS(this->current_request == nullptr, svc::ResultNotFound());
/* Ensure we have a request to service. */
R_UNLESS(!this->request_list.empty(), svc::ResultNotFound());
/* Pop the first request from the list. */
request = std::addressof(this->request_list.front());
this->request_list.pop_front();
/* Get the thread for the request. */
client_thread = KScopedAutoObject<KThread>(request->GetThread());
R_UNLESS(client_thread.IsNotNull(), svc::ResultSessionClosed());
}
/* Set the request as our current. */
this->current_request = request;
/* Get the client address. */
uintptr_t client_message = request->GetAddress();
size_t client_buffer_size = request->GetSize();
bool recv_list_broken = false;
/* Receive the message. */
Result result = ReceiveMessage(recv_list_broken, server_message, server_buffer_size, server_message_paddr, *client_thread.GetPointerUnsafe(), client_message, client_buffer_size, this, request);
/* Handle cleanup on receive failure. */
if (R_FAILED(result)) {
/* TODO */
MESOSPHERE_UNIMPLEMENTED();
}
return result;
}
Result KServerSession::SendReply(uintptr_t message, uintptr_t buffer_size, KPhysicalAddress message_paddr) {

View file

@ -439,34 +439,34 @@ namespace ams::svc::ipc {
return index + (spc.GetHeaderSize() / sizeof(*this->buffer));
}
ALWAYS_INLINE s32 SetHandle(s32 index, const ::ams::svc::Handle &hnd) {
ALWAYS_INLINE s32 SetHandle(s32 index, const ::ams::svc::Handle &hnd) const {
static_assert(util::IsAligned(sizeof(hnd), sizeof(*this->buffer)));
__builtin_memcpy(this->buffer + index, std::addressof(hnd), sizeof(hnd));
return index + (sizeof(hnd) / sizeof(*this->buffer));
}
ALWAYS_INLINE s32 SetProcessId(s32 index, const u64 pid) {
ALWAYS_INLINE s32 SetProcessId(s32 index, const u64 pid) const {
static_assert(util::IsAligned(sizeof(pid), sizeof(*this->buffer)));
__builtin_memcpy(this->buffer + index, std::addressof(pid), sizeof(pid));
return index + (sizeof(pid) / sizeof(*this->buffer));
}
ALWAYS_INLINE s32 Set(s32 index, const MapAliasDescriptor &desc) {
ALWAYS_INLINE s32 Set(s32 index, const MapAliasDescriptor &desc) const {
__builtin_memcpy(this->buffer + index, desc.GetData(), desc.GetDataSize());
return index + (desc.GetDataSize() / sizeof(*this->buffer));
}
ALWAYS_INLINE s32 Set(s32 index, const PointerDescriptor &desc) {
ALWAYS_INLINE s32 Set(s32 index, const PointerDescriptor &desc) const {
__builtin_memcpy(this->buffer + index, desc.GetData(), desc.GetDataSize());
return index + (desc.GetDataSize() / sizeof(*this->buffer));
}
ALWAYS_INLINE s32 Set(s32 index, const ReceiveListEntry &desc) {
ALWAYS_INLINE s32 Set(s32 index, const ReceiveListEntry &desc) const {
__builtin_memcpy(this->buffer + index, desc.GetData(), desc.GetDataSize());
return index + (desc.GetDataSize() / sizeof(*this->buffer));
}
ALWAYS_INLINE s32 Set(s32 index, const u32 val) {
ALWAYS_INLINE s32 Set(s32 index, const u32 val) const {
static_assert(util::IsAligned(sizeof(val), sizeof(*this->buffer)));
__builtin_memcpy(this->buffer + index, std::addressof(val), sizeof(val));
return index + (sizeof(val) / sizeof(*this->buffer));
@ -521,7 +521,7 @@ namespace ams::svc::ipc {
}
}
static constexpr ALWAYS_INLINE s32 GetMessageBufferSize(const MessageHeader &hdr, const SpecialHeader &spc) {
static constexpr ALWAYS_INLINE size_t GetMessageBufferSize(const MessageHeader &hdr, const SpecialHeader &spc) {
/* Get the size of the plain message. */
size_t msg_size = GetReceiveListIndex(hdr, spc) * sizeof(util::BitPack32);