diff --git a/libraries/libstratosphere/include/stratosphere/os.hpp b/libraries/libstratosphere/include/stratosphere/os.hpp index e0690a20b..d1ad8529f 100644 --- a/libraries/libstratosphere/include/stratosphere/os.hpp +++ b/libraries/libstratosphere/include/stratosphere/os.hpp @@ -48,4 +48,5 @@ #include #include #include +#include #include diff --git a/libraries/libstratosphere/include/stratosphere/os/os_barrier.hpp b/libraries/libstratosphere/include/stratosphere/os/os_barrier.hpp new file mode 100644 index 000000000..f7a7f5102 --- /dev/null +++ b/libraries/libstratosphere/include/stratosphere/os/os_barrier.hpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2018-2020 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#pragma once +#include +#include +#include + +namespace ams::os { + + class Barrier { + NON_COPYABLE(Barrier); + NON_MOVEABLE(Barrier); + private: + BarrierType m_barrier; + public: + explicit Barrier(int num_threads) { + InitializeBarrier(std::addressof(m_barrier), num_threads); + } + + ~Barrier() { + FinalizeBarrier(std::addressof(m_barrier)); + } + + void Await() { + return AwaitBarrier(std::addressof(m_barrier)); + } + + operator BarrierType &() { + return m_barrier; + } + + operator const BarrierType &() const { + return m_barrier; + } + + BarrierType *GetBase() { + return std::addressof(m_barrier); + } + }; + +} diff --git a/libraries/libstratosphere/include/stratosphere/os/os_barrier_api.hpp b/libraries/libstratosphere/include/stratosphere/os/os_barrier_api.hpp new file mode 100644 index 000000000..58aca4055 --- /dev/null +++ b/libraries/libstratosphere/include/stratosphere/os/os_barrier_api.hpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2018-2020 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#pragma once +#include + +namespace ams::os { + + struct BarrierType; + + void InitializeBarrier(BarrierType *barrier, int num_threads); + void FinalizeBarrier(BarrierType *barrier); + + void AwaitBarrier(BarrierType *barrier); + +} diff --git a/libraries/libstratosphere/include/stratosphere/os/os_barrier_types.hpp b/libraries/libstratosphere/include/stratosphere/os/os_barrier_types.hpp new file mode 100644 index 000000000..88a88762f --- /dev/null +++ b/libraries/libstratosphere/include/stratosphere/os/os_barrier_types.hpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2018-2020 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#pragma once +#include +#include +#include + +namespace ams::os { + + struct BarrierType { + u16 max_threads; + u16 waiting_threads; + u32 base_counter_lower; + u32 base_counter_upper; + + impl::InternalCriticalSectionStorage cs_barrier; + impl::InternalConditionVariableStorage cv_gathered; + }; + static_assert(std::is_trivial::value); + +} diff --git a/libraries/libstratosphere/source/os/os_barrier.cpp b/libraries/libstratosphere/source/os/os_barrier.cpp new file mode 100644 index 000000000..dad6ecd9b --- /dev/null +++ b/libraries/libstratosphere/source/os/os_barrier.cpp @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2018-2020 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +#include + +namespace ams::os { + + namespace { + + ALWAYS_INLINE bool IsBarrierInitialized(const BarrierType *barrier) { + return barrier->max_threads != 0; + } + + ALWAYS_INLINE u64 GetBarrierBaseCounterImpl(const BarrierType *barrier) { + /* Check pre-conditions. */ + AMS_ASSERT(util::GetReference(barrier->cs_barrier).IsLockedByCurrentThread()); + + /* Convert two u32s to u64. */ + return (static_cast(barrier->base_counter_lower) << 0) | (static_cast(barrier->base_counter_upper) << BITSIZEOF(barrier->base_counter_lower)); + } + + ALWAYS_INLINE void SetBarrierBaseCounterImpl(BarrierType *barrier, u64 value) { + /* Check pre-conditions. */ + AMS_ASSERT(util::GetReference(barrier->cs_barrier).IsLockedByCurrentThread()); + + /* Store as two u32s. */ + barrier->base_counter_lower = static_cast(value >> 0); + barrier->base_counter_upper = static_cast(value >> BITSIZEOF(barrier->base_counter_lower)); + } + + } + + void InitializeBarrier(BarrierType *barrier, int num_threads) { + /* Check pre-conditions. */ + AMS_ASSERT(num_threads >= 1); + + /* Construct objects. */ + util::ConstructAt(barrier->cs_barrier); + util::ConstructAt(barrier->cv_gathered); + + /* Set member variables. */ + barrier->max_threads = num_threads; + barrier->waiting_threads = 0; + barrier->base_counter_lower = 0; + barrier->base_counter_upper = 0; + } + + void FinalizeBarrier(BarrierType *barrier) { + /* Check pre-conditions. */ + AMS_ASSERT(IsBarrierInitialized(barrier)); + AMS_ASSERT(barrier->waiting_threads == 0); + + /* Clear max threads. */ + barrier->max_threads = 0; + + /* Destroy objects. */ + util::DestroyAt(barrier->cs_barrier); + util::DestroyAt(barrier->cv_gathered); + } + + void AwaitBarrier(BarrierType *barrier) { + /* Check pre-conditions. */ + AMS_ASSERT(IsBarrierInitialized(barrier)); + + /* Await the barrier. */ + { + /* Acquire exclusive access to the barrier. */ + auto &cs = util::GetReference(barrier->cs_barrier); + std::scoped_lock lk(cs); + + /* Read barrier state. */ + const u64 base_counter = GetBarrierBaseCounterImpl(barrier); + const auto max_threads = barrier->max_threads; + auto waiting_threads = barrier->waiting_threads; + + /* Determine next base counter. */ + const u64 done_base_counter = base_counter + max_threads; + + /* Increment waiting threads. */ + ++waiting_threads; + + /* Check if all threads have synchronized. */ + if (waiting_threads >= max_threads) { + /* They have, so reset waiting thread count. */ + barrier->waiting_threads = 0; + + /* Set the updated base counter. */ + SetBarrierBaseCounterImpl(barrier, done_base_counter); + + /* Broadcast to our cv. */ + util::GetReference(barrier->cv_gathered).Broadcast(); + } else { + /* More threads are needed, so update waiting thread count. */ + barrier->waiting_threads = waiting_threads; + + /* Wait for remaining threads to await. */ + while (GetBarrierBaseCounterImpl(barrier) < done_base_counter) { + util::GetReference(barrier->cv_gathered).Wait(std::addressof(cs)); + } + } + } + } + +}