diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_scheduler.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_scheduler.hpp
index 208d6df13..24f1c7048 100644
--- a/libraries/libmesosphere/include/mesosphere/kern_k_scheduler.hpp
+++ b/libraries/libmesosphere/include/mesosphere/kern_k_scheduler.hpp
@@ -39,13 +39,13 @@ namespace ams::kern {
             static_assert(ams::svc::HighestThreadPriority <= HighestCoreMigrationAllowedPriority);

             struct SchedulingState {
-                util::Atomic<u8> needs_scheduling{false};
+                util::Atomic<bool> needs_scheduling{false};
                 bool interrupt_task_runnable{false};
                 bool should_count_idle{false};
                 u64 idle_count{0};
                 KThread *highest_priority_thread{nullptr};
                 void *idle_thread_stack{nullptr};
-                util::Atomic<KThread *> prev_thread{nullptr};
+                KThread *prev_thread{nullptr};
                 KInterruptTaskManager *interrupt_task_manager{nullptr};

                 constexpr SchedulingState() = default;
@@ -100,7 +100,7 @@ namespace ams::kern {
             }

             ALWAYS_INLINE KThread *GetPreviousThread() const {
-                return m_state.prev_thread.Load();
+                return m_state.prev_thread;
             }

             ALWAYS_INLINE KThread *GetSchedulerCurrentThread() const {
diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_slab_heap.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_slab_heap.hpp
index c63a5691f..bd8157bd8 100644
--- a/libraries/libmesosphere/include/mesosphere/kern_k_slab_heap.hpp
+++ b/libraries/libmesosphere/include/mesosphere/kern_k_slab_heap.hpp
@@ -76,18 +76,20 @@ namespace ams::kern {
         NON_MOVEABLE(KSlabHeapBase);
         private:
             size_t m_obj_size{};
-            util::Atomic<uintptr_t> m_peak{0};
+            uintptr_t m_peak{};
             uintptr_t m_start{};
             uintptr_t m_end{};
         private:
             ALWAYS_INLINE void UpdatePeakImpl(uintptr_t obj) {
+                const util::AtomicRef<uintptr_t> peak_ref(m_peak);
+
                 const uintptr_t alloc_peak = obj + this->GetObjectSize();
-                uintptr_t cur_peak = m_peak.Load();
+                uintptr_t cur_peak = m_peak;
                 do {
                     if (alloc_peak <= cur_peak) {
                         break;
                     }
-                } while (!m_peak.CompareExchangeStrong(cur_peak, alloc_peak));
+                } while (!peak_ref.CompareExchangeStrong(cur_peak, alloc_peak));
             }
         public:
             constexpr KSlabHeapBase() = default;
@@ -110,8 +112,7 @@ namespace ams::kern {
                 const size_t num_obj = (memory_size / obj_size);
                 m_start = reinterpret_cast<uintptr_t>(memory);
                 m_end = m_start + num_obj * obj_size;
-
-                m_peak.Store(m_start);
+                m_peak = m_start;

                 /* Free the objects. */
                u8 *cur = reinterpret_cast<u8 *>(m_end);
@@ -175,7 +176,7 @@ namespace ams::kern {
             }

             ALWAYS_INLINE size_t GetPeakIndex() const {
-                return this->GetObjectIndex(reinterpret_cast<const void *>(m_peak.Load()));
+                return this->GetObjectIndex(reinterpret_cast<const void *>(m_peak));
             }

             ALWAYS_INLINE uintptr_t GetSlabHeapAddress() const {
diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_thread.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_thread.hpp
index c726e3fa2..23295e572 100644
--- a/libraries/libmesosphere/include/mesosphere/kern_k_thread.hpp
+++ b/libraries/libmesosphere/include/mesosphere/kern_k_thread.hpp
@@ -225,7 +225,7 @@ namespace ams::kern {
             s32 m_original_physical_ideal_core_id{};
             s32 m_num_core_migration_disables{};
             ThreadState m_thread_state{};
-            util::Atomic<u8> m_termination_requested{false};
+            util::Atomic<bool> m_termination_requested{false};
             bool m_wait_cancelled{};
             bool m_cancellable{};
             bool m_signaled{};
diff --git a/libraries/libmesosphere/source/kern_k_scheduler.cpp b/libraries/libmesosphere/source/kern_k_scheduler.cpp
index dcba85e38..b80e7e836 100644
--- a/libraries/libmesosphere/source/kern_k_scheduler.cpp
+++ b/libraries/libmesosphere/source/kern_k_scheduler.cpp
@@ -246,9 +246,9 @@ namespace ams::kern {
         if (cur_process != nullptr) {
             /* NOTE: Combining this into AMS_LIKELY(!... && ...) triggers an internal compiler error: Segmentation fault in GCC 9.2.0. */
             if (AMS_LIKELY(!cur_thread->IsTerminationRequested()) && AMS_LIKELY(cur_thread->GetActiveCore() == m_core_id)) {
-                m_state.prev_thread.Store(cur_thread);
+                m_state.prev_thread = cur_thread;
             } else {
-                m_state.prev_thread.Store(nullptr);
+                m_state.prev_thread = nullptr;
             }
         }

@@ -270,9 +270,12 @@ namespace ams::kern {
     void KScheduler::ClearPreviousThread(KThread *thread) {
         MESOSPHERE_ASSERT(IsSchedulerLockedByCurrentThread());
         for (size_t i = 0; i < cpu::NumCores; ++i) {
+            /* Get an atomic reference to the core scheduler's previous thread. */
+            const util::AtomicRef<KThread *> prev_thread(Kernel::GetScheduler(static_cast<s32>(i)).m_state.prev_thread);
+
             /* Atomically clear the previous thread if it's our target. */
             KThread *compare = thread;
-            Kernel::GetScheduler(static_cast<s32>(i)).m_state.prev_thread.CompareExchangeStrong(compare, nullptr);
+            prev_thread.CompareExchangeStrong(compare, nullptr);
         }
     }

diff --git a/libraries/libmesosphere/source/kern_k_thread.cpp b/libraries/libmesosphere/source/kern_k_thread.cpp
index bce7238f1..48993313c 100644
--- a/libraries/libmesosphere/source/kern_k_thread.cpp
+++ b/libraries/libmesosphere/source/kern_k_thread.cpp
@@ -1184,7 +1184,7 @@ namespace ams::kern {
         /* Determine if this is the first termination request. */
         const bool first_request = [&] ALWAYS_INLINE_LAMBDA () -> bool {
             /* Perform an atomic compare-and-swap from false to true.
*/ - u8 expected = false; + bool expected = false; return m_termination_requested.CompareExchangeStrong(expected, true); }(); diff --git a/libraries/libvapours/include/vapours/util/arch/arm64/util_atomic.hpp b/libraries/libvapours/include/vapours/util/arch/arm64/util_atomic.hpp index 2e025d313..588f59e1a 100644 --- a/libraries/libvapours/include/vapours/util/arch/arm64/util_atomic.hpp +++ b/libraries/libvapours/include/vapours/util/arch/arm64/util_atomic.hpp @@ -100,6 +100,121 @@ namespace ams::util { #undef AMS_UTIL_IMPL_DEFINE_ATOMIC_STORE_EXCLUSIVE_FUNCTION + template + constexpr ALWAYS_INLINE T ConvertToTypeForAtomic(AtomicStorage s) { + if constexpr (std::integral) { + return static_cast(s); + } else if constexpr(std::is_pointer::value) { + return reinterpret_cast(s); + } else { + return std::bit_cast(s); + } + } + + template + constexpr ALWAYS_INLINE AtomicStorage ConvertToStorageForAtomic(T arg) { + if constexpr (std::integral) { + return static_cast>(arg); + } else if constexpr(std::is_pointer::value) { + if (std::is_constant_evaluated() && arg == nullptr) { + return 0; + } + + return reinterpret_cast>(arg); + } else { + return std::bit_cast>(arg); + } + } + + template + ALWAYS_INLINE StorageType AtomicLoadImpl(volatile StorageType * const p) { + if constexpr (Order != std::memory_order_relaxed) { + return ::ams::util::impl::LoadAcquireForAtomic(p); + } else { + return *p; + } + } + + template + ALWAYS_INLINE void AtomicStoreImpl(volatile StorageType * const p, const StorageType s) { + if constexpr (Order != std::memory_order_relaxed) { + ::ams::util::impl::StoreReleaseForAtomic(p, s); + } else { + *p = s; + } + } + + template + ALWAYS_INLINE StorageType LoadExclusiveForAtomicByMemoryOrder(volatile StorageType * const p) { + if constexpr (Order == std::memory_order_relaxed) { + return ::ams::util::impl::LoadExclusiveForAtomic(p); + } else if constexpr (Order == std::memory_order_consume || Order == std::memory_order_acquire) { + return ::ams::util::impl::LoadAcquireExclusiveForAtomic(p); + } else if constexpr (Order == std::memory_order_release) { + return ::ams::util::impl::LoadExclusiveForAtomic(p); + } else if constexpr (Order == std::memory_order_acq_rel || Order == std::memory_order_seq_cst) { + return ::ams::util::impl::LoadAcquireExclusiveForAtomic(p); + } else { + static_assert(Order != Order, "Invalid memory order"); + } + } + + template + ALWAYS_INLINE bool StoreExclusiveForAtomicByMemoryOrder(volatile StorageType * const p, const StorageType s) { + if constexpr (Order == std::memory_order_relaxed) { + return ::ams::util::impl::StoreExclusiveForAtomic(p, s); + } else if constexpr (Order == std::memory_order_consume || Order == std::memory_order_acquire) { + return ::ams::util::impl::StoreExclusiveForAtomic(p, s); + } else if constexpr (Order == std::memory_order_release) { + return ::ams::util::impl::StoreReleaseExclusiveForAtomic(p, s); + } else if constexpr (Order == std::memory_order_acq_rel || Order == std::memory_order_seq_cst) { + return ::ams::util::impl::StoreReleaseExclusiveForAtomic(p, s); + } else { + static_assert(Order != Order, "Invalid memory order"); + } + } + + template + ALWAYS_INLINE StorageType AtomicExchangeImpl(volatile StorageType * const p, const StorageType s) { + StorageType current; + do { + current = ::ams::util::impl::LoadExclusiveForAtomicByMemoryOrder(p); + } while(AMS_UNLIKELY(!impl::StoreExclusiveForAtomicByMemoryOrder(p, s))); + + return current; + } + + template + ALWAYS_INLINE bool AtomicCompareExchangeWeakImpl(volatile AtomicStorage * 
const p, T &expected, T desired) { + const AtomicStorage e = ::ams::util::impl::ConvertToStorageForAtomic(expected); + const AtomicStorage d = ::ams::util::impl::ConvertToStorageForAtomic(desired); + + const AtomicStorage current = ::ams::util::impl::LoadExclusiveForAtomicByMemoryOrder(p); + if (AMS_UNLIKELY(current != e)) { + impl::ClearExclusiveForAtomic(); + expected = ::ams::util::impl::ConvertToTypeForAtomic(current); + return false; + } + + return AMS_LIKELY(impl::StoreExclusiveForAtomicByMemoryOrder(p, d)); + } + + template + ALWAYS_INLINE bool AtomicCompareExchangeStrongImpl(volatile AtomicStorage * const p, T &expected, T desired) { + const AtomicStorage e = ::ams::util::impl::ConvertToStorageForAtomic(expected); + const AtomicStorage d = ::ams::util::impl::ConvertToStorageForAtomic(desired); + + do { + if (const AtomicStorage current = ::ams::util::impl::LoadExclusiveForAtomicByMemoryOrder(p); AMS_UNLIKELY(current != e)) { + impl::ClearExclusiveForAtomic(); + expected = ::ams::util::impl::ConvertToTypeForAtomic(current); + return false; + } + } while (AMS_UNLIKELY(!impl::StoreExclusiveForAtomicByMemoryOrder(p, d))); + + return true; + } + } template @@ -117,27 +232,11 @@ namespace ams::util { using DifferenceType = typename std::conditional::type>::type; static constexpr ALWAYS_INLINE T ConvertToType(StorageType s) { - if constexpr (std::integral) { - return static_cast(s); - } else if constexpr(std::is_pointer::value) { - return reinterpret_cast(s); - } else { - return std::bit_cast(s); - } + return impl::ConvertToTypeForAtomic(s); } static constexpr ALWAYS_INLINE StorageType ConvertToStorage(T arg) { - if constexpr (std::integral) { - return static_cast(arg); - } else if constexpr(std::is_pointer::value) { - if (std::is_constant_evaluated() && arg == nullptr) { - return 0; - } - - return reinterpret_cast(arg); - } else { - return std::bit_cast(arg); - } + return impl::ConvertToStorageForAtomic(arg); } private: StorageType m_v; @@ -157,148 +256,31 @@ namespace ams::util { return desired; } + ALWAYS_INLINE operator T() const { return this->Load(); } + template ALWAYS_INLINE T Load() const { - if constexpr (Order != std::memory_order_relaxed) { - return ConvertToType(impl::LoadAcquireForAtomic(this->GetStoragePointer())); - } else { - return ConvertToType(*this->GetStoragePointer()); - } + return ConvertToType(impl::AtomicLoadImpl(this->GetStoragePointer())); } template ALWAYS_INLINE void Store(T arg) { - if constexpr (Order != std::memory_order_relaxed) { - impl::StoreReleaseForAtomic(this->GetStoragePointer(), ConvertToStorage(arg)); - } else { - *this->GetStoragePointer() = ConvertToStorage(arg); - } + return impl::AtomicStoreImpl(this->GetStoragePointer(), ConvertToStorage(arg)); } template ALWAYS_INLINE T Exchange(T arg) { - volatile StorageType * const p = this->GetStoragePointer(); - const StorageType s = ConvertToStorage(arg); - - StorageType current; - - if constexpr (Order == std::memory_order_relaxed) { - do { - current = impl::LoadExclusiveForAtomic(p); - } while (AMS_UNLIKELY(!impl::StoreExclusiveForAtomic(p, s))); - } else if constexpr (Order == std::memory_order_consume || Order == std::memory_order_acquire) { - do { - current = impl::LoadAcquireExclusiveForAtomic(p); - } while (AMS_UNLIKELY(!impl::StoreExclusiveForAtomic(p, s))); - } else if constexpr (Order == std::memory_order_release) { - do { - current = impl::LoadExclusiveForAtomic(p); - } while (AMS_UNLIKELY(!impl::StoreReleaseExclusiveForAtomic(p, s))); - } else if constexpr (Order == 
std::memory_order_acq_rel || Order == std::memory_order_seq_cst) { - do { - current = impl::LoadAcquireExclusiveForAtomic(p); - } while (AMS_UNLIKELY(!impl::StoreReleaseExclusiveForAtomic(p, s))); - } else { - static_assert(Order != Order, "Invalid memory order"); - } - - return current; + return ConvertToType(impl::AtomicExchangeImpl(this->GetStoragePointer(), ConvertToStorage(arg))); } template ALWAYS_INLINE bool CompareExchangeWeak(T &expected, T desired) { - volatile StorageType * const p = this->GetStoragePointer(); - const StorageType e = ConvertToStorage(expected); - const StorageType d = ConvertToStorage(desired); - - if constexpr (Order == std::memory_order_relaxed) { - const StorageType current = impl::LoadExclusiveForAtomic(p); - if (AMS_UNLIKELY(current != e)) { - impl::ClearExclusiveForAtomic(); - expected = ConvertToType(current); - return false; - } - - return AMS_LIKELY(impl::StoreExclusiveForAtomic(p, d)); - } else if constexpr (Order == std::memory_order_consume || Order == std::memory_order_acquire) { - const StorageType current = impl::LoadAcquireExclusiveForAtomic(p); - if (AMS_UNLIKELY(current != e)) { - impl::ClearExclusiveForAtomic(); - expected = ConvertToType(current); - return false; - } - - return AMS_LIKELY(impl::StoreExclusiveForAtomic(p, d)); - } else if constexpr (Order == std::memory_order_release) { - const StorageType current = impl::LoadExclusiveForAtomic(p); - if (AMS_UNLIKELY(current != e)) { - impl::ClearExclusiveForAtomic(); - expected = ConvertToType(current); - return false; - } - - return AMS_LIKELY(impl::StoreReleaseExclusiveForAtomic(p, d)); - } else if constexpr (Order == std::memory_order_acq_rel || Order == std::memory_order_seq_cst) { - const StorageType current = impl::LoadAcquireExclusiveForAtomic(p); - if (AMS_UNLIKELY(current != e)) { - impl::ClearExclusiveForAtomic(); - expected = ConvertToType(current); - return false; - } - - return AMS_LIKELY(impl::StoreReleaseExclusiveForAtomic(p, d)); - } else { - static_assert(Order != Order, "Invalid memory order"); - } + return impl::AtomicCompareExchangeWeakImpl(this->GetStoragePointer(), expected, desired); } template ALWAYS_INLINE bool CompareExchangeStrong(T &expected, T desired) { - volatile StorageType * const p = this->GetStoragePointer(); - const StorageType e = ConvertToStorage(expected); - const StorageType d = ConvertToStorage(desired); - - if constexpr (Order == std::memory_order_relaxed) { - StorageType current; - do { - if (current = impl::LoadExclusiveForAtomic(p); current != e) { - impl::ClearExclusiveForAtomic(); - expected = ConvertToType(current); - return false; - } - } while (!impl::StoreExclusiveForAtomic(p, d)); - } else if constexpr (Order == std::memory_order_consume || Order == std::memory_order_acquire) { - StorageType current; - do { - if (current = impl::LoadAcquireExclusiveForAtomic(p); current != e) { - impl::ClearExclusiveForAtomic(); - expected = ConvertToType(current); - return false; - } - } while (!impl::StoreExclusiveForAtomic(p, d)); - } else if constexpr (Order == std::memory_order_release) { - StorageType current; - do { - if (current = impl::LoadExclusiveForAtomic(p); current != e) { - impl::ClearExclusiveForAtomic(); - expected = ConvertToType(current); - return false; - } - } while (!impl::StoreReleaseExclusiveForAtomic(p, d)); - } else if constexpr (Order == std::memory_order_acq_rel || Order == std::memory_order_seq_cst) { - StorageType current; - do { - if (current = impl::LoadAcquireExclusiveForAtomic(p); current != e) { - 
impl::ClearExclusiveForAtomic(); - expected = ConvertToType(current); - return false; - } - } while (!impl::StoreReleaseExclusiveForAtomic(p, d)); - } else { - static_assert(Order != Order, "Invalid memory order"); - } - - return true; + return impl::AtomicCompareExchangeStrongImpl(this->GetStoragePointer(), expected, desired); } #define AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(_OPERATION_, _OPERATOR_, _POINTER_ALLOWED_) \ @@ -341,5 +323,108 @@ namespace ams::util { ALWAYS_INLINE T operator--(int) { static_assert(Enable == HasArithmeticFunctions); return this->FetchSub(1); } }; + template + class AtomicRef { + NON_MOVEABLE(AtomicRef); + public: + static constexpr size_t RequiredAlignment = std::max(sizeof(T), alignof(T)); + private: + using StorageType = impl::AtomicStorage; + static_assert(sizeof(StorageType) == sizeof(T)); + static_assert(alignof(StorageType) >= alignof(T)); + + static constexpr bool IsIntegral = std::integral; + static constexpr bool IsPointer = std::is_pointer::value; + + static constexpr bool HasArithmeticFunctions = IsIntegral || IsPointer; + + using DifferenceType = typename std::conditional::type>::type; + + static constexpr ALWAYS_INLINE T ConvertToType(StorageType s) { + return impl::ConvertToTypeForAtomic(s); + } + + static constexpr ALWAYS_INLINE StorageType ConvertToStorage(T arg) { + return impl::ConvertToStorageForAtomic(arg); + } + private: + volatile StorageType * const m_p; + private: + ALWAYS_INLINE volatile StorageType *GetStoragePointer() const { return m_p; } + public: + explicit ALWAYS_INLINE AtomicRef(T &t) : m_p(reinterpret_cast(std::addressof(t))) { /* ... */ } + ALWAYS_INLINE AtomicRef(const AtomicRef &) noexcept = default; + + AtomicRef() = delete; + AtomicRef &operator=(const AtomicRef &) = delete; + + ALWAYS_INLINE T operator=(T desired) const { return const_cast(this)->Store(desired); } + + ALWAYS_INLINE operator T() const { return this->Load(); } + + template + ALWAYS_INLINE T Load() const { + return ConvertToType(impl::AtomicLoadImpl(this->GetStoragePointer())); + } + + template + ALWAYS_INLINE void Store(T arg) const { + return impl::AtomicStoreImpl(this->GetStoragePointer(), ConvertToStorage(arg)); + } + + template + ALWAYS_INLINE T Exchange(T arg) const { + return ConvertToType(impl::AtomicExchangeImpl(this->GetStoragePointer(), ConvertToStorage(arg))); + } + + template + ALWAYS_INLINE bool CompareExchangeWeak(T &expected, T desired) const { + return impl::AtomicCompareExchangeWeakImpl(this->GetStoragePointer(), expected, desired); + } + + template + ALWAYS_INLINE bool CompareExchangeStrong(T &expected, T desired) const { + return impl::AtomicCompareExchangeStrongImpl(this->GetStoragePointer(), expected, desired); + } + + #define AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(_OPERATION_, _OPERATOR_, _POINTER_ALLOWED_) \ + template::type> \ + ALWAYS_INLINE T Fetch ## _OPERATION_(DifferenceType arg) const { \ + static_assert(Enable == (IsIntegral || (_POINTER_ALLOWED_ && IsPointer))); \ + volatile StorageType * const p = this->GetStoragePointer(); \ + \ + StorageType current; \ + do { \ + current = impl::LoadAcquireExclusiveForAtomic(p); \ + } while (AMS_UNLIKELY(!impl::StoreReleaseExclusiveForAtomic(p, ConvertToStorage(ConvertToType(current) _OPERATOR_ arg)))); \ + return ConvertToType(current); \ + } \ + \ + template::type> \ + ALWAYS_INLINE T operator _OPERATOR_##=(DifferenceType arg) const { \ + static_assert(Enable == (IsIntegral || (_POINTER_ALLOWED_ && IsPointer))); \ + return this->Fetch ## _OPERATION_(arg) 
_OPERATOR_ arg; \ + } + + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Add, +, true) + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Sub, -, true) + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(And, &, false) + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Or, |, false) + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Xor, ^, false) + + #undef AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION + + template::type> + ALWAYS_INLINE T operator++() const { static_assert(Enable == HasArithmeticFunctions); return this->FetchAdd(1) + 1; } + + template::type> + ALWAYS_INLINE T operator++(int) const { static_assert(Enable == HasArithmeticFunctions); return this->FetchAdd(1); } + + template::type> + ALWAYS_INLINE T operator--() const { static_assert(Enable == HasArithmeticFunctions); return this->FetchSub(1) - 1; } + + template::type> + ALWAYS_INLINE T operator--(int) const { static_assert(Enable == HasArithmeticFunctions); return this->FetchSub(1); } + }; } \ No newline at end of file diff --git a/libraries/libvapours/include/vapours/util/arch/generic/util_atomic.hpp b/libraries/libvapours/include/vapours/util/arch/generic/util_atomic.hpp index cb458a136..e44df5600 100644 --- a/libraries/libvapours/include/vapours/util/arch/generic/util_atomic.hpp +++ b/libraries/libvapours/include/vapours/util/arch/generic/util_atomic.hpp @@ -74,6 +74,8 @@ namespace ams::util { return (m_v = desired); } + ALWAYS_INLINE operator T() const { return this->Load(); } + template ALWAYS_INLINE T Load() const { return m_v.load(Order); @@ -84,22 +86,21 @@ namespace ams::util { return m_v.store(Order); } - template + template ALWAYS_INLINE T Exchange(T arg) { return m_v.exchange(arg, Order); } - template + template ALWAYS_INLINE bool CompareExchangeWeak(T &expected, T desired) { return m_v.compare_exchange_weak(expected, desired, Order); } - template + template ALWAYS_INLINE bool CompareExchangeStrong(T &expected, T desired) { return m_v.compare_exchange_strong(expected, desired, Order); } - #define AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(_OPERATION_, _OPERATION_LOWER_, _OPERATOR_, _POINTER_ALLOWED_) \ template::type> \ ALWAYS_INLINE T Fetch ## _OPERATION_(DifferenceType arg) { \ @@ -134,5 +135,89 @@ namespace ams::util { ALWAYS_INLINE T operator--(int) { static_assert(Enable == HasArithmeticFunctions); return this->FetchSub(1); } }; + template + class AtomicRef { + NON_MOVEABLE(AtomicRef); + public: + static constexpr size_t RequiredAlignment = std::atomic_ref::required_alignment; + private: + static constexpr bool IsIntegral = std::integral; + static constexpr bool IsPointer = std::is_pointer::value; + + static constexpr bool HasArithmeticFunctions = IsIntegral || IsPointer; + + using DifferenceType = typename std::conditional::type>::type; + private: + static_assert(std::atomic_ref::is_always_lock_free); + private: + std::atomic_ref m_ref; + public: + explicit ALWAYS_INLINE AtomicRef(T &t) : m_ref(t) { /* ... 
*/ } + ALWAYS_INLINE AtomicRef(const AtomicRef &) noexcept = default; + + AtomicRef() = delete; + AtomicRef &operator=(const AtomicRef &) = delete; + + ALWAYS_INLINE T operator=(T desired) const { return (m_ref = desired); } + + template + ALWAYS_INLINE T Load() const { + return m_ref.load(Order); + } + + template + ALWAYS_INLINE void Store(T arg) const { + return m_ref.store(arg, Order); + } + + template + ALWAYS_INLINE T Exchange(T arg) const { + return m_ref.exchange(arg, Order); + } + + template + ALWAYS_INLINE bool CompareExchangeWeak(T &expected, T desired) const { + return m_ref.compare_exchange_weak(expected, desired, Order); + } + + template + ALWAYS_INLINE bool CompareExchangeStrong(T &expected, T desired) const { + return m_ref.compare_exchange_strong(expected, desired, Order); + } + + #define AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(_OPERATION_, _OPERATION_LOWER_, _OPERATOR_, _POINTER_ALLOWED_) \ + template::type> \ + ALWAYS_INLINE T Fetch ## _OPERATION_(DifferenceType arg) const { \ + static_assert(Enable == (IsIntegral || (_POINTER_ALLOWED_ && IsPointer))); \ + return m_ref.fetch_##_OPERATION_LOWER_(arg); \ + } \ + \ + template::type> \ + ALWAYS_INLINE T operator _OPERATOR_##=(DifferenceType arg) const { \ + static_assert(Enable == (IsIntegral || (_POINTER_ALLOWED_ && IsPointer))); \ + return this->Fetch##_OPERATION_(arg) _OPERATOR_ arg; \ + } + + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Add, add, +, true) + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Sub, sub, -, true) + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(And, and, &, false) + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Or, or, |, false) + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Xor, xor, ^, false) + + #undef AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION + + template::type> + ALWAYS_INLINE T operator++() const { static_assert(Enable == HasArithmeticFunctions); return this->FetchAdd(1) + 1; } + + template::type> + ALWAYS_INLINE T operator++(int) const { static_assert(Enable == HasArithmeticFunctions); return this->FetchAdd(1); } + + template::type> + ALWAYS_INLINE T operator--() const { static_assert(Enable == HasArithmeticFunctions); return this->FetchSub(1) - 1; } + + template::type> + ALWAYS_INLINE T operator--(int) const { static_assert(Enable == HasArithmeticFunctions); return this->FetchSub(1); } + }; + } \ No newline at end of file
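Usage note (not part of the diff): the sketch below mirrors the KSlabHeapBase::UpdatePeakImpl pattern introduced above, where the field itself stays a plain uintptr_t and an atomic reference is constructed over it only for the compare-exchange loop. It is a minimal, standalone illustration that substitutes C++20 std::atomic_ref for util::AtomicRef (which the generic implementation above wraps); the PeakTracker type and the main() harness are hypothetical and exist only for this example.

// Standalone sketch: the "plain field + atomic reference at the CAS site" pattern.
#include <atomic>
#include <cstdint>
#include <cstdio>

struct PeakTracker {             /* hypothetical stand-in for KSlabHeapBase's peak bookkeeping */
    std::uintptr_t m_peak{};     /* plain member; only touched atomically through a reference */

    void UpdatePeak(std::uintptr_t alloc_end) {
        /* Construct the atomic reference over the plain member, as UpdatePeakImpl does. */
        std::atomic_ref<std::uintptr_t> peak_ref(m_peak);

        /* A non-atomic read is fine for the initial candidate; the CAS reloads it on failure. */
        std::uintptr_t cur_peak = m_peak;
        do {
            if (alloc_end <= cur_peak) {
                break;           /* another thread already recorded a higher peak */
            }
        } while (!peak_ref.compare_exchange_strong(cur_peak, alloc_end));
    }
};

int main() {
    PeakTracker tracker;
    tracker.UpdatePeak(0x1000);
    tracker.UpdatePeak(0x0800);  /* lower than the current peak, so it is left unchanged */
    std::printf("peak = 0x%zx\n", static_cast<std::size_t>(tracker.m_peak));
    return 0;
}

The apparent design point of the diff is the same: members such as m_peak and prev_thread that are normally accessed under a lock no longer need to be util::Atomic everywhere; a util::AtomicRef is created only at the few call sites that genuinely require atomicity.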