From d0673aa2fbdb8c5c1623ceba5c86af6d6d90013f Mon Sep 17 00:00:00 2001 From: Michael Scire Date: Fri, 19 Feb 2021 14:07:50 -0800 Subject: [PATCH] htcs: implement client socket bindings --- .../source/htcs/htcs_socket.cpp | 581 +++++++++++++++++- 1 file changed, 580 insertions(+), 1 deletion(-) diff --git a/libraries/libstratosphere/source/htcs/htcs_socket.cpp b/libraries/libstratosphere/source/htcs/htcs_socket.cpp index 5884e43fb..3178e99a1 100644 --- a/libraries/libstratosphere/source/htcs/htcs_socket.cpp +++ b/libraries/libstratosphere/source/htcs/htcs_socket.cpp @@ -37,6 +37,10 @@ namespace ams::htcs { constinit client::VirtualSocketCollection *g_sockets = nullptr; + void SetLastError(uintptr_t error_code) { + os::SetTlsValue(g_tls_slot, error_code); + } + void InitializeImpl(void *buffer, size_t buffer_size, int num_sessions) { /* Check the session count. */ AMS_ASSERT(0 < num_sessions && num_sessions <= SessionCountMax); @@ -171,8 +175,583 @@ namespace ams::htcs { g_manager = nullptr; g_monitor = nullptr; - /* Finalize the bsd client sessions. */ + /* Finalize the htcs client sessions. */ client::FinalizeSessionManager(); } + const HtcsPeerName GetPeerNameAny() { + /* Check that we have a manager. */ + AMS_ASSERT(g_manager != nullptr); + + /* Get name. */ + HtcsPeerName name; + g_manager->GetPeerNameAny(std::addressof(name)); + + return name; + } + + const HtcsPeerName GetDefaultHostName() { + /* Check that we have a manager. */ + AMS_ASSERT(g_manager != nullptr); + + /* Get name. */ + HtcsPeerName name; + g_manager->GetDefaultHostName(std::addressof(name)); + + return name; + } + + s32 GetLastError() { + /* Check that we have a manager. */ + AMS_ASSERT(g_manager != nullptr); + + return static_cast(os::GetTlsValue(g_tls_slot)); + } + + s32 Socket() { + /* Check that we have a manager. */ + AMS_ASSERT(g_manager != nullptr); + + /* Check that we have a socket collection. */ + AMS_ASSERT(g_sockets != nullptr); + + /* Perform the operation. */ + s32 error_code = 0; + const s32 ret = g_sockets->Socket(error_code); + if (ret < 0) { + SetLastError(static_cast(error_code)); + } + + return ret; + } + + s32 Close(s32 desc) { + /* Check that we have a manager. */ + AMS_ASSERT(g_manager != nullptr); + + /* Check that we have a socket collection. */ + AMS_ASSERT(g_sockets != nullptr); + + /* Perform the operation. */ + s32 error_code = 0; + const s32 ret = g_sockets->Close(desc, error_code); + if (ret < 0) { + SetLastError(static_cast(error_code)); + } + + return ret; + } + + s32 Connect(s32 desc, const SockAddrHtcs *address) { + /* Check that we have a manager. */ + AMS_ASSERT(g_manager != nullptr); + + /* Check that we have a socket collection. */ + AMS_ASSERT(g_sockets != nullptr); + + /* Check that the address family is correct. */ + AMS_ASSERT(address->family == HTCS_AF_HTCS); + + /* Perform the operation. */ + s32 error_code = 0; + const s32 ret = g_sockets->Connect(desc, address, error_code); + if (ret < 0) { + SetLastError(static_cast(error_code)); + } + + return ret; + } + + s32 Bind(s32 desc, const SockAddrHtcs *address) { + /* Check that we have a manager. */ + AMS_ASSERT(g_manager != nullptr); + + /* Check that we have a socket collection. */ + AMS_ASSERT(g_sockets != nullptr); + + /* Check that the address family is correct. */ + AMS_ASSERT(address->family == HTCS_AF_HTCS); + + /* Perform the operation. */ + s32 error_code = 0; + const s32 ret = g_sockets->Bind(desc, address, error_code); + if (ret < 0) { + SetLastError(static_cast(error_code)); + } + + return ret; + } + + s32 Listen(s32 desc, s32 backlog_count) { + /* Check that we have a manager. */ + AMS_ASSERT(g_manager != nullptr); + + /* Check that we have a socket collection. */ + AMS_ASSERT(g_sockets != nullptr); + + /* Perform the operation. */ + s32 error_code = 0; + const s32 ret = g_sockets->Listen(desc, backlog_count, error_code); + if (ret < 0) { + SetLastError(static_cast(error_code)); + } + + return ret; + } + + s32 Accept(s32 desc, SockAddrHtcs *address) { + /* Check that we have a manager. */ + AMS_ASSERT(g_manager != nullptr); + + /* Check that we have a socket collection. */ + AMS_ASSERT(g_sockets != nullptr); + + /* Ensure we have an address. */ + SockAddrHtcs tmp; + if (address == nullptr) { + address = std::addressof(tmp); + } + + /* Perform the operation. */ + s32 error_code = 0; + const s32 ret = g_sockets->Accept(desc, address, error_code); + if (ret < 0) { + SetLastError(static_cast(error_code)); + } + + return ret; + } + + s32 Shutdown(s32 desc, s32 how) { + /* Check that we have a manager. */ + AMS_ASSERT(g_manager != nullptr); + + /* Check that we have a socket collection. */ + AMS_ASSERT(g_sockets != nullptr); + + /* Perform the operation. */ + s32 error_code = 0; + const s32 ret = g_sockets->Shutdown(desc, how, error_code); + if (ret < 0) { + SetLastError(static_cast(error_code)); + } + + return ret; + } + + s32 Fcntl(s32 desc, s32 command, s32 value) { + /* Check that we have a manager. */ + AMS_ASSERT(g_manager != nullptr); + + /* Check that we have a socket collection. */ + AMS_ASSERT(g_sockets != nullptr); + + /* Perform the operation. */ + s32 error_code = 0; + const s32 ret = g_sockets->Fcntl(desc, command, value, error_code); + if (ret < 0) { + SetLastError(static_cast(error_code)); + } + + return ret; + } + + s32 Select(s32 count, FdSet *read, FdSet *write, FdSet *exception, TimeVal *timeout) { + /* Check that we have a manager. */ + AMS_ASSERT(g_manager != nullptr); + + /* Check that we have a socket collection. */ + AMS_ASSERT(g_sockets != nullptr); + + /* Check that we have some form of input. */ + if (read == nullptr && write == nullptr && exception == nullptr) { + SetLastError(static_cast(HTCS_EINVAL)); + return -1; + } + + /* Check that the timeout is valid. */ + if (timeout != nullptr && (timeout->tv_sec < 0 || timeout->tv_usec < 0)) { + SetLastError(static_cast(HTCS_EINVAL)); + return -1; + } + + /* Perform the operation. */ + s32 error_code = 0; + const s32 ret = g_sockets->Select(read, write, exception, timeout, error_code); + if (ret < 0) { + SetLastError(static_cast(error_code)); + } + + return ret; + } + + ssize_t Recv(s32 desc, void *buffer, size_t buffer_size, s32 flags) { + /* Check that we have a manager. */ + AMS_ASSERT(g_manager != nullptr); + + /* Check that we have a socket collection. */ + AMS_ASSERT(g_sockets != nullptr); + + /* Perform the operation. */ + s32 error_code = 0; + const ssize_t ret = g_sockets->Recv(desc, buffer, buffer_size, flags, error_code); + if (ret < 0) { + SetLastError(static_cast(error_code)); + } + + return ret; + } + + ssize_t Send(s32 desc, const void *buffer, size_t buffer_size, s32 flags) { + /* Check that we have a manager. */ + AMS_ASSERT(g_manager != nullptr); + + /* Check that we have a socket collection. */ + AMS_ASSERT(g_sockets != nullptr); + + /* Perform the operation. */ + s32 error_code = 0; + const ssize_t ret = g_sockets->Send(desc, buffer, buffer_size, flags, error_code); + if (ret < 0) { + SetLastError(static_cast(error_code)); + } + + return ret; + } + + void FdSetZero(FdSet *set) { + AMS_ASSERT(set != nullptr); + + std::memset(set, 0, sizeof(*set)); + } + + void FdSetSet(s32 fd, FdSet *set) { + AMS_ASSERT(set != nullptr); + + for (auto i = 0; i < FdSetSize; ++i) { + if (set->fds[i] == 0) { + set->fds[i] = fd; + break; + } + } + } + + void FdSetClr(s32 fd, FdSet *set) { + AMS_ASSERT(set != nullptr); + + for (auto i = 0; i < FdSetSize; ++i) { + if (set->fds[i] == fd) { + std::memcpy(set->fds + i, set->fds + i + 1, (FdSetSize - (i + 1)) * sizeof(fd)); + set->fds[FdSetSize - 1] = 0; + break; + } + } + } + + bool FdSetIsSet(s32 fd, const FdSet *set) { + AMS_ASSERT(set != nullptr); + + for (auto i = 0; i < FdSetSize; ++i) { + if (set->fds[i] == fd) { + return true; + } + } + + return false; + } + + namespace client { + + sf::SharedPointer socket(s32 &last_error) { + sf::SharedPointer socket = nullptr; + R_ABORT_UNLESS(g_manager->CreateSocket(std::addressof(last_error), std::addressof(socket), g_enable_disconnection_emulation)); + return socket; + } + + s32 close(sf::SharedPointer socket, s32 &last_error) { + s32 res; + socket->Close(std::addressof(last_error), std::addressof(res)); + return res; + } + + s32 bind(sf::SharedPointer socket, const htcs::SockAddrHtcs *address, s32 &last_error) { + /* Create null-terminated address. */ + htcs::SockAddrHtcs null_terminated_address; + null_terminated_address.family = address->family; + util::Strlcpy(null_terminated_address.peer_name.name, address->peer_name.name, PeerNameBufferLength); + util::Strlcpy(null_terminated_address.port_name.name, address->port_name.name, PortNameBufferLength); + + s32 res; + socket->Bind(std::addressof(last_error), std::addressof(res), null_terminated_address); + return res; + } + + s32 listen(sf::SharedPointer socket, s32 backlog_count, s32 &last_error) { + s32 res; + socket->Listen(std::addressof(last_error), std::addressof(res), backlog_count); + return res; + } + + sf::SharedPointer accept(sf::SharedPointer socket, htcs::SockAddrHtcs *address, s32 &last_error) { + /* Begin the accept. */ + sf::SharedPointer res = nullptr; + u32 task_id = 0; + sf::CopyHandle event_handle; + if (R_SUCCEEDED(socket->AcceptStart(std::addressof(task_id), std::addressof(event_handle)))) { + /* Create system event. */ + os::SystemEventType event; + os::AttachReadableHandleToSystemEvent(std::addressof(event), event_handle.GetValue(), true, os::EventClearMode_ManualClear); + + /* When we're done, clean up the event. */ + ON_SCOPE_EXIT { os::DestroySystemEvent(std::addressof(event)); }; + + /* Wait for the accept to finish. */ + os::WaitSystemEvent(std::addressof(event)); + + /* End the accept. */ + socket->AcceptResults(std::addressof(last_error), std::addressof(res), address, task_id); + } else { + /* Set error. */ + last_error = HTCS_EINTR; + } + + /* Sleep, if an error occurred. */ + if (last_error != HTCS_ENONE) { + os::SleepThread(TimeSpan::FromMilliSeconds(1)); + } + + return res; + } + + s32 fcntl(sf::SharedPointer socket, s32 command, s32 value, s32 &last_error) { + s32 res; + socket->Fcntl(std::addressof(last_error), std::addressof(res), command, value); + return res; + } + + s32 shutdown(sf::SharedPointer socket, s32 how, s32 &last_error) { + s32 res; + socket->Shutdown(std::addressof(last_error), std::addressof(res), how); + return res; + } + + s32 connect(sf::SharedPointer socket, const htcs::SockAddrHtcs *address, s32 &last_error) { + /* Create null-terminated address. */ + htcs::SockAddrHtcs null_terminated_address; + null_terminated_address.family = address->family; + util::Strlcpy(null_terminated_address.peer_name.name, address->peer_name.name, PeerNameBufferLength); + util::Strlcpy(null_terminated_address.port_name.name, address->port_name.name, PortNameBufferLength); + + s32 res; + socket->Connect(std::addressof(last_error), std::addressof(res), null_terminated_address); + return res; + } + + s32 select(s32 * const read, s32 &num_read, s32 * const write, s32 &num_write, s32 * const except, s32 &num_except, htcs::TimeVal *timeout, s32 &last_error) { + /* Determine the timeout values. */ + s64 tv_sec = -1; + s64 tv_usec = -1; + if (timeout != nullptr) { + tv_sec = timeout->tv_sec; + tv_usec = timeout->tv_usec; + } + + using InArray = sf::InMapAliasArray; + using OutArray = sf::OutMapAliasArray; + + /* Begin the select. */ + s32 res = -1; + u32 task_id = 0; + sf::CopyHandle event_handle; + if (R_SUCCEEDED(g_manager->StartSelect(std::addressof(task_id), std::addressof(event_handle), InArray(read, num_read), InArray(write, num_write), InArray(except, num_except), tv_sec, tv_usec))) { + /* Create system event. */ + os::SystemEventType event; + os::AttachReadableHandleToSystemEvent(std::addressof(event), event_handle.GetValue(), true, os::EventClearMode_ManualClear); + + /* When we're done, clean up the event. */ + ON_SCOPE_EXIT { os::DestroySystemEvent(std::addressof(event)); }; + + /* Wait for the select to finish. */ + os::WaitSystemEvent(std::addressof(event)); + + /* End the select. */ + g_manager->EndSelect(std::addressof(last_error), std::addressof(res), OutArray(read, num_read), OutArray(write, num_write), OutArray(except, num_except), task_id); + } else { + /* Set error. */ + last_error = HTCS_EINTR; + os::SleepThread(TimeSpan::FromMilliSeconds(1)); + } + + return res; + } + + namespace { + + constexpr size_t MaximumBufferSizeForSmallTransfer = 0xDFE0; + + ssize_t recvLarge(sf::SharedPointer socket, void *buffer, size_t buffer_size, s32 flags, s32 &last_error) { + /* Setup. */ + s64 res = -1; + last_error = HTCS_EINTR; + + /* Start the receive. */ + u32 task_id = 0; + sf::CopyHandle event_handle; + if (R_SUCCEEDED(socket->StartRecv(std::addressof(task_id), std::addressof(event_handle), static_cast(buffer_size), flags))) { + /* Create system event. */ + os::SystemEventType event; + os::AttachReadableHandleToSystemEvent(std::addressof(event), event_handle.GetValue(), true, os::EventClearMode_ManualClear); + + /* When we're done, clean up the event. */ + ON_SCOPE_EXIT { os::DestroySystemEvent(std::addressof(event)); }; + + /* Wait for the receive to finish. */ + os::WaitSystemEvent(std::addressof(event)); + + /* End the receive. */ + socket->EndRecv(std::addressof(last_error), std::addressof(res), sf::OutAutoSelectBuffer(buffer, buffer_size), task_id); + } else { + /* Set error. */ + last_error = HTCS_EINTR; + os::SleepThread(TimeSpan::FromMilliSeconds(1)); + } + + return static_cast(res); + } + + ssize_t sendLarge(sf::SharedPointer socket, const void *buffer, size_t buffer_size, s32 flags, s32 &last_error) { + /* Setup. */ + s64 res = -1; + last_error = HTCS_EINTR; + + /* Start the send. */ + u32 task_id = 0; + s64 max_size = 0; + sf::CopyHandle event_handle; + if (R_SUCCEEDED(socket->StartSend(std::addressof(task_id), std::addressof(event_handle), std::addressof(max_size), static_cast(buffer_size), flags))) { + /* Create system event. */ + os::SystemEventType event; + os::AttachReadableHandleToSystemEvent(std::addressof(event), event_handle.GetValue(), true, os::EventClearMode_ManualClear); + + /* When we're done, clean up the event. */ + ON_SCOPE_EXIT { os::DestroySystemEvent(std::addressof(event)); }; + + /* Send all the data. */ + bool done = false; + size_t sent = 0; + while (sent < buffer_size) { + /* Determine how much to send, this iteration. */ + const u8 *cur = static_cast(buffer) + sent; + const s64 cur_size = std::min(max_size, static_cast(buffer_size - sent)); + + /* Continue sending data. */ + s64 cur_sent = 0; + bool wait = false; + const Result result = socket->ContinueSend(std::addressof(cur_sent), std::addressof(wait), sf::InNonSecureAutoSelectBuffer(cur, cur_size), task_id); + if (cur_sent <= 0 || R_FAILED(result)) { + done = true; + break; + } + + /* Wait if we should. */ + if (wait) { + os::WaitSystemEvent(std::addressof(event)); + os::ClearSystemEvent(std::addressof(event)); + } + + /* Advance. */ + sent += cur_sent; + } + + /* Wait for the send to finish. */ + if (!done) { + os::WaitSystemEvent(std::addressof(event)); + } + + /* End the send. */ + socket->EndSend(std::addressof(last_error), std::addressof(res), task_id); + } else { + /* Set error. */ + last_error = HTCS_EINTR; + os::SleepThread(TimeSpan::FromMilliSeconds(1)); + } + + return static_cast(res); + } + + } + + ssize_t recv(sf::SharedPointer socket, void *buffer, size_t buffer_size, s32 flags, s32 &last_error) { + /* Determine how much to receive. */ + size_t recv_size = buffer_size; + + if ((flags & HTCS_MSG_WAITALL) == 0) { + recv_size = std::min(MaximumBufferSizeForSmallTransfer, buffer_size); + } + + /* Perform a large receive, if we have to. */ + if (recv_size > MaximumBufferSizeForSmallTransfer) { + return recvLarge(socket, buffer, recv_size, flags, last_error); + } + + /* Start the receive. */ + s64 res = -1; + u32 task_id = 0; + sf::CopyHandle event_handle; + if (R_SUCCEEDED(socket->RecvStart(std::addressof(task_id), std::addressof(event_handle), static_cast(recv_size), flags))) { + /* Create system event. */ + os::SystemEventType event; + os::AttachReadableHandleToSystemEvent(std::addressof(event), event_handle.GetValue(), true, os::EventClearMode_ManualClear); + + /* When we're done, clean up the event. */ + ON_SCOPE_EXIT { os::DestroySystemEvent(std::addressof(event)); }; + + /* Wait for the receive to finish. */ + os::WaitSystemEvent(std::addressof(event)); + + /* End the receive. */ + socket->RecvResults(std::addressof(last_error), std::addressof(res), sf::OutAutoSelectBuffer(buffer, recv_size), task_id); + } else { + /* Set error. */ + last_error = HTCS_EINTR; + os::SleepThread(TimeSpan::FromMilliSeconds(1)); + } + + return static_cast(res); + } + + ssize_t send(sf::SharedPointer socket, const void *buffer, size_t buffer_size, s32 flags, s32 &last_error) { + /* Perform a large send, if we have to. */ + if (buffer_size > MaximumBufferSizeForSmallTransfer) { + return sendLarge(socket, buffer, buffer_size, flags, last_error); + } + + /* Start the send. */ + s64 res = -1; + u32 task_id = 0; + sf::CopyHandle event_handle; + if (R_SUCCEEDED(socket->SendStart(std::addressof(task_id), std::addressof(event_handle), sf::InNonSecureAutoSelectBuffer(buffer, buffer_size), flags))) { + /* Create system event. */ + os::SystemEventType event; + os::AttachReadableHandleToSystemEvent(std::addressof(event), event_handle.GetValue(), true, os::EventClearMode_ManualClear); + + /* When we're done, clean up the event. */ + ON_SCOPE_EXIT { os::DestroySystemEvent(std::addressof(event)); }; + + /* Wait for the send to finish. */ + os::WaitSystemEvent(std::addressof(event)); + + /* End the send. */ + socket->SendResults(std::addressof(last_error), std::addressof(res), task_id); + } else { + /* Set error. */ + last_error = HTCS_EINTR; + os::SleepThread(TimeSpan::FromMilliSeconds(1)); + } + + return static_cast(res); + } + + } + }