tma: impl helper services, cleanup hostside packets

This commit is contained in:
Michael Scire 2018-11-06 20:20:07 -08:00
parent 46001263f8
commit 2572ae8378
10 changed files with 252 additions and 47 deletions

View file

@ -18,8 +18,6 @@ def main(argc, argv):
print 'Waiting for connection...' print 'Waiting for connection...'
c.wait_connected() c.wait_connected()
print 'Connected!' print 'Connected!'
while True:
c.send_packet('AAAAAAAA')
return 0 return 0
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -0,0 +1,116 @@
# Copyright (c) 2018 Atmosphere-NX
#
# This program is free software; you can redistribute it and/or modify it
# under the terms and conditions of the GNU General Public License,
# version 2, as published by the Free Software Foundation.
#
# This program is distributed in the hope it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import zlib
import ServiceId
from struct import unpack as up, pack as pk
HEADER_SIZE = 0x28
def crc32(s):
return zlib.crc32(s) & 0xFFFFFFFF
class Packet():
def __init__(self):
self.service = 0
self.task = 0
self.cmd = 0
self.continuation = 0
self.version = 0
self.body_len = 0
self.body = ''
self.offset = 0
def load_header(self, header):
assert len(header) == HEADER_SIZE
self.service, self.task, self.cmd, self.continuation, self.version, self.body_len, \
_, self.body_chk, self.hdr_chk = up('<IIHBBI16sII', header)
if crc32(header[:-4]) != self.hdr_chk:
raise ValueError('Invalid header checksum in received packet!')
def load_body(self, body):
assert len(body) == self.body_len
if crc32(body) != self.body_chk:
raise ValueError('Invalid body checksum in received packet!')
self.body = body
def get_data(self):
assert len(self.body) == self.body_len and self.body_len <= 0xE000
self.body_chk = crc32(self.body)
hdr = pk('<IIHBBIIIIII', self.service, self.task, self.cmd, self.continuation, self.version, self.body_len, 0, 0, 0, 0, self.body_chk)
self.hdr_chk = crc32(hdr)
hdr += pk('<I', self.hdr_chk)
return hdr + self.body
def set_service(self, srv):
if type(srv) is str:
self.service = ServiceId.hash(srv)
else:
self.service = srv
return self
def set_task(self, t):
self.task = t
return self
def set_cmd(self, x):
self.cmd = x
return self
def set_continuation(self, c):
self.continuation = c
return self
def set_version(self, v):
self.version = v
return self
def reset_offset(self):
self.offset = 0
return self
def write_str(self, s):
self.body += s
self.body_len += len(s)
return self
def write_u8(self, x):
self.body += pk('<B', x & 0xFF)
self.body_len += 1
return self
def write_u16(self, x):
self.body += pk('<H', x & 0xFFFF)
self.body_len += 2
return self
def write_u32(self, x):
self.body += pk('<I', x & 0xFFFFFFFF)
self.body_len += 4
return self
def write_u64(self, x):
self.body += pk('<Q', x & 0xFFFFFFFFFFFFFFFF)
self.body_len += 8
return self
def read_str(self):
s = ''
while self.body[self.offset] != '\x00' and self.offset < self.body_len:
s += self.body[self.offset]
self.offset += 1
def read_u8(self):
x, = up('<B', self.body[self.offset:self.offset+1])
self.offset += 1
return x
def read_u16(self):
x, = up('<H', self.body[self.offset:self.offset+2])
self.offset += 2
return x
def read_u32(self):
x, = up('<I', self.body[self.offset:self.offset+4])
self.offset += 4
return x
def read_u64(self):
x, = up('<Q', self.body[self.offset:self.offset+8])
self.offset += 8
return x
def read_struct(self, format, sz):
x = up(format, self.body[self.offset:self.offset+sz])
self.offset += sz
return x

View file

@ -0,0 +1,27 @@
# Copyright (c) 2018 Atmosphere-NX
#
# This program is free software; you can redistribute it and/or modify it
# under the terms and conditions of the GNU General Public License,
# version 2, as published by the Free Software Foundation.
#
# This program is distributed in the hope it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
def hash(s):
h = ord(s[0]) & 0xFFFFFFFF
for c in s:
h = ((1000003 * h) ^ ord(c)) & 0xFFFFFFFF
h ^= len(s)
return h
USB_QUERY_TARGET = hash("USBQueryTarget")
USB_SEND_HOST_INFO = hash("USBSendHostInfo")
USB_CONNECT = hash("USBConnect")
USB_DISCONNECT = hash("USBDisconnect")

View file

