libstratosphere: Add thread primitive, WaitableManager->RequestStop()

This commit is contained in:
Michael Scire 2018-11-07 23:18:46 -08:00
parent e65bee0d6a
commit 9b1a2451b0
3 changed files with 109 additions and 50 deletions

View file

@ -185,4 +185,31 @@ class TimeoutHelper {
return armGetSystemTick() >= this->end_tick; return armGetSystemTick() >= this->end_tick;
} }
};
class HosThread {
private:
Thread thr = {0};
public:
HosThread() {}
Result Initialize(ThreadFunc entry, void *arg, size_t stack_sz, int prio, int cpuid = -2) {
return threadCreate(&this->thr, entry, arg, stack_sz, prio, cpuid);
}
Handle GetHandle() const {
return this->thr.handle;
}
Result Start() {
return threadStart(&this->thr);
}
Result Join() {
Result rc = threadWaitForExit(&this->thr);
if (R_SUCCEEDED(rc)) {
rc = threadClose(&this->thr);
}
return rc;
}
}; };

View file

@ -13,7 +13,7 @@
* You should have received a copy of the GNU General Public License * You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>. * along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
#pragma once #pragma once
#include <switch.h> #include <switch.h>
@ -51,46 +51,50 @@ class WaitableManager : public SessionManagerBase {
std::array<std::weak_ptr<IDomainObject>, ManagerOptions::MaxDomains> domains; std::array<std::weak_ptr<IDomainObject>, ManagerOptions::MaxDomains> domains;
std::array<bool, ManagerOptions::MaxDomains> is_domain_allocated; std::array<bool, ManagerOptions::MaxDomains> is_domain_allocated;
std::array<DomainEntry, ManagerOptions::MaxDomainObjects> domain_objects; std::array<DomainEntry, ManagerOptions::MaxDomainObjects> domain_objects;
/* Waitable Manager */ /* Waitable Manager */
std::vector<IWaitable *> to_add_waitables; std::vector<IWaitable *> to_add_waitables;
std::vector<IWaitable *> waitables; std::vector<IWaitable *> waitables;
std::vector<IWaitable *> deferred_waitables; std::vector<IWaitable *> deferred_waitables;
u32 num_threads;
Thread *threads; u32 num_extra_threads = 0;
HosThread *threads = nullptr;
HosMutex process_lock; HosMutex process_lock;
HosMutex signal_lock; HosMutex signal_lock;
HosMutex add_lock; HosMutex add_lock;
HosMutex cur_thread_lock; HosMutex cur_thread_lock;
HosMutex deferred_lock; HosMutex deferred_lock;
bool has_new_waitables = false; bool has_new_waitables = false;
std::atomic<bool> should_stop = false;
IWaitable *next_signaled = nullptr; IWaitable *next_signaled = nullptr;
Handle main_thread_handle = INVALID_HANDLE;
Handle cur_thread_handle = INVALID_HANDLE; Handle cur_thread_handle = INVALID_HANDLE;
public: public:
WaitableManager(u32 n, u32 ss = 0x8000) : num_threads(n-1) { WaitableManager(u32 n, u32 ss = 0x8000) : num_extra_threads(n-1) {
u32 prio; u32 prio;
u32 cpuid = svcGetCurrentProcessorNumber();
Result rc; Result rc;
threads = new Thread[num_threads]; if (num_extra_threads) {
if (num_threads) { threads = new HosThread[num_extra_threads];
if (R_FAILED((rc = svcGetThreadPriority(&prio, CUR_THREAD_HANDLE)))) { if (R_FAILED((rc = svcGetThreadPriority(&prio, CUR_THREAD_HANDLE)))) {
fatalSimple(rc); fatalSimple(rc);
} }
for (unsigned int i = 0; i < num_threads; i++) { for (unsigned int i = 0; i < num_extra_threads; i++) {
threads[i] = {0}; if (R_FAILED(threads[i].Initialize(&WaitableManager::ProcessLoop, this, ss, prio))) {
threadCreate(&threads[i], &WaitableManager::ProcessLoop, this, ss, prio, cpuid); std::abort();
}
} }
} }
} }
~WaitableManager() override { ~WaitableManager() override {
/* This should call the destructor for every waitable. */ /* This should call the destructor for every waitable. */
std::for_each(to_add_waitables.begin(), to_add_waitables.end(), std::default_delete<IWaitable>{}); std::for_each(to_add_waitables.begin(), to_add_waitables.end(), std::default_delete<IWaitable>{});
std::for_each(waitables.begin(), waitables.end(), std::default_delete<IWaitable>{}); std::for_each(waitables.begin(), waitables.end(), std::default_delete<IWaitable>{});
std::for_each(deferred_waitables.begin(), deferred_waitables.end(), std::default_delete<IWaitable>{}); std::for_each(deferred_waitables.begin(), deferred_waitables.end(), std::default_delete<IWaitable>{});
/* TODO: Exit the threads? */ /* If we've reached here, we should already have exited the threads. */
} }
virtual void AddWaitable(IWaitable *w) override { virtual void AddWaitable(IWaitable *w) override {
@ -101,10 +105,15 @@ class WaitableManager : public SessionManagerBase {
this->CancelSynchronization(); this->CancelSynchronization();
} }
virtual void RequestStop() {
this->should_stop = true;
this->CancelSynchronization();
}
virtual void CancelSynchronization() { virtual void CancelSynchronization() {
svcCancelSynchronization(GetProcessingThreadHandle()); svcCancelSynchronization(GetProcessingThreadHandle());
} }
virtual void NotifySignaled(IWaitable *w) override { virtual void NotifySignaled(IWaitable *w) override {
std::scoped_lock lk{this->signal_lock}; std::scoped_lock lk{this->signal_lock};
if (this->next_signaled == nullptr) { if (this->next_signaled == nullptr) {
@ -112,18 +121,21 @@ class WaitableManager : public SessionManagerBase {
} }
this->CancelSynchronization(); this->CancelSynchronization();
} }
virtual void Process() override { virtual void Process() override {
/* Add initial set of waitables. */ /* Add initial set of waitables. */
AddWaitablesInternal(); AddWaitablesInternal();
/* Set main thread handle. */
this->main_thread_handle = GetCurrentThreadHandle();
Result rc; Result rc;
for (unsigned int i = 0; i < num_threads; i++) { for (unsigned int i = 0; i < num_extra_threads; i++) {
if (R_FAILED((rc = threadStart(&threads[i])))) { if (R_FAILED((rc = threads[i].Start()))) {
fatalSimple(rc); fatalSimple(rc);
} }
} }
ProcessLoop(this); ProcessLoop(this);
} }
private: private:
@ -131,35 +143,46 @@ class WaitableManager : public SessionManagerBase {
std::scoped_lock<HosMutex> lk{this->cur_thread_lock}; std::scoped_lock<HosMutex> lk{this->cur_thread_lock};
this->cur_thread_handle = h; this->cur_thread_handle = h;
} }
Handle GetProcessingThreadHandle() { Handle GetProcessingThreadHandle() {
std::scoped_lock<HosMutex> lk{this->cur_thread_lock}; std::scoped_lock<HosMutex> lk{this->cur_thread_lock};
return this->cur_thread_handle; return this->cur_thread_handle;
} }
static void ProcessLoop(void *t) { static void ProcessLoop(void *t) {
WaitableManager *this_ptr = (WaitableManager *)t; WaitableManager *this_ptr = (WaitableManager *)t;
while (true) { while (true) {
IWaitable *w = this_ptr->GetWaitable(); IWaitable *w = this_ptr->GetWaitable();
if (this_ptr->should_stop) {
if (GetCurrentThreadHandle() == this_ptr->main_thread_handle) {
/* Join all threads but the main one. */
for (unsigned int i = 0; i < this_ptr->num_extra_threads; i++) {
this_ptr->threads[i].Join();
}
break;
} else {
svcExitThread();
}
}
if (w) { if (w) {
Result rc = w->HandleSignaled(0); Result rc = w->HandleSignaled(0);
if (rc == 0xF601) { if (rc == 0xF601) {
/* Close! */ /* Close! */
delete w; delete w;
} else { } else {
if (w->IsDeferred()) { if (w->IsDeferred()) {
std::scoped_lock lk{this_ptr->deferred_lock}; std::scoped_lock lk{this_ptr->deferred_lock};
this_ptr->deferred_waitables.push_back(w); this_ptr->deferred_waitables.push_back(w);
} else { } else {
this_ptr->AddWaitable(w); this_ptr->AddWaitable(w);
} }
} }
} }
/* We finished processing, and maybe that means we can stop deferring an object. */ /* We finished processing, and maybe that means we can stop deferring an object. */
{ {
std::scoped_lock lk{this_ptr->deferred_lock}; std::scoped_lock lk{this_ptr->deferred_lock};
for (auto it = this_ptr->deferred_waitables.begin(); it != this_ptr->deferred_waitables.end();) { for (auto it = this_ptr->deferred_waitables.begin(); it != this_ptr->deferred_waitables.end();) {
auto w = *it; auto w = *it;
Result rc = w->HandleDeferred(); Result rc = w->HandleDeferred();
@ -181,46 +204,50 @@ class WaitableManager : public SessionManagerBase {
} }
} }
} }
IWaitable *GetWaitable() { IWaitable *GetWaitable() {
std::scoped_lock lk{this->process_lock}; std::scoped_lock lk{this->process_lock};
/* Set processing thread handle while in scope. */ /* Set processing thread handle while in scope. */
SetProcessingThreadHandle(GetCurrentThreadHandle()); SetProcessingThreadHandle(GetCurrentThreadHandle());
ON_SCOPE_EXIT { ON_SCOPE_EXIT {
SetProcessingThreadHandle(INVALID_HANDLE); SetProcessingThreadHandle(INVALID_HANDLE);
}; };
/* Prepare variables for result. */ /* Prepare variables for result. */
this->next_signaled = nullptr; this->next_signaled = nullptr;
IWaitable *result = nullptr; IWaitable *result = nullptr;
if (this->should_stop) {
return nullptr;
}
/* Add new waitables, if any. */ /* Add new waitables, if any. */
AddWaitablesInternal(); AddWaitablesInternal();
/* First, see if anything's already signaled. */ /* First, see if anything's already signaled. */
for (auto &w : this->waitables) { for (auto &w : this->waitables) {
if (w->IsSignaled()) { if (w->IsSignaled()) {
result = w; result = w;
} }
} }
/* It's possible somebody signaled us while we were iterating. */ /* It's possible somebody signaled us while we were iterating. */
{ {
std::scoped_lock lk{this->signal_lock}; std::scoped_lock lk{this->signal_lock};
if (this->next_signaled != nullptr) result = this->next_signaled; if (this->next_signaled != nullptr) result = this->next_signaled;
} }
if (result == nullptr) { if (result == nullptr) {
std::vector<Handle> handles; std::vector<Handle> handles;
std::vector<IWaitable *> wait_list; std::vector<IWaitable *> wait_list;
int handle_index = 0; int handle_index = 0;
Result rc; Result rc;
while (result == nullptr) { while (result == nullptr) {
/* Sort waitables by priority. */ /* Sort waitables by priority. */
std::sort(this->waitables.begin(), this->waitables.end(), IWaitable::Compare); std::sort(this->waitables.begin(), this->waitables.end(), IWaitable::Compare);
/* Copy out handles. */ /* Copy out handles. */
handles.resize(this->waitables.size()); handles.resize(this->waitables.size());
wait_list.resize(this->waitables.size()); wait_list.resize(this->waitables.size());
@ -234,10 +261,14 @@ class WaitableManager : public SessionManagerBase {
handles[num_handles++] = h; handles[num_handles++] = h;
} }
} }
/* Wait forever. */ /* Wait forever. */
rc = svcWaitSynchronization(&handle_index, handles.data(), num_handles, U64_MAX); rc = svcWaitSynchronization(&handle_index, handles.data(), num_handles, U64_MAX);
if (this->should_stop) {
return nullptr;
}
if (R_SUCCEEDED(rc)) { if (R_SUCCEEDED(rc)) {
IWaitable *w = wait_list[handle_index]; IWaitable *w = wait_list[handle_index];
size_t w_ind = std::distance(this->waitables.begin(), std::find(this->waitables.begin(), this->waitables.end(), w)); size_t w_ind = std::distance(this->waitables.begin(), std::find(this->waitables.begin(), this->waitables.end(), w));
@ -265,12 +296,12 @@ class WaitableManager : public SessionManagerBase {
} }
} }
} }
this->waitables.erase(std::remove_if(this->waitables.begin(), this->waitables.end(), [&](IWaitable *w) { return w == result; }), this->waitables.end()); this->waitables.erase(std::remove_if(this->waitables.begin(), this->waitables.end(), [&](IWaitable *w) { return w == result; }), this->waitables.end());
return result; return result;
} }
void AddWaitablesInternal() { void AddWaitablesInternal() {
std::scoped_lock lk{this->add_lock}; std::scoped_lock lk{this->add_lock};
if (this->has_new_waitables) { if (this->has_new_waitables) {
@ -284,7 +315,7 @@ class WaitableManager : public SessionManagerBase {
virtual void AddSession(Handle server_h, ServiceObjectHolder &&service) override { virtual void AddSession(Handle server_h, ServiceObjectHolder &&service) override {
this->AddWaitable(new ServiceSession(server_h, ManagerOptions::PointerBufferSize, std::move(service))); this->AddWaitable(new ServiceSession(server_h, ManagerOptions::PointerBufferSize, std::move(service)));
} }
/* Domain Manager */ /* Domain Manager */
public: public:
virtual std::shared_ptr<IDomainObject> AllocateDomain() override { virtual std::shared_ptr<IDomainObject> AllocateDomain() override {
@ -299,7 +330,7 @@ class WaitableManager : public SessionManagerBase {
} }
return nullptr; return nullptr;
} }
void FreeDomain(IDomainObject *domain) override { void FreeDomain(IDomainObject *domain) override {
std::scoped_lock lk{this->domain_lock}; std::scoped_lock lk{this->domain_lock};
for (size_t i = 0; i < ManagerOptions::MaxDomainObjects; i++) { for (size_t i = 0; i < ManagerOptions::MaxDomainObjects; i++) {
@ -313,7 +344,7 @@ class WaitableManager : public SessionManagerBase {
} }
} }
} }
virtual Result ReserveObject(IDomainObject *domain, u32 *out_object_id) override { virtual Result ReserveObject(IDomainObject *domain, u32 *out_object_id) override {
std::scoped_lock lk{this->domain_lock}; std::scoped_lock lk{this->domain_lock};
for (size_t i = 0; i < ManagerOptions::MaxDomainObjects; i++) { for (size_t i = 0; i < ManagerOptions::MaxDomainObjects; i++) {
@ -325,7 +356,7 @@ class WaitableManager : public SessionManagerBase {
} }
return 0x25A0A; return 0x25A0A;
} }
virtual Result ReserveSpecificObject(IDomainObject *domain, u32 object_id) override { virtual Result ReserveSpecificObject(IDomainObject *domain, u32 object_id) override {
std::scoped_lock lk{this->domain_lock}; std::scoped_lock lk{this->domain_lock};
if (this->domain_objects[object_id-1].owner == nullptr) { if (this->domain_objects[object_id-1].owner == nullptr) {
@ -334,14 +365,14 @@ class WaitableManager : public SessionManagerBase {
} }
return 0x25A0A; return 0x25A0A;
} }
virtual void SetObject(IDomainObject *domain, u32 object_id, ServiceObjectHolder&& holder) override { virtual void SetObject(IDomainObject *domain, u32 object_id, ServiceObjectHolder&& holder) override {
std::scoped_lock lk{this->domain_lock}; std::scoped_lock lk{this->domain_lock};
if (this->domain_objects[object_id-1].owner == domain) { if (this->domain_objects[object_id-1].owner == domain) {
this->domain_objects[object_id-1].obj_holder = std::move(holder); this->domain_objects[object_id-1].obj_holder = std::move(holder);
} }
} }
virtual ServiceObjectHolder *GetObject(IDomainObject *domain, u32 object_id) override { virtual ServiceObjectHolder *GetObject(IDomainObject *domain, u32 object_id) override {
std::scoped_lock lk{this->domain_lock}; std::scoped_lock lk{this->domain_lock};
if (this->domain_objects[object_id-1].owner == domain) { if (this->domain_objects[object_id-1].owner == domain) {
@ -349,7 +380,7 @@ class WaitableManager : public SessionManagerBase {
} }
return nullptr; return nullptr;
} }
virtual Result FreeObject(IDomainObject *domain, u32 object_id) override { virtual Result FreeObject(IDomainObject *domain, u32 object_id) override {
std::scoped_lock lk{this->domain_lock}; std::scoped_lock lk{this->domain_lock};
if (this->domain_objects[object_id-1].owner == domain) { if (this->domain_objects[object_id-1].owner == domain) {
@ -359,7 +390,7 @@ class WaitableManager : public SessionManagerBase {
} }
return 0x3D80B; return 0x3D80B;
} }
virtual Result ForceFreeObject(u32 object_id) override { virtual Result ForceFreeObject(u32 object_id) override {
std::scoped_lock lk{this->domain_lock}; std::scoped_lock lk{this->domain_lock};
if (this->domain_objects[object_id-1].owner != nullptr) { if (this->domain_objects[object_id-1].owner != nullptr) {

View file

@ -31,7 +31,8 @@ class WaitableManagerBase {
virtual void AddWaitable(IWaitable *w) = 0; virtual void AddWaitable(IWaitable *w) = 0;
virtual void NotifySignaled(IWaitable *w) = 0; virtual void NotifySignaled(IWaitable *w) = 0;
virtual void RequestStop() = 0;
virtual void Process() = 0; virtual void Process() = 0;
}; };