/* * Copyright (c) 2018-2020 Atmosphère-NX * * This program is free software; you can redistribute it and/or modify it * under the terms and conditions of the GNU General Public License, * version 2, as published by the Free Software Foundation. * * This program is distributed in the hope it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for * more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ #include #include "socket_api.hpp" #include "socket_allocator.hpp" extern "C" { #include } namespace ams::socket::impl { namespace { constinit bool g_initialized = false; constinit os::SdkMutex g_heap_mutex; constinit lmem::HeapHandle g_heap_handle = nullptr; constinit int g_heap_generation = -1; ALWAYS_INLINE bool IsInitialized() { return g_initialized; } } void *Alloc(size_t size) { std::scoped_lock lk(g_heap_mutex); AMS_ASSERT(g_heap_generation > 0); void *ptr = nullptr; if (!g_heap_handle) { socket::impl::SetLastError(Errno::EOpNotSupp); } else if ((ptr = lmem::AllocateFromExpHeap(g_heap_handle, size)) == nullptr) { socket::impl::SetLastError(Errno::ENoMem); } return ptr; } void *Calloc(size_t num, size_t size) { std::scoped_lock lk(g_heap_mutex); AMS_ASSERT(g_heap_generation > 0); void *ptr = nullptr; if (!g_heap_handle) { socket::impl::SetLastError(Errno::EOpNotSupp); } else if ((ptr = lmem::AllocateFromExpHeap(g_heap_handle, size * num)) == nullptr) { socket::impl::SetLastError(Errno::ENoMem); } else { std::memset(ptr, 0, size * num); } return ptr; } void Free(void *ptr) { std::scoped_lock lk(g_heap_mutex); AMS_ASSERT(g_heap_generation > 0); if (!g_heap_handle) { socket::impl::SetLastError(Errno::EOpNotSupp); } else if (ptr != nullptr) { lmem::FreeToExpHeap(g_heap_handle, ptr); } } bool HeapIsAvailable(int generation) { std::scoped_lock lk(g_heap_mutex); return g_heap_handle && g_heap_generation == generation; } int GetHeapGeneration() { std::scoped_lock lk(g_heap_mutex); return g_heap_generation; } Errno GetLastError() { if (AMS_LIKELY(IsInitialized())) { return static_cast(errno); } else { return Errno::EInval; } } void SetLastError(Errno err) { if (AMS_LIKELY(IsInitialized())) { errno = static_cast(err); } } u32 InetHtonl(u32 host) { if constexpr (util::IsBigEndian()) { return host; } else { return util::SwapBytes(host); } } u16 InetHtons(u16 host) { if constexpr (util::IsBigEndian()) { return host; } else { return util::SwapBytes(host); } } u32 InetNtohl(u32 net) { if constexpr (util::IsBigEndian()) { return net; } else { return util::SwapBytes(net); } } u16 InetNtohs(u16 net) { if constexpr (util::IsBigEndian()) { return net; } else { return util::SwapBytes(net); } } namespace { void InitializeHeapImpl(void *buffer, size_t size) { /* NOTE: Nintendo uses both CreateOption_ThreadSafe *and* a global heap mutex. */ /* This is unnecessary, and using a single SdkMutex is more performant, since we're not recursive. */ std::scoped_lock lk(g_heap_mutex); g_heap_handle = lmem::CreateExpHeap(buffer, size, lmem::CreateOption_None); } Result InitializeCommon(const Config &config) { /* Check pre-conditions. */ AMS_ABORT_UNLESS(!IsInitialized()); AMS_ABORT_UNLESS(config.GetMemoryPool() != nullptr); AMS_ABORT_UNLESS(1 <= config.GetConcurrencyCountMax() && config.GetConcurrencyCountMax() <= ConcurrencyLimitMax); if (!config.IsSmbpClient()) { AMS_ABORT_UNLESS(config.GetAllocatorPoolSize() < config.GetMemoryPoolSize()); } AMS_ABORT_UNLESS(util::IsAligned(config.GetMemoryPoolSize(), os::MemoryPageSize)); AMS_ABORT_UNLESS(util::IsAligned(config.GetAllocatorPoolSize(), os::MemoryPageSize)); AMS_ABORT_UNLESS(config.GetAllocatorPoolSize() >= 4_KB); if (!config.IsSystemClient()) { R_UNLESS(config.GetMemoryPoolSize() >= socket::MinSocketMemoryPoolSize, socket::ResultInsufficientProvidedMemory()); } const size_t transfer_memory_size = config.GetMemoryPoolSize() - config.GetAllocatorPoolSize(); if (!config.IsSmbpClient()) { R_UNLESS(transfer_memory_size >= socket::MinMemHeapAllocatorSize, socket::ResultInsufficientProvidedMemory()); } else { R_UNLESS(config.GetMemoryPoolSize() >= socket::MinimumSharedMbufPoolReservation, socket::ResultInsufficientProvidedMemory()); } /* Initialize the allocator heap. */ InitializeHeapImpl(static_cast(config.GetMemoryPool()) + transfer_memory_size, config.GetAllocatorPoolSize()); /* Initialize libnx. */ { const ::BsdInitConfig libnx_config = { .version = config.GetVersion(), .tmem_buffer = config.GetMemoryPool(), .tmem_buffer_size = transfer_memory_size, .tcp_tx_buf_size = static_cast(config.GetTcpInitialSendBufferSize()), .tcp_rx_buf_size = static_cast(config.GetTcpInitialReceiveBufferSize()), .tcp_tx_buf_max_size = static_cast(config.GetTcpAutoSendBufferSizeMax()), .tcp_rx_buf_max_size = static_cast(config.GetTcpAutoReceiveBufferSizeMax()), .udp_tx_buf_size = static_cast(config.GetUdpSendBufferSize()), .udp_rx_buf_size = static_cast(config.GetUdpReceiveBufferSize()), .sb_efficiency = static_cast(config.GetSocketBufferEfficiency()), }; const auto service_type = config.IsSystemClient() ? (1 << 1) : (1 << 0); R_ABORT_UNLESS(sm::Initialize()); R_ABORT_UNLESS(::bsdInitialize(std::addressof(libnx_config), static_cast(config.GetConcurrencyCountMax()), service_type)); } /* Set the heap generation. */ g_heap_generation = (g_heap_generation + 1) % MinimumHeapAlignment; /* TODO: socket::resolver::EnableResolverCalls()? Not necessary in our case (htc), but consider calling it. */ return ResultSuccess(); } ALWAYS_INLINE struct sockaddr *ConvertForLibnx(SockAddr *addr) { static_assert(sizeof(SockAddr) == sizeof(struct sockaddr)); static_assert(alignof(SockAddr) == alignof(struct sockaddr)); return reinterpret_cast(addr); } ALWAYS_INLINE const struct sockaddr *ConvertForLibnx(const SockAddr *addr) { static_assert(sizeof(SockAddr) == sizeof(struct sockaddr)); static_assert(alignof(SockAddr) == alignof(struct sockaddr)); return reinterpret_cast(addr); } static_assert(std::same_as); Errno TranslateResultToBsdErrorImpl(const Result &result) { if (R_SUCCEEDED(result)) { return Errno::ESuccess; } else if (svc::ResultInvalidCurrentMemory::Includes(result) || svc::ResultOutOfAddressSpace::Includes(result)) { return Errno::EFault; } else if (sf::hipc::ResultCommunicationError::Includes(result)) { return Errno::EL3Hlt; } else if (sf::hipc::ResultOutOfResource::Includes(result)) { return Errno::EAgain; } else { R_ABORT_UNLESS(result); return static_cast(-1); } } ALWAYS_INLINE void TranslateResultToBsdError(Errno &bsd_error, int &result) { Errno translate_error = Errno::ESuccess; if ((translate_error = TranslateResultToBsdErrorImpl(static_cast<::ams::Result>(::g_bsdResult))) != Errno::ESuccess) { bsd_error = translate_error; result = -1; } } } Result Initialize(const Config &config) { return InitializeCommon(config); } Result Finalize() { /* Check pre-conditions. */ AMS_ABORT_UNLESS(IsInitialized()); /* TODO: If we support statistics, kill the statistics thread. */ /* TODO: socket::resolver::DisableResolverCalls()? */ /* Finalize libnx. */ ::bsdExit(); /* Finalize the heap. */ lmem::HeapHandle heap_handle; { std::scoped_lock lk(g_heap_mutex); heap_handle = g_heap_handle; g_heap_handle = nullptr; } lmem::DestroyExpHeap(heap_handle); return ResultSuccess(); } Result InitializeAllocatorForInternal(void *buffer, size_t size) { /* Check pre-conditions. */ AMS_ABORT_UNLESS(util::IsAligned(size, os::MemoryPageSize)); AMS_ABORT_UNLESS(size >= 4_KB); InitializeHeapImpl(buffer, size); return ResultSuccess(); } ssize_t RecvFrom(s32 desc, void *buffer, size_t buffer_size, MsgFlag flags, SockAddr *out_address, SockLenT *out_addr_len) { /* Check pre-conditions. */ AMS_ABORT_UNLESS(IsInitialized()); /* Check input. */ if (buffer_size == 0) { return 0; } else if (buffer == nullptr) { socket::impl::SetLastError(Errno::EInval); return -1; } else if (buffer_size > std::numeric_limits::max()) { socket::impl::SetLastError(Errno::EFault); return -1; } /* If this is just a normal receive call, perform a normal receive. */ if (out_address == nullptr || out_addr_len == nullptr || *out_addr_len == 0) { return impl::Recv(desc, buffer, buffer_size, flags); } /* Perform the call. */ socklen_t length; Errno error = Errno::ESuccess; int result = ::bsdRecvFrom(desc, buffer, buffer_size, static_cast(flags), ConvertForLibnx(out_address), std::addressof(length)); TranslateResultToBsdError(error, result); if (result >= 0) { *out_addr_len = length; } else { socket::impl::SetLastError(error); } return result; } ssize_t Recv(s32 desc, void *buffer, size_t buffer_size, MsgFlag flags) { /* Check pre-conditions. */ AMS_ABORT_UNLESS(IsInitialized()); /* Check input. */ if (buffer_size == 0) { return 0; } else if (buffer == nullptr) { socket::impl::SetLastError(Errno::EInval); return -1; } else if (buffer_size > std::numeric_limits::max()) { socket::impl::SetLastError(Errno::EFault); return -1; } /* Perform the call. */ Errno error = Errno::ESuccess; int result = ::bsdRecv(desc, buffer, buffer_size, static_cast(flags)); TranslateResultToBsdError(error, result); if (result < 0) { socket::impl::SetLastError(error); } return result; } ssize_t SendTo(s32 desc, const void *buffer, size_t buffer_size, MsgFlag flags, const SockAddr *address, SockLenT len) { /* Check pre-conditions. */ AMS_ABORT_UNLESS(IsInitialized()); /* If this is a normal send, perform a normal send. */ if (address == nullptr || len == 0) { return impl::Send(desc, buffer, buffer_size, flags); } /* Check input. */ if (buffer_size == 0) { return 0; } else if (buffer == nullptr) { socket::impl::SetLastError(Errno::EInval); return -1; } else if (buffer_size > std::numeric_limits::max()) { socket::impl::SetLastError(Errno::EFault); return -1; } /* Perform the call. */ Errno error = Errno::ESuccess; int result = ::bsdSendTo(desc, buffer, buffer_size, static_cast(flags), ConvertForLibnx(address), len); TranslateResultToBsdError(error, result); if (result < 0) { socket::impl::SetLastError(error); } return result; } ssize_t Send(s32 desc, const void *buffer, size_t buffer_size, MsgFlag flags) { /* Check pre-conditions. */ AMS_ABORT_UNLESS(IsInitialized()); /* Check input. */ if (buffer_size == 0) { return 0; } else if (buffer == nullptr) { socket::impl::SetLastError(Errno::EInval); return -1; } else if (buffer_size > std::numeric_limits::max()) { socket::impl::SetLastError(Errno::EFault); return -1; } /* Perform the call. */ Errno error = Errno::ESuccess; int result = ::bsdSend(desc, buffer, buffer_size, static_cast(flags)); TranslateResultToBsdError(error, result); if (result < 0) { socket::impl::SetLastError(error); } return result; } s32 Shutdown(s32 desc, ShutdownMethod how) { /* Check pre-conditions. */ AMS_ABORT_UNLESS(IsInitialized()); /* Perform the call. */ Errno error = Errno::ESuccess; int result = ::bsdShutdown(desc, static_cast(how)); TranslateResultToBsdError(error, result); if (result < 0) { socket::impl::SetLastError(error); } return result; } s32 SocketExempt(Family domain, Type type, Protocol protocol) { /* Check pre-conditions. */ AMS_ABORT_UNLESS(IsInitialized()); /* Perform the call. */ Errno error = Errno::ESuccess; int result = ::bsdSocketExempt(static_cast(domain), static_cast(type), static_cast(protocol)); TranslateResultToBsdError(error, result); if (result < 0) { socket::impl::SetLastError(error); } return result; } s32 Accept(s32 desc, SockAddr *out_address, SockLenT *out_addr_len) { /* Check pre-conditions. */ AMS_ABORT_UNLESS(IsInitialized()); /* Check input. */ if (out_address == nullptr && out_addr_len != nullptr && *out_addr_len != 0) { socket::impl::SetLastError(Errno::EFault); return -1; } socklen_t addrlen = static_cast((out_address && out_addr_len) ? *out_addr_len : 0); /* Perform the call. */ Errno error = Errno::ESuccess; int result = ::bsdAccept(desc, ConvertForLibnx(out_address), std::addressof(addrlen)); TranslateResultToBsdError(error, result); if (result >= 0) { if (out_addr_len != nullptr) { *out_addr_len = addrlen; } } else { socket::impl::SetLastError(error); } return result; } s32 Bind(s32 desc, const SockAddr *address, SockLenT len) { /* Check pre-conditions. */ AMS_ABORT_UNLESS(IsInitialized()); /* Check input. */ if (address == nullptr || len == 0) { socket::impl::SetLastError(Errno::EInval); return -1; } /* Perform the call. */ Errno error = Errno::ESuccess; int result = ::bsdBind(desc, ConvertForLibnx(address), len); TranslateResultToBsdError(error, result); if (result < 0) { socket::impl::SetLastError(error); } return result; } s32 GetSockName(s32 desc, SockAddr *out_address, SockLenT *out_addr_len) { /* Check pre-conditions. */ AMS_ABORT_UNLESS(IsInitialized()); /* Check input. */ if (out_address == nullptr || out_addr_len == nullptr || *out_addr_len == 0) { socket::impl::SetLastError(Errno::EInval); return -1; } /* Perform the call. */ socklen_t length; Errno error = Errno::ESuccess; int result = ::bsdGetSockName(desc, ConvertForLibnx(out_address), std::addressof(length)); TranslateResultToBsdError(error, result); if (result >= 0) { *out_addr_len = length; } else { socket::impl::SetLastError(error); } return result; } s32 SetSockOpt(s32 desc, Level level, Option option_name, const void *option_value, SockLenT option_size) { /* Check pre-conditions. */ AMS_ABORT_UNLESS(IsInitialized()); /* Check input. */ if (option_value == nullptr) { socket::impl::SetLastError(Errno::EInval); return -1; } /* Perform the call. */ Errno error = Errno::ESuccess; int result = ::bsdSetSockOpt(desc, static_cast(level), static_cast(option_name), option_value, option_size); TranslateResultToBsdError(error, result); if (result < 0) { socket::impl::SetLastError(error); } return result; } s32 Listen(s32 desc, s32 backlog) { /* Check pre-conditions. */ AMS_ABORT_UNLESS(IsInitialized()); /* Perform the call. */ Errno error = Errno::ESuccess; int result = ::bsdListen(desc, backlog); TranslateResultToBsdError(error, result); if (result < 0) { socket::impl::SetLastError(error); } return result; } s32 Close(s32 desc) { /* Check pre-conditions. */ AMS_ABORT_UNLESS(IsInitialized()); /* Perform the call. */ Errno error = Errno::ESuccess; int result = ::bsdClose(desc); TranslateResultToBsdError(error, result); if (result < 0) { socket::impl::SetLastError(error); } return result; } }