@ -15,6 +15,8 @@ from UsbInterface import UsbInterface
from threading import Thread, Condition from threading import Thread, Condition
from collections import deque from collections import deque
import time import time
import ServiceId
from Packet import Packet
class UsbConnection(UsbInterface): class UsbConnection(UsbInterface):
# Auto connect thread func. # Auto connect thread func.
@ -25,12 +27,6 @@ class UsbConnection(UsbInterface):
except ValueError as e: except ValueError as e:
continue continue
def recv_thread(connection): def recv_thread(connection):
if connection.is_connected():
try:
# If we've previously been connected, PyUSB will read garbage...
connection.recv_packet()
except ValueError:
pass
while connection.is_connected(): while connection.is_connected():
try: try:
connection.recv_packet() connection.recv_packet()
@ -65,6 +61,7 @@ class UsbConnection(UsbInterface):
self.conn_thrd.start() self.conn_thrd.start()
return self return self
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
self.disconnect()
time.sleep(1) time.sleep(1)
print 'Closing!' print 'Closing!'
time.sleep(1) time.sleep(1)
@ -80,24 +77,43 @@ class UsbConnection(UsbInterface):
self.conn_lock.acquire() self.conn_lock.acquire()
assert not self.connected assert not self.connected
self.intf = intf self.intf = intf
self.connected = True
self.conn_lock.notify() try:
self.conn_lock.release() # Perform Query + Connection handshake
self.intf.send_packet(Packet().set_service(ServiceId.USB_QUERY_TARGET))
query_resp = self.intf.read_packet()
print 'Found Switch, Protocol version 0x%x' % query_resp.read_u32()
self.intf.send_packet(Packet().set_service(ServiceId.USB_SEND_HOST_INFO).write_u32(0).write_u32(0))
self.intf.send_packet(Packet().set_service(ServiceId.USB_CONNECT))
resp = self.intf.read_packet()
# Spawn threads
self.recv_thrd = Thread(target=UsbConnection.recv_thread, args=(self,)) self.recv_thrd = Thread(target=UsbConnection.recv_thread, args=(self,))
self.send_thrd = Thread(target=UsbConnection.send_thread, args=(self,)) self.send_thrd = Thread(target=UsbConnection.send_thread, args=(self,))
self.recv_thrd.daemon = True self.recv_thrd.daemon = True
self.send_thrd.daemon = True self.send_thrd.daemon = True
self.recv_thrd.start() self.recv_thrd.start()
self.send_thrd.start() self.send_thrd.start()
self.connected = True
finally:
# Finish connection.
self.conn_lock.notify()
self.conn_lock.release()
def disconnect(self): def disconnect(self):
self.conn_lock.acquire() self.conn_lock.acquire()
if self.connected: if self.connected:
self.connected = False self.connected = False
self.intf.send_packet(Packet().set_service(ServiceId.USB_DISCONNECT))
self.conn_lock.release() self.conn_lock.release()
def recv_packet(self): def recv_packet(self):
hdr, body = self.intf.read_packet() packet = self.intf.read_packet()
print('Got Packet: %s' % body.encode('hex')) assert type(packet) is Packet
dat = packet.read_u64()
print('Got Packet: %08x' % dat)
def send_packet(self, packet): def send_packet(self, packet):
assert type(packet) is Packet
self.send_lock.acquire() self.send_lock.acquire()
if len(self.send_queue) == 0x40: if len(self.send_queue) == 0x40:
self.send_lock.wait() self.send_lock.wait()

View file

@ -11,11 +11,8 @@
# #
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import usb, zlib import usb
from struct import unpack as up, pack as pk import Packet
def crc32(s):
return zlib.crc32(s) & 0xFFFFFFFF
class UsbInterface(): class UsbInterface():
def __init__(self): def __init__(self):
@ -50,20 +47,16 @@ class UsbInterface():
def blocking_write(self, data): def blocking_write(self, data):
self.ep_out.write(data, 0xFFFFFFFFFFFFFFFF) self.ep_out.write(data, 0xFFFFFFFFFFFFFFFF)
def read_packet(self): def read_packet(self):
hdr = self.blocking_read(0x28) packet = Packet.Packet()
_, _, _, body_size, _, _, _, _, body_chk, hdr_chk = up('<IIIIIIIIII', hdr) hdr = self.blocking_read(Packet.HEADER_SIZE)
if crc32(hdr[:-4]) != hdr_chk: packet.load_header(hdr)
raise ValueError('Invalid header checksum in received packet!') if packet.body_len:
body = self.blocking_read(body_size) packet.load_body(self.blocking_read(packet.body_len))
if len(body) != body_size: return packet
raise ValueError('Failed to receive packet body!') def send_packet(self, packet):
elif crc32(body) != body_chk: data = packet.get_data()
raise ValueError('Invalid body checksum in received packet!') self.blocking_write(data[:Packet.HEADER_SIZE])
return (hdr, body) if (len(data) > Packet.HEADER_SIZE):
def send_packet(self, body): self.blocking_write(data[Packet.HEADER_SIZE:])
hdr = pk('<IIIIIIIII', 0, 0, 0, len(body), 0, 0, 0, 0, crc32(body))
hdr += pk('<I', crc32(hdr))
self.blocking_write(hdr)
self.blocking_write(body)

