htc: implement remaining worker thread send logic (for channel mux)

This commit is contained in:
Michael Scire 2021-02-09 03:21:45 -08:00 committed by SciresM
parent df3d62df84
commit 4ed665bcd3
14 changed files with 346 additions and 13 deletions

View file

@ -15,6 +15,7 @@
*/ */
#include <stratosphere.hpp> #include <stratosphere.hpp>
#include "htclow_mux.hpp" #include "htclow_mux.hpp"
#include "../htclow_packet_factory.hpp"
#include "../ctrl/htclow_ctrl_state_machine.hpp" #include "../ctrl/htclow_ctrl_state_machine.hpp"
namespace ams::htclow::mux { namespace ams::htclow::mux {
@ -22,7 +23,7 @@ namespace ams::htclow::mux {
Mux::Mux(PacketFactory *pf, ctrl::HtcctrlStateMachine *sm) Mux::Mux(PacketFactory *pf, ctrl::HtcctrlStateMachine *sm)
: m_packet_factory(pf), m_state_machine(sm), m_task_manager(), m_event(os::EventClearMode_ManualClear), : m_packet_factory(pf), m_state_machine(sm), m_task_manager(), m_event(os::EventClearMode_ManualClear),
m_channel_impl_map(pf, sm, std::addressof(m_task_manager), std::addressof(m_event)), m_global_send_buffer(pf), m_channel_impl_map(pf, sm, std::addressof(m_task_manager), std::addressof(m_event)), m_global_send_buffer(pf),
m_mutex(), m_is_sleeping(false), m_version(ProtocolVersion) m_mutex(), m_state(MuxState::Normal), m_version(ProtocolVersion)
{ {
/* ... */ /* ... */
} }
@ -78,6 +79,50 @@ namespace ams::htclow::mux {
} }
} }
bool Mux::QuerySendPacket(PacketHeader *header, PacketBody *body, int *out_body_size) {
/* Lock ourselves. */
std::scoped_lock lk(m_mutex);
/* Get our map. */
auto &map = m_channel_impl_map.GetMap();
/* Iterate the map, checking for valid packet each time. */
for (auto &pair : map) {
/* Get the current channel impl. */
auto &channel_impl = m_channel_impl_map.GetChannelImpl(pair.second);
/* Check for an error packet. */
/* NOTE: it's unclear why Nintendo does this every iteration of the loop... */
if (auto *error_packet = m_global_send_buffer.GetNextPacket(); error_packet != nullptr) {
std::memcpy(header, error_packet->GetHeader(), sizeof(*header));
*out_body_size = 0;
return true;
}
/* See if the channel has something for us to send. */
if (channel_impl.QuerySendPacket(header, body, out_body_size)) {
return this->IsSendable(header->packet_type);
}
}
return false;
}
void Mux::RemovePacket(const PacketHeader &header) {
/* Lock ourselves. */
std::scoped_lock lk(m_mutex);
/* Remove the packet from the appropriate source. */
if (header.packet_type == PacketType_Error) {
m_global_send_buffer.RemovePacket();
} else if (m_channel_impl_map.Exists(header.channel)) {
m_channel_impl_map[header.channel].RemovePacket(header);
}
/* Notify the task manager. */
m_task_manager.NotifySendReady();
}
void Mux::UpdateChannelState() { void Mux::UpdateChannelState() {
/* Lock ourselves. */ /* Lock ourselves. */
std::scoped_lock lk(m_mutex); std::scoped_lock lk(m_mutex);
@ -95,9 +140,9 @@ namespace ams::htclow::mux {
/* Update whether we're sleeping. */ /* Update whether we're sleeping. */
if (m_state_machine->IsSleeping()) { if (m_state_machine->IsSleeping()) {
m_is_sleeping = true; m_state = MuxState::Sleep;
} else { } else {
m_is_sleeping = false; m_state = MuxState::Normal;
m_event.Signal(); m_event.Signal();
} }
} }
@ -108,8 +153,23 @@ namespace ams::htclow::mux {
} }
Result Mux::SendErrorPacket(impl::ChannelInternalType channel) { Result Mux::SendErrorPacket(impl::ChannelInternalType channel) {
/* TODO */ /* Create and send the packet. */
AMS_ABORT("Mux::SendErrorPacket"); R_TRY(m_global_send_buffer.AddPacket(m_packet_factory->MakeErrorPacket(channel)));
/* Signal our event. */
m_event.Signal();
return ResultSuccess();
}
bool Mux::IsSendable(PacketType packet_type) const {
switch (m_state) {
case MuxState::Normal:
return true;
case MuxState::Sleep:
return false;
AMS_UNREACHABLE_DEFAULT_CASE();
}
} }
} }

