mirror of
https://github.com/Atmosphere-NX/Atmosphere
synced 2024-12-23 04:41:12 +00:00
tma: impl helper services, cleanup hostside packets
This commit is contained in:
parent
46001263f8
commit
2572ae8378
10 changed files with 252 additions and 47 deletions
|
@ -18,8 +18,6 @@ def main(argc, argv):
|
||||||
print 'Waiting for connection...'
|
print 'Waiting for connection...'
|
||||||
c.wait_connected()
|
c.wait_connected()
|
||||||
print 'Connected!'
|
print 'Connected!'
|
||||||
while True:
|
|
||||||
c.send_packet('AAAAAAAA')
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
116
stratosphere/tma/client/Packet.py
Normal file
116
stratosphere/tma/client/Packet.py
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
# Copyright (c) 2018 Atmosphere-NX
|
||||||
|
#
|
||||||
|
# This program is free software; you can redistribute it and/or modify it
|
||||||
|
# under the terms and conditions of the GNU General Public License,
|
||||||
|
# version 2, as published by the Free Software Foundation.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope it will be useful, but WITHOUT
|
||||||
|
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
|
||||||
|
# more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU General Public License
|
||||||
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
import zlib
|
||||||
|
import ServiceId
|
||||||
|
from struct import unpack as up, pack as pk
|
||||||
|
|
||||||
|
HEADER_SIZE = 0x28
|
||||||
|
|
||||||
|
def crc32(s):
|
||||||
|
return zlib.crc32(s) & 0xFFFFFFFF
|
||||||
|
|
||||||
|
class Packet():
|
||||||
|
def __init__(self):
|
||||||
|
self.service = 0
|
||||||
|
self.task = 0
|
||||||
|
self.cmd = 0
|
||||||
|
self.continuation = 0
|
||||||
|
self.version = 0
|
||||||
|
self.body_len = 0
|
||||||
|
self.body = ''
|
||||||
|
self.offset = 0
|
||||||
|
def load_header(self, header):
|
||||||
|
assert len(header) == HEADER_SIZE
|
||||||
|
self.service, self.task, self.cmd, self.continuation, self.version, self.body_len, \
|
||||||
|
_, self.body_chk, self.hdr_chk = up('<IIHBBI16sII', header)
|
||||||
|
if crc32(header[:-4]) != self.hdr_chk:
|
||||||
|
raise ValueError('Invalid header checksum in received packet!')
|
||||||
|
def load_body(self, body):
|
||||||
|
assert len(body) == self.body_len
|
||||||
|
if crc32(body) != self.body_chk:
|
||||||
|
raise ValueError('Invalid body checksum in received packet!')
|
||||||
|
self.body = body
|
||||||
|
def get_data(self):
|
||||||
|
assert len(self.body) == self.body_len and self.body_len <= 0xE000
|
||||||
|
self.body_chk = crc32(self.body)
|
||||||
|
hdr = pk('<IIHBBIIIIII', self.service, self.task, self.cmd, self.continuation, self.version, self.body_len, 0, 0, 0, 0, self.body_chk)
|
||||||
|
self.hdr_chk = crc32(hdr)
|
||||||
|
hdr += pk('<I', self.hdr_chk)
|
||||||
|
return hdr + self.body
|
||||||
|
def set_service(self, srv):
|
||||||
|
if type(srv) is str:
|
||||||
|
self.service = ServiceId.hash(srv)
|
||||||
|
else:
|
||||||
|
self.service = srv
|
||||||
|
return self
|
||||||
|
def set_task(self, t):
|
||||||
|
self.task = t
|
||||||
|
return self
|
||||||
|
def set_cmd(self, x):
|
||||||
|
self.cmd = x
|
||||||
|
return self
|
||||||
|
def set_continuation(self, c):
|
||||||
|
self.continuation = c
|
||||||
|
return self
|
||||||
|
def set_version(self, v):
|
||||||
|
self.version = v
|
||||||
|
return self
|
||||||
|
def reset_offset(self):
|
||||||
|
self.offset = 0
|
||||||
|
return self
|
||||||
|
def write_str(self, s):
|
||||||
|
self.body += s
|
||||||
|
self.body_len += len(s)
|
||||||
|
return self
|
||||||
|
def write_u8(self, x):
|
||||||
|
self.body += pk('<B', x & 0xFF)
|
||||||
|
self.body_len += 1
|
||||||
|
return self
|
||||||
|
def write_u16(self, x):
|
||||||
|
self.body += pk('<H', x & 0xFFFF)
|
||||||
|
self.body_len += 2
|
||||||
|
return self
|
||||||
|
def write_u32(self, x):
|
||||||
|
self.body += pk('<I', x & 0xFFFFFFFF)
|
||||||
|
self.body_len += 4
|
||||||
|
return self
|
||||||
|
def write_u64(self, x):
|
||||||
|
self.body += pk('<Q', x & 0xFFFFFFFFFFFFFFFF)
|
||||||
|
self.body_len += 8
|
||||||
|
return self
|
||||||
|
def read_str(self):
|
||||||
|
s = ''
|
||||||
|
while self.body[self.offset] != '\x00' and self.offset < self.body_len:
|
||||||
|
s += self.body[self.offset]
|
||||||
|
self.offset += 1
|
||||||
|
def read_u8(self):
|
||||||
|
x, = up('<B', self.body[self.offset:self.offset+1])
|
||||||
|
self.offset += 1
|
||||||
|
return x
|
||||||
|
def read_u16(self):
|
||||||
|
x, = up('<H', self.body[self.offset:self.offset+2])
|
||||||
|
self.offset += 2
|
||||||
|
return x
|
||||||
|
def read_u32(self):
|
||||||
|
x, = up('<I', self.body[self.offset:self.offset+4])
|
||||||
|
self.offset += 4
|
||||||
|
return x
|
||||||
|
def read_u64(self):
|
||||||
|
x, = up('<Q', self.body[self.offset:self.offset+8])
|
||||||
|
self.offset += 8
|
||||||
|
return x
|
||||||
|
def read_struct(self, format, sz):
|
||||||
|
x = up(format, self.body[self.offset:self.offset+sz])
|
||||||
|
self.offset += sz
|
||||||
|
return x
|
27
stratosphere/tma/client/ServiceId.py
Normal file
27
stratosphere/tma/client/ServiceId.py
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
# Copyright (c) 2018 Atmosphere-NX
|
||||||
|
#
|
||||||
|
# This program is free software; you can redistribute it and/or modify it
|
||||||
|
# under the terms and conditions of the GNU General Public License,
|
||||||
|
# version 2, as published by the Free Software Foundation.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope it will be useful, but WITHOUT
|
||||||
|
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
|
||||||
|
# more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU General Public License
|
||||||
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
def hash(s):
|
||||||
|
h = ord(s[0]) & 0xFFFFFFFF
|
||||||
|
for c in s:
|
||||||
|
h = ((1000003 * h) ^ ord(c)) & 0xFFFFFFFF
|
||||||
|
h ^= len(s)
|
||||||
|
return h
|
||||||
|
|
||||||
|
USB_QUERY_TARGET = hash("USBQueryTarget")
|
||||||
|
USB_SEND_HOST_INFO = hash("USBSendHostInfo")
|
||||||
|
USB_CONNECT = hash("USBConnect")
|
||||||
|
USB_DISCONNECT = hash("USBDisconnect")
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,8 @@ from UsbInterface import UsbInterface
|
||||||
from threading import Thread, Condition
|
from threading import Thread, Condition
|
||||||
from collections import deque
|
from collections import deque
|
||||||
import time
|
import time
|
||||||
|
import ServiceId
|
||||||
|
from Packet import Packet
|
||||||
|
|
||||||
class UsbConnection(UsbInterface):
|
class UsbConnection(UsbInterface):
|
||||||
# Auto connect thread func.
|
# Auto connect thread func.
|
||||||
|
@ -25,12 +27,6 @@ class UsbConnection(UsbInterface):
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
continue
|
continue
|
||||||
def recv_thread(connection):
|
def recv_thread(connection):
|
||||||
if connection.is_connected():
|
|
||||||
try:
|
|
||||||
# If we've previously been connected, PyUSB will read garbage...
|
|
||||||
connection.recv_packet()
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
while connection.is_connected():
|
while connection.is_connected():
|
||||||
try:
|
try:
|
||||||
connection.recv_packet()
|
connection.recv_packet()
|
||||||
|
@ -65,6 +61,7 @@ class UsbConnection(UsbInterface):
|
||||||
self.conn_thrd.start()
|
self.conn_thrd.start()
|
||||||
return self
|
return self
|
||||||
def __exit__(self, type, value, traceback):
|
def __exit__(self, type, value, traceback):
|
||||||
|
self.disconnect()
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
print 'Closing!'
|
print 'Closing!'
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
@ -80,24 +77,43 @@ class UsbConnection(UsbInterface):
|
||||||
self.conn_lock.acquire()
|
self.conn_lock.acquire()
|
||||||
assert not self.connected
|
assert not self.connected
|
||||||
self.intf = intf
|
self.intf = intf
|
||||||
self.connected = True
|
|
||||||
self.conn_lock.notify()
|
try:
|
||||||
self.conn_lock.release()
|
# Perform Query + Connection handshake
|
||||||
self.recv_thrd = Thread(target=UsbConnection.recv_thread, args=(self,))
|
self.intf.send_packet(Packet().set_service(ServiceId.USB_QUERY_TARGET))
|
||||||
self.send_thrd = Thread(target=UsbConnection.send_thread, args=(self,))
|
query_resp = self.intf.read_packet()
|
||||||
self.recv_thrd.daemon = True
|
print 'Found Switch, Protocol version 0x%x' % query_resp.read_u32()
|
||||||
self.send_thrd.daemon = True
|
|
||||||
self.recv_thrd.start()
|
self.intf.send_packet(Packet().set_service(ServiceId.USB_SEND_HOST_INFO).write_u32(0).write_u32(0))
|
||||||
self.send_thrd.start()
|
|
||||||
|
self.intf.send_packet(Packet().set_service(ServiceId.USB_CONNECT))
|
||||||
|
resp = self.intf.read_packet()
|
||||||
|
|
||||||
|
# Spawn threads
|
||||||
|
self.recv_thrd = Thread(target=UsbConnection.recv_thread, args=(self,))
|
||||||
|
self.send_thrd = Thread(target=UsbConnection.send_thread, args=(self,))
|
||||||
|
self.recv_thrd.daemon = True
|
||||||
|
self.send_thrd.daemon = True
|
||||||
|
self.recv_thrd.start()
|
||||||
|
self.send_thrd.start()
|
||||||
|
self.connected = True
|
||||||
|
finally:
|
||||||
|
# Finish connection.
|
||||||
|
self.conn_lock.notify()
|
||||||
|
self.conn_lock.release()
|
||||||
def disconnect(self):
|
def disconnect(self):
|
||||||
self.conn_lock.acquire()
|
self.conn_lock.acquire()
|
||||||
if self.connected:
|
if self.connected:
|
||||||
self.connected = False
|
self.connected = False
|
||||||
|
self.intf.send_packet(Packet().set_service(ServiceId.USB_DISCONNECT))
|
||||||
self.conn_lock.release()
|
self.conn_lock.release()
|
||||||
def recv_packet(self):
|
def recv_packet(self):
|
||||||
hdr, body = self.intf.read_packet()
|
packet = self.intf.read_packet()
|
||||||
print('Got Packet: %s' % body.encode('hex'))
|
assert type(packet) is Packet
|
||||||
|
dat = packet.read_u64()
|
||||||
|
print('Got Packet: %08x' % dat)
|
||||||
def send_packet(self, packet):
|
def send_packet(self, packet):
|
||||||
|
assert type(packet) is Packet
|
||||||
self.send_lock.acquire()
|
self.send_lock.acquire()
|
||||||
if len(self.send_queue) == 0x40:
|
if len(self.send_queue) == 0x40:
|
||||||
self.send_lock.wait()
|
self.send_lock.wait()
|
||||||
|
|
|
@ -11,11 +11,8 @@
|
||||||
#
|
#
|
||||||
# You should have received a copy of the GNU General Public License
|
# You should have received a copy of the GNU General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
import usb, zlib
|
import usb
|
||||||
from struct import unpack as up, pack as pk
|
import Packet
|
||||||
|
|
||||||
def crc32(s):
|
|
||||||
return zlib.crc32(s) & 0xFFFFFFFF
|
|
||||||
|
|
||||||
class UsbInterface():
|
class UsbInterface():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -50,20 +47,16 @@ class UsbInterface():
|
||||||
def blocking_write(self, data):
|
def blocking_write(self, data):
|
||||||
self.ep_out.write(data, 0xFFFFFFFFFFFFFFFF)
|
self.ep_out.write(data, 0xFFFFFFFFFFFFFFFF)
|
||||||
def read_packet(self):
|
def read_packet(self):
|
||||||
hdr = self.blocking_read(0x28)
|
packet = Packet.Packet()
|
||||||
_, _, _, body_size, _, _, _, _, body_chk, hdr_chk = up('<IIIIIIIIII', hdr)
|
hdr = self.blocking_read(Packet.HEADER_SIZE)
|
||||||
if crc32(hdr[:-4]) != hdr_chk:
|
packet.load_header(hdr)
|
||||||
raise ValueError('Invalid header checksum in received packet!')
|
if packet.body_len:
|
||||||
body = self.blocking_read(body_size)
|
packet.load_body(self.blocking_read(packet.body_len))
|
||||||
if len(body) != body_size:
|
return packet
|
||||||
raise ValueError('Failed to receive packet body!')
|
def send_packet(self, packet):
|
||||||
elif crc32(body) != body_chk:
|
data = packet.get_data()
|
||||||
raise ValueError('Invalid body checksum in received packet!')
|
self.blocking_write(data[:Packet.HEADER_SIZE])
|
||||||
return (hdr, body)
|
if (len(data) > Packet.HEADER_SIZE):
|
||||||
def send_packet(self, body):
|
self.blocking_write(data[Packet.HEADER_SIZE:])
|
||||||
hdr = pk('<IIIIIIIII', 0, 0, 0, len(body), 0, 0, 0, 0, crc32(body))
|
|
||||||
hdr += pk('<I', crc32(hdr))
|
|
||||||
self.blocking_write(hdr)
|
|
||||||
self.blocking_write(body)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -190,7 +190,7 @@ class TmaPacket {
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
TmaConnResult Read(const T &t) {
|
TmaConnResult Read(T &t) {
|
||||||
return Read(&t, sizeof(T));
|
return Read(&t, sizeof(T));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -34,5 +34,13 @@ static constexpr u32 HashServiceName(const char *name) {
|
||||||
|
|
||||||
enum class TmaService : u32 {
|
enum class TmaService : u32 {
|
||||||
Invalid = 0,
|
Invalid = 0,
|
||||||
|
|
||||||
|
/* Special nodes, for facilitating connection over USB. */
|
||||||
|
UsbQueryTarget = HashServiceName("USBQueryTarget"),
|
||||||
|
UsbSendHostInfo = HashServiceName("USBSendHostInfo"),
|
||||||
|
UsbConnect = HashServiceName("USBConnect"),
|
||||||
|
UsbDisconnect = HashServiceName("USBDisconnect"),
|
||||||
|
|
||||||
|
|
||||||
TestService = HashServiceName("AtmosphereTestService"), /* Temporary service, will be used to debug communications. */
|
TestService = HashServiceName("AtmosphereTestService"), /* Temporary service, will be used to debug communications. */
|
||||||
};
|
};
|
||||||
|
|
|
@ -76,18 +76,54 @@ void TmaUsbConnection::RecvThreadFunc(void *arg) {
|
||||||
this_ptr->SetConnected(true);
|
this_ptr->SetConnected(true);
|
||||||
|
|
||||||
while (res == TmaConnResult::Success) {
|
while (res == TmaConnResult::Success) {
|
||||||
if (!this_ptr->IsConnected()) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
TmaPacket *packet = this_ptr->AllocateRecvPacket();
|
TmaPacket *packet = this_ptr->AllocateRecvPacket();
|
||||||
if (packet == nullptr) { std::abort(); }
|
if (packet == nullptr) { std::abort(); }
|
||||||
|
|
||||||
res = TmaUsbComms::ReceivePacket(packet);
|
res = TmaUsbComms::ReceivePacket(packet);
|
||||||
|
|
||||||
if (res == TmaConnResult::Success) {
|
if (res == TmaConnResult::Success) {
|
||||||
TmaPacket *send_packet = this_ptr->AllocateSendPacket();
|
switch (packet->GetServiceId()) {
|
||||||
send_packet->Write<u64>(i++);
|
case TmaService::UsbQueryTarget: {
|
||||||
this_ptr->send_queue.Send(reinterpret_cast<uintptr_t>(send_packet));
|
this_ptr->SetConnected(false);
|
||||||
|
|
||||||
|
res = this_ptr->SendQueryReply(packet);
|
||||||
|
|
||||||
|
if (!this_ptr->has_woken_up) {
|
||||||
|
/* TODO: Cancel background work. */
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case TmaService::UsbSendHostInfo: {
|
||||||
|
struct {
|
||||||
|
u32 version;
|
||||||
|
u32 sleeping;
|
||||||
|
} host_info;
|
||||||
|
packet->Read<decltype(host_info)>(host_info);
|
||||||
|
|
||||||
|
if (!this_ptr->has_woken_up || !host_info.sleeping) {
|
||||||
|
/* TODO: Cancel background work. */
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case TmaService::UsbConnect: {
|
||||||
|
res = this_ptr->SendQueryReply(packet);
|
||||||
|
|
||||||
|
if (res == TmaConnResult::Success) {
|
||||||
|
this_ptr->SetConnected(true);
|
||||||
|
this_ptr->OnConnectionEvent(ConnectionEvent::Connected);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case TmaService::UsbDisconnect: {
|
||||||
|
this_ptr->SetConnected(false);
|
||||||
|
this_ptr->OnDisconnected();
|
||||||
|
|
||||||
|
/* TODO: Cancel background work. */
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
this_ptr->FreePacket(packet);
|
this_ptr->FreePacket(packet);
|
||||||
} else {
|
} else {
|
||||||
this_ptr->FreePacket(packet);
|
this_ptr->FreePacket(packet);
|
||||||
|
@ -153,3 +189,13 @@ TmaConnResult TmaUsbConnection::SendPacket(TmaPacket *packet) {
|
||||||
return TmaConnResult::Disconnected;
|
return TmaConnResult::Disconnected;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TmaConnResult TmaUsbConnection::SendQueryReply(TmaPacket *packet) {
|
||||||
|
packet->ClearOffset();
|
||||||
|
struct {
|
||||||
|
u32 version;
|
||||||
|
} target_info;
|
||||||
|
target_info.version = 0;
|
||||||
|
packet->Write<decltype(target_info)>(target_info);
|
||||||
|
return TmaUsbComms::SendPacket(packet);
|
||||||
|
}
|
||||||
|
|
|
@ -29,6 +29,7 @@ class TmaUsbConnection : public TmaConnection {
|
||||||
static void SendThreadFunc(void *arg);
|
static void SendThreadFunc(void *arg);
|
||||||
static void RecvThreadFunc(void *arg);
|
static void RecvThreadFunc(void *arg);
|
||||||
static void OnUsbStateChange(void *this_ptr, u32 state);
|
static void OnUsbStateChange(void *this_ptr, u32 state);
|
||||||
|
TmaConnResult SendQueryReply(TmaPacket *packet);
|
||||||
void ClearSendQueue();
|
void ClearSendQueue();
|
||||||
void StartThreads();
|
void StartThreads();
|
||||||
void StopThreads();
|
void StopThreads();
|
||||||
|
|
|
@ -436,7 +436,7 @@ TmaConnResult TmaUsbComms::SendPacket(TmaPacket *packet) {
|
||||||
res = TmaConnResult::GeneralFailure;
|
res = TmaConnResult::GeneralFailure;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (res == TmaConnResult::Success) {
|
if (res == TmaConnResult::Success && 0 < body_len) {
|
||||||
/* Copy body to send buffer. */
|
/* Copy body to send buffer. */
|
||||||
packet->CopyBodyTo(g_send_data_buf);
|
packet->CopyBodyTo(g_send_data_buf);
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue