Merge pull request from tecnovert/multinet

Add simplex chat test.
This commit is contained in:
tecnovert 2025-04-10 23:02:39 +00:00 committed by GitHub
commit 6777aff0b9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 3168 additions and 591 deletions

View file

@ -122,8 +122,16 @@ from .explorers import (
ExplorerBitAps,
ExplorerChainz,
)
from .network.simplex import (
initialiseSimplexNetwork,
sendSimplexMsg,
readSimplexMsgs,
)
from .network.util import (
getMsgPubkey,
)
import basicswap.config as cfg
import basicswap.network as bsn
import basicswap.network.network as bsn
import basicswap.protocols.atomic_swap_1 as atomic_swap_1
import basicswap.protocols.xmr_swap_1 as xmr_swap_1
from .basicswap_util import (
@ -428,6 +436,9 @@ class BasicSwap(BaseApp):
self.swaps_in_progress = dict()
self.dleag_split_size_init = 16000
self.dleag_split_size = 17000
self.SMSG_SECONDS_IN_HOUR = (
60 * 60
) # Note: Set smsgsregtestadjust=0 for regtest
@ -526,6 +537,8 @@ class BasicSwap(BaseApp):
self._network = None
for t in self.threads:
if hasattr(t, "stop") and callable(t.stop):
t.stop()
t.join()
if sys.version_info[1] >= 9:
@ -1078,6 +1091,17 @@ class BasicSwap(BaseApp):
f"network_key {self.network_key}\nnetwork_pubkey {self.network_pubkey}\nnetwork_addr {self.network_addr}"
)
self.active_networks = []
network_config_list = self.settings.get("networks", [])
if len(network_config_list) < 1:
network_config_list = [{"type": "smsg", "enabled": True}]
for network in network_config_list:
if network["type"] == "smsg":
self.active_networks.append({"type": "smsg"})
elif network["type"] == "simplex":
initialiseSimplexNetwork(self, network)
ro = self.callrpc("smsglocalkeys")
found = False
for k in ro["smsg_keys"]:
@ -1655,6 +1679,33 @@ class BasicSwap(BaseApp):
bid_valid = (bid.expire_at - now) + 10 * 60 # Add 10 minute buffer
return max(smsg_min_valid, min(smsg_max_valid, bid_valid))
def sendMessage(
self, addr_from: str, addr_to: str, payload_hex: bytes, msg_valid: int, cursor
) -> bytes:
message_id: bytes = None
# First network in list will set message_id
for network in self.active_networks:
net_message_id = None
if network["type"] == "smsg":
net_message_id = self.sendSmsg(
addr_from, addr_to, payload_hex, msg_valid
)
elif network["type"] == "simplex":
net_message_id = sendSimplexMsg(
self,
network,
addr_from,
addr_to,
bytes.fromhex(payload_hex),
msg_valid,
cursor,
)
else:
raise ValueError("Unknown network: {}".format(network["type"]))
if not message_id:
message_id = net_message_id
return message_id
def sendSmsg(
self, addr_from: str, addr_to: str, payload_hex: bytes, msg_valid: int
) -> bytes:
@ -2200,7 +2251,9 @@ class BasicSwap(BaseApp):
offer_bytes = msg_buf.to_bytes()
payload_hex = str.format("{:02x}", MessageTypes.OFFER) + offer_bytes.hex()
msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds)
offer_id = self.sendSmsg(offer_addr, offer_addr_to, payload_hex, msg_valid)
offer_id = self.sendMessage(
offer_addr, offer_addr_to, payload_hex, msg_valid, cursor
)
security_token = extra_options.get("security_token", None)
if security_token is not None and len(security_token) != 20:
@ -2305,8 +2358,8 @@ class BasicSwap(BaseApp):
)
msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, offer.time_valid)
msg_id = self.sendSmsg(
offer.addr_from, self.network_addr, payload_hex, msg_valid
msg_id = self.sendMessage(
offer.addr_from, self.network_addr, payload_hex, msg_valid, cursor
)
self.log.debug(
f"Revoked offer {self.log.id(offer_id)} in msg {self.log.id(msg_id)}"
@ -3152,7 +3205,9 @@ class BasicSwap(BaseApp):
bid_addr = self.prepareSMSGAddress(addr_send_from, AddressTypes.BID, cursor)
msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds)
bid_id = self.sendSmsg(bid_addr, offer.addr_from, payload_hex, msg_valid)
bid_id = self.sendMessage(
bid_addr, offer.addr_from, payload_hex, msg_valid, cursor
)
bid = Bid(
protocol_version=msg_buf.protocol_version,
@ -3488,8 +3543,8 @@ class BasicSwap(BaseApp):
)
msg_valid: int = self.getAcceptBidMsgValidTime(bid)
accept_msg_id = self.sendSmsg(
offer.addr_from, bid.bid_addr, payload_hex, msg_valid
accept_msg_id = self.sendMessage(
offer.addr_from, bid.bid_addr, payload_hex, msg_valid, cursor
)
self.addMessageLink(
@ -3519,20 +3574,29 @@ class BasicSwap(BaseApp):
dleag: bytes,
msg_valid: int,
bid_msg_ids,
cursor,
) -> None:
msg_buf2 = XmrSplitMessage(
msg_id=bid_id, msg_type=msg_type, sequence=1, dleag=dleag[16000:32000]
)
msg_bytes = msg_buf2.to_bytes()
payload_hex = str.format("{:02x}", MessageTypes.XMR_BID_SPLIT) + msg_bytes.hex()
bid_msg_ids[1] = self.sendSmsg(addr_from, addr_to, payload_hex, msg_valid)
msg_buf3 = XmrSplitMessage(
msg_id=bid_id, msg_type=msg_type, sequence=2, dleag=dleag[32000:]
)
msg_bytes = msg_buf3.to_bytes()
payload_hex = str.format("{:02x}", MessageTypes.XMR_BID_SPLIT) + msg_bytes.hex()
bid_msg_ids[2] = self.sendSmsg(addr_from, addr_to, payload_hex, msg_valid)
sent_bytes = self.dleag_split_size_init
num_sent = 1
while sent_bytes < len(dleag):
size_to_send: int = min(self.dleag_split_size, len(dleag) - sent_bytes)
msg_buf = XmrSplitMessage(
msg_id=bid_id,
msg_type=msg_type,
sequence=num_sent,
dleag=dleag[sent_bytes : sent_bytes + size_to_send],
)
msg_bytes = msg_buf.to_bytes()
payload_hex = (
str.format("{:02x}", MessageTypes.XMR_BID_SPLIT) + msg_bytes.hex()
)
bid_msg_ids[num_sent] = self.sendMessage(
addr_from, addr_to, payload_hex, msg_valid, cursor
)
num_sent += 1
sent_bytes += size_to_send
def postXmrBid(
self, offer_id: bytes, amount: int, addr_send_from: str = None, extra_options={}
@ -3608,8 +3672,8 @@ class BasicSwap(BaseApp):
)
msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds)
xmr_swap.bid_id = self.sendSmsg(
bid_addr, offer.addr_from, payload_hex, msg_valid
xmr_swap.bid_id = self.sendMessage(
bid_addr, offer.addr_from, payload_hex, msg_valid, cursor
)
bid = Bid(
@ -3691,7 +3755,7 @@ class BasicSwap(BaseApp):
if ci_to.curve_type() == Curves.ed25519:
xmr_swap.kbsf_dleag = ci_to.proveDLEAG(kbsf)
xmr_swap.pkasf = xmr_swap.kbsf_dleag[0:33]
msg_buf.kbsf_dleag = xmr_swap.kbsf_dleag[:16000]
msg_buf.kbsf_dleag = xmr_swap.kbsf_dleag[: self.dleag_split_size_init]
elif ci_to.curve_type() == Curves.secp256k1:
for i in range(10):
xmr_swap.kbsf_dleag = ci_to.signRecoverable(
@ -3721,8 +3785,8 @@ class BasicSwap(BaseApp):
bid_addr = self.prepareSMSGAddress(addr_send_from, AddressTypes.BID, cursor)
msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds)
xmr_swap.bid_id = self.sendSmsg(
bid_addr, offer.addr_from, payload_hex, msg_valid
xmr_swap.bid_id = self.sendMessage(
bid_addr, offer.addr_from, payload_hex, msg_valid, cursor
)
bid_msg_ids = {}
@ -3735,6 +3799,7 @@ class BasicSwap(BaseApp):
xmr_swap.kbsf_dleag,
msg_valid,
bid_msg_ids,
cursor,
)
bid = Bid(
@ -4013,7 +4078,7 @@ class BasicSwap(BaseApp):
if ci_to.curve_type() == Curves.ed25519:
xmr_swap.kbsl_dleag = ci_to.proveDLEAG(kbsl)
msg_buf.kbsl_dleag = xmr_swap.kbsl_dleag[:16000]
msg_buf.kbsl_dleag = xmr_swap.kbsl_dleag[: self.dleag_split_size_init]
elif ci_to.curve_type() == Curves.secp256k1:
for i in range(10):
xmr_swap.kbsl_dleag = ci_to.signRecoverable(
@ -4048,7 +4113,9 @@ class BasicSwap(BaseApp):
msg_valid: int = self.getAcceptBidMsgValidTime(bid)
bid_msg_ids = {}
bid_msg_ids[0] = self.sendSmsg(addr_from, addr_to, payload_hex, msg_valid)
bid_msg_ids[0] = self.sendMessage(
addr_from, addr_to, payload_hex, msg_valid, use_cursor
)
if ci_to.curve_type() == Curves.ed25519:
self.sendXmrSplitMessages(
@ -4059,6 +4126,7 @@ class BasicSwap(BaseApp):
xmr_swap.kbsl_dleag,
msg_valid,
bid_msg_ids,
use_cursor,
)
bid.setState(BidStates.BID_ACCEPTED) # ADS
@ -4180,8 +4248,8 @@ class BasicSwap(BaseApp):
msg_buf.kbvf = kbvf
msg_buf.kbsf_dleag = (
xmr_swap.kbsf_dleag
if len(xmr_swap.kbsf_dleag) < 16000
else xmr_swap.kbsf_dleag[:16000]
if len(xmr_swap.kbsf_dleag) < self.dleag_split_size_init
else xmr_swap.kbsf_dleag[: self.dleag_split_size_init]
)
bid_bytes = msg_buf.to_bytes()
@ -4193,7 +4261,9 @@ class BasicSwap(BaseApp):
addr_to: str = bid.bid_addr
msg_valid: int = self.getAcceptBidMsgValidTime(bid)
bid_msg_ids = {}
bid_msg_ids[0] = self.sendSmsg(addr_from, addr_to, payload_hex, msg_valid)
bid_msg_ids[0] = self.sendMessage(
addr_from, addr_to, payload_hex, msg_valid, use_cursor
)
if ci_to.curve_type() == Curves.ed25519:
self.sendXmrSplitMessages(
@ -4204,6 +4274,7 @@ class BasicSwap(BaseApp):
xmr_swap.kbsf_dleag,
msg_valid,
bid_msg_ids,
use_cursor,
)
bid.setState(BidStates.BID_REQUEST_ACCEPTED)
@ -6808,75 +6879,61 @@ class BasicSwap(BaseApp):
now: int = self.getTime()
ttl_xmr_split_messages = 60 * 60
bid_cursor = None
dleag_proof_len: int = 48893 # coincurve.dleag.dleag_proof_len()
try:
cursor = self.openDB()
bid_cursor = self.getNewDBCursor()
q_bids = self.query(
Bid, bid_cursor, {"state": int(BidStates.BID_RECEIVING)}
Bid,
bid_cursor,
{
"state": (
int(BidStates.BID_RECEIVING),
int(BidStates.BID_RECEIVING_ACC),
)
},
)
for bid in q_bids:
q = cursor.execute(
"SELECT COUNT(*) FROM xmr_split_data WHERE bid_id = :bid_id AND msg_type = :msg_type",
{"bid_id": bid.bid_id, "msg_type": int(XmrSplitMsgTypes.BID)},
).fetchone()
num_segments = q[0]
if num_segments > 1:
try:
self.receiveXmrBid(bid, cursor)
except Exception as ex:
self.log.info(
f"Verify adaptor-sig bid {self.log.id(bid.bid_id)} failed: {ex}"
)
if self.debug:
self.log.error(traceback.format_exc())
bid.setState(
BidStates.BID_ERROR, "Failed validation: " + str(ex)
)
self.updateDB(
bid,
cursor,
[
"bid_id",
],
)
self.updateBidInProgress(bid)
continue
if bid.created_at + ttl_xmr_split_messages < now:
self.log.debug(
f"Expiring partially received bid: {self.log.id(bid.bid_id)}."
)
bid.setState(BidStates.BID_ERROR, "Timed out")
self.updateDB(
bid,
cursor,
[
"bid_id",
],
)
q_bids = self.query(
Bid, bid_cursor, {"state": int(BidStates.BID_RECEIVING_ACC)}
)
for bid in q_bids:
q = cursor.execute(
"SELECT COUNT(*) FROM xmr_split_data WHERE bid_id = :bid_id AND msg_type = :msg_type",
"SELECT LENGTH(kbsl_dleag), LENGTH(kbsf_dleag) FROM xmr_swaps WHERE bid_id = :bid_id",
{
"bid_id": bid.bid_id,
"msg_type": int(XmrSplitMsgTypes.BID_ACCEPT),
},
).fetchone()
num_segments = q[0]
if num_segments > 1:
kbsl_dleag_len: int = q[0]
kbsf_dleag_len: int = q[1]
if bid.state == int(BidStates.BID_RECEIVING_ACC):
bid_type: str = "bid accept"
msg_type: int = int(XmrSplitMsgTypes.BID_ACCEPT)
total_dleag_size: int = kbsl_dleag_len
else:
bid_type: str = "bid"
msg_type: int = int(XmrSplitMsgTypes.BID)
total_dleag_size: int = kbsf_dleag_len
q = cursor.execute(
"SELECT COUNT(*), SUM(LENGTH(dleag)) AS total_dleag_size FROM xmr_split_data WHERE bid_id = :bid_id AND msg_type = :msg_type",
{"bid_id": bid.bid_id, "msg_type": msg_type},
).fetchone()
total_dleag_size += 0 if q[1] is None else q[1]
if total_dleag_size >= dleag_proof_len:
try:
self.receiveXmrBidAccept(bid, cursor)
if bid.state == int(BidStates.BID_RECEIVING):
self.receiveXmrBid(bid, cursor)
elif bid.state == int(BidStates.BID_RECEIVING_ACC):
self.receiveXmrBidAccept(bid, cursor)
else:
raise ValueError("Unexpected bid state")
except Exception as ex:
self.log.info(
f"Verify adaptor-sig {bid_type} {self.log.id(bid.bid_id)} failed: {ex}"
)
if self.debug:
self.log.error(traceback.format_exc())
self.log.info(
f"Verify adaptor-sig bid accept {self.log.id(bid.bid_id)} failed: {ex}."
)
bid.setState(
BidStates.BID_ERROR, "Failed accept validation: " + str(ex)
BidStates.BID_ERROR, f"Failed {bid_type} validation: {ex}"
)
self.updateDB(
bid,
@ -6889,7 +6946,7 @@ class BasicSwap(BaseApp):
continue
if bid.created_at + ttl_xmr_split_messages < now:
self.log.debug(
f"Expiring partially received bid accept: {self.log.id(bid.bid_id)}."
f"Expiring partially received {bid_type}: {self.log.id(bid.bid_id)}."
)
bid.setState(BidStates.BID_ERROR, "Timed out")
self.updateDB(
@ -6899,7 +6956,6 @@ class BasicSwap(BaseApp):
"bid_id",
],
)
# Expire old records
cursor.execute(
"DELETE FROM xmr_split_data WHERE created_at + :ttl < :now",
@ -7029,6 +7085,7 @@ class BasicSwap(BaseApp):
if self.isOfferRevoked(offer_id, msg["from"]):
raise ValueError("Offer has been revoked {}.".format(offer_id.hex()))
pk_from: bytes = getMsgPubkey(self, msg)
try:
cursor = self.openDB()
# Offers must be received on the public network_addr or manually created addresses
@ -7069,6 +7126,7 @@ class BasicSwap(BaseApp):
rate_negotiable=offer_data.rate_negotiable,
addr_to=msg["to"],
addr_from=msg["from"],
pk_from=pk_from,
created_at=msg["sent"],
expire_at=msg["sent"] + offer_data.time_valid,
was_sent=False,
@ -7417,6 +7475,7 @@ class BasicSwap(BaseApp):
bid = self.getBid(bid_id)
if bid is None:
pk_from: bytes = getMsgPubkey(self, msg)
bid = Bid(
active_ind=1,
bid_id=bid_id,
@ -7431,6 +7490,7 @@ class BasicSwap(BaseApp):
created_at=msg["sent"],
expire_at=msg["sent"] + bid_data.time_valid,
bid_addr=msg["from"],
pk_bid_addr=pk_from,
was_received=True,
chain_a_height_start=ci_from.getChainHeight(),
chain_b_height_start=ci_to.getChainHeight(),
@ -7829,12 +7889,13 @@ class BasicSwap(BaseApp):
)
if ci_to.curve_type() == Curves.ed25519:
ensure(len(bid_data.kbsf_dleag) == 16000, "Invalid kbsf_dleag size")
ensure(len(bid_data.kbsf_dleag) <= 16000, "Invalid kbsf_dleag size")
bid_id = bytes.fromhex(msg["msgid"])
bid, xmr_swap = self.getXmrBid(bid_id)
if bid is None:
pk_from: bytes = getMsgPubkey(self, msg)
bid = Bid(
active_ind=1,
bid_id=bid_id,
@ -7846,6 +7907,7 @@ class BasicSwap(BaseApp):
created_at=msg["sent"],
expire_at=msg["sent"] + bid_data.time_valid,
bid_addr=msg["from"],
pk_bid_addr=pk_from,
was_received=True,
chain_a_height_start=ci_from.getChainHeight(),
chain_b_height_start=ci_to.getChainHeight(),
@ -8175,8 +8237,8 @@ class BasicSwap(BaseApp):
msg_valid: int = self.getActiveBidMsgValidTime()
addr_send_from: str = offer.addr_from if reverse_bid else bid.bid_addr
addr_send_to: str = bid.bid_addr if reverse_bid else offer.addr_from
coin_a_lock_tx_sigs_l_msg_id = self.sendSmsg(
addr_send_from, addr_send_to, payload_hex, msg_valid
coin_a_lock_tx_sigs_l_msg_id = self.sendMessage(
addr_send_from, addr_send_to, payload_hex, msg_valid, cursor
)
self.addMessageLink(
Concepts.BID,
@ -8544,8 +8606,8 @@ class BasicSwap(BaseApp):
addr_send_from: str = bid.bid_addr if reverse_bid else offer.addr_from
addr_send_to: str = offer.addr_from if reverse_bid else bid.bid_addr
msg_valid: int = self.getActiveBidMsgValidTime()
coin_a_lock_release_msg_id = self.sendSmsg(
addr_send_from, addr_send_to, payload_hex, msg_valid
coin_a_lock_release_msg_id = self.sendMessage(
addr_send_from, addr_send_to, payload_hex, msg_valid, cursor
)
self.addMessageLink(
Concepts.BID,
@ -8964,8 +9026,8 @@ class BasicSwap(BaseApp):
)
msg_valid: int = self.getActiveBidMsgValidTime()
xmr_swap.coin_a_lock_refund_spend_tx_msg_id = self.sendSmsg(
addr_send_from, addr_send_to, payload_hex, msg_valid
xmr_swap.coin_a_lock_refund_spend_tx_msg_id = self.sendMessage(
addr_send_from, addr_send_to, payload_hex, msg_valid, cursor
)
bid.setState(BidStates.XMR_SWAP_MSG_SCRIPT_LOCK_SPEND_TX)
@ -9347,6 +9409,7 @@ class BasicSwap(BaseApp):
bid, xmr_swap = self.getXmrBid(bid_id)
if bid is None:
pk_from: bytes = getMsgPubkey(self, msg)
bid = Bid(
active_ind=1,
bid_id=bid_id,
@ -9358,6 +9421,7 @@ class BasicSwap(BaseApp):
created_at=msg["sent"],
expire_at=msg["sent"] + bid_data.time_valid,
bid_addr=msg["from"],
pk_bid_addr=pk_from,
was_sent=False,
was_received=True,
chain_a_height_start=ci_from.getChainHeight(),
@ -9460,7 +9524,7 @@ class BasicSwap(BaseApp):
"Invalid destination address",
)
if ci_to.curve_type() == Curves.ed25519:
ensure(len(msg_data.kbsf_dleag) == 16000, "Invalid kbsf_dleag size")
ensure(len(msg_data.kbsf_dleag) <= 16000, "Invalid kbsf_dleag size")
xmr_swap.dest_af = msg_data.dest_af
xmr_swap.pkaf = msg_data.pkaf
@ -9495,6 +9559,14 @@ class BasicSwap(BaseApp):
def processMsg(self, msg) -> None:
try:
if "hex" not in msg:
if self.debug:
if "error" in msg:
self.log.debug(
"Message error {}: {}.".format(msg["msgid"], msg["error"])
)
raise ValueError("Invalid msg received {}.".format(msg["msgid"]))
return
msg_type = int(msg["hex"][:2], 16)
if msg_type == MessageTypes.OFFER:
@ -9708,6 +9780,10 @@ class BasicSwap(BaseApp):
self.processMsg(msg)
try:
for network in self.active_networks:
if network["type"] == "simplex":
readSimplexMsgs(self, network)
# TODO: Wait for blocks / txns, would need to check multiple coins
now: int = self.getTime()
self.expireBidsAndOffers(now)

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -5,20 +5,22 @@
"""Helpful routines for regression testing."""
from base64 import b64encode
from binascii import unhexlify
from decimal import Decimal, ROUND_DOWN
from subprocess import CalledProcessError
import hashlib
import inspect
import json
import logging
import os
import random
import pathlib
import platform
import re
import time
from . import coverage
from .authproxy import AuthServiceProxy, JSONRPCException
from io import BytesIO
from collections.abc import Callable
from typing import Optional
logger = logging.getLogger("TestFramework.utils")
@ -28,23 +30,46 @@ logger = logging.getLogger("TestFramework.utils")
def assert_approx(v, vexp, vspan=0.00001):
"""Assert that `v` is within `vspan` of `vexp`"""
if isinstance(v, Decimal) or isinstance(vexp, Decimal):
v=Decimal(v)
vexp=Decimal(vexp)
vspan=Decimal(vspan)
if v < vexp - vspan:
raise AssertionError("%s < [%s..%s]" % (str(v), str(vexp - vspan), str(vexp + vspan)))
if v > vexp + vspan:
raise AssertionError("%s > [%s..%s]" % (str(v), str(vexp - vspan), str(vexp + vspan)))
def assert_fee_amount(fee, tx_size, fee_per_kB):
"""Assert the fee was in range"""
target_fee = round(tx_size * fee_per_kB / 1000, 8)
def assert_fee_amount(fee, tx_size, feerate_BTC_kvB):
"""Assert the fee is in range."""
assert isinstance(tx_size, int)
target_fee = get_fee(tx_size, feerate_BTC_kvB)
if fee < target_fee:
raise AssertionError("Fee of %s BTC too low! (Should be %s BTC)" % (str(fee), str(target_fee)))
# allow the wallet's estimation to be at most 2 bytes off
if fee > (tx_size + 2) * fee_per_kB / 1000:
high_fee = get_fee(tx_size + 2, feerate_BTC_kvB)
if fee > high_fee:
raise AssertionError("Fee of %s BTC too high! (Should be %s BTC)" % (str(fee), str(target_fee)))
def summarise_dict_differences(thing1, thing2):
if not isinstance(thing1, dict) or not isinstance(thing2, dict):
return thing1, thing2
d1, d2 = {}, {}
for k in sorted(thing1.keys()):
if k not in thing2:
d1[k] = thing1[k]
elif thing1[k] != thing2[k]:
d1[k], d2[k] = summarise_dict_differences(thing1[k], thing2[k])
for k in sorted(thing2.keys()):
if k not in thing1:
d2[k] = thing2[k]
return d1, d2
def assert_equal(thing1, thing2, *args):
if thing1 != thing2 and not args and isinstance(thing1, dict) and isinstance(thing2, dict):
d1,d2 = summarise_dict_differences(thing1, thing2)
raise AssertionError("not(%s == %s)\n in particular not(%s == %s)" % (thing1, thing2, d1, d2))
if thing1 != thing2 or any(thing1 != arg for arg in args):
raise AssertionError("not(%s)" % " == ".join(str(arg) for arg in (thing1, thing2) + args))
@ -79,7 +104,7 @@ def assert_raises_message(exc, message, fun, *args, **kwds):
raise AssertionError("No exception raised")
def assert_raises_process_error(returncode, output, fun, *args, **kwds):
def assert_raises_process_error(returncode: int, output: str, fun: Callable, *args, **kwds):
"""Execute a process and asserts the process return code and output.
Calls function `fun` with arguments `args` and `kwds`. Catches a CalledProcessError
@ -87,9 +112,9 @@ def assert_raises_process_error(returncode, output, fun, *args, **kwds):
no CalledProcessError was raised or if the return code and output are not as expected.
Args:
returncode (int): the process return code.
output (string): [a substring of] the process output.
fun (function): the function to call. This should execute a process.
returncode: the process return code.
output: [a substring of] the process output.
fun: the function to call. This should execute a process.
args*: positional arguments for the function.
kwds**: named arguments for the function.
"""
@ -104,7 +129,7 @@ def assert_raises_process_error(returncode, output, fun, *args, **kwds):
raise AssertionError("No exception raised")
def assert_raises_rpc_error(code, message, fun, *args, **kwds):
def assert_raises_rpc_error(code: Optional[int], message: Optional[str], fun: Callable, *args, **kwds):
"""Run an RPC and verify that a specific JSONRPC exception code and message is raised.
Calls function `fun` with arguments `args` and `kwds`. Catches a JSONRPCException
@ -112,11 +137,11 @@ def assert_raises_rpc_error(code, message, fun, *args, **kwds):
no JSONRPCException was raised or if the error code/message are not as expected.
Args:
code (int), optional: the error code returned by the RPC call (defined
in src/rpc/protocol.h). Set to None if checking the error code is not required.
message (string), optional: [a substring of] the error string returned by the
RPC call. Set to None if checking the error string is not required.
fun (function): the function to call. This should be the name of an RPC.
code: the error code returned by the RPC call (defined in src/rpc/protocol.h).
Set to None if checking the error code is not required.
message: [a substring of] the error string returned by the RPC call.
Set to None if checking the error string is not required.
fun: the function to call. This should be the name of an RPC.
args*: positional arguments for the function.
kwds**: named arguments for the function.
"""
@ -203,29 +228,45 @@ def check_json_precision():
raise RuntimeError("JSON encode/decode loses precision")
def EncodeDecimal(o):
if isinstance(o, Decimal):
return str(o)
raise TypeError(repr(o) + " is not JSON serializable")
def count_bytes(hex_string):
return len(bytearray.fromhex(hex_string))
def hex_str_to_bytes(hex_str):
return unhexlify(hex_str.encode('ascii'))
def str_to_b64str(string):
return b64encode(string.encode('utf-8')).decode('ascii')
def ceildiv(a, b):
"""
Divide 2 ints and round up to next int rather than round down
Implementation requires python integers, which have a // operator that does floor division.
Other types like decimal.Decimal whose // operator truncates towards 0 will not work.
"""
assert isinstance(a, int)
assert isinstance(b, int)
return -(-a // b)
def get_fee(tx_size, feerate_btc_kvb):
"""Calculate the fee in BTC given a feerate is BTC/kvB. Reflects CFeeRate::GetFee"""
feerate_sat_kvb = int(feerate_btc_kvb * Decimal(1e8)) # Fee in sat/kvb as an int to avoid float precision errors
target_fee_sat = ceildiv(feerate_sat_kvb * tx_size, 1000) # Round calculated fee up to nearest sat
return target_fee_sat / Decimal(1e8) # Return result in BTC
def satoshi_round(amount):
return Decimal(amount).quantize(Decimal('0.00000001'), rounding=ROUND_DOWN)
def wait_until(predicate, *, attempts=float('inf'), timeout=float('inf'), lock=None, timeout_factor=1.0):
def wait_until_helper_internal(predicate, *, attempts=float('inf'), timeout=float('inf'), lock=None, timeout_factor=1.0):
"""Sleep until the predicate resolves to be True.
Warning: Note that this method is not recommended to be used in tests as it is
not aware of the context of the test framework. Using the `wait_until()` members
from `BitcoinTestFramework` or `P2PInterface` class ensures the timeout is
properly scaled. Furthermore, `wait_until()` from `P2PInterface` class in
`p2p.py` has a preset lock.
"""
if attempts == float('inf') and timeout == float('inf'):
timeout = 60
timeout = timeout * timeout_factor
@ -253,6 +294,16 @@ def wait_until(predicate, *, attempts=float('inf'), timeout=float('inf'), lock=N
raise RuntimeError('Unreachable')
def sha256sum_file(filename):
h = hashlib.sha256()
with open(filename, 'rb') as f:
d = f.read(4096)
while len(d) > 0:
h.update(d)
d = f.read(4096)
return h.digest()
# RPC/P2P connection constants and functions
############################################
@ -269,15 +320,15 @@ class PortSeed:
n = None
def get_rpc_proxy(url, node_number, *, timeout=None, coveragedir=None):
def get_rpc_proxy(url: str, node_number: int, *, timeout: Optional[int]=None, coveragedir: Optional[str]=None) -> coverage.AuthServiceProxyWrapper:
"""
Args:
url (str): URL of the RPC server to call
node_number (int): the node number (or id) that this calls to
url: URL of the RPC server to call
node_number: the node number (or id) that this calls to
Kwargs:
timeout (int): HTTP timeout in seconds
coveragedir (str): Directory
timeout: HTTP timeout in seconds
coveragedir: Directory
Returns:
AuthServiceProxy. convenience object for making RPC calls.
@ -288,11 +339,10 @@ def get_rpc_proxy(url, node_number, *, timeout=None, coveragedir=None):
proxy_kwargs['timeout'] = int(timeout)
proxy = AuthServiceProxy(url, **proxy_kwargs)
proxy.url = url # store URL on proxy for info
coverage_logfile = coverage.get_filename(coveragedir, node_number) if coveragedir else None
return coverage.AuthServiceProxyWrapper(proxy, coverage_logfile)
return coverage.AuthServiceProxyWrapper(proxy, url, coverage_logfile)
def p2p_port(n):
@ -321,38 +371,76 @@ def rpc_url(datadir, i, chain, rpchost):
################
def initialize_datadir(dirname, n, chain):
def initialize_datadir(dirname, n, chain, disable_autoconnect=True):
datadir = get_datadir_path(dirname, n)
if not os.path.isdir(datadir):
os.makedirs(datadir)
# Translate chain name to config name
if chain == 'testnet3':
write_config(os.path.join(datadir, "particl.conf"), n=n, chain=chain, disable_autoconnect=disable_autoconnect)
os.makedirs(os.path.join(datadir, 'stderr'), exist_ok=True)
os.makedirs(os.path.join(datadir, 'stdout'), exist_ok=True)
return datadir
def write_config(config_path, *, n, chain, extra_config="", disable_autoconnect=True):
# Translate chain subdirectory name to config name
if chain == 'testnet':
chain_name_conf_arg = 'testnet'
chain_name_conf_section = 'test'
else:
chain_name_conf_arg = chain
chain_name_conf_section = chain
with open(os.path.join(datadir, "particl.conf"), 'w', encoding='utf8') as f:
f.write("{}=1\n".format(chain_name_conf_arg))
f.write("[{}]\n".format(chain_name_conf_section))
with open(config_path, 'w', encoding='utf8') as f:
if chain_name_conf_arg:
f.write("{}=1\n".format(chain_name_conf_arg))
if chain_name_conf_section:
f.write("[{}]\n".format(chain_name_conf_section))
f.write("port=" + str(p2p_port(n)) + "\n")
f.write("rpcport=" + str(rpc_port(n)) + "\n")
# Disable server-side timeouts to avoid intermittent issues
f.write("rpcservertimeout=99000\n")
f.write("rpcdoccheck=1\n")
f.write("fallbackfee=0.0002\n")
f.write("server=1\n")
f.write("keypool=1\n")
f.write("discover=0\n")
f.write("dnsseed=0\n")
f.write("fixedseeds=0\n")
f.write("listenonion=0\n")
# Increase peertimeout to avoid disconnects while using mocktime.
# peertimeout is measured in mock time, so setting it large enough to
# cover any duration in mock time is sufficient. It can be overridden
# in tests.
f.write("peertimeout=999999999\n")
f.write("printtoconsole=0\n")
f.write("upnp=0\n")
f.write("natpmp=0\n")
f.write("shrinkdebugfile=0\n")
os.makedirs(os.path.join(datadir, 'stderr'), exist_ok=True)
os.makedirs(os.path.join(datadir, 'stdout'), exist_ok=True)
return datadir
f.write("deprecatedrpc=create_bdb\n") # Required to run the tests
# To improve SQLite wallet performance so that the tests don't timeout, use -unsafesqlitesync
f.write("unsafesqlitesync=1\n")
if disable_autoconnect:
f.write("connect=0\n")
f.write(extra_config)
def get_datadir_path(dirname, n):
return os.path.join(dirname, "node" + str(n))
return pathlib.Path(dirname) / f"node{n}"
def get_temp_default_datadir(temp_dir: pathlib.Path) -> tuple[dict, pathlib.Path]:
"""Return os-specific environment variables that can be set to make the
GetDefaultDataDir() function return a datadir path under the provided
temp_dir, as well as the complete path it would return."""
if platform.system() == "Windows":
env = dict(APPDATA=str(temp_dir))
datadir = temp_dir / "Particl"
else:
env = dict(HOME=str(temp_dir))
if platform.system() == "Darwin":
datadir = temp_dir / "Library/Application Support/Particl"
else:
datadir = temp_dir / ".particl"
return env, datadir
def append_config(datadir, options):
@ -395,7 +483,7 @@ def delete_cookie_file(datadir, chain):
def softfork_active(node, key):
"""Return whether a softfork is active."""
return node.getblockchaininfo()['softforks'][key]['active']
return node.getdeploymentinfo()['deployments'][key]['active']
def set_node_times(nodes, t):
@ -403,208 +491,51 @@ def set_node_times(nodes, t):
node.setmocktime(t)
def disconnect_nodes(from_connection, node_num):
def get_peer_ids():
result = []
for peer in from_connection.getpeerinfo():
if "testnode{}".format(node_num) in peer['subver']:
result.append(peer['id'])
return result
peer_ids = get_peer_ids()
if not peer_ids:
logger.warning("disconnect_nodes: {} and {} were not connected".format(
from_connection.index,
node_num,
))
return
for peer_id in peer_ids:
try:
from_connection.disconnectnode(nodeid=peer_id)
except JSONRPCException as e:
# If this node is disconnected between calculating the peer id
# and issuing the disconnect, don't worry about it.
# This avoids a race condition if we're mass-disconnecting peers.
if e.error['code'] != -29: # RPC_CLIENT_NODE_NOT_CONNECTED
raise
# wait to disconnect
wait_until(lambda: not get_peer_ids(), timeout=5)
def connect_nodes(from_connection, node_num):
ip_port = "127.0.0.1:" + str(p2p_port(node_num))
from_connection.addnode(ip_port, "onetry")
# poll until version handshake complete to avoid race conditions
# with transaction relaying
# See comments in net_processing:
# * Must have a version message before anything else
# * Must have a verack message before anything else
wait_until(lambda: all(peer['version'] != 0 for peer in from_connection.getpeerinfo()))
wait_until(lambda: all(peer['bytesrecv_per_msg'].pop('verack', 0) == 24 for peer in from_connection.getpeerinfo()))
def check_node_connections(*, node, num_in, num_out):
info = node.getnetworkinfo()
assert_equal(info["connections_in"], num_in)
assert_equal(info["connections_out"], num_out)
# Transaction/Block functions
#############################
def find_output(node, txid, amount, *, blockhash=None):
"""
Return index to output of txid with value amount
Raises exception if there is none.
"""
txdata = node.getrawtransaction(txid, 1, blockhash)
for i in range(len(txdata["vout"])):
if txdata["vout"][i]["value"] == amount:
return i
raise RuntimeError("find_output txid %s : %s not found" % (txid, str(amount)))
def gather_inputs(from_node, amount_needed, confirmations_required=1):
"""
Return a random set of unspent txouts that are enough to pay amount_needed
"""
assert confirmations_required >= 0
utxo = from_node.listunspent(confirmations_required)
random.shuffle(utxo)
inputs = []
total_in = Decimal("0.00000000")
while total_in < amount_needed and len(utxo) > 0:
t = utxo.pop()
total_in += t["amount"]
inputs.append({"txid": t["txid"], "vout": t["vout"], "address": t["address"]})
if total_in < amount_needed:
raise RuntimeError("Insufficient funds: need %d, have %d" % (amount_needed, total_in))
return (total_in, inputs)
def make_change(from_node, amount_in, amount_out, fee):
"""
Create change output(s), return them
"""
outputs = {}
amount = amount_out + fee
change = amount_in - amount
if change > amount * 2:
# Create an extra change output to break up big inputs
change_address = from_node.getnewaddress()
# Split change in two, being careful of rounding:
outputs[change_address] = Decimal(change / 2).quantize(Decimal('0.00000001'), rounding=ROUND_DOWN)
change = amount_in - amount - outputs[change_address]
if change > 0:
outputs[from_node.getnewaddress()] = change
return outputs
def random_transaction(nodes, amount, min_fee, fee_increment, fee_variants):
"""
Create a random transaction.
Returns (txid, hex-encoded-transaction-data, fee)
"""
from_node = random.choice(nodes)
to_node = random.choice(nodes)
fee = min_fee + fee_increment * random.randint(0, fee_variants)
(total_in, inputs) = gather_inputs(from_node, amount + fee)
outputs = make_change(from_node, total_in, amount, fee)
outputs[to_node.getnewaddress()] = float(amount)
rawtx = from_node.createrawtransaction(inputs, outputs)
signresult = from_node.signrawtransactionwithwallet(rawtx)
txid = from_node.sendrawtransaction(signresult["hex"], 0)
return (txid, signresult["hex"], fee)
# Helper to create at least "count" utxos
# Pass in a fee that is sufficient for relay and mining new transactions.
def create_confirmed_utxos(fee, node, count):
to_generate = int(0.5 * count) + 101
while to_generate > 0:
node.generate(min(25, to_generate))
to_generate -= 25
utxos = node.listunspent()
iterations = count - len(utxos)
addr1 = node.getnewaddress()
addr2 = node.getnewaddress()
if iterations <= 0:
return utxos
for i in range(iterations):
t = utxos.pop()
inputs = []
inputs.append({"txid": t["txid"], "vout": t["vout"]})
outputs = {}
send_value = t['amount'] - fee
outputs[addr1] = satoshi_round(send_value / 2)
outputs[addr2] = satoshi_round(send_value / 2)
raw_tx = node.createrawtransaction(inputs, outputs)
signed_tx = node.signrawtransactionwithwallet(raw_tx)["hex"]
node.sendrawtransaction(signed_tx)
while (node.getmempoolinfo()['size'] > 0):
node.generate(1)
utxos = node.listunspent()
assert len(utxos) >= count
return utxos
# Create large OP_RETURN txouts that can be appended to a transaction
# to make it large (helper for constructing large transactions).
# to make it large (helper for constructing large transactions). The
# total serialized size of the txouts is about 66k vbytes.
def gen_return_txouts():
# Some pre-processing to create a bunch of OP_RETURN txouts to insert into transactions we create
# So we have big transactions (and therefore can't fit very many into each block)
# create one script_pubkey
script_pubkey = "6a4d0200" # OP_RETURN OP_PUSH2 512 bytes
for i in range(512):
script_pubkey = script_pubkey + "01"
# concatenate 128 txouts of above script_pubkey which we'll insert before the txout for change
txouts = []
from .messages import CTxOut
txout = CTxOut()
txout.nValue = 0
txout.scriptPubKey = hex_str_to_bytes(script_pubkey)
for k in range(128):
txouts.append(txout)
from .script import CScript, OP_RETURN
txouts = [CTxOut(nValue=0, scriptPubKey=CScript([OP_RETURN, b'\x01'*67437]))]
assert_equal(sum([len(txout.serialize()) for txout in txouts]), 67456)
return txouts
# Create a spend of each passed-in utxo, splicing in "txouts" to each raw
# transaction to make it large. See gen_return_txouts() above.
def create_lots_of_big_transactions(node, txouts, utxos, num, fee):
addr = node.getnewaddress()
def create_lots_of_big_transactions(mini_wallet, node, fee, tx_batch_size, txouts, utxos=None):
txids = []
from .messages import CTransaction
for _ in range(num):
t = utxos.pop()
inputs = [{"txid": t["txid"], "vout": t["vout"]}]
outputs = {}
change = t['amount'] - fee
outputs[addr] = satoshi_round(change)
rawtx = node.createrawtransaction(inputs, outputs)
tx = CTransaction()
tx.deserialize(BytesIO(hex_str_to_bytes(rawtx)))
for txout in txouts:
tx.vout.append(txout)
newtx = tx.serialize().hex()
signresult = node.signrawtransactionwithwallet(newtx, None, "NONE")
txid = node.sendrawtransaction(signresult["hex"], 0)
txids.append(txid)
use_internal_utxos = utxos is None
for _ in range(tx_batch_size):
tx = mini_wallet.create_self_transfer(
utxo_to_spend=None if use_internal_utxos else utxos.pop(),
fee=fee,
)["tx"]
tx.vout.extend(txouts)
res = node.testmempoolaccept([tx.serialize().hex()])[0]
assert_equal(res['fees']['base'], fee)
txids.append(node.sendrawtransaction(tx.serialize().hex()))
return txids
def mine_large_block(node, utxos=None):
def mine_large_block(test_framework, mini_wallet, node):
# generate a 66k transaction,
# and 14 of them is close to the 1MB block limit
num = 14
txouts = gen_return_txouts()
utxos = utxos if utxos is not None else []
if len(utxos) < num:
utxos.clear()
utxos.extend(node.listunspent())
fee = 100 * node.getnetworkinfo()["relayfee"]
create_lots_of_big_transactions(node, txouts, utxos, num, fee=fee)
node.generate(1)
create_lots_of_big_transactions(mini_wallet, node, fee, 14, txouts)
test_framework.generate(node, 1)
def find_vout_for_address(node, txid, addr):
@ -614,11 +545,6 @@ def find_vout_for_address(node, txid, addr):
"""
tx = node.getrawtransaction(txid, True)
for i in range(len(tx["vout"])):
scriptPubKey = tx["vout"][i]["scriptPubKey"]
if "addresses" in scriptPubKey:
if any([addr == a for a in scriptPubKey["addresses"]]):
return i
elif "address" in scriptPubKey:
if addr == scriptPubKey["address"]:
return i
if addr == tx["vout"][i]["scriptPubKey"]["address"]:
return i
raise RuntimeError("Vout not found for address: txid=%s, addr=%s" % (txid, addr))

View file

@ -13,7 +13,7 @@ from enum import IntEnum, auto
from typing import Optional
CURRENT_DB_VERSION = 27
CURRENT_DB_VERSION = 28
CURRENT_DB_DATA_VERSION = 6
@ -174,6 +174,7 @@ class Offer(Table):
secret_hash = Column("blob")
addr_from = Column("string")
pk_from = Column("blob")
addr_to = Column("string")
created_at = Column("integer")
expire_at = Column("integer")
@ -216,6 +217,7 @@ class Bid(Table):
created_at = Column("integer")
expire_at = Column("integer")
bid_addr = Column("string")
pk_bid_addr = Column("blob")
proof_address = Column("string")
proof_utxos = Column("blob")
# Address to spend lock tx to - address from wallet if empty TODO
@ -927,15 +929,12 @@ class DBMethods:
table_name: str = table_class.__tablename__
query: str = "SELECT "
columns = []
for mc in inspect.getmembers(table_class):
mc_name, mc_obj = mc
if not hasattr(mc_obj, "__sqlite3_column__"):
continue
if len(columns) > 0:
query += ", "
query += mc_name
@ -943,10 +942,29 @@ class DBMethods:
query += f" FROM {table_name} WHERE 1=1 "
query_data = {}
for ck in constraints:
if not validColumnName(ck):
raise ValueError(f"Invalid constraint column: {ck}")
query += f" AND {ck} = :{ck} "
constraint_value = constraints[ck]
if isinstance(constraint_value, tuple) or isinstance(
constraint_value, list
):
if len(constraint_value) < 2:
raise ValueError(f"Too few constraint values for list: {ck}")
query += f" AND {ck} IN ("
for i, cv in enumerate(constraint_value):
cv_name: str = f"{ck}_{i}"
if i > 0:
query += ","
query += ":" + cv_name
query_data[cv_name] = cv
query += ") "
else:
query += f" AND {ck} = :{ck} "
query_data[ck] = constraint_value
for order_col, order_dir in order_by.items():
if validColumnName(order_col) is False:
@ -959,7 +977,6 @@ class DBMethods:
if query_suffix:
query += query_suffix
query_data = constraints.copy()
query_data.update(extra_query_data)
rows = cursor.execute(query, query_data)
for row in rows:

View file

@ -428,6 +428,11 @@ def upgradeDatabase(self, db_version):
elif current_version == 26:
db_version += 1
cursor.execute("ALTER TABLE offers ADD COLUMN auto_accept_type INTEGER")
elif current_version == 27:
db_version += 1
cursor.execute("ALTER TABLE offers ADD COLUMN pk_from BLOB")
cursor.execute("ALTER TABLE bids ADD COLUMN pk_bid_addr BLOB")
if current_version != db_version:
self.db_version = db_version
self.setIntKV("db_version", db_version, cursor)

View file

View file

@ -0,0 +1,350 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2025 The Basicswap developers
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
import base64
import json
import threading
import websocket
from queue import Queue, Empty
from basicswap.util.smsg import (
smsgEncrypt,
smsgDecrypt,
smsgGetID,
)
from basicswap.chainparams import (
Coins,
)
from basicswap.util.address import (
b58decode,
decodeWif,
)
from basicswap.basicswap_util import (
BidStates,
)
def encode_base64(data: bytes) -> str:
return base64.b64encode(data).decode("utf-8")
def decode_base64(encoded_data: str) -> bytes:
return base64.b64decode(encoded_data)
class WebSocketThread(threading.Thread):
def __init__(self, url: str, tag: str = None, logger=None):
super().__init__()
self.url: str = url
self.tag = tag
self.logger = logger
self.ws = None
self.mutex = threading.Lock()
self.corrId: int = 0
self.connected: bool = False
self.delay_event = threading.Event()
self.recv_queue = Queue()
self.cmd_recv_queue = Queue()
def on_message(self, ws, message):
if self.logger:
self.logger.debug("Simplex received msg")
else:
print(f"{self.tag} - Received msg")
if message.startswith('{"corrId"'):
self.cmd_recv_queue.put(message)
else:
self.recv_queue.put(message)
def queue_get(self):
try:
return self.recv_queue.get(block=False)
except Empty:
return None
def cmd_queue_get(self):
try:
return self.cmd_recv_queue.get(block=False)
except Empty:
return None
def on_error(self, ws, error):
if self.logger:
self.logger.error(f"Simplex ws - {error}")
else:
print(f"{self.tag} - Error: {error}")
def on_close(self, ws, close_status_code, close_msg):
if self.logger:
self.logger.info(f"Simplex ws - Closed: {close_status_code}, {close_msg}")
else:
print(f"{self.tag} - Closed: {close_status_code}, {close_msg}")
def on_open(self, ws):
if self.logger:
self.logger.info("Simplex ws - Connection opened")
else:
print(f"{self.tag}: WebSocket connection opened")
self.connected = True
def send_command(self, cmd_str: str):
with self.mutex:
self.corrId += 1
if self.logger:
self.logger.debug(f"Simplex sent command {self.corrId}")
else:
print(f"{self.tag}: sent command {self.corrId}")
cmd = json.dumps({"corrId": str(self.corrId), "cmd": cmd_str})
self.ws.send(cmd)
return self.corrId
def run(self):
self.ws = websocket.WebSocketApp(
self.url,
on_message=self.on_message,
on_error=self.on_error,
on_open=self.on_open,
on_close=self.on_close,
)
while not self.delay_event.is_set():
self.ws.run_forever()
self.delay_event.wait(0.5)
def stop(self):
self.delay_event.set()
if self.ws:
self.ws.close()
def waitForResponse(ws_thread, sent_id, delay_event):
sent_id = str(sent_id)
for i in range(100):
message = ws_thread.cmd_queue_get()
if message is not None:
data = json.loads(message)
# print(f"json: {json.dumps(data, indent=4)}")
if "corrId" in data:
if data["corrId"] == sent_id:
return data
delay_event.wait(0.5)
raise ValueError(f"waitForResponse timed-out waiting for id: {sent_id}")
def waitForConnected(ws_thread, delay_event):
for i in range(100):
if ws_thread.connected:
return True
delay_event.wait(0.5)
raise ValueError("waitForConnected timed-out.")
def getPrivkeyForAddress(self, addr) -> bytes:
ci_part = self.ci(Coins.PART)
try:
return ci_part.decodeKey(
self.callrpc(
"smsgdumpprivkey",
[
addr,
],
)
)
except Exception as e: # noqa: F841
pass
try:
return ci_part.decodeKey(
ci_part.rpc_wallet(
"dumpprivkey",
[
addr,
],
)
)
except Exception as e: # noqa: F841
pass
raise ValueError("key not found")
def sendSimplexMsg(
self, network, addr_from: str, addr_to: str, payload: bytes, msg_valid: int, cursor
) -> bytes:
self.log.debug("sendSimplexMsg")
try:
rv = self.callrpc(
"smsggetpubkey",
[
addr_to,
],
)
pubkey_to: bytes = b58decode(rv["publickey"])
except Exception as e: # noqa: F841
use_cursor = self.openDB(cursor)
try:
query: str = "SELECT pk_from FROM offers WHERE addr_from = :addr_to LIMIT 1"
rows = use_cursor.execute(query, {"addr_to": addr_to}).fetchall()
if len(rows) > 0:
pubkey_to = rows[0][0]
else:
query: str = (
"SELECT pk_bid_addr FROM bids WHERE bid_addr = :addr_to LIMIT 1"
)
rows = use_cursor.execute(query, {"addr_to": addr_to}).fetchall()
if len(rows) > 0:
pubkey_to = rows[0][0]
else:
raise ValueError(f"Could not get public key for address {addr_to}")
finally:
if cursor is None:
self.closeDB(use_cursor, commit=False)
privkey_from = getPrivkeyForAddress(self, addr_from)
payload += bytes((0,)) # Include null byte to match smsg
smsg_msg: bytes = smsgEncrypt(privkey_from, pubkey_to, payload)
smsg_id = smsgGetID(smsg_msg)
ws_thread = network["ws_thread"]
sent_id = ws_thread.send_command("#bsx " + encode_base64(smsg_msg))
response = waitForResponse(ws_thread, sent_id, self.delay_event)
if response["resp"]["type"] != "newChatItems":
json_str = json.dumps(response, indent=4)
self.log.debug(f"Response {json_str}")
raise ValueError("Send failed")
return smsg_id
def decryptSimplexMsg(self, msg_data):
ci_part = self.ci(Coins.PART)
# Try with the network key first
network_key: bytes = decodeWif(self.network_key)
try:
decrypted = smsgDecrypt(network_key, msg_data, output_dict=True)
decrypted["from"] = ci_part.pubkey_to_address(
bytes.fromhex(decrypted["pk_from"])
)
decrypted["to"] = self.network_addr
decrypted["msg_net"] = "simplex"
return decrypted
except Exception as e: # noqa: F841
pass
# Try with all active bid/offer addresses
query: str = """SELECT DISTINCT address FROM (
SELECT bid_addr AS address FROM bids WHERE active_ind = 1
AND (in_progress = 1 OR (state > :bid_received AND state < :bid_completed) OR (state IN (:bid_received, :bid_sent) AND expire_at > :now))
UNION
SELECT addr_from AS address FROM offers WHERE active_ind = 1 AND expire_at > :now
)"""
now: int = self.getTime()
try:
cursor = self.openDB()
addr_rows = cursor.execute(
query,
{
"bid_received": int(BidStates.BID_RECEIVED),
"bid_completed": int(BidStates.SWAP_COMPLETED),
"bid_sent": int(BidStates.BID_SENT),
"now": now,
},
).fetchall()
finally:
self.closeDB(cursor, commit=False)
decrypted = None
for row in addr_rows:
addr = row[0]
try:
vk_addr = getPrivkeyForAddress(self, addr)
decrypted = smsgDecrypt(vk_addr, msg_data, output_dict=True)
decrypted["from"] = ci_part.pubkey_to_address(
bytes.fromhex(decrypted["pk_from"])
)
decrypted["to"] = addr
decrypted["msg_net"] = "simplex"
return decrypted
except Exception as e: # noqa: F841
pass
return decrypted
def readSimplexMsgs(self, network):
ws_thread = network["ws_thread"]
for i in range(100):
message = ws_thread.queue_get()
if message is None:
break
data = json.loads(message)
# self.log.debug(f"message 1: {json.dumps(data, indent=4)}")
try:
if data["resp"]["type"] in ("chatItemsStatusesUpdated", "newChatItems"):
for chat_item in data["resp"]["chatItems"]:
item_status = chat_item["chatItem"]["meta"]["itemStatus"]
if item_status["type"] in ("sndRcvd", "rcvNew"):
snd_progress = item_status.get("sndProgress", None)
if snd_progress:
if snd_progress != "complete":
item_id = chat_item["chatItem"]["meta"]["itemId"]
self.log.debug(
f"simplex chat item {item_id} {snd_progress}"
)
continue
try:
msg_data: bytes = decode_base64(
chat_item["chatItem"]["content"]["msgContent"]["text"]
)
decrypted_msg = decryptSimplexMsg(self, msg_data)
if decrypted_msg is None:
continue
self.processMsg(decrypted_msg)
except Exception as e: # noqa: F841
# self.log.debug(f"decryptSimplexMsg error: {e}")
pass
except Exception as e:
self.log.debug(f"readSimplexMsgs error: {e}")
self.delay_event.wait(0.05)
def initialiseSimplexNetwork(self, network_config) -> None:
self.log.debug("initialiseSimplexNetwork")
client_host: str = network_config.get("client_host", "127.0.0.1")
ws_port: str = network_config.get("ws_port")
ws_thread = WebSocketThread(f"ws://{client_host}:{ws_port}", logger=self.log)
self.threads.append(ws_thread)
ws_thread.start()
waitForConnected(ws_thread, self.delay_event)
sent_id = ws_thread.send_command("/groups")
response = waitForResponse(ws_thread, sent_id, self.delay_event)
if len(response["resp"]["groups"]) < 1:
sent_id = ws_thread.send_command("/c " + network_config["group_link"])
response = waitForResponse(ws_thread, sent_id, self.delay_event)
assert "groupLinkId" in response["resp"]["connection"]
network = {
"type": "simplex",
"ws_thread": ws_thread,
}
self.active_networks.append(network)

View file

@ -0,0 +1,107 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2025 The Basicswap developers
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
import os
import select
import subprocess
import time
from basicswap.bin.run import Daemon
def initSimplexClient(args, logger, delay_event):
logger.info("Initialising Simplex client")
(pipe_r, pipe_w) = os.pipe() # subprocess.PIPE is buffered, blocks when read
if os.name == "nt":
str_args = " ".join(args)
p = subprocess.Popen(
str_args, shell=True, stdin=subprocess.PIPE, stdout=pipe_w, stderr=pipe_w
)
else:
p = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=pipe_w, stderr=pipe_w)
def readOutput():
buf = os.read(pipe_r, 1024).decode("utf-8")
response = None
# logging.debug(f"simplex-chat output: {buf}")
if "display name:" in buf:
logger.debug("Setting display name")
response = b"user\n"
else:
logger.debug(f"Unexpected output: {buf}")
return
if response is not None:
p.stdin.write(response)
p.stdin.flush()
try:
start_time: int = time.time()
max_wait_seconds: int = 60
while p.poll() is None:
if time.time() > start_time + max_wait_seconds:
raise ValueError("Timed out")
if os.name == "nt":
readOutput()
delay_event.wait(0.1)
continue
while len(select.select([pipe_r], [], [], 0)[0]) == 1:
readOutput()
delay_event.wait(0.1)
except Exception as e:
logger.error(f"initSimplexClient: {e}")
finally:
if p.poll() is None:
p.terminate()
os.close(pipe_r)
os.close(pipe_w)
p.stdin.close()
def startSimplexClient(
bin_path: str,
data_path: str,
server_address: str,
websocket_port: int,
logger,
delay_event,
) -> Daemon:
logger.info("Starting Simplex client")
if not os.path.exists(data_path):
os.makedirs(data_path)
db_path = os.path.join(data_path, "simplex_client_data")
args = [bin_path, "-d", db_path, "-s", server_address, "-p", str(websocket_port)]
if not os.path.exists(db_path):
# Need to set initial profile through CLI
# TODO: Must be a better way?
init_args = args + ["-e", "/help"] # Run command ro exit client
initSimplexClient(init_args, logger, delay_event)
args += ["-l", "debug"]
opened_files = []
stdout_dest = open(
os.path.join(data_path, "simplex_stdout.log"),
"w",
)
opened_files.append(stdout_dest)
stderr_dest = stdout_dest
return Daemon(
subprocess.Popen(
args,
shell=False,
stdin=subprocess.PIPE,
stdout=stdout_dest,
stderr=stderr_dest,
cwd=data_path,
),
opened_files,
)

20
basicswap/network/util.py Normal file
View file

@ -0,0 +1,20 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2025 The Basicswap developers
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
from basicswap.util.address import b58decode
def getMsgPubkey(self, msg) -> bytes:
if "pk_from" in msg:
return bytes.fromhex(msg["pk_from"])
rv = self.callrpc(
"smsggetpubkey",
[
msg["from"],
],
)
return b58decode(rv["publickey"])

229
basicswap/util/smsg.py Normal file
View file

@ -0,0 +1,229 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2025 The Basicswap developers
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
import hashlib
import hmac
import secrets
import time
from typing import Union, Dict
from coincurve.keys import (
PublicKey,
PrivateKey,
)
from Crypto.Cipher import AES
from basicswap.util.crypto import hash160, sha256, ripemd160
from basicswap.util.ecc import getSecretInt
from basicswap.contrib.test_framework.messages import (
uint256_from_compact,
uint256_from_str,
)
AES_BLOCK_SIZE = 16
def aes_pad(s: bytes):
c = AES_BLOCK_SIZE - len(s) % AES_BLOCK_SIZE
return s + (bytes((c,)) * c)
def aes_unpad(s: bytes):
return s[: -(s[len(s) - 1])]
def aes_encrypt(raw: bytes, pass_data: bytes, iv: bytes):
assert len(pass_data) == 32
assert len(iv) == 16
raw = aes_pad(raw)
cipher = AES.new(pass_data, AES.MODE_CBC, iv)
return cipher.encrypt(raw)
def aes_decrypt(enc, pass_data: bytes, iv: bytes):
assert len(pass_data) == 32
assert len(iv) == 16
cipher = AES.new(pass_data, AES.MODE_CBC, iv)
return aes_unpad(cipher.decrypt(enc))
SMSG_MIN_TTL = 60 * 60
SMSG_BUCKET_LEN = 60 * 60
SMSG_HDR_LEN = (
108 # Length of unencrypted header, 4 + 4 + 2 + 1 + 8 + 4 + 16 + 33 + 32 + 4
)
SMSG_PL_HDR_LEN = 1 + 20 + 65 + 4 # Length of encrypted header in payload
def smsgGetTimestamp(smsg_message: bytes) -> int:
assert len(smsg_message) > SMSG_HDR_LEN
return int.from_bytes(smsg_message[11 : 11 + 8], byteorder="little")
def smsgGetPOWHash(smsg_message: bytes) -> bytes:
assert len(smsg_message) > SMSG_HDR_LEN
ofs: int = 4
nonce: bytes = smsg_message[ofs : ofs + 4]
iv: bytes = nonce * 8
m = hmac.new(iv, digestmod="SHA256")
m.update(smsg_message[4:])
return m.digest()
def smsgGetID(smsg_message: bytes) -> bytes:
assert len(smsg_message) > SMSG_HDR_LEN
smsg_timestamp = int.from_bytes(smsg_message[11 : 11 + 8], byteorder="little")
return smsg_timestamp.to_bytes(8, byteorder="big") + ripemd160(smsg_message[8:])
def smsgEncrypt(privkey_from: bytes, pubkey_to: bytes, payload: bytes) -> bytes:
# assert len(payload) < 128 # Requires lz4 if payload > 128 bytes
# TODO: Add lz4 to match core smsg
smsg_timestamp = int(time.time())
r = getSecretInt().to_bytes(32, byteorder="big")
R = PublicKey.from_secret(r).format()
p = PrivateKey(r).ecdh(pubkey_to)
H = hashlib.sha512(p).digest()
key_e: bytes = H[:32]
key_m: bytes = H[32:]
smsg_iv: bytes = secrets.token_bytes(16)
payload_hash: bytes = sha256(sha256(payload))
signature: bytes = PrivateKey(privkey_from).sign_recoverable(
payload_hash, hasher=None
)
# Convert format to BTC, add 4 to mark as compressed key
recid = signature[64]
signature = bytes((27 + recid + 4,)) + signature[:64]
pubkey_from: bytes = PublicKey.from_secret(privkey_from).format()
pkh_from: bytes = hash160(pubkey_from)
len_payload = len(payload)
address_version = 0
plaintext_data: bytes = (
bytes((address_version,))
+ pkh_from
+ signature
+ len_payload.to_bytes(4, byteorder="little")
+ payload
)
ciphertext: bytes = aes_encrypt(plaintext_data, key_e, smsg_iv)
m = hmac.new(key_m, digestmod="SHA256")
m.update(smsg_timestamp.to_bytes(8, byteorder="little"))
m.update(smsg_iv)
m.update(ciphertext)
mac: bytes = m.digest()
smsg_hash = bytes((0,)) * 4
smsg_nonce = bytes((0,)) * 4
smsg_version = bytes((2, 1))
smsg_flags = bytes((0,))
smsg_ttl = SMSG_MIN_TTL
assert len(R) == 33
assert len(mac) == 32
smsg_message: bytes = (
smsg_hash
+ smsg_nonce
+ smsg_version
+ smsg_flags
+ smsg_timestamp.to_bytes(8, byteorder="little")
+ smsg_ttl.to_bytes(4, byteorder="little")
+ smsg_iv
+ R
+ mac
+ len(ciphertext).to_bytes(4, byteorder="little")
+ ciphertext
)
target: int = uint256_from_compact(0x1EFFFFFF)
for i in range(1000000):
pow_hash = smsgGetPOWHash(smsg_message)
if uint256_from_str(pow_hash) > target:
smsg_nonce = (int.from_bytes(smsg_nonce, byteorder="little") + 1).to_bytes(
4, byteorder="little"
)
smsg_message = pow_hash[:4] + smsg_nonce + smsg_message[8:]
continue
smsg_message = pow_hash[:4] + smsg_message[4:]
return smsg_message
raise ValueError("Failed to set POW hash.")
def smsgDecrypt(
privkey_to: bytes, encrypted_message: bytes, output_dict: bool = False
) -> Union[bytes, Dict]:
# Without lz4
assert len(encrypted_message) > SMSG_HDR_LEN
smsg_timestamp = int.from_bytes(encrypted_message[11 : 11 + 8], byteorder="little")
ofs: int = 23
smsg_iv = encrypted_message[ofs : ofs + 16]
ofs += 16
R = encrypted_message[ofs : ofs + 33]
ofs += 33
mac = encrypted_message[ofs : ofs + 32]
ofs += 32
ciphertextlen = int.from_bytes(encrypted_message[ofs : ofs + 4], byteorder="little")
ofs += 4
ciphertext = encrypted_message[ofs:]
assert len(ciphertext) == ciphertextlen
p = PrivateKey(privkey_to).ecdh(R)
H = hashlib.sha512(p).digest()
key_e: bytes = H[:32]
key_m: bytes = H[32:]
m = hmac.new(key_m, digestmod="SHA256")
m.update(smsg_timestamp.to_bytes(8, byteorder="little"))
m.update(smsg_iv)
m.update(ciphertext)
mac_calculated: bytes = m.digest()
assert mac == mac_calculated
plaintext = aes_decrypt(ciphertext, key_e, smsg_iv)
ofs = 1
pkh_from = plaintext[ofs : ofs + 20]
ofs += 20
signature = plaintext[ofs : ofs + 65]
ofs += 65
ofs += 4
payload = plaintext[ofs:]
payload_hash: bytes = sha256(sha256(payload))
# Convert format from BTC
recid = (signature[0] - 27) & 3
signature = signature[1:] + bytes((recid,))
pubkey_signer = PublicKey.from_signature_and_message(
signature, payload_hash, hasher=None
).format()
pkh_from_recovered: bytes = hash160(pubkey_signer)
assert pkh_from == pkh_from_recovered
if output_dict:
return {
"msgid": smsgGetID(encrypted_message).hex(),
"sent": smsg_timestamp,
"hex": payload.hex(),
"pk_from": pubkey_signer.hex(),
}
return payload

View file

@ -3,4 +3,5 @@ python-gnupg==0.5.4
Jinja2==3.1.6
pycryptodome==3.21.0
PySocks==1.7.1
websocket-client==1.8.0
coincurve@https://github.com/basicswap/coincurve/archive/refs/tags/basicswap_v0.2.zip

View file

@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile with Python 3.12
# This file is autogenerated by pip-compile with Python 3.13
# by the following command:
#
# pip-compile --generate-hashes --output-file=requirements.txt requirements.in
@ -305,3 +305,7 @@ pyzmq==26.2.1 \
--hash=sha256:f9ba5def063243793dec6603ad1392f735255cbc7202a3a484c14f99ec290705 \
--hash=sha256:fc409c18884eaf9ddde516d53af4f2db64a8bc7d81b1a0c274b8aa4e929958e8
# via -r requirements.in
websocket-client==1.8.0 \
--hash=sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526 \
--hash=sha256:3239df9f44da632f96012472805d40a23281a991027ce11d2f45a6f24ac4c3da
# via -r requirements.in

View file

@ -30,7 +30,6 @@ from basicswap.contrib.test_framework.messages import (
CTransaction,
CTxIn,
COutPoint,
ToHex,
)
from basicswap.contrib.test_framework.script import (
CScript,
@ -318,7 +317,7 @@ class Test(TestFunctions):
tx = CTransaction()
tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx)
tx_hex = tx.serialize().hex()
tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = ci.rpc_wallet(
@ -357,10 +356,10 @@ class Test(TestFunctions):
)
)
tx_spend.vout.append(ci.txoType()(ci.make_int(1.099), script_out))
tx_spend_hex = ToHex(tx_spend)
tx_spend_hex = tx_spend.serialize().hex()
tx_spend.nLockTime = chain_height + 2
tx_spend_invalid_hex = ToHex(tx_spend)
tx_spend_invalid_hex = tx_spend.serialize().hex()
for tx_hex in [tx_spend_invalid_hex, tx_spend_hex]:
try:

View file

@ -0,0 +1,342 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2025 The Basicswap developers
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
"""
docker run \
-e "ADDR=127.0.0.1" \
-e "PASS=password" \
-p 5223:5223 \
-v /tmp/simplex/smp/config:/etc/opt/simplex:z \
-v /tmp/simplex/smp/logs:/var/opt/simplex:z \
-v /tmp/simplex/certs:/certificates \
simplexchat/smp-server:latest
Fingerprint: Q8SNxc2SRcKyXlhJM8KFUgPNW4KXPGRm4eSLtT_oh-I=
export SIMPLEX_SERVER_ADDRESS=smp://Q8SNxc2SRcKyXlhJM8KFUgPNW4KXPGRm4eSLtT_oh-I=:password@127.0.0.1:5223,443
https://github.com/simplex-chat/simplex-chat/issues/4127
json: {"corrId":"3","cmd":"/_send #1 text test123"}
direct message: {"corrId":"1","cmd":"/_send @2 text the message"}
"""
import json
import logging
import os
import random
import shutil
import sys
import unittest
import basicswap.config as cfg
from basicswap.basicswap import (
BidStates,
SwapTypes,
)
from basicswap.chainparams import Coins
from basicswap.network.simplex import (
WebSocketThread,
waitForConnected,
waitForResponse,
)
from basicswap.network.simplex_chat import startSimplexClient
from tests.basicswap.common import (
stopDaemons,
wait_for_bid,
wait_for_offer,
)
from tests.basicswap.test_xmr import BaseTest, test_delay_event, RESET_TEST
SIMPLEX_SERVER_ADDRESS = os.getenv("SIMPLEX_SERVER_ADDRESS")
SIMPLEX_CLIENT_PATH = os.path.expanduser(os.getenv("SIMPLEX_CLIENT_PATH"))
TEST_DIR = cfg.TEST_DATADIRS
logger = logging.getLogger()
logger.level = logging.DEBUG
if not len(logger.handlers):
logger.addHandler(logging.StreamHandler(sys.stdout))
class TestSimplex(unittest.TestCase):
daemons = []
remove_testdir: bool = False
@classmethod
def tearDownClass(cls):
stopDaemons(cls.daemons)
def test_basic(self):
if os.path.isdir(TEST_DIR):
if RESET_TEST:
logging.info("Removing " + TEST_DIR)
shutil.rmtree(TEST_DIR)
else:
logging.info("Restoring instance from " + TEST_DIR)
if not os.path.exists(TEST_DIR):
os.makedirs(TEST_DIR)
client1_dir = os.path.join(TEST_DIR, "client1")
if os.path.exists(client1_dir):
shutil.rmtree(client1_dir)
client1_daemon = startSimplexClient(
SIMPLEX_CLIENT_PATH,
client1_dir,
SIMPLEX_SERVER_ADDRESS,
5225,
logger,
test_delay_event,
)
self.daemons.append(client1_daemon)
client2_dir = os.path.join(TEST_DIR, "client2")
if os.path.exists(client2_dir):
shutil.rmtree(client2_dir)
client2_daemon = startSimplexClient(
SIMPLEX_CLIENT_PATH,
client2_dir,
SIMPLEX_SERVER_ADDRESS,
5226,
logger,
test_delay_event,
)
self.daemons.append(client2_daemon)
threads = []
try:
ws_thread = WebSocketThread("ws://127.0.0.1:5225", tag="C1")
ws_thread.start()
threads.append(ws_thread)
ws_thread2 = WebSocketThread("ws://127.0.0.1:5226", tag="C2")
ws_thread2.start()
threads.append(ws_thread2)
waitForConnected(ws_thread, test_delay_event)
sent_id = ws_thread.send_command("/group bsx")
response = waitForResponse(ws_thread, sent_id, test_delay_event)
assert response["resp"]["type"] == "groupCreated"
ws_thread.send_command("/set voice #bsx off")
ws_thread.send_command("/set files #bsx off")
ws_thread.send_command("/set direct #bsx off")
ws_thread.send_command("/set reactions #bsx off")
ws_thread.send_command("/set reports #bsx off")
ws_thread.send_command("/set disappear #bsx on week")
sent_id = ws_thread.send_command("/create link #bsx")
connReqContact = None
connReqMsgData = waitForResponse(ws_thread, sent_id, test_delay_event)
connReqContact = connReqMsgData["resp"]["connReqContact"]
group_link = "https://simplex.chat" + connReqContact[8:]
logger.info(f"group_link: {group_link}")
sent_id = ws_thread2.send_command("/c " + group_link)
response = waitForResponse(ws_thread2, sent_id, test_delay_event)
assert "groupLinkId" in response["resp"]["connection"]
sent_id = ws_thread2.send_command("/groups")
response = waitForResponse(ws_thread2, sent_id, test_delay_event)
assert len(response["resp"]["groups"]) == 1
ws_thread.send_command("#bsx test msg 1")
found_1 = False
found_2 = False
for i in range(100):
message = ws_thread.queue_get()
if message is not None:
data = json.loads(message)
# print(f"message 1: {json.dumps(data, indent=4)}")
try:
if data["resp"]["type"] in (
"chatItemsStatusesUpdated",
"newChatItems",
):
for chat_item in data["resp"]["chatItems"]:
# print(f"chat_item 1: {json.dumps(chat_item, indent=4)}")
if chat_item["chatItem"]["meta"]["itemStatus"][
"type"
] in ("sndRcvd", "rcvNew"):
if (
chat_item["chatItem"]["content"]["msgContent"][
"text"
]
== "test msg 1"
):
found_1 = True
except Exception as e:
print(f"error 1: {e}")
message = ws_thread2.queue_get()
if message is not None:
data = json.loads(message)
# print(f"message 2: {json.dumps(data, indent=4)}")
try:
if data["resp"]["type"] in (
"chatItemsStatusesUpdated",
"newChatItems",
):
for chat_item in data["resp"]["chatItems"]:
# print(f"chat_item 1: {json.dumps(chat_item, indent=4)}")
if chat_item["chatItem"]["meta"]["itemStatus"][
"type"
] in ("sndRcvd", "rcvNew"):
if (
chat_item["chatItem"]["content"]["msgContent"][
"text"
]
== "test msg 1"
):
found_2 = True
except Exception as e:
print(f"error 2: {e}")
if found_1 and found_2:
break
test_delay_event.wait(0.5)
assert found_1 is True
assert found_2 is True
finally:
for t in threads:
t.stop()
t.join()
class Test(BaseTest):
__test__ = True
start_ltc_nodes = False
start_xmr_nodes = True
group_link = None
daemons = []
coin_to = Coins.XMR
# coin_to = Coins.PART
@classmethod
def prepareTestDir(cls):
base_ws_port: int = 5225
for i in range(cls.num_nodes):
client_dir = os.path.join(TEST_DIR, f"simplex_client{i}")
if os.path.exists(client_dir):
shutil.rmtree(client_dir)
client_daemon = startSimplexClient(
SIMPLEX_CLIENT_PATH,
client_dir,
SIMPLEX_SERVER_ADDRESS,
base_ws_port + i,
logger,
test_delay_event,
)
cls.daemons.append(client_daemon)
# Create the group for bsx
logger.info("Creating BSX group")
ws_thread = None
try:
ws_thread = WebSocketThread(f"ws://127.0.0.1:{base_ws_port}", tag="C0")
ws_thread.start()
waitForConnected(ws_thread, test_delay_event)
sent_id = ws_thread.send_command("/group bsx")
response = waitForResponse(ws_thread, sent_id, test_delay_event)
assert response["resp"]["type"] == "groupCreated"
ws_thread.send_command("/set voice #bsx off")
ws_thread.send_command("/set files #bsx off")
ws_thread.send_command("/set direct #bsx off")
ws_thread.send_command("/set reactions #bsx off")
ws_thread.send_command("/set reports #bsx off")
ws_thread.send_command("/set disappear #bsx on week")
sent_id = ws_thread.send_command("/create link #bsx")
connReqContact = None
connReqMsgData = waitForResponse(ws_thread, sent_id, test_delay_event)
connReqContact = connReqMsgData["resp"]["connReqContact"]
cls.group_link = "https://simplex.chat" + connReqContact[8:]
logger.info(f"BSX group_link: {cls.group_link}")
finally:
if ws_thread:
ws_thread.stop()
ws_thread.join()
@classmethod
def tearDownClass(cls):
logging.info("Finalising Test")
super(Test, cls).tearDownClass()
stopDaemons(cls.daemons)
@classmethod
def addCoinSettings(cls, settings, datadir, node_id):
settings["networks"] = [
{
"type": "simplex",
"server_address": SIMPLEX_SERVER_ADDRESS,
"client_path": SIMPLEX_CLIENT_PATH,
"ws_port": 5225 + node_id,
"group_link": cls.group_link,
},
]
def test_01_swap(self):
logging.info("---------- Test xmr swap")
swap_clients = self.swap_clients
for sc in swap_clients:
sc.dleag_split_size_init = 9000
sc.dleag_split_size = 11000
assert len(swap_clients[0].active_networks) == 1
assert swap_clients[0].active_networks[0]["type"] == "simplex"
coin_from = Coins.BTC
coin_to = self.coin_to
ci_from = swap_clients[0].ci(coin_from)
ci_to = swap_clients[1].ci(coin_to)
swap_value = ci_from.make_int(random.uniform(0.2, 20.0), r=1)
rate_swap = ci_to.make_int(random.uniform(0.2, 20.0), r=1)
offer_id = swap_clients[0].postOffer(
coin_from, coin_to, swap_value, rate_swap, swap_value, SwapTypes.XMR_SWAP
)
wait_for_offer(test_delay_event, swap_clients[1], offer_id)
offer = swap_clients[1].getOffer(offer_id)
bid_id = swap_clients[1].postBid(offer_id, offer.amount_from)
wait_for_bid(test_delay_event, swap_clients[0], bid_id, BidStates.BID_RECEIVED)
swap_clients[0].acceptBid(bid_id)
wait_for_bid(
test_delay_event,
swap_clients[0],
bid_id,
BidStates.SWAP_COMPLETED,
wait_for=320,
)
wait_for_bid(
test_delay_event,
swap_clients[1],
bid_id,
BidStates.SWAP_COMPLETED,
sent=True,
wait_for=320,
)

View file

@ -0,0 +1,147 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2025 The Basicswap developers
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
import logging
from basicswap.chainparams import Coins
from basicswap.util.smsg import (
smsgEncrypt,
smsgDecrypt,
smsgGetID,
smsgGetTimestamp,
SMSG_BUCKET_LEN,
)
from basicswap.contrib.test_framework.messages import (
NODE_SMSG,
msg_smsgPong,
msg_smsgMsg,
)
from basicswap.contrib.test_framework.p2p import (
P2PInterface,
P2P_SERVICES,
NetworkThread,
)
from basicswap.contrib.test_framework.util import (
PortSeed,
)
from tests.basicswap.common import BASE_PORT
from tests.basicswap.test_xmr import BaseTest, test_delay_event
class P2PInterfaceSMSG(P2PInterface):
def __init__(self):
super().__init__()
self.is_part = True
def on_smsgPing(self, msg):
logging.info("on_smsgPing")
self.send_message(msg_smsgPong(1))
def on_smsgPong(self, msg):
logging.info("on_smsgPong", msg)
def on_smsgInv(self, msg):
logging.info("on_smsgInv")
def wait_for_smsg(ci, msg_id: str, wait_for=20) -> None:
for i in range(wait_for):
if test_delay_event.is_set():
raise ValueError("Test stopped.")
try:
ci.rpc_wallet("smsg", [msg_id])
return
except Exception as e:
logging.info(e)
test_delay_event.wait(1)
raise ValueError("wait_for_smsg timed out.")
class Test(BaseTest):
__test__ = True
start_ltc_nodes = False
start_xmr_nodes = False
@classmethod
def setUpClass(cls):
super(Test, cls).setUpClass()
PortSeed.n = 1
logging.info("Setting up network thread")
cls.network_thread = NetworkThread()
cls.network_thread.network_event_loop.set_debug(True)
cls.network_thread.start()
cls.network_thread.network_event_loop.set_debug(True)
@classmethod
def run_loop_ended(cls):
logging.info("run_loop_ended")
logging.info("Closing down network thread")
cls.network_thread.close()
@classmethod
def tearDownClass(cls):
logging.info("Finalising Test")
# logging.info('Closing down network thread')
# cls.network_thread.close()
super(Test, cls).tearDownClass()
@classmethod
def coins_loop(cls):
super(Test, cls).coins_loop()
def test_01_p2p(self):
swap_clients = self.swap_clients
kwargs = {}
kwargs["dstport"] = BASE_PORT
kwargs["dstaddr"] = "127.0.0.1"
services = P2P_SERVICES | NODE_SMSG
p2p_conn = P2PInterfaceSMSG()
p2p_conn.p2p_connected_to_node = True
p2p_conn.peer_connect(
**kwargs,
services=services,
send_version=True,
net="regtest",
timeout_factor=99999,
supports_v2_p2p=False,
)()
p2p_conn.wait_for_connect()
p2p_conn.wait_for_verack()
p2p_conn.sync_with_ping()
ci0_part = swap_clients[0].ci(Coins.PART)
test_key_recv: bytes = ci0_part.getNewRandomKey()
test_key_recv_wif: str = ci0_part.encodeKey(test_key_recv)
test_key_recv_pk: bytes = ci0_part.getPubkey(test_key_recv)
ci0_part.rpc("smsgimportprivkey", [test_key_recv_wif, "test key"])
message_test: str = "Test message"
test_key_send: bytes = ci0_part.getNewRandomKey()
encrypted_message: bytes = smsgEncrypt(
test_key_send, test_key_recv_pk, message_test.encode("utf-8")
)
decrypted_message: bytes = smsgDecrypt(test_key_recv, encrypted_message)
assert decrypted_message.decode("utf-8") == message_test
msg_id: bytes = smsgGetID(encrypted_message)
smsg_timestamp: int = smsgGetTimestamp(encrypted_message)
smsg_bucket: int = smsg_timestamp - (smsg_timestamp % SMSG_BUCKET_LEN)
smsgMsg = msg_smsgMsg(1, smsg_bucket, encrypted_message)
p2p_conn.send_message(smsgMsg)
wait_for_smsg(ci0_part, msg_id.hex())
rv = ci0_part.rpc_wallet("smsg", [msg_id.hex()])
assert rv["text"] == message_test

View file

@ -26,7 +26,6 @@ from tests.basicswap.common import (
waitForRPC,
)
from basicswap.contrib.test_framework.messages import (
ToHex,
CTxIn,
COutPoint,
CTransaction,
@ -251,7 +250,7 @@ class TestBCH(BasicSwapTest):
tx = CTransaction()
tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx)
tx_hex = tx.serialize().hex()
tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = ci.rpc_wallet(
@ -285,10 +284,10 @@ class TestBCH(BasicSwapTest):
)
)
tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out))
tx_spend_hex = ToHex(tx_spend)
tx_spend_hex = tx_spend.serialize().hex()
tx_spend.nLockTime = chain_height + 2
tx_spend_invalid_hex = ToHex(tx_spend)
tx_spend_invalid_hex = tx_spend.serialize().hex()
for tx_hex in [tx_spend_invalid_hex, tx_spend_hex]:
try:
@ -362,7 +361,7 @@ class TestBCH(BasicSwapTest):
tx = CTransaction()
tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx)
tx_hex = tx.serialize().hex()
tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = ci.rpc_wallet(
@ -405,7 +404,7 @@ class TestBCH(BasicSwapTest):
)
)
tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out))
tx_spend_hex = ToHex(tx_spend)
tx_spend_hex = tx_spend.serialize().hex()
try:
txid = ci.rpc(
"sendrawtransaction",
@ -640,7 +639,7 @@ class TestBCH(BasicSwapTest):
tx = CTransaction()
tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx)
tx_hex = tx.serialize().hex()
tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = ci.rpc_wallet(
@ -682,7 +681,7 @@ class TestBCH(BasicSwapTest):
)
)
tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out))
tx_spend_hex = ToHex(tx_spend)
tx_spend_hex = tx_spend.serialize().hex()
txid = ci.rpc(
"sendrawtransaction",
@ -730,7 +729,7 @@ class TestBCH(BasicSwapTest):
tx = CTransaction()
tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx)
tx_hex = tx.serialize().hex()
tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = ci.rpc_wallet(
@ -772,7 +771,7 @@ class TestBCH(BasicSwapTest):
)
)
tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out))
tx_spend_hex = ToHex(tx_spend)
tx_spend_hex = tx_spend.serialize().hex()
txid = ci.rpc(
"sendrawtransaction",

View file

@ -46,8 +46,7 @@ from tests.basicswap.common import (
)
from basicswap.contrib.test_framework.descriptors import descsum_create
from basicswap.contrib.test_framework.messages import (
ToHex,
FromHex,
from_hex,
CTxIn,
COutPoint,
CTransaction,
@ -860,7 +859,7 @@ class BasicSwapTest(TestFunctions):
addr_p2sh_segwit,
],
)
decoded_tx = FromHex(CTransaction(), tx_funded)
decoded_tx = from_hex(CTransaction(), tx_funded)
decoded_tx.vin[0].scriptSig = bytes.fromhex("16" + addr_p2sh_segwit_info["hex"])
txid_with_scriptsig = decoded_tx.rehash()
assert txid_with_scriptsig == tx_signed_decoded["txid"]
@ -950,7 +949,7 @@ class BasicSwapTest(TestFunctions):
tx = CTransaction()
tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx)
tx_hex = tx.serialize().hex()
tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = ci.rpc_wallet(
@ -979,10 +978,10 @@ class BasicSwapTest(TestFunctions):
tx_spend.wit.vtxinwit[0].scriptWitness.stack = [
script,
]
tx_spend_hex = ToHex(tx_spend)
tx_spend_hex = tx_spend.serialize().hex()
tx_spend.nLockTime = chain_height + 2
tx_spend_invalid_hex = ToHex(tx_spend)
tx_spend_invalid_hex = tx_spend.serialize().hex()
for tx_hex in [tx_spend_invalid_hex, tx_spend_hex]:
try:
@ -1055,7 +1054,7 @@ class BasicSwapTest(TestFunctions):
tx = CTransaction()
tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx)
tx_hex = tx.serialize().hex()
tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = ci.rpc_wallet(
@ -1094,7 +1093,7 @@ class BasicSwapTest(TestFunctions):
tx_spend.wit.vtxinwit[0].scriptWitness.stack = [
script,
]
tx_spend_hex = ToHex(tx_spend)
tx_spend_hex = tx_spend.serialize().hex()
try:
txid = ci.rpc(
"sendrawtransaction",
@ -1435,7 +1434,7 @@ class BasicSwapTest(TestFunctions):
tx = CTransaction()
tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx)
tx_hex = tx.serialize().hex()
tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = ci.rpc_wallet(
@ -1477,7 +1476,7 @@ class BasicSwapTest(TestFunctions):
)
)
tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out))
tx_spend_hex = ToHex(tx_spend)
tx_spend_hex = tx_spend.serialize().hex()
txid = ci.rpc(
"sendrawtransaction",
@ -1525,7 +1524,7 @@ class BasicSwapTest(TestFunctions):
tx = CTransaction()
tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx)
tx_hex = tx.serialize().hex()
tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = ci.rpc_wallet(
@ -1567,7 +1566,7 @@ class BasicSwapTest(TestFunctions):
tx_spend.wit.vtxinwit[0].scriptWitness.stack = [
script,
]
tx_spend_hex = ToHex(tx_spend)
tx_spend_hex = tx_spend.serialize().hex()
txid = ci.rpc(
"sendrawtransaction",

View file

@ -56,7 +56,6 @@ from basicswap.contrib.test_framework.messages import (
CTransaction,
CTxIn,
CTxInWitness,
ToHex,
)
from basicswap.contrib.test_framework.script import (
CScript,
@ -211,7 +210,7 @@ class Test(BaseTest):
tx = CTransaction()
tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx)
tx_hex = tx.serialize().hex()
tx_funded = callnoderpc(0, "fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = callnoderpc(
@ -248,10 +247,10 @@ class Test(BaseTest):
tx_spend.wit.vtxinwit[0].scriptWitness.stack = [
script,
]
tx_spend_hex = ToHex(tx_spend)
tx_spend_hex = tx_spend.serialize().hex()
tx_spend.nLockTime = chain_height + 2
tx_spend_invalid_hex = ToHex(tx_spend)
tx_spend_invalid_hex = tx_spend.serialize().hex()
for tx_hex in [tx_spend_invalid_hex, tx_spend_hex]:
try:

View file

@ -247,7 +247,7 @@ def ltcCli(cmd, node_id=0):
def signal_handler(sig, frame):
logging.info("signal {} detected.".format(sig))
logging.info(f"signal {sig} detected.")
signal_event.set()
test_delay_event.set()
@ -309,6 +309,7 @@ def run_loop(cls):
for c in cls.swap_clients:
c.update()
test_delay_event.wait(1.0)
cls.run_loop_ended()
class BaseTest(unittest.TestCase):
@ -322,12 +323,13 @@ class BaseTest(unittest.TestCase):
ltc_daemons = []
xmr_daemons = []
xmr_wallet_auth = []
restore_instance = False
extra_wait_time = 0
restore_instance: bool = False
extra_wait_time: int = 0
num_nodes: int = NUM_NODES
start_ltc_nodes = False
start_xmr_nodes = True
has_segwit = True
start_ltc_nodes: bool = False
start_xmr_nodes: bool = True
has_segwit: bool = True
xmr_addr = None
btc_addr = None
@ -392,6 +394,8 @@ class BaseTest(unittest.TestCase):
cls.stream_fp.setFormatter(formatter)
logger.addHandler(cls.stream_fp)
cls.prepareTestDir()
try:
logging.info("Preparing coin nodes.")
for i in range(NUM_NODES):
@ -645,6 +649,7 @@ class BaseTest(unittest.TestCase):
start_nodes,
cls,
)
basicswap_dir = os.path.join(
os.path.join(TEST_DIR, "basicswap_" + str(i))
)
@ -966,6 +971,10 @@ class BaseTest(unittest.TestCase):
super(BaseTest, cls).tearDownClass()
@classmethod
def prepareTestDir(cls):
pass
@classmethod
def addCoinSettings(cls, settings, datadir, node_id):
pass
@ -995,6 +1004,10 @@ class BaseTest(unittest.TestCase):
{"wallet_address": cls.xmr_addr, "amount_of_blocks": 1},
)
@classmethod
def run_loop_ended(cls):
pass
@classmethod
def waitForParticlHeight(cls, num_blocks, node_id=0):
logging.info(