View file

@ -21,7 +21,13 @@
namespace ams::htclow::mux { namespace ams::htclow::mux {
enum class MuxState {
Normal,
Sleep,
};
class Mux { class Mux {
private:
private: private:
PacketFactory *m_packet_factory; PacketFactory *m_packet_factory;
ctrl::HtcctrlStateMachine *m_state_machine; ctrl::HtcctrlStateMachine *m_state_machine;
@ -30,7 +36,7 @@ namespace ams::htclow::mux {
ChannelImplMap m_channel_impl_map; ChannelImplMap m_channel_impl_map;
GlobalSendBuffer m_global_send_buffer; GlobalSendBuffer m_global_send_buffer;
os::SdkMutex m_mutex; os::SdkMutex m_mutex;
bool m_is_sleeping; MuxState m_state;
s16 m_version; s16 m_version;
public: public:
Mux(PacketFactory *pf, ctrl::HtcctrlStateMachine *sm); Mux(PacketFactory *pf, ctrl::HtcctrlStateMachine *sm);
@ -51,6 +57,8 @@ namespace ams::htclow::mux {
Result CheckChannelExist(impl::ChannelInternalType channel); Result CheckChannelExist(impl::ChannelInternalType channel);
Result SendErrorPacket(impl::ChannelInternalType channel); Result SendErrorPacket(impl::ChannelInternalType channel);
bool IsSendable(PacketType packet_type) const;
}; };
} }

View file

@ -132,6 +132,30 @@ namespace ams::htclow::mux {
return ResultSuccess(); return ResultSuccess();
} }
bool ChannelImpl::QuerySendPacket(PacketHeader *header, PacketBody *body, int *out_body_size) {
/* Check our send buffer. */
if (m_send_buffer.QueryNextPacket(header, body, out_body_size, m_next_max_data, m_total_send_size, m_share.has_value(), m_share.value_or(0))) {
/* Update tracking variables. */
if (header->packet_type == PacketType_Data) {
m_cur_max_data = m_next_max_data;
}
return true;
} else {
return false;
}
}
void ChannelImpl::RemovePacket(const PacketHeader &header) {
/* Remove the packet. */
m_send_buffer.RemovePacket(header);
/* Check if the send buffer is now empty. */
if (m_send_buffer.Empty()) {
m_task_manager->NotifySendBufferEmpty(m_channel);
}
}
void ChannelImpl::UpdateState() { void ChannelImpl::UpdateState() {
/* Check if shutdown must be forced. */ /* Check if shutdown must be forced. */
if (m_state_machine->IsUnsupportedServiceChannelToShutdown(m_channel)) { if (m_state_machine->IsUnsupportedServiceChannelToShutdown(m_channel)) {

View file

@ -43,7 +43,9 @@ namespace ams::htclow::mux {
RingBuffer m_receive_buffer; RingBuffer m_receive_buffer;
s16 m_version; s16 m_version;
ChannelConfig m_config; ChannelConfig m_config;
/* TODO: tracking variables. */ u64 m_total_send_size;
u64 m_next_max_data;
u64 m_cur_max_data;
u64 m_offset; u64 m_offset;
std::optional<u64> m_share; std::optional<u64> m_share;
os::Event m_state_change_event; os::Event m_state_change_event;
@ -55,6 +57,10 @@ namespace ams::htclow::mux {
Result ProcessReceivePacket(const PacketHeader &header, const void *body, size_t body_size); Result ProcessReceivePacket(const PacketHeader &header, const void *body, size_t body_size);
bool QuerySendPacket(PacketHeader *header, PacketBody *body, int *out_body_size);
void RemovePacket(const PacketHeader &header);
void UpdateState(); void UpdateState();
private: private:
void ShutdownForce(); void ShutdownForce();

View file

@ -40,13 +40,13 @@ namespace ams::htclow::mux {
public: public:
ChannelImplMap(PacketFactory *pf, ctrl::HtcctrlStateMachine *sm, TaskManager *tm, os::Event *ev); ChannelImplMap(PacketFactory *pf, ctrl::HtcctrlStateMachine *sm, TaskManager *tm, os::Event *ev);
ChannelImpl &GetChannelImpl(int index);
ChannelImpl &GetChannelImpl(impl::ChannelInternalType channel); ChannelImpl &GetChannelImpl(impl::ChannelInternalType channel);
bool Exists(impl::ChannelInternalType channel) const { bool Exists(impl::ChannelInternalType channel) const {
return m_map.find(channel) != m_map.end(); return m_map.find(channel) != m_map.end();
} }
private: private:
ChannelImpl &GetChannelImpl(int index);
public: public:
MapType &GetMap() { MapType &GetMap() {
return m_map; return m_map;

View file

@ -0,0 +1,52 @@
/*
* Copyright (c) 2018-2020 Atmosphère-NX
*
* This program is free software; you can redistribute it and/or modify it
* under the terms and conditions of the GNU General Public License,
* version 2, as published by the Free Software Foundation.
*
* This program is distributed in the hope it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include <stratosphere.hpp>
#include "htclow_mux_global_send_buffer.hpp"
#include "../htclow_packet_factory.hpp"
namespace ams::htclow::mux {
Packet *GlobalSendBuffer::GetNextPacket() {
if (!m_packet_list.empty()) {
return std::addressof(m_packet_list.front());
} else {
return nullptr;
}
}
Result GlobalSendBuffer::AddPacket(std::unique_ptr<Packet, PacketDeleter> ptr) {
/* Global send buffer only supports adding error packets. */
R_UNLESS(ptr->GetHeader()->packet_type == PacketType_Error, htclow::ResultInvalidArgument());
/* Check if we already have an error packet for the channel. */
for (const auto &packet : m_packet_list) {
R_SUCCEED_IF(packet.GetHeader()->channel == ptr->GetHeader()->channel);
}
/* We don't, so push back a new one. */
m_packet_list.push_back(*(ptr.release()));
return ResultSuccess();
}
void GlobalSendBuffer::RemovePacket() {
auto *packet = std::addressof(m_packet_list.front());
m_packet_list.pop_front();
m_packet_factory->Delete(packet);
}
}

View file

@ -33,6 +33,11 @@ namespace ams::htclow::mux {
PacketList m_packet_list; PacketList m_packet_list;
public: public:
GlobalSendBuffer(PacketFactory *pf) : m_packet_factory(pf), m_packet_list() { /* ... */ } GlobalSendBuffer(PacketFactory *pf) : m_packet_factory(pf), m_packet_list() { /* ... */ }
Packet *GetNextPacket();
Result AddPacket(std::unique_ptr<Packet, PacketDeleter> ptr);
void RemovePacket();
}; };
} }

View file

@ -45,4 +45,50 @@ namespace ams::htclow::mux {
return ResultSuccess(); return ResultSuccess();
} }
Result RingBuffer::Copy(void *dst, size_t size) {
/* Select buffer to discard from. */
void *buffer = m_is_read_only ? m_read_only_buffer : m_buffer;
R_UNLESS(buffer != nullptr, htclow::ResultChannelBufferHasNotEnoughData());
/* Verify that we have enough data. */
R_UNLESS(m_data_size >= size, htclow::ResultChannelBufferHasNotEnoughData());
/* Determine position and copy sizes. */
const size_t pos = m_offset;
const size_t left = std::min(m_buffer_size - pos, size);
const size_t over = size - left;
/* Copy. */
if (left != 0) {
std::memcpy(dst, static_cast<const u8 *>(buffer) + pos, left);
}
if (over != 0) {
std::memcpy(static_cast<u8 *>(dst) + left, buffer, over);
}
/* Mark that we can discard. */
m_can_discard = true;
return ResultSuccess();
}
Result RingBuffer::Discard(size_t size) {
/* Select buffer to discard from. */
void *buffer = m_is_read_only ? m_read_only_buffer : m_buffer;
R_UNLESS(buffer != nullptr, htclow::ResultChannelBufferHasNotEnoughData());
/* Verify that the data we're discarding has been read. */
R_UNLESS(m_can_discard, htclow::ResultChannelCannotDiscard());
/* Verify that we have enough data. */
R_UNLESS(m_data_size >= size, htclow::ResultChannelBufferHasNotEnoughData());
/* Discard. */
m_offset = (m_offset + size) % m_buffer_size;
m_data_size -= size;
m_can_discard = false;
return ResultSuccess();
}
} }

View file

@ -26,13 +26,17 @@ namespace ams::htclow::mux {
size_t m_buffer_size; size_t m_buffer_size;
size_t m_data_size; size_t m_data_size;
size_t m_offset; size_t m_offset;
bool m_has_copied; bool m_can_discard;
public: public:
RingBuffer() : m_buffer(), m_read_only_buffer(), m_is_read_only(true), m_buffer_size(), m_data_size(), m_offset(), m_has_copied(false) { /* ... */ } RingBuffer() : m_buffer(), m_read_only_buffer(), m_is_read_only(true), m_buffer_size(), m_data_size(), m_offset(), m_can_discard(false) { /* ... */ }
size_t GetDataSize() { return m_data_size; } size_t GetDataSize() { return m_data_size; }
Result Write(const void *data, size_t size); Result Write(const void *data, size_t size);
Result Copy(void *dst, size_t size);
Result Discard(size_t size);
}; };
} }