View file

@ -190,7 +190,7 @@ class TmaPacket {
} }
template<typename T> template<typename T>
TmaConnResult Read(const T &t) { TmaConnResult Read(T &t) {
return Read(&t, sizeof(T)); return Read(&t, sizeof(T));
} }

View file

@ -34,5 +34,13 @@ static constexpr u32 HashServiceName(const char *name) {
enum class TmaService : u32 { enum class TmaService : u32 {
Invalid = 0, Invalid = 0,
/* Special nodes, for facilitating connection over USB. */
UsbQueryTarget = HashServiceName("USBQueryTarget"),
UsbSendHostInfo = HashServiceName("USBSendHostInfo"),
UsbConnect = HashServiceName("USBConnect"),
UsbDisconnect = HashServiceName("USBDisconnect"),
TestService = HashServiceName("AtmosphereTestService"), /* Temporary service, will be used to debug communications. */ TestService = HashServiceName("AtmosphereTestService"), /* Temporary service, will be used to debug communications. */
}; };

View file

@ -76,18 +76,54 @@ void TmaUsbConnection::RecvThreadFunc(void *arg) {
this_ptr->SetConnected(true); this_ptr->SetConnected(true);
while (res == TmaConnResult::Success) { while (res == TmaConnResult::Success) {
if (!this_ptr->IsConnected()) {
break;
}
TmaPacket *packet = this_ptr->AllocateRecvPacket(); TmaPacket *packet = this_ptr->AllocateRecvPacket();
if (packet == nullptr) { std::abort(); } if (packet == nullptr) { std::abort(); }
res = TmaUsbComms::ReceivePacket(packet); res = TmaUsbComms::ReceivePacket(packet);
if (res == TmaConnResult::Success) { if (res == TmaConnResult::Success) {
TmaPacket *send_packet = this_ptr->AllocateSendPacket(); switch (packet->GetServiceId()) {
send_packet->Write<u64>(i++); case TmaService::UsbQueryTarget: {
this_ptr->send_queue.Send(reinterpret_cast<uintptr_t>(send_packet)); this_ptr->SetConnected(false);
res = this_ptr->SendQueryReply(packet);
if (!this_ptr->has_woken_up) {
/* TODO: Cancel background work. */
}
}
break;
case TmaService::UsbSendHostInfo: {
struct {
u32 version;
u32 sleeping;
} host_info;
packet->Read<decltype(host_info)>(host_info);
if (!this_ptr->has_woken_up || !host_info.sleeping) {
/* TODO: Cancel background work. */
}
}
break;
case TmaService::UsbConnect: {
res = this_ptr->SendQueryReply(packet);
if (res == TmaConnResult::Success) {
this_ptr->SetConnected(true);
this_ptr->OnConnectionEvent(ConnectionEvent::Connected);
}
}
break;
case TmaService::UsbDisconnect: {
this_ptr->SetConnected(false);
this_ptr->OnDisconnected();
/* TODO: Cancel background work. */
}
break;
default:
break;
}
this_ptr->FreePacket(packet); this_ptr->FreePacket(packet);
} else { } else {
this_ptr->FreePacket(packet); this_ptr->FreePacket(packet);
@ -153,3 +189,13 @@ TmaConnResult TmaUsbConnection::SendPacket(TmaPacket *packet) {
return TmaConnResult::Disconnected; return TmaConnResult::Disconnected;
} }
} }
TmaConnResult TmaUsbConnection::SendQueryReply(TmaPacket *packet) {
packet->ClearOffset();
struct {
u32 version;
} target_info;
target_info.version = 0;
packet->Write<decltype(target_info)>(target_info);
return TmaUsbComms::SendPacket(packet);
}

View file

@ -29,6 +29,7 @@ class TmaUsbConnection : public TmaConnection {
static void SendThreadFunc(void *arg); static void SendThreadFunc(void *arg);
static void RecvThreadFunc(void *arg); static void RecvThreadFunc(void *arg);
static void OnUsbStateChange(void *this_ptr, u32 state); static void OnUsbStateChange(void *this_ptr, u32 state);
TmaConnResult SendQueryReply(TmaPacket *packet);
void ClearSendQueue(); void ClearSendQueue();
void StartThreads(); void StartThreads();
void StopThreads(); void StopThreads();

View file

@ -436,7 +436,7 @@ TmaConnResult TmaUsbComms::SendPacket(TmaPacket *packet) {
res = TmaConnResult::GeneralFailure; res = TmaConnResult::GeneralFailure;
} }
if (res == TmaConnResult::Success) { if (res == TmaConnResult::Success && 0 < body_len) {
/* Copy body to send buffer. */ /* Copy body to send buffer. */
packet->CopyBodyTo(g_send_data_buf); packet->CopyBodyTo(g_send_data_buf);