diff --git a/libraries/libstratosphere/include/stratosphere/os/os_message_queue.hpp b/libraries/libstratosphere/include/stratosphere/os/os_message_queue.hpp index d7f6205e6..fc1fdeb31 100644 --- a/libraries/libstratosphere/include/stratosphere/os/os_message_queue.hpp +++ b/libraries/libstratosphere/include/stratosphere/os/os_message_queue.hpp @@ -46,17 +46,17 @@ namespace ams::os { return TimedSendMessageQueue(std::addressof(this->mq), data, timeout); } - /* Sending (LIFO functionality) */ - void SendNext(uintptr_t data) { - return SendNextMessageQueue(std::addressof(this->mq), data); + /* Jamming (LIFO functionality) */ + void Jam(uintptr_t data) { + return JamMessageQueue(std::addressof(this->mq), data); } - bool TrySendNext(uintptr_t data) { - return TrySendNextMessageQueue(std::addressof(this->mq), data); + bool TryJam(uintptr_t data) { + return TryJamMessageQueue(std::addressof(this->mq), data); } - bool TimedSendNext(uintptr_t data, TimeSpan timeout) { - return TimedSendNextMessageQueue(std::addressof(this->mq), data, timeout); + bool TimedJam(uintptr_t data, TimeSpan timeout) { + return TimedJamMessageQueue(std::addressof(this->mq), data, timeout); } /* Receive functionality */ diff --git a/libraries/libstratosphere/include/stratosphere/os/os_message_queue_api.hpp b/libraries/libstratosphere/include/stratosphere/os/os_message_queue_api.hpp index 51fc0a5c2..0f772a459 100644 --- a/libraries/libstratosphere/include/stratosphere/os/os_message_queue_api.hpp +++ b/libraries/libstratosphere/include/stratosphere/os/os_message_queue_api.hpp @@ -31,10 +31,10 @@ namespace ams::os { bool TrySendMessageQueue(MessageQueueType *mq, uintptr_t data); bool TimedSendMessageQueue(MessageQueueType *mq, uintptr_t data, TimeSpan timeout); - /* Sending (LIFO functionality) */ - void SendNextMessageQueue(MessageQueueType *mq, uintptr_t data); - bool TrySendNextMessageQueue(MessageQueueType *mq, uintptr_t data); - bool TimedSendNextMessageQueue(MessageQueueType *mq, uintptr_t data, TimeSpan timeout); + /* Jamming (LIFO functionality) */ + void JamMessageQueue(MessageQueueType *mq, uintptr_t data); + bool TryJamMessageQueue(MessageQueueType *mq, uintptr_t data); + bool TimedJamMessageQueue(MessageQueueType *mq, uintptr_t data, TimeSpan timeout); /* Receive functionality */ void ReceiveMessageQueue(uintptr_t *out, MessageQueueType *mq); diff --git a/libraries/libstratosphere/source/os/impl/os_message_queue_helper.hpp b/libraries/libstratosphere/source/os/impl/os_message_queue_helper.hpp new file mode 100644 index 000000000..f39ed66e0 --- /dev/null +++ b/libraries/libstratosphere/source/os/impl/os_message_queue_helper.hpp @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2018-2020 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +#pragma once +#include + +namespace ams::os::impl { + + template + concept IsMessageQueueType = requires(T &t) { + { t.buffer } -> std::convertible_to; + { t.offset } -> std::convertible_to; + { t.count } -> std::same_as; + { t.capacity } -> std::same_as; + }; + + template requires IsMessageQueueType + class MessageQueueHelper { + public: + static ALWAYS_INLINE bool IsMessageQueueFull(const T *mq) { + return mq->count >= mq->capacity; + } + + static ALWAYS_INLINE bool IsMessageQueueEmpty(const T *mq) { + return mq->count == 0; + } + + static void EnqueueUnsafe(T *mq, uintptr_t data) { + /* Ensure our limits are correct. */ + auto count = mq->count; + const auto capacity = mq->capacity; + AMS_ASSERT(count < capacity); + + /* Determine where we're writing. */ + auto ind = mq->offset + count; + if (ind >= capacity) { + ind -= capacity; + } + AMS_ASSERT(0 <= ind && ind < capacity); + + /* Write the data. */ + mq->buffer[ind] = data; + ++count; + + /* Update tracking. */ + mq->count = count; + } + + static uintptr_t DequeueUnsafe(T *mq) { + /* Ensure our limits are correct. */ + auto count = mq->count; + auto offset = mq->offset; + const auto capacity = mq->capacity; + AMS_ASSERT(count > 0); + AMS_ASSERT(offset >= 0 && offset < capacity); + + /* Get the data. */ + auto data = mq->buffer[offset++]; + + /* Calculate new tracking variables. */ + if (offset >= capacity) { + offset -= capacity; + } + --count; + + /* Update tracking. */ + mq->offset = offset; + mq->count = count; + + return data; + } + + static void JamUnsafe(T *mq, uintptr_t data) { + /* Ensure our limits are correct. */ + auto count = mq->count; + const auto capacity = mq->capacity; + AMS_ASSERT(count < capacity); + + /* Determine where we're writing. */ + auto offset = mq->offset - 1; + if (offset < 0) { + offset += capacity; + } + AMS_ASSERT(0 <= offset && offset < capacity); + + /* Write the data. */ + mq->buffer[offset] = data; + ++count; + + /* Update tracking. */ + mq->offset = offset; + mq->count = count; + } + + static uintptr_t PeekUnsafe(const T *mq) { + /* Ensure our limits are correct. */ + const auto count = mq->count; + const auto offset = mq->offset; + AMS_ASSERT(count > 0); + + return mq->buffer[offset]; + } + }; + +} diff --git a/libraries/libstratosphere/source/os/os_message_queue.cpp b/libraries/libstratosphere/source/os/os_message_queue.cpp index e950c07b8..bde4dfa83 100644 --- a/libraries/libstratosphere/source/os/os_message_queue.cpp +++ b/libraries/libstratosphere/source/os/os_message_queue.cpp @@ -17,94 +17,13 @@ #include "impl/os_timeout_helper.hpp" #include "impl/os_waitable_object_list.hpp" #include "impl/os_waitable_holder_impl.hpp" +#include "impl/os_message_queue_helper.hpp" namespace ams::os { namespace { - ALWAYS_INLINE bool IsMessageQueueFull(const MessageQueueType *mq) { - return mq->count >= mq->capacity; - } - - ALWAYS_INLINE bool IsMessageQueueEmpty(const MessageQueueType *mq) { - return mq->count == 0; - } - - void SendUnsafe(MessageQueueType *mq, uintptr_t data) { - /* Ensure our limits are correct. */ - auto count = mq->count; - auto capacity = mq->capacity; - AMS_ASSERT(count < capacity); - - /* Determine where we're writing. */ - auto ind = mq->offset + count; - if (ind >= capacity) { - ind -= capacity; - } - AMS_ASSERT(0 <= ind && ind < capacity); - - /* Write the data. */ - mq->buffer[ind] = data; - ++count; - - /* Update tracking. */ - mq->count = count; - } - - void SendNextUnsafe(MessageQueueType *mq, uintptr_t data) { - /* Ensure our limits are correct. */ - auto count = mq->count; - auto capacity = mq->capacity; - AMS_ASSERT(count < capacity); - - /* Determine where we're writing. */ - auto offset = mq->offset - 1; - if (offset < 0) { - offset += capacity; - } - AMS_ASSERT(0 <= offset && offset < capacity); - - /* Write the data. */ - mq->buffer[offset] = data; - ++count; - - /* Update tracking. */ - mq->offset = offset; - mq->count = count; - } - - uintptr_t ReceiveUnsafe(MessageQueueType *mq) { - /* Ensure our limits are correct. */ - auto count = mq->count; - auto offset = mq->offset; - auto capacity = mq->capacity; - AMS_ASSERT(count > 0); - AMS_ASSERT(offset >= 0 && offset < capacity); - - /* Get the data. */ - auto data = mq->buffer[offset]; - - /* Calculate new tracking variables. */ - if ((++offset) >= capacity) { - offset -= capacity; - } - --count; - - /* Update tracking. */ - mq->offset = offset; - mq->count = count; - - return data; - } - - uintptr_t PeekUnsafe(const MessageQueueType *mq) { - /* Ensure our limits are correct. */ - auto count = mq->count; - auto offset = mq->offset; - AMS_ASSERT(count > 0); - - return mq->buffer[offset]; - } + using MessageQueueHelper = impl::MessageQueueHelper; } @@ -158,12 +77,12 @@ namespace ams::os { /* Acquire mutex, wait sendable. */ std::scoped_lock lk(GetReference(mq->cs_queue)); - while (IsMessageQueueFull(mq)) { + while (MessageQueueHelper::IsMessageQueueFull(mq)) { GetReference(mq->cv_not_full).Wait(GetPointer(mq->cs_queue)); } /* Send, signal. */ - SendUnsafe(mq, data); + MessageQueueHelper::EnqueueUnsafe(mq, data); GetReference(mq->cv_not_empty).Broadcast(); GetReference(mq->waitlist_not_empty).SignalAllThreads(); } @@ -176,12 +95,12 @@ namespace ams::os { /* Acquire mutex, check sendable. */ std::scoped_lock lk(GetReference(mq->cs_queue)); - if (IsMessageQueueFull(mq)) { + if (MessageQueueHelper::IsMessageQueueFull(mq)) { return false; } /* Send, signal. */ - SendUnsafe(mq, data); + MessageQueueHelper::EnqueueUnsafe(mq, data); GetReference(mq->cv_not_empty).Broadcast(); GetReference(mq->waitlist_not_empty).SignalAllThreads(); } @@ -198,7 +117,7 @@ namespace ams::os { impl::TimeoutHelper timeout_helper(timeout); std::scoped_lock lk(GetReference(mq->cs_queue)); - while (IsMessageQueueFull(mq)) { + while (MessageQueueHelper::IsMessageQueueFull(mq)) { if (timeout_helper.TimedOut()) { return false; } @@ -206,7 +125,7 @@ namespace ams::os { } /* Send, signal. */ - SendUnsafe(mq, data); + MessageQueueHelper::EnqueueUnsafe(mq, data); GetReference(mq->cv_not_empty).Broadcast(); GetReference(mq->waitlist_not_empty).SignalAllThreads(); } @@ -214,38 +133,38 @@ namespace ams::os { return true; } - /* Sending (LIFO functionality) */ - void SendNextMessageQueue(MessageQueueType *mq, uintptr_t data) { + /* Jamming (LIFO functionality) */ + void JamMessageQueue(MessageQueueType *mq, uintptr_t data) { AMS_ASSERT(mq->state == MessageQueueType::State_Initialized); { /* Acquire mutex, wait sendable. */ std::scoped_lock lk(GetReference(mq->cs_queue)); - while (IsMessageQueueFull(mq)) { + while (MessageQueueHelper::IsMessageQueueFull(mq)) { GetReference(mq->cv_not_full).Wait(GetPointer(mq->cs_queue)); } /* Send, signal. */ - SendNextUnsafe(mq, data); + MessageQueueHelper::JamUnsafe(mq, data); GetReference(mq->cv_not_empty).Broadcast(); GetReference(mq->waitlist_not_empty).SignalAllThreads(); } } - bool TrySendNextMessageQueue(MessageQueueType *mq, uintptr_t data) { + bool TryJamMessageQueue(MessageQueueType *mq, uintptr_t data) { AMS_ASSERT(mq->state == MessageQueueType::State_Initialized); { /* Acquire mutex, check sendable. */ std::scoped_lock lk(GetReference(mq->cs_queue)); - if (IsMessageQueueFull(mq)) { + if (MessageQueueHelper::IsMessageQueueFull(mq)) { return false; } /* Send, signal. */ - SendNextUnsafe(mq, data); + MessageQueueHelper::JamUnsafe(mq, data); GetReference(mq->cv_not_empty).Broadcast(); GetReference(mq->waitlist_not_empty).SignalAllThreads(); } @@ -253,7 +172,7 @@ namespace ams::os { return true; } - bool TimedSendNextMessageQueue(MessageQueueType *mq, uintptr_t data, TimeSpan timeout) { + bool TimedJamMessageQueue(MessageQueueType *mq, uintptr_t data, TimeSpan timeout) { AMS_ASSERT(mq->state == MessageQueueType::State_Initialized); AMS_ASSERT(timeout.GetNanoSeconds() >= 0); @@ -262,7 +181,7 @@ namespace ams::os { impl::TimeoutHelper timeout_helper(timeout); std::scoped_lock lk(GetReference(mq->cs_queue)); - while (IsMessageQueueFull(mq)) { + while (MessageQueueHelper::IsMessageQueueFull(mq)) { if (timeout_helper.TimedOut()) { return false; } @@ -270,7 +189,7 @@ namespace ams::os { } /* Send, signal. */ - SendNextUnsafe(mq, data); + MessageQueueHelper::JamUnsafe(mq, data); GetReference(mq->cv_not_empty).Broadcast(); GetReference(mq->waitlist_not_empty).SignalAllThreads(); } @@ -286,12 +205,12 @@ namespace ams::os { /* Acquire mutex, wait receivable. */ std::scoped_lock lk(GetReference(mq->cs_queue)); - while (IsMessageQueueEmpty(mq)) { + while (MessageQueueHelper::IsMessageQueueEmpty(mq)) { GetReference(mq->cv_not_empty).Wait(GetPointer(mq->cs_queue)); } /* Receive, signal. */ - *out = ReceiveUnsafe(mq); + *out = MessageQueueHelper::DequeueUnsafe(mq); GetReference(mq->cv_not_full).Broadcast(); GetReference(mq->waitlist_not_full).SignalAllThreads(); } @@ -304,12 +223,12 @@ namespace ams::os { /* Acquire mutex, check receivable. */ std::scoped_lock lk(GetReference(mq->cs_queue)); - if (IsMessageQueueEmpty(mq)) { + if (MessageQueueHelper::IsMessageQueueEmpty(mq)) { return false; } /* Receive, signal. */ - *out = ReceiveUnsafe(mq); + *out = MessageQueueHelper::DequeueUnsafe(mq); GetReference(mq->cv_not_full).Broadcast(); GetReference(mq->waitlist_not_full).SignalAllThreads(); } @@ -326,7 +245,7 @@ namespace ams::os { impl::TimeoutHelper timeout_helper(timeout); std::scoped_lock lk(GetReference(mq->cs_queue)); - while (IsMessageQueueEmpty(mq)) { + while (MessageQueueHelper::IsMessageQueueEmpty(mq)) { if (timeout_helper.TimedOut()) { return false; } @@ -334,7 +253,7 @@ namespace ams::os { } /* Receive, signal. */ - *out = ReceiveUnsafe(mq); + *out = MessageQueueHelper::DequeueUnsafe(mq); GetReference(mq->cv_not_full).Broadcast(); GetReference(mq->waitlist_not_full).SignalAllThreads(); } @@ -350,12 +269,12 @@ namespace ams::os { /* Acquire mutex, wait receivable. */ std::scoped_lock lk(GetReference(mq->cs_queue)); - while (IsMessageQueueEmpty(mq)) { + while (MessageQueueHelper::IsMessageQueueEmpty(mq)) { GetReference(mq->cv_not_empty).Wait(GetPointer(mq->cs_queue)); } /* Peek. */ - *out = PeekUnsafe(mq); + *out = MessageQueueHelper::PeekUnsafe(mq); } } @@ -366,12 +285,12 @@ namespace ams::os { /* Acquire mutex, check receivable. */ std::scoped_lock lk(GetReference(mq->cs_queue)); - if (IsMessageQueueEmpty(mq)) { + if (MessageQueueHelper::IsMessageQueueEmpty(mq)) { return false; } /* Peek. */ - *out = PeekUnsafe(mq); + *out = MessageQueueHelper::PeekUnsafe(mq); } return true; @@ -386,7 +305,7 @@ namespace ams::os { impl::TimeoutHelper timeout_helper(timeout); std::scoped_lock lk(GetReference(mq->cs_queue)); - while (IsMessageQueueEmpty(mq)) { + while (MessageQueueHelper::IsMessageQueueEmpty(mq)) { if (timeout_helper.TimedOut()) { return false; } @@ -394,7 +313,7 @@ namespace ams::os { } /* Peek. */ - *out = PeekUnsafe(mq); + *out = MessageQueueHelper::PeekUnsafe(mq); } return true;