refactor: Replace struct.pack/unpack.

This commit is contained in:
tecnovert 2024-02-02 18:53:44 +02:00
parent 7c9504e0cd
commit 8318961f0b
No known key found for this signature in database
GPG key ID: 8ED6D8750C4E3F93

View file

@ -24,7 +24,6 @@ import queue
import random import random
import select import select
import socket import socket
import struct
import hashlib import hashlib
import logging import logging
import secrets import secrets
@ -41,7 +40,7 @@ from basicswap.contrib.rfc6979 import (
START_TOKEN = 0xabcd START_TOKEN = 0xabcd
MSG_START_TOKEN = struct.pack('>H', START_TOKEN) MSG_START_TOKEN = START_TOKEN.to_bytes(2, 'big')
MSG_MAX_SIZE = 0x200000 # 2MB MSG_MAX_SIZE = 0x200000 # 2MB
@ -83,8 +82,8 @@ class MsgHandshake:
pass pass
def encode_aad(self): # Additional Authenticated Data def encode_aad(self): # Additional Authenticated Data
return struct.pack('>H', NetMessageTypes.HANDSHAKE) + \ return int(NetMessageTypes.HANDSHAKE).to_bytes(2, 'big') + \
struct.pack('>Q', self._timestamp) + \ self._timestamp.to_bytes(8, 'big') + \
self._ephem_pk self._ephem_pk
def encode(self): def encode(self):
@ -92,7 +91,7 @@ class MsgHandshake:
def decode(self, msg_mv): def decode(self, msg_mv):
o = 2 o = 2
self._timestamp = struct.unpack('>Q', msg_mv[o: o + 8])[0] self._timestamp = int.from_bytes(msg_mv[o: o + 8], 'big')
o += 8 o += 8
self._ephem_pk = bytes(msg_mv[o: o + 33]) self._ephem_pk = bytes(msg_mv[o: o + 33])
o += 33 o += 33
@ -333,7 +332,7 @@ class Network:
ss = k.ecdh(peer._pubkey) ss = k.ecdh(peer._pubkey)
hashed = hashlib.sha512(ss + struct.pack('>Q', msg._timestamp)).digest() hashed = hashlib.sha512(ss + msg._timestamp.to_bytes(8, 'big')).digest()
peer._ke = hashed[:32] peer._ke = hashed[:32]
peer._km = hashed[32:] peer._km = hashed[32:]
@ -386,7 +385,7 @@ class Network:
nk = PrivateKey(self._network_key) nk = PrivateKey(self._network_key)
ss = nk.ecdh(msg._ephem_pk) ss = nk.ecdh(msg._ephem_pk)
hashed = hashlib.sha512(ss + struct.pack('>Q', msg._timestamp)).digest() hashed = hashlib.sha512(ss + msg._timestamp.to_bytes(8, 'big')).digest()
peer._ke = hashed[:32] peer._ke = hashed[:32]
peer._km = hashed[32:] peer._km = hashed[32:]
@ -427,7 +426,7 @@ class Network:
mac = msg_mv[-16:] mac = msg_mv[-16:]
plaintext = cipher.decrypt_and_verify(msg_mv[2: -16], mac) plaintext = cipher.decrypt_and_verify(msg_mv[2: -16], mac)
ping_nonce = struct.unpack('>I', plaintext[:4])[0] ping_nonce = int.from_bytes(plaintext[:4], 'big')
# Version is added to a ping following a handshake message # Version is added to a ping following a handshake message
if len(plaintext) >= 10: if len(plaintext) >= 10:
peer._ready = True peer._ready = True
@ -450,7 +449,7 @@ class Network:
mac = msg_mv[-16:] mac = msg_mv[-16:]
plaintext = cipher.decrypt_and_verify(msg_mv[2: -16], mac) plaintext = cipher.decrypt_and_verify(msg_mv[2: -16], mac)
pong_nonce = struct.unpack('>I', plaintext[:4])[0] pong_nonce = int.from_bytes(plaintext[:4], 'big')
if pong_nonce == peer._ping_nonce: if pong_nonce == peer._ping_nonce:
peer._last_ping_rtt = (time.time_ns() // 1000) - peer._last_ping_at peer._last_ping_rtt = (time.time_ns() // 1000) - peer._last_ping_at
@ -462,14 +461,14 @@ class Network:
def send_ping(self, peer): def send_ping(self, peer):
ping_nonce = random.getrandbits(32) ping_nonce = random.getrandbits(32)
msg_bytes = struct.pack('>H', NetMessageTypes.PING) msg_bytes = int(NetMessageTypes.PING).to_bytes(2, 'big')
nonce = peer._sent_nonce[:24] nonce = peer._sent_nonce[:24]
cipher = ChaCha20_Poly1305.new(key=peer._ke, nonce=nonce) cipher = ChaCha20_Poly1305.new(key=peer._ke, nonce=nonce)
cipher.update(msg_bytes) cipher.update(msg_bytes)
cipher.update(nonce) cipher.update(nonce)
payload = struct.pack('>I', ping_nonce) payload = ping_nonce.to_bytes(4, 'big')
if peer._last_ping_at == 0: if peer._last_ping_at == 0:
payload += self._sc._version payload += self._sc._version
ct, mac = cipher.encrypt_and_digest(payload) ct, mac = cipher.encrypt_and_digest(payload)
@ -484,14 +483,14 @@ class Network:
self.send_msg(peer, msg_bytes) self.send_msg(peer, msg_bytes)
def send_pong(self, peer, ping_nonce): def send_pong(self, peer, ping_nonce):
msg_bytes = struct.pack('>H', NetMessageTypes.PONG) msg_bytes = int(NetMessageTypes.PONG).to_bytes(2, 'big')
nonce = peer._sent_nonce[:24] nonce = peer._sent_nonce[:24]
cipher = ChaCha20_Poly1305.new(key=peer._ke, nonce=nonce) cipher = ChaCha20_Poly1305.new(key=peer._ke, nonce=nonce)
cipher.update(msg_bytes) cipher.update(msg_bytes)
cipher.update(nonce) cipher.update(nonce)
payload = struct.pack('>I', ping_nonce) payload = ping_nonce.to_bytes(4, 'big')
ct, mac = cipher.encrypt_and_digest(payload) ct, mac = cipher.encrypt_and_digest(payload)
msg_bytes += ct + mac msg_bytes += ct + mac
@ -503,7 +502,7 @@ class Network:
msg_encoded = msg if isinstance(msg, bytes) else msg.encode() msg_encoded = msg if isinstance(msg, bytes) else msg.encode()
len_encoded = len(msg_encoded) len_encoded = len(msg_encoded)
msg_packed = bytearray(MSG_START_TOKEN) + struct.pack('>I', len_encoded) + msg_encoded msg_packed = bytearray(MSG_START_TOKEN) + len_encoded.to_bytes(4, 'big') + msg_encoded
peer._socket.sendall(msg_packed) peer._socket.sendall(msg_packed)
peer._bytes_sent += len_encoded peer._bytes_sent += len_encoded
@ -515,7 +514,7 @@ class Network:
try: try:
mv = memoryview(msg_bytes) mv = memoryview(msg_bytes)
o = 0 o = 0
msg_type = struct.unpack('>H', mv[o: o + 2])[0] msg_type = int.from_bytes(mv[o: o + 2], 'big')
if msg_type == NetMessageTypes.HANDSHAKE: if msg_type == NetMessageTypes.HANDSHAKE:
self.process_handshake(peer, mv) self.process_handshake(peer, mv)
elif msg_type == NetMessageTypes.PING: elif msg_type == NetMessageTypes.PING:
@ -548,13 +547,13 @@ class Network:
raise ValueError('Invalid start token') raise ValueError('Invalid start token')
o += 2 o += 2
msg_len = struct.unpack('>I', mv[o: o + 4])[0] msg_len = int.from_bytes(mv[o: o + 4], 'big')
o += 4 o += 4
if msg_len < 2 or msg_len > MSG_MAX_SIZE: if msg_len < 2 or msg_len > MSG_MAX_SIZE:
raise ValueError('Invalid data length') raise ValueError('Invalid data length')
# Precheck msg_type # Precheck msg_type
msg_type = struct.unpack('>H', mv[o: o + 2])[0] msg_type = int.from_bytes(mv[o: o + 2], 'big')
# o += 2 # Don't inc offset, msg includes type # o += 2 # Don't inc offset, msg includes type
if not NetMessageTypes.has_value(msg_type): if not NetMessageTypes.has_value(msg_type):
raise ValueError('Invalid msg type') raise ValueError('Invalid msg type')