mirror of
https://github.com/Atmosphere-NX/Atmosphere
synced 2025-01-26 17:12:54 +00:00
tma: impl helper services, cleanup hostside packets
This commit is contained in:
parent
46001263f8
commit
2572ae8378
10 changed files with 252 additions and 47 deletions
|
@ -18,8 +18,6 @@ def main(argc, argv):
|
|||
print 'Waiting for connection...'
|
||||
c.wait_connected()
|
||||
print 'Connected!'
|
||||
while True:
|
||||
c.send_packet('AAAAAAAA')
|
||||
return 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
116
stratosphere/tma/client/Packet.py
Normal file
116
stratosphere/tma/client/Packet.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
# Copyright (c) 2018 Atmosphere-NX
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify it
|
||||
# under the terms and conditions of the GNU General Public License,
|
||||
# version 2, as published by the Free Software Foundation.
|
||||
#
|
||||
# This program is distributed in the hope it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
|
||||
# more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
import zlib
|
||||
import ServiceId
|
||||
from struct import unpack as up, pack as pk
|
||||
|
||||
HEADER_SIZE = 0x28
|
||||
|
||||
def crc32(s):
|
||||
return zlib.crc32(s) & 0xFFFFFFFF
|
||||
|
||||
class Packet():
|
||||
def __init__(self):
|
||||
self.service = 0
|
||||
self.task = 0
|
||||
self.cmd = 0
|
||||
self.continuation = 0
|
||||
self.version = 0
|
||||
self.body_len = 0
|
||||
self.body = ''
|
||||
self.offset = 0
|
||||
def load_header(self, header):
|
||||
assert len(header) == HEADER_SIZE
|
||||
self.service, self.task, self.cmd, self.continuation, self.version, self.body_len, \
|
||||
_, self.body_chk, self.hdr_chk = up('<IIHBBI16sII', header)
|
||||
if crc32(header[:-4]) != self.hdr_chk:
|
||||
raise ValueError('Invalid header checksum in received packet!')
|
||||
def load_body(self, body):
|
||||
assert len(body) == self.body_len
|
||||
if crc32(body) != self.body_chk:
|
||||
raise ValueError('Invalid body checksum in received packet!')
|
||||
self.body = body
|
||||
def get_data(self):
|
||||
assert len(self.body) == self.body_len and self.body_len <= 0xE000
|
||||
self.body_chk = crc32(self.body)
|
||||
hdr = pk('<IIHBBIIIIII', self.service, self.task, self.cmd, self.continuation, self.version, self.body_len, 0, 0, 0, 0, self.body_chk)
|
||||
self.hdr_chk = crc32(hdr)
|
||||
hdr += pk('<I', self.hdr_chk)
|
||||
return hdr + self.body
|
||||
def set_service(self, srv):
|
||||
if type(srv) is str:
|
||||
self.service = ServiceId.hash(srv)
|
||||
else:
|
||||
self.service = srv
|
||||
return self
|
||||
def set_task(self, t):
|
||||
self.task = t
|
||||
return self
|
||||
def set_cmd(self, x):
|
||||
self.cmd = x
|
||||
return self
|
||||
def set_continuation(self, c):
|
||||
self.continuation = c
|
||||
return self
|
||||
def set_version(self, v):
|
||||
self.version = v
|
||||
return self
|
||||
def reset_offset(self):
|
||||
self.offset = 0
|
||||
return self
|
||||
def write_str(self, s):
|
||||
self.body += s
|
||||
self.body_len += len(s)
|
||||
return self
|
||||
def write_u8(self, x):
|
||||
self.body += pk('<B', x & 0xFF)
|
||||
self.body_len += 1
|
||||
return self
|
||||
def write_u16(self, x):
|
||||
self.body += pk('<H', x & 0xFFFF)
|
||||
self.body_len += 2
|
||||
return self
|
||||
def write_u32(self, x):
|
||||
self.body += pk('<I', x & 0xFFFFFFFF)
|
||||
self.body_len += 4
|
||||
return self
|
||||
def write_u64(self, x):
|
||||
self.body += pk('<Q', x & 0xFFFFFFFFFFFFFFFF)
|
||||
self.body_len += 8
|
||||
return self
|
||||
def read_str(self):
|
||||
s = ''
|
||||
while self.body[self.offset] != '\x00' and self.offset < self.body_len:
|
||||
s += self.body[self.offset]
|
||||
self.offset += 1
|
||||
def read_u8(self):
|
||||
x, = up('<B', self.body[self.offset:self.offset+1])
|
||||
self.offset += 1
|
||||
return x
|
||||
def read_u16(self):
|
||||
x, = up('<H', self.body[self.offset:self.offset+2])
|
||||
self.offset += 2
|
||||
return x
|
||||
def read_u32(self):
|
||||
x, = up('<I', self.body[self.offset:self.offset+4])
|
||||
self.offset += 4
|
||||
return x
|
||||
def read_u64(self):
|
||||
x, = up('<Q', self.body[self.offset:self.offset+8])
|
||||
self.offset += 8
|
||||
return x
|
||||
def read_struct(self, format, sz):
|
||||
x = up(format, self.body[self.offset:self.offset+sz])
|
||||
self.offset += sz
|
||||
return x
|
27
stratosphere/tma/client/ServiceId.py
Normal file
27
stratosphere/tma/client/ServiceId.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) 2018 Atmosphere-NX
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify it
|
||||
# under the terms and conditions of the GNU General Public License,
|
||||
# version 2, as published by the Free Software Foundation.
|
||||
#
|
||||
# This program is distributed in the hope it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
|
||||
# more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
def hash(s):
|
||||
h = ord(s[0]) & 0xFFFFFFFF
|
||||
for c in s:
|
||||
h = ((1000003 * h) ^ ord(c)) & 0xFFFFFFFF
|
||||
h ^= len(s)
|
||||
return h
|
||||
|
||||
USB_QUERY_TARGET = hash("USBQueryTarget")
|
||||
USB_SEND_HOST_INFO = hash("USBSendHostInfo")
|
||||
USB_CONNECT = hash("USBConnect")
|
||||
USB_DISCONNECT = hash("USBDisconnect")
|
||||
|
||||
|
|
@ -15,6 +15,8 @@ from UsbInterface import UsbInterface
|
|||
from threading import Thread, Condition
|
||||
from collections import deque
|
||||
import time
|
||||
import ServiceId
|
||||
from Packet import Packet
|
||||
|
||||
class UsbConnection(UsbInterface):
|
||||
# Auto connect thread func.
|
||||
|
@ -25,12 +27,6 @@ class UsbConnection(UsbInterface):
|
|||
except ValueError as e:
|
||||
continue
|
||||
def recv_thread(connection):
|
||||
if connection.is_connected():
|
||||
try:
|
||||
# If we've previously been connected, PyUSB will read garbage...
|
||||
connection.recv_packet()
|
||||
except ValueError:
|
||||
pass
|
||||
while connection.is_connected():
|
||||
try:
|
||||
connection.recv_packet()
|
||||
|
@ -65,6 +61,7 @@ class UsbConnection(UsbInterface):
|
|||
self.conn_thrd.start()
|
||||
return self
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.disconnect()
|
||||
time.sleep(1)
|
||||
print 'Closing!'
|
||||
time.sleep(1)
|
||||
|
@ -80,24 +77,43 @@ class UsbConnection(UsbInterface):
|
|||
self.conn_lock.acquire()
|
||||
assert not self.connected
|
||||
self.intf = intf
|
||||
self.connected = True
|
||||
self.conn_lock.notify()
|
||||
self.conn_lock.release()
|
||||
self.recv_thrd = Thread(target=UsbConnection.recv_thread, args=(self,))
|
||||
self.send_thrd = Thread(target=UsbConnection.send_thread, args=(self,))
|
||||
self.recv_thrd.daemon = True
|
||||
self.send_thrd.daemon = True
|
||||
self.recv_thrd.start()
|
||||
self.send_thrd.start()
|
||||
|
||||
try:
|
||||
# Perform Query + Connection handshake
|
||||
self.intf.send_packet(Packet().set_service(ServiceId.USB_QUERY_TARGET))
|
||||
query_resp = self.intf.read_packet()
|
||||
print 'Found Switch, Protocol version 0x%x' % query_resp.read_u32()
|
||||
|
||||
self.intf.send_packet(Packet().set_service(ServiceId.USB_SEND_HOST_INFO).write_u32(0).write_u32(0))
|
||||
|
||||
self.intf.send_packet(Packet().set_service(ServiceId.USB_CONNECT))
|
||||
resp = self.intf.read_packet()
|
||||
|
||||
# Spawn threads
|
||||
self.recv_thrd = Thread(target=UsbConnection.recv_thread, args=(self,))
|
||||
self.send_thrd = Thread(target=UsbConnection.send_thread, args=(self,))
|
||||
self.recv_thrd.daemon = True
|
||||
self.send_thrd.daemon = True
|
||||
self.recv_thrd.start()
|
||||
self.send_thrd.start()
|
||||
self.connected = True
|
||||
finally:
|
||||
# Finish connection.
|
||||
self.conn_lock.notify()
|
||||
self.conn_lock.release()
|
||||
def disconnect(self):
|
||||
self.conn_lock.acquire()
|
||||
if self.connected:
|
||||
self.connected = False
|
||||
self.intf.send_packet(Packet().set_service(ServiceId.USB_DISCONNECT))
|
||||
self.conn_lock.release()
|
||||
def recv_packet(self):
|
||||
hdr, body = self.intf.read_packet()
|
||||
print('Got Packet: %s' % body.encode('hex'))
|
||||
packet = self.intf.read_packet()
|
||||
assert type(packet) is Packet
|
||||
dat = packet.read_u64()
|
||||
print('Got Packet: %08x' % dat)
|
||||
def send_packet(self, packet):
|
||||
assert type(packet) is Packet
|
||||
self.send_lock.acquire()
|
||||
if len(self.send_queue) == 0x40:
|
||||
self.send_lock.wait()
|
||||
|
|
|
@ -11,11 +11,8 @@
|
|||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
import usb, zlib
|
||||
from struct import unpack as up, pack as pk
|
||||
|
||||
def crc32(s):
|
||||
return zlib.crc32(s) & 0xFFFFFFFF
|
||||
import usb
|
||||
import Packet
|
||||
|
||||
class UsbInterface():
|
||||
def __init__(self):
|
||||
|
@ -50,20 +47,16 @@ class UsbInterface():
|
|||
def blocking_write(self, data):
|
||||
self.ep_out.write(data, 0xFFFFFFFFFFFFFFFF)
|
||||
def read_packet(self):
|
||||
hdr = self.blocking_read(0x28)
|
||||
_, _, _, body_size, _, _, _, _, body_chk, hdr_chk = up('<IIIIIIIIII', hdr)
|
||||
if crc32(hdr[:-4]) != hdr_chk:
|
||||
raise ValueError('Invalid header checksum in received packet!')
|
||||
body = self.blocking_read(body_size)
|
||||
if len(body) != body_size:
|
||||
raise ValueError('Failed to receive packet body!')
|
||||
elif crc32(body) != body_chk:
|
||||
raise ValueError('Invalid body checksum in received packet!')
|
||||
return (hdr, body)
|
||||
def send_packet(self, body):
|
||||
hdr = pk('<IIIIIIIII', 0, 0, 0, len(body), 0, 0, 0, 0, crc32(body))
|
||||
hdr += pk('<I', crc32(hdr))
|
||||
self.blocking_write(hdr)
|
||||
self.blocking_write(body)
|
||||
packet = Packet.Packet()
|
||||
hdr = self.blocking_read(Packet.HEADER_SIZE)
|
||||
packet.load_header(hdr)
|
||||
if packet.body_len:
|
||||
packet.load_body(self.blocking_read(packet.body_len))
|
||||
return packet
|
||||
def send_packet(self, packet):
|
||||
data = packet.get_data()
|
||||
self.blocking_write(data[:Packet.HEADER_SIZE])
|
||||
if (len(data) > Packet.HEADER_SIZE):
|
||||
self.blocking_write(data[Packet.HEADER_SIZE:])
|
||||
|
||||
|
||||
|
|
|
@ -190,7 +190,7 @@ class TmaPacket {
|
|||
}
|
||||
|
||||
template<typename T>
|
||||
TmaConnResult Read(const T &t) {
|
||||
TmaConnResult Read(T &t) {
|
||||
return Read(&t, sizeof(T));
|
||||
}
|
||||
|
||||
|
|
|
@ -34,5 +34,13 @@ static constexpr u32 HashServiceName(const char *name) {
|
|||
|
||||
enum class TmaService : u32 {
|
||||
Invalid = 0,
|
||||
|
||||
/* Special nodes, for facilitating connection over USB. */
|
||||
UsbQueryTarget = HashServiceName("USBQueryTarget"),
|
||||
UsbSendHostInfo = HashServiceName("USBSendHostInfo"),
|
||||
UsbConnect = HashServiceName("USBConnect"),
|
||||
UsbDisconnect = HashServiceName("USBDisconnect"),
|
||||
|
||||
|
||||
TestService = HashServiceName("AtmosphereTestService"), /* Temporary service, will be used to debug communications. */
|
||||
};
|
||||
|
|
|
@ -76,18 +76,54 @@ void TmaUsbConnection::RecvThreadFunc(void *arg) {
|
|||
this_ptr->SetConnected(true);
|
||||
|
||||
while (res == TmaConnResult::Success) {
|
||||
if (!this_ptr->IsConnected()) {
|
||||
break;
|
||||
}
|
||||
TmaPacket *packet = this_ptr->AllocateRecvPacket();
|
||||
if (packet == nullptr) { std::abort(); }
|
||||
|
||||
res = TmaUsbComms::ReceivePacket(packet);
|
||||
|
||||
if (res == TmaConnResult::Success) {
|
||||
TmaPacket *send_packet = this_ptr->AllocateSendPacket();
|
||||
send_packet->Write<u64>(i++);
|
||||
this_ptr->send_queue.Send(reinterpret_cast<uintptr_t>(send_packet));
|
||||
switch (packet->GetServiceId()) {
|
||||
case TmaService::UsbQueryTarget: {
|
||||
this_ptr->SetConnected(false);
|
||||
|
||||
res = this_ptr->SendQueryReply(packet);
|
||||
|
||||
if (!this_ptr->has_woken_up) {
|
||||
/* TODO: Cancel background work. */
|
||||
}
|
||||
}
|
||||
break;
|
||||
case TmaService::UsbSendHostInfo: {
|
||||
struct {
|
||||
u32 version;
|
||||
u32 sleeping;
|
||||
} host_info;
|
||||
packet->Read<decltype(host_info)>(host_info);
|
||||
|
||||
if (!this_ptr->has_woken_up || !host_info.sleeping) {
|
||||
/* TODO: Cancel background work. */
|
||||
}
|
||||
}
|
||||
break;
|
||||
case TmaService::UsbConnect: {
|
||||
res = this_ptr->SendQueryReply(packet);
|
||||
|
||||
if (res == TmaConnResult::Success) {
|
||||
this_ptr->SetConnected(true);
|
||||
this_ptr->OnConnectionEvent(ConnectionEvent::Connected);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case TmaService::UsbDisconnect: {
|
||||
this_ptr->SetConnected(false);
|
||||
this_ptr->OnDisconnected();
|
||||
|
||||
/* TODO: Cancel background work. */
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
this_ptr->FreePacket(packet);
|
||||
} else {
|
||||
this_ptr->FreePacket(packet);
|
||||
|
@ -153,3 +189,13 @@ TmaConnResult TmaUsbConnection::SendPacket(TmaPacket *packet) {
|
|||
return TmaConnResult::Disconnected;
|
||||
}
|
||||
}
|
||||
|
||||
TmaConnResult TmaUsbConnection::SendQueryReply(TmaPacket *packet) {
|
||||
packet->ClearOffset();
|
||||
struct {
|
||||
u32 version;
|
||||
} target_info;
|
||||
target_info.version = 0;
|
||||
packet->Write<decltype(target_info)>(target_info);
|
||||
return TmaUsbComms::SendPacket(packet);
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@ class TmaUsbConnection : public TmaConnection {
|
|||
static void SendThreadFunc(void *arg);
|
||||
static void RecvThreadFunc(void *arg);
|
||||
static void OnUsbStateChange(void *this_ptr, u32 state);
|
||||
TmaConnResult SendQueryReply(TmaPacket *packet);
|
||||
void ClearSendQueue();
|
||||
void StartThreads();
|
||||
void StopThreads();
|
||||
|
|
|
@ -436,7 +436,7 @@ TmaConnResult TmaUsbComms::SendPacket(TmaPacket *packet) {
|
|||
res = TmaConnResult::GeneralFailure;
|
||||
}
|
||||
|
||||
if (res == TmaConnResult::Success) {
|
||||
if (res == TmaConnResult::Success && 0 < body_len) {
|
||||
/* Copy body to send buffer. */
|
||||
packet->CopyBodyTo(g_send_data_buf);
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue