Atmosphere/libraries/libstratosphere/source/os/impl/os_thread_manager_impl.os.windows.cpp

263 lines
9.3 KiB
C++

/*
* Copyright (c) Atmosphère-NX
*
* This program is free software; you can redistribute it and/or modify it
* under the terms and conditions of the GNU General Public License,
* version 2, as published by the Free Software Foundation.
*
* This program is distributed in the hope it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include <stratosphere.hpp>
#include "os_thread_manager_impl.os.windows.hpp"
#include "os_thread_manager.hpp"
namespace ams::os::impl {
namespace {
constexpr size_t DefaultStackSize = 1_MB;
class ScopedSaveLastError {
NON_COPYABLE(ScopedSaveLastError);
NON_MOVEABLE(ScopedSaveLastError);
private:
DWORD m_last_error;
public:
ALWAYS_INLINE ScopedSaveLastError() : m_last_error(::GetLastError()) { /* ... */ }
ALWAYS_INLINE ~ScopedSaveLastError() { ::SetLastError(m_last_error); }
};
unsigned __stdcall InvokeThread(void *arg) {
ThreadType *thread = static_cast<ThreadType *>(arg);
/* Invoke the thread. */
ThreadManager::InvokeThread(thread);
return 0;
}
os::NativeHandle DuplicateThreadHandle(os::NativeHandle thread_handle) {
/* Get the thread's windows handle. */
os::NativeHandle windows_handle = os::InvalidNativeHandle;
AMS_ABORT_UNLESS(::DuplicateHandle(::GetCurrentProcess(), thread_handle, ::GetCurrentProcess(), std::addressof(windows_handle), 0, 0, DUPLICATE_SAME_ACCESS));
return windows_handle;
}
os::ThreadType *DynamicAllocateAndRegisterThreadType() {
/* Get the thread manager. */
auto &thread_manager = GetThreadManager();
/* Allocate a thread. */
auto *thread = thread_manager.AllocateThreadType();
AMS_ABORT_UNLESS(thread != nullptr);
/* Setup the thread object. */
SetupThreadObjectUnsafe(thread, nullptr, nullptr, nullptr, nullptr, 0, DefaultThreadPriority);
thread->state = ThreadType::State_Started;
thread->auto_registered = true;
/* Set the thread's windows handle. */
thread->native_handle = DuplicateThreadHandle(::GetCurrentThread());
thread->ideal_core = ::GetCurrentProcessorNumber();
thread->affinity_mask = thread_manager.GetThreadAvailableCoreMask();
/* Place the object under the thread manager. */
thread_manager.PlaceThreadObjectUnderThreadManagerSafe(thread);
return thread;
}
}
ThreadManagerWindowsImpl::ThreadManagerWindowsImpl(ThreadType *main_thread) : m_tls_index(::TlsAlloc()) {
/* Verify tls index is valid. */
AMS_ASSERT(m_tls_index != TLS_OUT_OF_INDEXES);
/* Setup the main thread object. */
SetupThreadObjectUnsafe(main_thread, nullptr, nullptr, nullptr, nullptr, DefaultStackSize, DefaultThreadPriority);
/* Setup the main thread's windows handle information. */
main_thread->native_handle = DuplicateThreadHandle(::GetCurrentThread());
main_thread->ideal_core = ::GetCurrentProcessorNumber();
main_thread->affinity_mask = this->GetThreadAvailableCoreMask();
}
Result ThreadManagerWindowsImpl::CreateThread(ThreadType *thread, s32 ideal_core) {
/* Create the thread. */
os::NativeHandle thread_handle = reinterpret_cast<os::NativeHandle>(_beginthreadex(nullptr, thread->stack_size, &InvokeThread, thread, CREATE_SUSPENDED, 0));
AMS_ASSERT(thread_handle != os::InvalidNativeHandle);
/* Set the thread's windows handle information. */
thread->native_handle = thread_handle;
thread->ideal_core = ideal_core;
thread->affinity_mask = this->GetThreadAvailableCoreMask();
R_SUCCEED();
}
void ThreadManagerWindowsImpl::DestroyThreadUnsafe(ThreadType *thread) {
/* Close the thread's handle. */
const auto ret = ::CloseHandle(thread->native_handle);
AMS_ASSERT(ret);
AMS_UNUSED(ret);
thread->native_handle = os::InvalidNativeHandle;
}
void ThreadManagerWindowsImpl::StartThread(const ThreadType *thread) {
ScopedSaveLastError save_error;
/* Resume the thread. */
const auto ret = ::ResumeThread(thread->native_handle);
AMS_ASSERT(ret == 1);
AMS_UNUSED(ret);
}
void ThreadManagerWindowsImpl::WaitForThreadExit(ThreadType *thread) {
::WaitForSingleObject(thread->native_handle, INFINITE);
}
bool ThreadManagerWindowsImpl::TryWaitForThreadExit(ThreadType *thread) {
return ::WaitForSingleObject(thread->native_handle, 0) == 0;
}
void ThreadManagerWindowsImpl::YieldThread() {
::Sleep(0);
}
bool ThreadManagerWindowsImpl::ChangePriority(ThreadType *thread, s32 priority) {
AMS_UNUSED(thread, priority);
return true;
}
s32 ThreadManagerWindowsImpl::GetCurrentPriority(const ThreadType *thread) const {
return thread->base_priority;
}
ThreadId ThreadManagerWindowsImpl::GetThreadId(const ThreadType *thread) const {
ScopedSaveLastError save_error;
const auto thread_id = ::GetThreadId(thread->native_handle);
AMS_ASSERT(thread_id != 0);
return thread_id;
}
void ThreadManagerWindowsImpl::SuspendThreadUnsafe(ThreadType *thread) {
ScopedSaveLastError save_error;
const auto ret = ::SuspendThread(thread->native_handle);
AMS_ASSERT(ret == 0);
AMS_UNUSED(ret);
}
void ThreadManagerWindowsImpl::ResumeThreadUnsafe(ThreadType *thread) {
ScopedSaveLastError save_error;
const auto ret = ::ResumeThread(thread->native_handle);
AMS_ASSERT(ret == 1);
AMS_UNUSED(ret);
}
/* TODO: void ThreadManagerWindowsImpl::GetThreadContextUnsafe(ThreadContextInfo *out_context, const ThreadType *thread); */
void ThreadManagerWindowsImpl::NotifyThreadNameChangedImpl(const ThreadType *thread) const {
/* TODO */
AMS_UNUSED(thread);
}
void ThreadManagerWindowsImpl::SetCurrentThread(ThreadType *thread) const {
ScopedSaveLastError save_error;
::TlsSetValue(m_tls_index, thread);
}
ThreadType *ThreadManagerWindowsImpl::GetCurrentThread() const {
ScopedSaveLastError save_error;
/* Get the thread from tls index. */
ThreadType *thread = static_cast<ThreadType *>(static_cast<void *>(::TlsGetValue(m_tls_index)));
/* If the thread's TLS isn't set, we need to find it (and set tls) or make it. */
if (thread == nullptr) {
/* Try to find the thread. */
thread = GetThreadManager().FindThreadTypeById(::GetCurrentThreadId());
if (thread == nullptr) {
/* Create the thread. */
thread = DynamicAllocateAndRegisterThreadType();
}
/* Set the thread's TLS. */
this->SetCurrentThread(thread);
}
return thread;
}
s32 ThreadManagerWindowsImpl::GetCurrentCoreNumber() const {
ScopedSaveLastError save_error;
return ::GetCurrentProcessorNumber();
}
void ThreadManagerWindowsImpl::SetThreadCoreMask(ThreadType *thread, s32 ideal_core, u64 affinity_mask) const {
ScopedSaveLastError save_error;
/* If we should use the default, set the actual ideal core. */
if (ideal_core == IdealCoreUseDefault) {
affinity_mask = this->GetThreadAvailableCoreMask();
ideal_core = util::CountTrailingZeros(affinity_mask);
affinity_mask = static_cast<decltype(affinity_mask)>(1) << ideal_core;
}
/* Lock the thread. */
std::scoped_lock lk(util::GetReference(thread->cs_thread));
/* Set the thread's affinity mask. */
const auto old = ::SetThreadAffinityMask(thread->native_handle, affinity_mask);
AMS_ABORT_UNLESS(old != 0);
/* Set the ideal core. */
if (ideal_core != IdealCoreNoUpdate) {
thread->ideal_core = ideal_core;
}
/* Set the tracked affinity mask. */
thread->affinity_mask = affinity_mask;
}
void ThreadManagerWindowsImpl::GetThreadCoreMask(s32 *out_ideal_core, u64 *out_affinity_mask, const ThreadType *thread) const {
ScopedSaveLastError save_error;
/* Lock the thread. */
std::scoped_lock lk(util::GetReference(thread->cs_thread));
/* Set the output. */
if (out_ideal_core != nullptr) {
*out_ideal_core = thread->ideal_core;
}
if (out_affinity_mask != nullptr) {
*out_affinity_mask = thread->affinity_mask;
}
}
u64 ThreadManagerWindowsImpl::GetThreadAvailableCoreMask() const {
ScopedSaveLastError save_error;
/* Get the process's affinity mask. */
u64 process_affinity, system_affinity;
AMS_ABORT_UNLESS(::GetProcessAffinityMask(::GetCurrentProcess(), std::addressof(process_affinity), std::addressof(system_affinity)) != 0);
return process_affinity;
}
}