From adc80eabb01fc9adb38b6ee99ea29b002077f778 Mon Sep 17 00:00:00 2001 From: tecnovert Date: Tue, 28 May 2024 23:17:54 +0200 Subject: [PATCH] Add simple protobuf encoder and decoder. --- basicswap/basicswap.py | 24 ++-- basicswap/basicswap_util.py | 2 +- basicswap/messages.proto | 2 +- basicswap/messages_npb.py | 244 ++++++++++++++++++++++++++++++++++ bin/basicswap_run.py | 3 + tests/basicswap/test_other.py | 28 ++++ 6 files changed, 293 insertions(+), 10 deletions(-) create mode 100644 basicswap/messages_npb.py diff --git a/basicswap/basicswap.py b/basicswap/basicswap.py index 12af5c7..167d8fb 100644 --- a/basicswap/basicswap.py +++ b/basicswap/basicswap.py @@ -67,7 +67,6 @@ from .script import ( OpCodes, ) from .messages_pb2 import ( - OfferMessage, BidMessage, BidAcceptMessage, XmrBidMessage, @@ -80,6 +79,9 @@ from .messages_pb2 import ( ADSBidIntentMessage, ADSBidIntentAcceptMessage, ) +from .messages_npb import ( + OfferMessage, +) from .db import ( CURRENT_DB_VERSION, Concepts, @@ -1525,7 +1527,7 @@ class BasicSwap(BaseApp): coin_from_has_csv = self.coin_clients[coin_from]['use_csv'] coin_to_has_csv = self.coin_clients[coin_to]['use_csv'] - if lock_type == OfferMessage.SEQUENCE_LOCK_TIME: + if lock_type == TxLockTypes.SEQUENCE_LOCK_TIME: ensure(lock_value >= self.min_sequence_lock_seconds and lock_value <= self.max_sequence_lock_seconds, 'Invalid lock_value time') if swap_type == SwapTypes.XMR_SWAP: reverse_bid: bool = self.is_reverse_ads_bid(coin_from) @@ -1533,7 +1535,7 @@ class BasicSwap(BaseApp): ensure(itx_coin_has_csv, 'ITX coin needs CSV activated.') else: ensure(coin_from_has_csv and coin_to_has_csv, 'Both coins need CSV activated.') - elif lock_type == OfferMessage.SEQUENCE_LOCK_BLOCKS: + elif lock_type == TxLockTypes.SEQUENCE_LOCK_BLOCKS: ensure(lock_value >= 5 and lock_value <= 1000, 'Invalid lock_value blocks') if swap_type == SwapTypes.XMR_SWAP: reverse_bid: bool = self.is_reverse_ads_bid(coin_from) @@ -1680,7 +1682,7 @@ class BasicSwap(BaseApp): proof_addr, proof_sig, proof_utxos = self.getProofOfFunds(coin_from_t, int(amount), proof_of_funds_hash) # TODO: For now proof_of_funds is just a client side check, may need to be sent with offers in future however. - offer_bytes = msg_buf.SerializeToString() + offer_bytes = msg_buf.to_bytes() payload_hex = str.format('{:02x}', MessageTypes.OFFER) + offer_bytes.hex() msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) offer_id = self.sendSmsg(offer_addr, offer_addr_to, payload_hex, msg_valid) @@ -4808,11 +4810,17 @@ class BasicSwap(BaseApp): offer_bytes = bytes.fromhex(msg['hex'][2:-2]) offer_data = OfferMessage() - offer_data.ParseFromString(offer_bytes[:2]) - if offer_data.protocol_version < MINPROTO_VERSION or offer_data.protocol_version > MAXPROTO_VERSION: - self.log.warning(f'Incoming offer invalid protocol version: {offer_data.protocol_version} ') + try: + offer_data.from_bytes(offer_bytes[:2], init_all=False) + ensure(offer_data.protocol_version >= MINPROTO_VERSION and offer_data.protocol_version <= MAXPROTO_VERSION, 'protocol_version out of range') + except Exception as e: + self.log.warning('Incoming offer invalid protocol version: {}.'.format(getattr(offer_data, 'protocol_version', -1))) + return + try: + offer_data.from_bytes(offer_bytes) + except Exception as e: + self.log.warning('Failed to decode offer, protocol version: {}, {}.'.format(getattr(offer_data, 'protocol_version', -1), str(e))) return - offer_data.ParseFromString(offer_bytes) # Validate offer data now: int = self.getTime() diff --git a/basicswap/basicswap_util.py b/basicswap/basicswap_util.py index f4d8b3d..af7cecf 100644 --- a/basicswap/basicswap_util.py +++ b/basicswap/basicswap_util.py @@ -475,7 +475,7 @@ def getOfferProofOfFundsHash(offer_msg, offer_addr): # TODO: Hash must not include proof_of_funds sig if it exists in offer_msg h = hashlib.sha256() h.update(offer_addr.encode('utf-8')) - offer_bytes = offer_msg.SerializeToString() + offer_bytes = offer_msg.to_bytes() h.update(offer_bytes) return h.digest() diff --git a/basicswap/messages.proto b/basicswap/messages.proto index 8bd5b27..4fadc8e 100644 --- a/basicswap/messages.proto +++ b/basicswap/messages.proto @@ -51,7 +51,7 @@ message BidMessage { bytes proof_utxos = 9; /* 32 byte txid 2 byte vout, repeated */ /* optional */ - bytes pkhash_buyer_to = 13; /* When pubkey hash is different on the to-chain */ + bytes pkhash_buyer_to = 10; /* When pubkey hash is different on the to-chain */ } /* For tests */ diff --git a/basicswap/messages_npb.py b/basicswap/messages_npb.py new file mode 100644 index 0000000..ad17fe6 --- /dev/null +++ b/basicswap/messages_npb.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright (c) 2024 tecnovert +# Distributed under the MIT software license, see the accompanying +# file LICENSE or http://www.opensource.org/licenses/mit-license.php. + + +''' +0 VARINT int32, int64, uint32, uint64, sint32, sint64, bool, enum +1 I64 fixed64, sfixed64, double +2 LEN string, bytes, embedded messages, packed repeated fields +5 I32 fixed32, sfixed32, float + +Don't encode fields of default values. +When decoding initialise all fields not set from data. +''' + +from basicswap.util.integer import encode_varint, decode_varint + + +class NonProtobufClass(): + def to_bytes(self) -> bytes: + rv = bytes() + + for field_num, v in self._map.items(): + field_name, wire_type, field_type = v + if not hasattr(self, field_name): + continue + field_value = getattr(self, field_name) + tag = (field_num << 3) | wire_type + if wire_type == 0: + if field_value == 0: + continue + rv += encode_varint(tag) + rv += encode_varint(field_value) + elif wire_type == 2: + if len(field_value) == 0: + continue + rv += encode_varint(tag) + if isinstance(field_value, str): + field_value = field_value.encode('utf-8') + rv += encode_varint(len(field_value)) + rv += field_value + else: + raise ValueError(f'Unknown wire_type {wire_type}') + return rv + + def from_bytes(self, b: bytes, init_all: bool = True) -> None: + max_len: int = len(b) + o: int = 0 + while o < max_len: + tag, lv = decode_varint(b, o) + o += lv + wire_type = tag & 7 + field_num = tag >> 3 + + field_name, wire_type_expect, field_type = self._map[field_num] + if wire_type != wire_type_expect: + raise ValueError(f'Unexpected wire_type {wire_type} for field {field_num}') + + if wire_type == 0: + field_value, lv = decode_varint(b, o) + o += lv + elif wire_type == 2: + field_len, lv = decode_varint(b, o) + o += lv + field_value = b[o: o + field_len] + o += field_len + if field_type == 1: + field_value = field_value.decode('utf-8') + else: + raise ValueError(f'Unknown wire_type {wire_type}') + + setattr(self, field_name, field_value) + + if not init_all: + return + # Set default values for missing fields + for field_num, v in self._map.items(): + field_name, wire_type, field_type = v + if hasattr(self, field_name): + continue + if wire_type == 0: + setattr(self, field_name, 0) + elif wire_type == 2: + if field_type == 1: + setattr(self, field_name, str()) + else: + setattr(self, field_name, bytes()) + else: + raise ValueError(f'Unknown wire_type {wire_type}') + + +class OfferMessage(NonProtobufClass): + _map = { + 1: ('protocol_version', 0, 0), + 2: ('coin_from', 0, 0), + 3: ('coin_to', 0, 0), + 4: ('amount_from', 0, 0), + 5: ('amount_to', 0, 0), + 6: ('min_bid_amount', 0, 0), + 7: ('time_valid', 0, 0), + 8: ('lock_type', 0, 0), + 9: ('lock_value', 0, 0), + 10: ('swap_type', 0, 0), + 11: ('proof_address', 2, 1), + 12: ('proof_signature', 2, 1), + 13: ('pkhash_seller', 2, 0), + 14: ('secret_hash', 2, 0), + 15: ('fee_rate_from', 0, 0), + 16: ('fee_rate_to', 0, 0), + 17: ('amount_negotiable', 0, 2), + 18: ('rate_negotiable', 0, 2), + 19: ('proof_utxos', 2, 0), + } + + +class BidMessage(NonProtobufClass): + _map = { + 1: ('protocol_version', 0, 0), + 2: ('offer_msg_id', 2, 0), + 3: ('time_valid', 0, 0), + 4: ('amount', 0, 0), + 5: ('amount_to', 0, 0), + 6: ('pkhash_buyer', 2, 0), + 7: ('proof_address', 2, 1), + 8: ('proof_signature', 2, 1), + 9: ('proof_utxos', 2, 0), + 10: ('pkhash_buyer_to', 2, 0), + } + + +class BidAcceptMessage(NonProtobufClass): + # Step 3, seller -> buyer + _map = { + 1: ('bid_msg_id', 2, 0), + 2: ('initiate_txid', 2, 0), + 3: ('contract_script', 2, 0), + 4: ('pkhash_seller', 2, 0), + } + + +class OfferRevokeMessage(NonProtobufClass): + _map = { + 1: ('offer_msg_id', 2, 0), + 2: ('signature', 2, 0), + } + + +class BidRejectMessage(NonProtobufClass): + _map = { + 1: ('bid_msg_id', 2, 0), + 2: ('reject_code', 0, 0), + } + + +class XmrBidMessage(NonProtobufClass): + # MSG1L, F -> L + _map = { + 1: ('protocol_version', 0, 0), + 2: ('offer_msg_id', 2, 0), + 3: ('time_valid', 0, 0), + 4: ('amount', 0, 0), + 5: ('amount_to', 0, 0), + 6: ('pkaf', 2, 0), + 7: ('kbvf', 2, 0), + 8: ('kbsf_dleag', 2, 0), + 9: ('dest_af', 2, 0), + } + + +class XmrSplitMessage(NonProtobufClass): + _map = { + 1: ('msg_id', 2, 0), + 2: ('msg_type', 0, 0), + 3: ('sequence', 0, 0), + 4: ('dleag', 2, 0), + } + + +class XmrBidAcceptMessage(NonProtobufClass): + _map = { + 1: ('bid_msg_id', 2, 0), + 2: ('pkal', 2, 0), + 3: ('kbvl', 2, 0), + 4: ('kbsl_dleag', 2, 0), + + # MSG2F + 5: ('a_lock_tx', 2, 0), + 6: ('a_lock_tx_script', 2, 0), + 7: ('a_lock_refund_tx', 2, 0), + 8: ('a_lock_refund_tx_script', 2, 0), + 9: ('a_lock_refund_spend_tx', 2, 0), + 10: ('al_lock_refund_tx_sig', 2, 0), + } + + +class XmrBidLockTxSigsMessage(NonProtobufClass): + # MSG3L + _map = { + 1: ('bid_msg_id', 2, 0), + 2: ('af_lock_refund_spend_tx_esig', 2, 0), + 3: ('af_lock_refund_tx_sig', 2, 0), + } + + +class XmrBidLockSpendTxMessage(NonProtobufClass): + # MSG4F + _map = { + 1: ('bid_msg_id', 2, 0), + 2: ('a_lock_spend_tx', 2, 0), + 3: ('kal_sig', 2, 0), + } + + +class XmrBidLockReleaseMessage(NonProtobufClass): + # MSG5F + _map = { + 1: ('bid_msg_id', 2, 0), + 2: ('al_lock_spend_tx_esig', 2, 0), + } + + +class ADSBidIntentMessage(NonProtobufClass): + # L -> F Sent from bidder, construct a reverse bid + _map = { + 1: ('protocol_version', 0, 0), + 2: ('offer_msg_id', 2, 0), + 3: ('time_valid', 0, 0), + 4: ('amount_from', 0, 0), + 5: ('amount_to', 0, 0), + } + + +class ADSBidIntentAcceptMessage(NonProtobufClass): + # F -> L Sent from offerer, construct a reverse bid + _map = { + 1: ('bid_msg_id', 2, 0), + 2: ('pkaf', 2, 0), + 3: ('kbvf', 2, 0), + 4: ('kbsf_dleag', 2, 0), + 5: ('dest_af', 2, 0), + } diff --git a/bin/basicswap_run.py b/bin/basicswap_run.py index d8c55ad..8d51f98 100755 --- a/bin/basicswap_run.py +++ b/bin/basicswap_run.py @@ -425,6 +425,9 @@ def main(): logger.warning('Unknown argument %s', v) + if os.name == 'nt': + logger.warning('Running on windows is discouraged and windows support may be discontinued in the future. Please consider using the WSL docker setup instead.') + if data_dir is None: data_dir = os.path.join(os.path.expanduser(cfg.BASICSWAP_DATADIR)) logger.info('Using datadir: %s', data_dir) diff --git a/tests/basicswap/test_other.py b/tests/basicswap/test_other.py index 84c9bad..5f4bacb 100644 --- a/tests/basicswap/test_other.py +++ b/tests/basicswap/test_other.py @@ -46,6 +46,10 @@ from basicswap.util import ( from basicswap.messages_pb2 import ( BidMessage, BidMessage_test, + OfferMessage, +) +from basicswap.messages_npb import ( + OfferMessage as OfferMessage_npb, ) from basicswap.contrib.test_framework.script import hash160 as hash160_btc @@ -433,6 +437,30 @@ class Test(unittest.TestCase): assert (msg_buf_v2.protocol_version == 2) assert (msg_buf_v2.time_valid == 0) + msg_buf = OfferMessage() + msg_buf.protocol_version = 2 + msg_buf.amount_from = 1024 + msg_buf.amount_to = 0 # test if it gets encoded + msg_buf.pkhash_seller = bytes((1,)) * 32 + msg_buf.proof_address = 'a string' + msg_buf.amount_negotiable = True + msg_buf.rate_negotiable = False + msg_buf.fee_rate_to = 2485 + pb_serialised_msg = msg_buf.SerializeToString() + + npb = OfferMessage_npb() + npb.from_bytes(pb_serialised_msg) + assert (npb.protocol_version == msg_buf.protocol_version) + assert (npb.amount_from == msg_buf.amount_from) + assert (npb.amount_to == msg_buf.amount_to) + assert (npb.pkhash_seller == msg_buf.pkhash_seller) + assert (npb.proof_address == msg_buf.proof_address) + assert (npb.amount_negotiable == msg_buf.amount_negotiable) + assert (npb.fee_rate_to == msg_buf.fee_rate_to) + + npb_serialised_msg = npb.to_bytes() + assert (npb_serialised_msg == pb_serialised_msg) + def test_is_private_ip_address(self): test_addresses = [ ('localhost', True),