From a0cc22302c4905b75409d256684b4434dde8f491 Mon Sep 17 00:00:00 2001 From: Michael Scire Date: Wed, 15 Jul 2020 09:15:49 -0700 Subject: [PATCH] kern: add KAddressArbiter::WaitIfEqual --- .../mesosphere/kern_k_condition_variable.hpp | 43 +++---- .../include/mesosphere/kern_k_process.hpp | 8 ++ .../include/mesosphere/kern_k_thread.hpp | 48 +++++++- .../source/kern_k_address_arbiter.cpp | 108 ++++++++++++++++++ .../source/kern_k_condition_variable.cpp | 2 +- .../libmesosphere/source/kern_k_thread.cpp | 10 +- .../source/svc/kern_svc_address_arbiter.cpp | 66 ++++++++++- .../util/util_intrusive_red_black_tree.hpp | 34 ++++++ 8 files changed, 275 insertions(+), 44 deletions(-) create mode 100644 libraries/libmesosphere/source/kern_k_address_arbiter.cpp diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_condition_variable.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_condition_variable.hpp index 9000dcc68..3c8928157 100644 --- a/libraries/libmesosphere/include/mesosphere/kern_k_condition_variable.hpp +++ b/libraries/libmesosphere/include/mesosphere/kern_k_condition_variable.hpp @@ -20,26 +20,9 @@ namespace ams::kern { - struct KConditionVariableComparator { - static constexpr ALWAYS_INLINE int Compare(const KThread &lhs, const KThread &rhs) { - const uintptr_t l_key = lhs.GetConditionVariableKey(); - const uintptr_t r_key = rhs.GetConditionVariableKey(); - - if (l_key < r_key) { - /* Sort first by key */ - return -1; - } else if (l_key == r_key && lhs.GetPriority() < rhs.GetPriority()) { - /* And then by priority. */ - return -1; - } else { - return 1; - } - } - }; - class KConditionVariable { public: - using ThreadTree = util::IntrusiveRedBlackTreeMemberTraits<&KThread::condvar_arbiter_tree_node>::TreeType; + using ThreadTree = typename KThread::ConditionVariableThreadTreeType; private: ThreadTree tree; public: @@ -52,20 +35,20 @@ namespace ams::kern { /* Condition variable. */ void Signal(uintptr_t cv_key, s32 count); Result Wait(KProcessAddress addr, uintptr_t key, u32 value, s64 timeout); - - ALWAYS_INLINE void BeforeUpdatePriority(KThread *thread) { - MESOSPHERE_ASSERT(KScheduler::IsSchedulerLockedByCurrentThread()); - - this->tree.erase(this->tree.iterator_to(*thread)); - } - - ALWAYS_INLINE void AfterUpdatePriority(KThread *thread) { - MESOSPHERE_ASSERT(KScheduler::IsSchedulerLockedByCurrentThread()); - - this->tree.insert(*thread); - } private: KThread *SignalImpl(KThread *thread); }; + ALWAYS_INLINE void BeforeUpdatePriority(KConditionVariable::ThreadTree *tree, KThread *thread) { + MESOSPHERE_ASSERT(KScheduler::IsSchedulerLockedByCurrentThread()); + + tree->erase(tree->iterator_to(*thread)); + } + + ALWAYS_INLINE void AfterUpdatePriority(KConditionVariable::ThreadTree *tree, KThread *thread) { + MESOSPHERE_ASSERT(KScheduler::IsSchedulerLockedByCurrentThread()); + + tree->insert(*thread); + } + } diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_process.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_process.hpp index f07c6f236..72b2dde2f 100644 --- a/libraries/libmesosphere/include/mesosphere/kern_k_process.hpp +++ b/libraries/libmesosphere/include/mesosphere/kern_k_process.hpp @@ -231,6 +231,14 @@ namespace ams::kern { return this->cond_var.Wait(address, cv_key, tag, ns); } + Result SignalAddressArbiter(uintptr_t address, ams::svc::SignalType signal_type, s32 value, s32 count) { + return this->address_arbiter.SignalToAddress(address, signal_type, value, count); + } + + Result WaitAddressArbiter(uintptr_t address, ams::svc::ArbitrationType arb_type, s32 value, s64 timeout) { + return this->address_arbiter.WaitForAddress(address, arb_type, value, timeout); + } + static KProcess *GetProcessFromId(u64 process_id); static Result GetProcessList(s32 *out_num_processes, ams::kern::svc::KUserPointer out_process_ids, s32 max_out_count); diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_thread.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_thread.hpp index 7810213b4..0439ae9f5 100644 --- a/libraries/libmesosphere/include/mesosphere/kern_k_thread.hpp +++ b/libraries/libmesosphere/include/mesosphere/kern_k_thread.hpp @@ -119,6 +119,23 @@ namespace ams::kern { constexpr SyncObjectBuffer() : sync_objects() { /* ... */ } }; static_assert(sizeof(SyncObjectBuffer::sync_objects) == sizeof(SyncObjectBuffer::handles)); + + struct ConditionVariableComparator { + static constexpr ALWAYS_INLINE int Compare(const KThread &lhs, const KThread &rhs) { + const uintptr_t l_key = lhs.GetConditionVariableKey(); + const uintptr_t r_key = rhs.GetConditionVariableKey(); + + if (l_key < r_key) { + /* Sort first by key */ + return -1; + } else if (l_key == r_key && lhs.GetPriority() < rhs.GetPriority()) { + /* And then by priority. */ + return -1; + } else { + return 1; + } + } + }; private: static inline std::atomic s_next_thread_id = 0; private: @@ -150,10 +167,13 @@ namespace ams::kern { using WaiterListTraits = util::IntrusiveListMemberTraitsDeferredAssert<&KThread::waiter_list_node>; using WaiterList = WaiterListTraits::ListType; + using ConditionVariableThreadTreeTraits = util::IntrusiveRedBlackTreeMemberTraitsDeferredAssert<&KThread::condvar_arbiter_tree_node>; + using ConditionVariableThreadTree = ConditionVariableThreadTreeTraits::TreeType; + WaiterList waiter_list{}; WaiterList paused_waiter_list{}; KThread *lock_owner{}; - KConditionVariable *cond_var{}; + ConditionVariableThreadTree *condvar_tree{}; uintptr_t debug_params[3]{}; u32 arbiter_value{}; u32 suspend_request_flags{}; @@ -290,8 +310,21 @@ namespace ams::kern { this->priority = priority; } - void ClearConditionVariable() { - this->cond_var = nullptr; + constexpr void ClearConditionVariableTree() { + this->condvar_tree = nullptr; + } + + constexpr void SetAddressArbiter(ConditionVariableThreadTree *tree, uintptr_t address) { + this->condvar_tree = tree; + this->condvar_key = address; + } + + constexpr void ClearAddressArbiter() { + this->condvar_tree = nullptr; + } + + constexpr bool IsWaitingForAddressArbiter() const { + return this->condvar_tree != nullptr; } constexpr s32 GetIdealCore() const { return this->ideal_core_id; } @@ -308,7 +341,7 @@ namespace ams::kern { constexpr const QueueEntry &GetSleepingQueueEntry() const { return this->sleeping_queue_entry; } constexpr void SetSleepingQueue(KThreadQueue *q) { this->sleeping_queue = q; } - constexpr KConditionVariable *GetConditionVariable() const { return this->cond_var; } + constexpr ConditionVariableThreadTree *GetConditionVariableTree() const { return this->condvar_tree; } constexpr s32 GetNumKernelWaiters() const { return this->num_kernel_waiters; } @@ -416,9 +449,16 @@ namespace ams::kern { static constexpr bool IsWaiterListValid() { return WaiterListTraits::IsValid(); } + + static constexpr bool IsConditionVariableThreadTreeValid() { + return ConditionVariableThreadTreeTraits::IsValid(); + } + + using ConditionVariableThreadTreeType = ConditionVariableThreadTree; }; static_assert(alignof(KThread) == 0x10); static_assert(KThread::IsWaiterListValid()); + static_assert(KThread::IsConditionVariableThreadTreeValid()); class KScopedDisableDispatch { public: diff --git a/libraries/libmesosphere/source/kern_k_address_arbiter.cpp b/libraries/libmesosphere/source/kern_k_address_arbiter.cpp new file mode 100644 index 000000000..629e57839 --- /dev/null +++ b/libraries/libmesosphere/source/kern_k_address_arbiter.cpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2018-2020 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +#include + +namespace ams::kern { + + namespace { + + constinit KThread g_arbiter_compare_thread; + + ALWAYS_INLINE bool ReadFromUser(s32 *out, KProcessAddress address) { + return UserspaceAccess::CopyMemoryFromUserSize32Bit(out, GetVoidPointer(address)); + } + + } + + Result KAddressArbiter::Signal(uintptr_t addr, s32 count) { + MESOSPHERE_UNIMPLEMENTED(); + } + + Result KAddressArbiter::SignalAndIncrementIfEqual(uintptr_t addr, s32 value, s32 count) { + MESOSPHERE_UNIMPLEMENTED(); + } + + Result KAddressArbiter::SignalAndModifyByWaitingCountIfEqual(uintptr_t addr, s32 value, s32 count) { + MESOSPHERE_UNIMPLEMENTED(); + } + + Result KAddressArbiter::WaitIfLessThan(uintptr_t addr, s32 value, bool decrement, s64 timeout) { + MESOSPHERE_UNIMPLEMENTED(); + } + + Result KAddressArbiter::WaitIfEqual(uintptr_t addr, s32 value, s64 timeout) { + /* Prepare to wait. */ + KThread *cur_thread = GetCurrentThreadPointer(); + KHardwareTimer *timer; + + { + KScopedSchedulerLockAndSleep slp(std::addressof(timer), cur_thread, timeout); + + /* Check that the thread isn't terminating. */ + if (cur_thread->IsTerminationRequested()) { + slp.CancelSleep(); + return svc::ResultTerminationRequested(); + } + + /* Set the synced object. */ + cur_thread->SetSyncedObject(nullptr, ams::svc::ResultTimedOut()); + + /* Read the value from userspace. */ + s32 user_value; + if (!ReadFromUser(std::addressof(user_value), addr)) { + slp.CancelSleep(); + return svc::ResultInvalidCurrentMemory(); + } + + /* Check that the value is equal. */ + if (value != user_value) { + slp.CancelSleep(); + return svc::ResultInvalidState(); + } + + /* Check that the timeout is non-zero. */ + if (timeout == 0) { + slp.CancelSleep(); + return svc::ResultTimedOut(); + } + + /* Set the arbiter. */ + cur_thread->SetAddressArbiter(std::addressof(this->tree), addr); + this->tree.insert(*cur_thread); + cur_thread->SetState(KThread::ThreadState_Waiting); + } + + /* Cancel the timer wait. */ + if (timer != nullptr) { + timer->CancelTask(cur_thread); + } + + /* Remove from the address arbiter. */ + { + KScopedSchedulerLock sl; + + if (cur_thread->IsWaitingForAddressArbiter()) { + this->tree.erase(this->tree.iterator_to(*cur_thread)); + cur_thread->ClearAddressArbiter(); + } + } + + /* Get the result. */ + KSynchronizationObject *dummy; + return cur_thread->GetWaitResult(std::addressof(dummy)); + } + +} diff --git a/libraries/libmesosphere/source/kern_k_condition_variable.cpp b/libraries/libmesosphere/source/kern_k_condition_variable.cpp index 526d504e2..50df876cb 100644 --- a/libraries/libmesosphere/source/kern_k_condition_variable.cpp +++ b/libraries/libmesosphere/source/kern_k_condition_variable.cpp @@ -146,7 +146,7 @@ namespace ams::kern { } it = this->tree.erase(it); - target_thread->ClearConditionVariable(); + target_thread->ClearConditionVariableTree(); ++num_waiters; } } diff --git a/libraries/libmesosphere/source/kern_k_thread.cpp b/libraries/libmesosphere/source/kern_k_thread.cpp index a0ebf689e..3c0681683 100644 --- a/libraries/libmesosphere/source/kern_k_thread.cpp +++ b/libraries/libmesosphere/source/kern_k_thread.cpp @@ -94,7 +94,7 @@ namespace ams::kern { /* Set parent and condvar tree. */ this->parent = nullptr; - this->cond_var = nullptr; + this->condvar_tree = nullptr; /* Set sync booleans. */ this->signaled = false; @@ -519,8 +519,8 @@ namespace ams::kern { } /* Ensure we don't violate condition variable red black tree invariants. */ - if (auto *cond_var = thread->GetConditionVariable(); cond_var != nullptr) { - cond_var->BeforeUpdatePriority(thread); + if (auto *cv_tree = thread->GetConditionVariableTree(); cv_tree != nullptr) { + BeforeUpdatePriority(cv_tree, thread); } /* Change the priority. */ @@ -528,8 +528,8 @@ namespace ams::kern { thread->SetPriority(new_priority); /* Restore the condition variable, if relevant. */ - if (auto *cond_var = thread->GetConditionVariable(); cond_var != nullptr) { - cond_var->AfterUpdatePriority(thread); + if (auto *cv_tree = thread->GetConditionVariableTree(); cv_tree != nullptr) { + AfterUpdatePriority(cv_tree, thread); } /* Update the scheduler. */ diff --git a/libraries/libmesosphere/source/svc/kern_svc_address_arbiter.cpp b/libraries/libmesosphere/source/svc/kern_svc_address_arbiter.cpp index 3943f0421..0fdb4780e 100644 --- a/libraries/libmesosphere/source/svc/kern_svc_address_arbiter.cpp +++ b/libraries/libmesosphere/source/svc/kern_svc_address_arbiter.cpp @@ -21,28 +21,86 @@ namespace ams::kern::svc { namespace { + constexpr bool IsKernelAddress(uintptr_t address) { + return KernelVirtualAddressSpaceBase <= address && address < KernelVirtualAddressSpaceEnd; + } + constexpr bool IsValidSignalType(ams::svc::SignalType type) { + switch (type) { + case ams::svc::SignalType_Signal: + case ams::svc::SignalType_SignalAndIncrementIfEqual: + case ams::svc::SignalType_SignalAndModifyByWaitingCountIfEqual: + return true; + default: + return false; + } + } + + constexpr bool IsValidArbitrationType(ams::svc::ArbitrationType type) { + switch (type) { + case ams::svc::ArbitrationType_WaitIfLessThan: + case ams::svc::ArbitrationType_DecrementAndWaitIfLessThan: + case ams::svc::ArbitrationType_WaitIfEqual: + return true; + default: + return false; + } + } + + Result WaitForAddress(uintptr_t address, ams::svc::ArbitrationType arb_type, int32_t value, int64_t timeout_ns) { + /* Validate input. */ + R_UNLESS(AMS_LIKELY(!IsKernelAddress(address)), svc::ResultInvalidCurrentMemory()); + R_UNLESS(util::IsAligned(address, sizeof(int32_t)), svc::ResultInvalidAddress()); + R_UNLESS(IsValidArbitrationType(arb_type), svc::ResultInvalidEnumValue()); + + /* Convert timeout from nanoseconds to ticks. */ + s64 timeout; + if (timeout_ns > 0) { + const ams::svc::Tick offset_tick(TimeSpan::FromNanoSeconds(timeout_ns)); + if (AMS_LIKELY(offset_tick > 0)) { + timeout = KHardwareTimer::GetTick() + offset_tick + 2; + if (AMS_UNLIKELY(timeout <= 0)) { + timeout = std::numeric_limits::max(); + } + } else { + timeout = std::numeric_limits::max(); + } + } else { + timeout = timeout_ns; + } + + return GetCurrentProcess().WaitAddressArbiter(address, arb_type, value, timeout); + } + + Result SignalToAddress(uintptr_t address, ams::svc::SignalType signal_type, int32_t value, int32_t count) { + /* Validate input. */ + R_UNLESS(AMS_LIKELY(!IsKernelAddress(address)), svc::ResultInvalidCurrentMemory()); + R_UNLESS(util::IsAligned(address, sizeof(int32_t)), svc::ResultInvalidAddress()); + R_UNLESS(IsValidSignalType(signal_type), svc::ResultInvalidEnumValue()); + + return GetCurrentProcess().SignalAddressArbiter(address, signal_type, value, count); + } } /* ============================= 64 ABI ============================= */ Result WaitForAddress64(ams::svc::Address address, ams::svc::ArbitrationType arb_type, int32_t value, int64_t timeout_ns) { - MESOSPHERE_PANIC("Stubbed SvcWaitForAddress64 was called."); + return WaitForAddress(address, arb_type, value, timeout_ns); } Result SignalToAddress64(ams::svc::Address address, ams::svc::SignalType signal_type, int32_t value, int32_t count) { - MESOSPHERE_PANIC("Stubbed SvcSignalToAddress64 was called."); + return SignalToAddress(address, signal_type, value, count); } /* ============================= 64From32 ABI ============================= */ Result WaitForAddress64From32(ams::svc::Address address, ams::svc::ArbitrationType arb_type, int32_t value, int64_t timeout_ns) { - MESOSPHERE_PANIC("Stubbed SvcWaitForAddress64From32 was called."); + return WaitForAddress(address, arb_type, value, timeout_ns); } Result SignalToAddress64From32(ams::svc::Address address, ams::svc::SignalType signal_type, int32_t value, int32_t count) { - MESOSPHERE_PANIC("Stubbed SvcSignalToAddress64From32 was called."); + return SignalToAddress(address, signal_type, value, count); } } diff --git a/libraries/libvapours/include/vapours/util/util_intrusive_red_black_tree.hpp b/libraries/libvapours/include/vapours/util/util_intrusive_red_black_tree.hpp index 4244f93d3..d7910f5b8 100644 --- a/libraries/libvapours/include/vapours/util/util_intrusive_red_black_tree.hpp +++ b/libraries/libvapours/include/vapours/util/util_intrusive_red_black_tree.hpp @@ -279,6 +279,40 @@ namespace ams::util { static_assert(GetParent(GetNode(GetPointer(DerivedStorage))) == GetPointer(DerivedStorage)); }; + template> + class IntrusiveRedBlackTreeMemberTraitsDeferredAssert; + + template + class IntrusiveRedBlackTreeMemberTraitsDeferredAssert { + public: + template + using TreeType = IntrusiveRedBlackTree; + + static constexpr bool IsValid() { + TYPED_STORAGE(Derived) DerivedStorage = {}; + return GetParent(GetNode(GetPointer(DerivedStorage))) == GetPointer(DerivedStorage); + } + private: + template + friend class IntrusiveRedBlackTree; + + static constexpr IntrusiveRedBlackTreeNode *GetNode(Derived *parent) { + return std::addressof(parent->*Member); + } + + static constexpr IntrusiveRedBlackTreeNode const *GetNode(Derived const *parent) { + return std::addressof(parent->*Member); + } + + static constexpr Derived *GetParent(IntrusiveRedBlackTreeNode *node) { + return util::GetParentPointer(node); + } + + static constexpr Derived const *GetParent(IntrusiveRedBlackTreeNode const *node) { + return util::GetParentPointer(node); + } + }; + template class IntrusiveRedBlackTreeBaseNode : public IntrusiveRedBlackTreeNode{};