View file

@ -19,11 +19,106 @@
namespace ams::htclow::mux { namespace ams::htclow::mux {
bool SendBuffer::IsPriorPacket(PacketType packet_type) const {
return packet_type == PacketType_MaxData;
}
void SendBuffer::SetVersion(s16 version) { void SendBuffer::SetVersion(s16 version) {
/* Set version. */ /* Set version. */
m_version = version; m_version = version;
} }
void SendBuffer::MakeDataPacketHeader(PacketHeader *header, int body_size, s16 version, u64 share, u32 offset) const {
/* Set all packet fields. */
header->signature = HtcGen2Signature;
header->offset = offset;
header->reserved = 0;
header->version = version;
header->body_size = body_size;
header->channel = m_channel;
header->packet_type = PacketType_Data;
header->share = share;
}
void SendBuffer::CopyPacket(PacketHeader *header, PacketBody *body, int *out_body_size, const Packet &packet) {
/* Get the body size. */
const int body_size = packet.GetBodySize();
AMS_ASSERT(0 <= body_size && body_size <= static_cast<int>(sizeof(*body)));
/* Copy the header. */
std::memcpy(header, packet.GetHeader(), sizeof(*header));
/* Copy the body. */
std::memcpy(body, packet.GetBody(), body_size);
/* Set the output body size. */
*out_body_size = body_size;
}
bool SendBuffer::QueryNextPacket(PacketHeader *header, PacketBody *body, int *out_body_size, u64 max_data, u64 total_send_size, bool has_share, u64 share) {
/* Check for a max data packet. */
if (!m_packet_list.empty()) {
this->CopyPacket(header, body, out_body_size, m_packet_list.front());
return true;
}
/* Check that we have data. */
const auto ring_buffer_data_size = m_ring_buffer.GetDataSize();
if (ring_buffer_data_size > 0) {
return false;
}
/* Check that we're valid for flow control. */
if (m_flow_control_enabled && !has_share) {
return false;
}
/* Determine the sendable size. */
const auto offset = total_send_size - ring_buffer_data_size;
const auto sendable_size = std::min(share - offset, ring_buffer_data_size);
if (sendable_size == 0) {
return false;
}
/* We're additionally bound by the actual packet size. */
const auto data_size = std::min(sendable_size, m_max_packet_size);
/* Make data packet header. */
this->MakeDataPacketHeader(header, data_size, m_version, max_data, share);
/* Copy the data. */
R_ABORT_UNLESS(m_ring_buffer.Copy(body, data_size));
/* Set output body size. */
*out_body_size = data_size;
return true;
}
void SendBuffer::RemovePacket(const PacketHeader &header) {
/* Get the packet type. */
const auto packet_type = header.packet_type;
if (this->IsPriorPacket(packet_type)) {
/* Packet will be using our list. */
auto *packet = std::addressof(m_packet_list.front());
m_packet_list.pop_front();
m_packet_factory->Delete(packet);
} else {
/* Packet managed by ring buffer. */
AMS_ABORT_UNLESS(packet_type == PacketType_Data);
/* Discard the packet's data. */
const Result result = m_ring_buffer.Discard(header.body_size);
if (!htclow::ResultChannelCannotDiscard::Includes(result)) {
R_ABORT_UNLESS(result);
}
}
}
bool SendBuffer::Empty() {
return m_packet_list.empty() && m_ring_buffer.GetDataSize() == 0;
}
void SendBuffer::Clear() { void SendBuffer::Clear() {
while (!m_packet_list.empty()) { while (!m_packet_list.empty()) {
auto *packet = std::addressof(m_packet_list.front()); auto *packet = std::addressof(m_packet_list.front());

View file

@ -37,11 +37,23 @@ namespace ams::htclow::mux {
s16 m_version; s16 m_version;
bool m_flow_control_enabled; bool m_flow_control_enabled;
size_t m_max_packet_size; size_t m_max_packet_size;
private:
bool IsPriorPacket(PacketType packet_type) const;
void MakeDataPacketHeader(PacketHeader *header, int body_size, s16 version, u64 share, u32 offset) const;
void CopyPacket(PacketHeader *header, PacketBody *body, int *out_body_size, const Packet &packet);
public: public:
SendBuffer(impl::ChannelInternalType channel, PacketFactory *pf); SendBuffer(impl::ChannelInternalType channel, PacketFactory *pf);
void SetVersion(s16 version); void SetVersion(s16 version);
bool QueryNextPacket(PacketHeader *header, PacketBody *body, int *out_body_size, u64 max_data, u64 total_send_size, bool has_share, u64 share);
void RemovePacket(const PacketHeader &header);
bool Empty();
void Clear(); void Clear();
}; };

View file

@ -34,6 +34,22 @@ namespace ams::htclow::mux {
} }
} }
void TaskManager::NotifySendReady() {
for (auto i = 0; i < MaxTaskCount; ++i) {
if (m_valid[i] && m_tasks[i].type == TaskType_Send) {
this->CompleteTask(i, EventTrigger_SendReady);
}
}
}
void TaskManager::NotifySendBufferEmpty(impl::ChannelInternalType channel) {
for (auto i = 0; i < MaxTaskCount; ++i) {
if (m_valid[i] && m_tasks[i].channel == channel && m_tasks[i].type == TaskType_Flush) {
this->CompleteTask(i, EventTrigger_SendBufferEmpty);
}
}
}
void TaskManager::NotifyConnectReady() { void TaskManager::NotifyConnectReady() {
for (auto i = 0; i < MaxTaskCount; ++i) { for (auto i = 0; i < MaxTaskCount; ++i) {
if (m_valid[i] && m_tasks[i].type == TaskType_Connect) { if (m_valid[i] && m_tasks[i].type == TaskType_Connect) {

View file

@ -23,6 +23,8 @@ namespace ams::htclow::mux {
enum EventTrigger : u8 { enum EventTrigger : u8 {
EventTrigger_Disconnect = 1, EventTrigger_Disconnect = 1,
EventTrigger_ReceiveData = 2, EventTrigger_ReceiveData = 2,
EventTrigger_SendReady = 5,
EventTrigger_SendBufferEmpty = 10,
EventTrigger_ConnectReady = 11, EventTrigger_ConnectReady = 11,
}; };
@ -51,6 +53,8 @@ namespace ams::htclow::mux {
void NotifyDisconnect(impl::ChannelInternalType channel); void NotifyDisconnect(impl::ChannelInternalType channel);
void NotifyReceiveData(impl::ChannelInternalType channel, size_t size); void NotifyReceiveData(impl::ChannelInternalType channel, size_t size);
void NotifySendReady();
void NotifySendBufferEmpty(impl::ChannelInternalType channel);
void NotifyConnectReady(); void NotifyConnectReady();
private: private:
void CompleteTask(int index, EventTrigger trigger); void CompleteTask(int index, EventTrigger trigger);

View file

@ -39,6 +39,7 @@ namespace ams::htclow {
R_DEFINE_ERROR_RESULT(ChannelStateTransitionError, 1104); R_DEFINE_ERROR_RESULT(ChannelStateTransitionError, 1104);
R_DEFINE_ERROR_RESULT(ChannelReceiveBufferEmpty, 1106); R_DEFINE_ERROR_RESULT(ChannelReceiveBufferEmpty, 1106);
R_DEFINE_ERROR_RESULT(ChannelSequenceIdNotMatched, 1107); R_DEFINE_ERROR_RESULT(ChannelSequenceIdNotMatched, 1107);
R_DEFINE_ERROR_RESULT(ChannelCannotDiscard, 1108);
R_DEFINE_ERROR_RANGE(DriverError, 1200, 1999); R_DEFINE_ERROR_RANGE(DriverError, 1200, 1999);
R_DEFINE_ERROR_RESULT(DriverOpened, 1201); R_DEFINE_ERROR_RESULT(DriverOpened, 1201);