From 9b1a2451b04fcdecc2a63b267b14223ea23526c0 Mon Sep 17 00:00:00 2001 From: Michael Scire Date: Wed, 7 Nov 2018 23:18:46 -0800 Subject: [PATCH] libstratosphere: Add thread primitive, WaitableManager->RequestStop() --- .../include/stratosphere/hossynch.hpp | 27 ++++ .../include/stratosphere/waitable_manager.hpp | 129 +++++++++++------- .../stratosphere/waitable_manager_base.hpp | 3 +- 3 files changed, 109 insertions(+), 50 deletions(-) diff --git a/stratosphere/libstratosphere/include/stratosphere/hossynch.hpp b/stratosphere/libstratosphere/include/stratosphere/hossynch.hpp index 64df3eafb..4f0cbd2de 100644 --- a/stratosphere/libstratosphere/include/stratosphere/hossynch.hpp +++ b/stratosphere/libstratosphere/include/stratosphere/hossynch.hpp @@ -185,4 +185,31 @@ class TimeoutHelper { return armGetSystemTick() >= this->end_tick; } +}; + +class HosThread { + private: + Thread thr = {0}; + public: + HosThread() {} + + Result Initialize(ThreadFunc entry, void *arg, size_t stack_sz, int prio, int cpuid = -2) { + return threadCreate(&this->thr, entry, arg, stack_sz, prio, cpuid); + } + + Handle GetHandle() const { + return this->thr.handle; + } + + Result Start() { + return threadStart(&this->thr); + } + + Result Join() { + Result rc = threadWaitForExit(&this->thr); + if (R_SUCCEEDED(rc)) { + rc = threadClose(&this->thr); + } + return rc; + } }; \ No newline at end of file diff --git a/stratosphere/libstratosphere/include/stratosphere/waitable_manager.hpp b/stratosphere/libstratosphere/include/stratosphere/waitable_manager.hpp index c47ce81e3..47244e561 100644 --- a/stratosphere/libstratosphere/include/stratosphere/waitable_manager.hpp +++ b/stratosphere/libstratosphere/include/stratosphere/waitable_manager.hpp @@ -13,7 +13,7 @@ * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ - + #pragma once #include @@ -51,46 +51,50 @@ class WaitableManager : public SessionManagerBase { std::array, ManagerOptions::MaxDomains> domains; std::array is_domain_allocated; std::array domain_objects; - + /* Waitable Manager */ std::vector to_add_waitables; std::vector waitables; std::vector deferred_waitables; - u32 num_threads; - Thread *threads; + + u32 num_extra_threads = 0; + HosThread *threads = nullptr; + HosMutex process_lock; HosMutex signal_lock; HosMutex add_lock; HosMutex cur_thread_lock; HosMutex deferred_lock; bool has_new_waitables = false; - + std::atomic should_stop = false; + IWaitable *next_signaled = nullptr; + Handle main_thread_handle = INVALID_HANDLE; Handle cur_thread_handle = INVALID_HANDLE; public: - WaitableManager(u32 n, u32 ss = 0x8000) : num_threads(n-1) { + WaitableManager(u32 n, u32 ss = 0x8000) : num_extra_threads(n-1) { u32 prio; - u32 cpuid = svcGetCurrentProcessorNumber(); Result rc; - threads = new Thread[num_threads]; - if (num_threads) { + if (num_extra_threads) { + threads = new HosThread[num_extra_threads]; if (R_FAILED((rc = svcGetThreadPriority(&prio, CUR_THREAD_HANDLE)))) { fatalSimple(rc); } - for (unsigned int i = 0; i < num_threads; i++) { - threads[i] = {0}; - threadCreate(&threads[i], &WaitableManager::ProcessLoop, this, ss, prio, cpuid); + for (unsigned int i = 0; i < num_extra_threads; i++) { + if (R_FAILED(threads[i].Initialize(&WaitableManager::ProcessLoop, this, ss, prio))) { + std::abort(); + } } } } - + ~WaitableManager() override { /* This should call the destructor for every waitable. */ std::for_each(to_add_waitables.begin(), to_add_waitables.end(), std::default_delete{}); std::for_each(waitables.begin(), waitables.end(), std::default_delete{}); std::for_each(deferred_waitables.begin(), deferred_waitables.end(), std::default_delete{}); - - /* TODO: Exit the threads? */ + + /* If we've reached here, we should already have exited the threads. */ } virtual void AddWaitable(IWaitable *w) override { @@ -101,10 +105,15 @@ class WaitableManager : public SessionManagerBase { this->CancelSynchronization(); } + virtual void RequestStop() { + this->should_stop = true; + this->CancelSynchronization(); + } + virtual void CancelSynchronization() { svcCancelSynchronization(GetProcessingThreadHandle()); } - + virtual void NotifySignaled(IWaitable *w) override { std::scoped_lock lk{this->signal_lock}; if (this->next_signaled == nullptr) { @@ -112,18 +121,21 @@ class WaitableManager : public SessionManagerBase { } this->CancelSynchronization(); } - + virtual void Process() override { /* Add initial set of waitables. */ AddWaitablesInternal(); - + + /* Set main thread handle. */ + this->main_thread_handle = GetCurrentThreadHandle(); + Result rc; - for (unsigned int i = 0; i < num_threads; i++) { - if (R_FAILED((rc = threadStart(&threads[i])))) { + for (unsigned int i = 0; i < num_extra_threads; i++) { + if (R_FAILED((rc = threads[i].Start()))) { fatalSimple(rc); } } - + ProcessLoop(this); } private: @@ -131,35 +143,46 @@ class WaitableManager : public SessionManagerBase { std::scoped_lock lk{this->cur_thread_lock}; this->cur_thread_handle = h; } - + Handle GetProcessingThreadHandle() { std::scoped_lock lk{this->cur_thread_lock}; return this->cur_thread_handle; } - + static void ProcessLoop(void *t) { WaitableManager *this_ptr = (WaitableManager *)t; while (true) { IWaitable *w = this_ptr->GetWaitable(); + if (this_ptr->should_stop) { + if (GetCurrentThreadHandle() == this_ptr->main_thread_handle) { + /* Join all threads but the main one. */ + for (unsigned int i = 0; i < this_ptr->num_extra_threads; i++) { + this_ptr->threads[i].Join(); + } + break; + } else { + svcExitThread(); + } + } if (w) { Result rc = w->HandleSignaled(0); if (rc == 0xF601) { /* Close! */ delete w; } else { - if (w->IsDeferred()) { + if (w->IsDeferred()) { std::scoped_lock lk{this_ptr->deferred_lock}; this_ptr->deferred_waitables.push_back(w); - } else { + } else { this_ptr->AddWaitable(w); } } } - + /* We finished processing, and maybe that means we can stop deferring an object. */ { std::scoped_lock lk{this_ptr->deferred_lock}; - + for (auto it = this_ptr->deferred_waitables.begin(); it != this_ptr->deferred_waitables.end();) { auto w = *it; Result rc = w->HandleDeferred(); @@ -181,46 +204,50 @@ class WaitableManager : public SessionManagerBase { } } } - + IWaitable *GetWaitable() { std::scoped_lock lk{this->process_lock}; - + /* Set processing thread handle while in scope. */ SetProcessingThreadHandle(GetCurrentThreadHandle()); ON_SCOPE_EXIT { SetProcessingThreadHandle(INVALID_HANDLE); }; - + /* Prepare variables for result. */ this->next_signaled = nullptr; IWaitable *result = nullptr; - + + if (this->should_stop) { + return nullptr; + } + /* Add new waitables, if any. */ AddWaitablesInternal(); - + /* First, see if anything's already signaled. */ for (auto &w : this->waitables) { if (w->IsSignaled()) { result = w; } } - + /* It's possible somebody signaled us while we were iterating. */ { std::scoped_lock lk{this->signal_lock}; if (this->next_signaled != nullptr) result = this->next_signaled; } - + if (result == nullptr) { std::vector handles; std::vector wait_list; - + int handle_index = 0; Result rc; while (result == nullptr) { /* Sort waitables by priority. */ std::sort(this->waitables.begin(), this->waitables.end(), IWaitable::Compare); - + /* Copy out handles. */ handles.resize(this->waitables.size()); wait_list.resize(this->waitables.size()); @@ -234,10 +261,14 @@ class WaitableManager : public SessionManagerBase { handles[num_handles++] = h; } } - + /* Wait forever. */ rc = svcWaitSynchronization(&handle_index, handles.data(), num_handles, U64_MAX); - + + if (this->should_stop) { + return nullptr; + } + if (R_SUCCEEDED(rc)) { IWaitable *w = wait_list[handle_index]; size_t w_ind = std::distance(this->waitables.begin(), std::find(this->waitables.begin(), this->waitables.end(), w)); @@ -265,12 +296,12 @@ class WaitableManager : public SessionManagerBase { } } } - + this->waitables.erase(std::remove_if(this->waitables.begin(), this->waitables.end(), [&](IWaitable *w) { return w == result; }), this->waitables.end()); - + return result; } - + void AddWaitablesInternal() { std::scoped_lock lk{this->add_lock}; if (this->has_new_waitables) { @@ -284,7 +315,7 @@ class WaitableManager : public SessionManagerBase { virtual void AddSession(Handle server_h, ServiceObjectHolder &&service) override { this->AddWaitable(new ServiceSession(server_h, ManagerOptions::PointerBufferSize, std::move(service))); } - + /* Domain Manager */ public: virtual std::shared_ptr AllocateDomain() override { @@ -299,7 +330,7 @@ class WaitableManager : public SessionManagerBase { } return nullptr; } - + void FreeDomain(IDomainObject *domain) override { std::scoped_lock lk{this->domain_lock}; for (size_t i = 0; i < ManagerOptions::MaxDomainObjects; i++) { @@ -313,7 +344,7 @@ class WaitableManager : public SessionManagerBase { } } } - + virtual Result ReserveObject(IDomainObject *domain, u32 *out_object_id) override { std::scoped_lock lk{this->domain_lock}; for (size_t i = 0; i < ManagerOptions::MaxDomainObjects; i++) { @@ -325,7 +356,7 @@ class WaitableManager : public SessionManagerBase { } return 0x25A0A; } - + virtual Result ReserveSpecificObject(IDomainObject *domain, u32 object_id) override { std::scoped_lock lk{this->domain_lock}; if (this->domain_objects[object_id-1].owner == nullptr) { @@ -334,14 +365,14 @@ class WaitableManager : public SessionManagerBase { } return 0x25A0A; } - + virtual void SetObject(IDomainObject *domain, u32 object_id, ServiceObjectHolder&& holder) override { std::scoped_lock lk{this->domain_lock}; if (this->domain_objects[object_id-1].owner == domain) { this->domain_objects[object_id-1].obj_holder = std::move(holder); } } - + virtual ServiceObjectHolder *GetObject(IDomainObject *domain, u32 object_id) override { std::scoped_lock lk{this->domain_lock}; if (this->domain_objects[object_id-1].owner == domain) { @@ -349,7 +380,7 @@ class WaitableManager : public SessionManagerBase { } return nullptr; } - + virtual Result FreeObject(IDomainObject *domain, u32 object_id) override { std::scoped_lock lk{this->domain_lock}; if (this->domain_objects[object_id-1].owner == domain) { @@ -359,7 +390,7 @@ class WaitableManager : public SessionManagerBase { } return 0x3D80B; } - + virtual Result ForceFreeObject(u32 object_id) override { std::scoped_lock lk{this->domain_lock}; if (this->domain_objects[object_id-1].owner != nullptr) { diff --git a/stratosphere/libstratosphere/include/stratosphere/waitable_manager_base.hpp b/stratosphere/libstratosphere/include/stratosphere/waitable_manager_base.hpp index 777403b21..35242aac2 100644 --- a/stratosphere/libstratosphere/include/stratosphere/waitable_manager_base.hpp +++ b/stratosphere/libstratosphere/include/stratosphere/waitable_manager_base.hpp @@ -31,7 +31,8 @@ class WaitableManagerBase { virtual void AddWaitable(IWaitable *w) = 0; virtual void NotifySignaled(IWaitable *w) = 0; - + + virtual void RequestStop() = 0; virtual void Process() = 0; };