mirror of
https://github.com/basicswap/basicswap.git
synced 2024-11-16 15:58:17 +00:00
Load in-progress bids only when unlocked.
This commit is contained in:
parent
3234e3fba3
commit
2922b171a6
11 changed files with 276 additions and 89 deletions
|
@ -1,6 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2019-2022 tecnovert
|
||||
# Copyright (c) 2019-2023 tecnovert
|
||||
# Distributed under the MIT software license, see the accompanying
|
||||
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
|
||||
|
||||
|
@ -92,7 +92,7 @@ class BaseApp:
|
|||
except Exception:
|
||||
return {}
|
||||
|
||||
def setDaemonPID(self, name, pid):
|
||||
def setDaemonPID(self, name, pid) -> None:
|
||||
if isinstance(name, Coins):
|
||||
self.coin_clients[name]['pid'] = pid
|
||||
return
|
||||
|
@ -100,12 +100,12 @@ class BaseApp:
|
|||
if v['name'] == name:
|
||||
v['pid'] = pid
|
||||
|
||||
def getChainDatadirPath(self, coin):
|
||||
def getChainDatadirPath(self, coin) -> str:
|
||||
datadir = self.coin_clients[coin]['datadir']
|
||||
testnet_name = '' if self.chain == 'mainnet' else chainparams[coin][self.chain].get('name', self.chain)
|
||||
return os.path.join(datadir, testnet_name)
|
||||
|
||||
def getCoinIdFromName(self, coin_name):
|
||||
def getCoinIdFromName(self, coin_name: str):
|
||||
for c, params in chainparams.items():
|
||||
if coin_name.lower() == params['name'].lower():
|
||||
return c
|
||||
|
@ -146,7 +146,7 @@ class BaseApp:
|
|||
raise ValueError('CLI error ' + str(out[1]))
|
||||
return out[0].decode('utf-8').strip()
|
||||
|
||||
def is_transient_error(self, ex):
|
||||
def is_transient_error(self, ex) -> bool:
|
||||
if isinstance(ex, TemporaryError):
|
||||
return True
|
||||
str_error = str(ex).lower()
|
||||
|
@ -164,13 +164,13 @@ class BaseApp:
|
|||
|
||||
socket.setdefaulttimeout(timeout)
|
||||
|
||||
def popConnectionParameters(self):
|
||||
def popConnectionParameters(self) -> None:
|
||||
if self.use_tor_proxy:
|
||||
socket.socket = self.default_socket
|
||||
socket.getaddrinfo = self.default_socket_getaddrinfo
|
||||
socket.setdefaulttimeout(self.default_socket_timeout)
|
||||
|
||||
def logException(self, message):
|
||||
def logException(self, message) -> None:
|
||||
self.log.error(message)
|
||||
if self.debug:
|
||||
self.log.error(traceback.format_exc())
|
||||
|
|
|
@ -216,6 +216,7 @@ class WatchedTransaction():
|
|||
|
||||
class BasicSwap(BaseApp):
|
||||
ws_server = None
|
||||
_read_zmq_queue: bool = True
|
||||
protocolInterfaces = {
|
||||
SwapTypes.SELLER_FIRST: atomic_swap_1.AtomicSwapInterface(),
|
||||
SwapTypes.XMR_SWAP: xmr_swap_1.XmrSwapInterface(),
|
||||
|
@ -696,7 +697,41 @@ class BasicSwap(BaseApp):
|
|||
self._network = bsn.Network(self.settings['p2p_host'], self.settings['p2p_port'], network_key, self)
|
||||
self._network.startNetwork()
|
||||
|
||||
self.initialise()
|
||||
self.log.debug('network_key %s\nnetwork_pubkey %s\nnetwork_addr %s',
|
||||
self.network_key, self.network_pubkey, self.network_addr)
|
||||
|
||||
ro = self.callrpc('smsglocalkeys')
|
||||
found = False
|
||||
for k in ro['smsg_keys']:
|
||||
if k['address'] == self.network_addr:
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
self.log.info('Importing network key to SMSG')
|
||||
self.callrpc('smsgimportprivkey', [self.network_key, 'basicswap offers'])
|
||||
ro = self.callrpc('smsglocalkeys', ['anon', '-', self.network_addr])
|
||||
ensure(ro['result'] == 'Success.', 'smsglocalkeys failed')
|
||||
|
||||
# TODO: Ensure smsg is enabled for the active wallet.
|
||||
|
||||
# Initialise locked state
|
||||
_, _ = self.getLockedState()
|
||||
|
||||
# Re-load in-progress bids
|
||||
self.loadFromDB()
|
||||
|
||||
# Scan inbox
|
||||
# TODO: Redundant? small window for zmq messages to go unnoticed during startup?
|
||||
# options = {'encoding': 'hex'}
|
||||
options = {'encoding': 'none'}
|
||||
ro = self.callrpc('smsginbox', ['unread', '', options])
|
||||
nm = 0
|
||||
for msg in ro['messages']:
|
||||
# TODO: Remove workaround for smsginbox bug
|
||||
get_msg = self.callrpc('smsg', [msg['msgid'], {'encoding': 'hex', 'setread': True}])
|
||||
self.processMsg(get_msg)
|
||||
nm += 1
|
||||
self.log.info('Scanned %d unread messages.', nm)
|
||||
|
||||
def stopDaemon(self, coin):
|
||||
if coin == Coins.XMR:
|
||||
|
@ -757,6 +792,11 @@ class BasicSwap(BaseApp):
|
|||
if synced < 1.0:
|
||||
raise ValueError('{} chain is still syncing, currently at {}.'.format(self.coin_clients[c]['name'], synced))
|
||||
|
||||
def isSystemUnlocked(self):
|
||||
# TODO - Check all active coins
|
||||
ci = self.ci(Coins.PART)
|
||||
return not ci.isWalletLocked()
|
||||
|
||||
def checkSystemStatus(self):
|
||||
ci = self.ci(Coins.PART)
|
||||
if ci.isWalletLocked():
|
||||
|
@ -801,6 +841,7 @@ class BasicSwap(BaseApp):
|
|||
self._is_encrypted, self._is_locked = self.ci(Coins.PART).isWalletEncryptedLocked()
|
||||
|
||||
def unlockWallets(self, password, coin=None):
|
||||
self._read_zmq_queue = False
|
||||
for c in self.activeCoins():
|
||||
if coin and c != coin:
|
||||
continue
|
||||
|
@ -808,13 +849,20 @@ class BasicSwap(BaseApp):
|
|||
if c == Coins.PART:
|
||||
self._is_locked = False
|
||||
|
||||
self.loadFromDB()
|
||||
self._read_zmq_queue = True
|
||||
|
||||
def lockWallets(self, coin=None):
|
||||
self._read_zmq_queue = False
|
||||
self.swaps_in_progress.clear()
|
||||
|
||||
for c in self.activeCoins():
|
||||
if coin and c != coin:
|
||||
continue
|
||||
self.ci(c).lockWallet()
|
||||
if c == Coins.PART:
|
||||
self._is_locked = True
|
||||
self._read_zmq_queue = True
|
||||
|
||||
def initialiseWallet(self, coin_type, raise_errors=False):
|
||||
if coin_type == Coins.PART:
|
||||
|
@ -929,7 +977,7 @@ class BasicSwap(BaseApp):
|
|||
with self.mxDB:
|
||||
try:
|
||||
session = scoped_session(self.session_factory)
|
||||
session.execute('DELETE FROM kv_string WHERE key = "{}" '.format(str_key))
|
||||
session.execute('DELETE FROM kv_string WHERE key = :key', {'key': str_key})
|
||||
session.commit()
|
||||
finally:
|
||||
session.close()
|
||||
|
@ -1037,7 +1085,10 @@ class BasicSwap(BaseApp):
|
|||
if session is None:
|
||||
self.closeSession(use_session)
|
||||
|
||||
def loadFromDB(self):
|
||||
def loadFromDB(self) -> None:
|
||||
if self.isSystemUnlocked() is False:
|
||||
self.log.info('Not loading from db. System is locked.')
|
||||
return
|
||||
self.log.info('Loading data from db')
|
||||
self.mxDB.acquire()
|
||||
self.swaps_in_progress.clear()
|
||||
|
@ -1061,39 +1112,6 @@ class BasicSwap(BaseApp):
|
|||
session.remove()
|
||||
self.mxDB.release()
|
||||
|
||||
def initialise(self):
|
||||
self.log.debug('network_key %s\nnetwork_pubkey %s\nnetwork_addr %s',
|
||||
self.network_key, self.network_pubkey, self.network_addr)
|
||||
|
||||
ro = self.callrpc('smsglocalkeys')
|
||||
found = False
|
||||
for k in ro['smsg_keys']:
|
||||
if k['address'] == self.network_addr:
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
self.log.info('Importing network key to SMSG')
|
||||
self.callrpc('smsgimportprivkey', [self.network_key, 'basicswap offers'])
|
||||
ro = self.callrpc('smsglocalkeys', ['anon', '-', self.network_addr])
|
||||
ensure(ro['result'] == 'Success.', 'smsglocalkeys failed')
|
||||
|
||||
# TODO: Ensure smsg is enabled for the active wallet.
|
||||
|
||||
self.loadFromDB()
|
||||
|
||||
# Scan inbox
|
||||
# TODO: Redundant? small window for zmq messages to go unnoticed during startup?
|
||||
# options = {'encoding': 'hex'}
|
||||
options = {'encoding': 'none'}
|
||||
ro = self.callrpc('smsginbox', ['unread', '', options])
|
||||
nm = 0
|
||||
for msg in ro['messages']:
|
||||
# TODO: Remove workaround for smsginbox bug
|
||||
get_msg = self.callrpc('smsg', [msg['msgid'], {'encoding': 'hex', 'setread': True}])
|
||||
self.processMsg(get_msg)
|
||||
nm += 1
|
||||
self.log.info('Scanned %d unread messages.', nm)
|
||||
|
||||
def getActiveBidMsgValidTime(self):
|
||||
return self.SMSG_SECONDS_IN_HOUR * 48
|
||||
|
||||
|
@ -1882,7 +1900,7 @@ class BasicSwap(BaseApp):
|
|||
try:
|
||||
self._contract_count += 1
|
||||
session = scoped_session(self.session_factory)
|
||||
session.execute('UPDATE kv_int SET value = {} WHERE KEY="contract_count"'.format(self._contract_count))
|
||||
session.execute('UPDATE kv_int SET value = :value WHERE KEY="contract_count"', {'value': self._contract_count})
|
||||
session.commit()
|
||||
finally:
|
||||
session.close()
|
||||
|
@ -3870,7 +3888,11 @@ class BasicSwap(BaseApp):
|
|||
c['last_height_checked'] = last_height_checked
|
||||
self.setIntKV('last_height_checked_' + chainparams[coin_type]['name'], last_height_checked)
|
||||
|
||||
def expireMessages(self):
|
||||
def expireMessages(self) -> None:
|
||||
if self._is_locked is True:
|
||||
self.log.debug('Not expiring messages while system locked')
|
||||
return
|
||||
|
||||
self.mxDB.acquire()
|
||||
rpc_conn = None
|
||||
try:
|
||||
|
@ -3947,9 +3969,9 @@ class BasicSwap(BaseApp):
|
|||
self.logException(f'checkQueuedActions failed: {ex}')
|
||||
|
||||
if self.debug:
|
||||
session.execute('UPDATE actions SET active_ind = 2 WHERE trigger_at <= {}'.format(now))
|
||||
session.execute('UPDATE actions SET active_ind = 2 WHERE trigger_at <= :now', {'now': now})
|
||||
else:
|
||||
session.execute('DELETE FROM actions WHERE trigger_at <= {}'.format(now))
|
||||
session.execute('DELETE FROM actions WHERE trigger_at <= :now', {'now': now})
|
||||
|
||||
session.commit()
|
||||
except Exception as ex:
|
||||
|
@ -5014,7 +5036,7 @@ class BasicSwap(BaseApp):
|
|||
|
||||
if coin_to == Coins.XMR:
|
||||
address_to = self.getCachedMainWalletAddress(ci_to)
|
||||
elif coin_to == Coins.PART_BLIND:
|
||||
elif coin_to in (Coins.PART_BLIND, Coins.PART_ANON):
|
||||
address_to = self.getCachedStealthAddressForCoin(coin_to)
|
||||
else:
|
||||
address_to = self.getReceiveAddressFromPool(coin_to, bid_id, TxTypes.XMR_SWAP_B_LOCK_SPEND)
|
||||
|
@ -5323,6 +5345,9 @@ class BasicSwap(BaseApp):
|
|||
rv = None
|
||||
if msg_type == MessageTypes.OFFER:
|
||||
self.processOffer(msg)
|
||||
elif msg_type == MessageTypes.OFFER_REVOKE:
|
||||
self.processOfferRevoke(msg)
|
||||
# TODO: When changing from wallet keys (encrypted/locked) handle swap messages while locked
|
||||
elif msg_type == MessageTypes.BID:
|
||||
self.processBid(msg)
|
||||
elif msg_type == MessageTypes.BID_ACCEPT:
|
||||
|
@ -5339,8 +5364,6 @@ class BasicSwap(BaseApp):
|
|||
self.processXmrSplitMessage(msg)
|
||||
elif msg_type == MessageTypes.XMR_BID_LOCK_RELEASE_LF:
|
||||
self.processXmrLockReleaseMessage(msg)
|
||||
if msg_type == MessageTypes.OFFER_REVOKE:
|
||||
self.processOfferRevoke(msg)
|
||||
|
||||
except InactiveCoin as ex:
|
||||
self.log.info('Ignoring message involving inactive coin {}, type {}'.format(Coins(ex.coinid).name, MessageTypes(msg_type).name))
|
||||
|
@ -5381,10 +5404,10 @@ class BasicSwap(BaseApp):
|
|||
|
||||
def update(self):
|
||||
try:
|
||||
# while True:
|
||||
message = self.zmqSubscriber.recv(flags=zmq.NOBLOCK)
|
||||
if message == b'smsg':
|
||||
self.processZmqSmsg()
|
||||
if self._read_zmq_queue:
|
||||
message = self.zmqSubscriber.recv(flags=zmq.NOBLOCK)
|
||||
if message == b'smsg':
|
||||
self.processZmqSmsg()
|
||||
except zmq.Again as ex:
|
||||
pass
|
||||
except Exception as ex:
|
||||
|
@ -6178,6 +6201,7 @@ class BasicSwap(BaseApp):
|
|||
|
||||
addr_info = self.callrpc('getaddressinfo', [new_addr])
|
||||
self.callrpc('smsgaddlocaladdress', [new_addr]) # Enable receiving smsgs
|
||||
self.callrpc('smsglocalkeys', ['anon', '-', new_addr])
|
||||
|
||||
use_session.add(SmsgAddress(addr=new_addr, use_type=use_type, active_ind=1, created_at=now, note=addressnote, pubkey=addr_info['pubkey']))
|
||||
return new_addr, addr_info['pubkey']
|
||||
|
@ -6193,6 +6217,7 @@ class BasicSwap(BaseApp):
|
|||
ci = self.ci(Coins.PART)
|
||||
add_addr = ci.pubkey_to_address(bytes.fromhex(pubkey_hex))
|
||||
self.callrpc('smsgaddaddress', [add_addr, pubkey_hex])
|
||||
self.callrpc('smsglocalkeys', ['anon', '-', add_addr])
|
||||
|
||||
session.add(SmsgAddress(addr=add_addr, use_type=AddressTypes.SEND_OFFER, active_ind=1, created_at=now, note=addressnote, pubkey=pubkey_hex))
|
||||
session.commit()
|
||||
|
@ -6209,7 +6234,7 @@ class BasicSwap(BaseApp):
|
|||
mode = '-' if active_ind == 0 else '+'
|
||||
self.callrpc('smsglocalkeys', ['recv', mode, address])
|
||||
|
||||
session.execute('UPDATE smsgaddresses SET active_ind = {}, note = "{}" WHERE addr = "{}"'.format(active_ind, addressnote, address))
|
||||
session.execute('UPDATE smsgaddresses SET active_ind = :active_ind, note = :note WHERE addr = :addr', {'active_ind': active_ind, 'note': addressnote, 'addr': address})
|
||||
session.commit()
|
||||
finally:
|
||||
session.close()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2019-2022 tecnovert
|
||||
# Copyright (c) 2019-2023 tecnovert
|
||||
# Distributed under the MIT software license, see the accompanying
|
||||
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
|
||||
|
||||
|
@ -374,29 +374,29 @@ class CoinInterface:
|
|||
ticker = 'rt' + ticker
|
||||
return ticker
|
||||
|
||||
def getExchangeTicker(self, exchange_name):
|
||||
def getExchangeTicker(self, exchange_name: str) -> str:
|
||||
return chainparams[self.coin_type()]['ticker']
|
||||
|
||||
def getExchangeName(self, exchange_name):
|
||||
def getExchangeName(self, exchange_name: str) -> str:
|
||||
return chainparams[self.coin_type()]['name']
|
||||
|
||||
def ticker_mainnet(self):
|
||||
def ticker_mainnet(self) -> str:
|
||||
ticker = chainparams[self.coin_type()]['ticker']
|
||||
return ticker
|
||||
|
||||
def min_amount(self):
|
||||
def min_amount(self) -> int:
|
||||
return chainparams[self.coin_type()][self._network]['min_amount']
|
||||
|
||||
def max_amount(self):
|
||||
def max_amount(self) -> int:
|
||||
return chainparams[self.coin_type()][self._network]['max_amount']
|
||||
|
||||
def setWalletSeedWarning(self, value):
|
||||
def setWalletSeedWarning(self, value: bool) -> None:
|
||||
self._unknown_wallet_seed = value
|
||||
|
||||
def setWalletRestoreHeight(self, value):
|
||||
def setWalletRestoreHeight(self, value: int) -> None:
|
||||
self._restore_height = value
|
||||
|
||||
def knownWalletSeed(self):
|
||||
def knownWalletSeed(self) -> bool:
|
||||
return not self._unknown_wallet_seed
|
||||
|
||||
def chainparams(self):
|
||||
|
@ -408,13 +408,13 @@ class CoinInterface:
|
|||
def has_segwit(self) -> bool:
|
||||
return chainparams[self.coin_type()].get('has_segwit', True)
|
||||
|
||||
def is_transient_error(self, ex):
|
||||
def is_transient_error(self, ex) -> bool:
|
||||
if isinstance(ex, TemporaryError):
|
||||
return True
|
||||
str_error = str(ex).lower()
|
||||
str_error: str = str(ex).lower()
|
||||
if 'not enough unlocked money' in str_error:
|
||||
return True
|
||||
if 'No unlocked balance' in str_error:
|
||||
if 'no unlocked balance' in str_error:
|
||||
return True
|
||||
if 'transaction was rejected by daemon' in str_error:
|
||||
return True
|
||||
|
|
|
@ -423,6 +423,7 @@ class KnownIdentity(Base):
|
|||
num_recv_bids_failed = sa.Column(sa.Integer)
|
||||
automation_override = sa.Column(sa.Integer) # AutomationOverrideOptions
|
||||
visibility_override = sa.Column(sa.Integer) # VisibilityOverrideOptions
|
||||
data = sa.Column(sa.LargeBinary)
|
||||
note = sa.Column(sa.String)
|
||||
updated_at = sa.Column(sa.BigInteger)
|
||||
created_at = sa.Column(sa.BigInteger)
|
||||
|
|
|
@ -238,10 +238,11 @@ def upgradeDatabase(self, db_version):
|
|||
tx_data BLOB,
|
||||
used_by BLOB,
|
||||
PRIMARY KEY (record_id))''')
|
||||
elif current_version == 16:
|
||||
elif current_version == 17:
|
||||
db_version += 1
|
||||
session.execute('ALTER TABLE knownidentities ADD COLUMN automation_override INTEGER')
|
||||
session.execute('ALTER TABLE knownidentities ADD COLUMN visibility_override INTEGER')
|
||||
session.execute('ALTER TABLE knownidentities ADD COLUMN data BLOB')
|
||||
session.execute('UPDATE knownidentities SET active_ind = 1')
|
||||
|
||||
if current_version != db_version:
|
||||
|
|
|
@ -684,7 +684,7 @@ class PARTInterfaceBlind(PARTInterface):
|
|||
return -1
|
||||
return None
|
||||
|
||||
def spendBLockTx(self, chain_b_lock_txid, address_to, kbv, kbs, cb_swap_value, b_fee, restore_height, spend_actual_balance=False):
|
||||
def spendBLockTx(self, chain_b_lock_txid: bytes, address_to: str, kbv: bytes, kbs: bytes, cb_swap_value: int, b_fee: int, restore_height: int, spend_actual_balance: bool = False) -> bytes:
|
||||
Kbv = self.getPubkey(kbv)
|
||||
Kbs = self.getPubkey(kbs)
|
||||
sx_addr = self.formatStealthAddress(Kbv, Kbs)
|
||||
|
@ -813,7 +813,7 @@ class PARTInterfaceAnon(PARTInterface):
|
|||
return -1
|
||||
return None
|
||||
|
||||
def spendBLockTx(self, chain_b_lock_txid, address_to, kbv, kbs, cb_swap_value, b_fee, restore_height, spend_actual_balance=False):
|
||||
def spendBLockTx(self, chain_b_lock_txid: bytes, address_to: str, kbv: bytes, kbs: bytes, cb_swap_value: int, b_fee: int, restore_height: int, spend_actual_balance: bool = False) -> bytes:
|
||||
Kbv = self.getPubkey(kbv)
|
||||
Kbs = self.getPubkey(kbs)
|
||||
sx_addr = self.formatStealthAddress(Kbv, Kbs)
|
||||
|
|
|
@ -417,7 +417,7 @@ class XMRInterface(CoinInterface):
|
|||
|
||||
return bytes.fromhex(rv['tx_hash_list'][0])
|
||||
|
||||
def withdrawCoin(self, value, addr_to, subfee):
|
||||
def withdrawCoin(self, value: int, addr_to: str, subfee: bool) -> str:
|
||||
with self._mx_wallet:
|
||||
value_sats = make_int(value, self.exp())
|
||||
|
||||
|
@ -427,7 +427,7 @@ class XMRInterface(CoinInterface):
|
|||
if subfee:
|
||||
balance = self.rpc_wallet_cb('get_balance')
|
||||
diff = balance['unlocked_balance'] - value_sats
|
||||
if diff > 0 and diff <= 10:
|
||||
if diff >= 0 and diff <= 10:
|
||||
self._log.info('subfee enabled and value close to total, using sweep_all.')
|
||||
params = {'address': addr_to}
|
||||
if self._fee_priority > 0:
|
||||
|
|
|
@ -344,7 +344,7 @@ def js_bids(self, url_split, post_string: str, is_json: bool) -> bytes:
|
|||
data = describeBid(swap_client, bid, xmr_swap, offer, xmr_offer, events, edit_bid, show_txns, for_api=True)
|
||||
return bytes(json.dumps(data), 'UTF-8')
|
||||
|
||||
post_data = getFormData(post_string, is_json)
|
||||
post_data = {} if post_string == '' else getFormData(post_string, is_json)
|
||||
offer_id, filters = parseBidFilters(post_data)
|
||||
|
||||
bids = swap_client.listBids(offer_id=offer_id, filters=filters)
|
||||
|
|
|
@ -66,7 +66,7 @@ def dumpje(jin):
|
|||
return json.dumps(jin, default=jsonDecimal).replace('"', '\\"')
|
||||
|
||||
|
||||
def SerialiseNum(n):
|
||||
def SerialiseNum(n: int) -> bytes:
|
||||
if n == 0:
|
||||
return bytes((0x00,))
|
||||
if n > 0 and n <= 16:
|
||||
|
@ -84,7 +84,7 @@ def SerialiseNum(n):
|
|||
return bytes((len(rv),)) + rv
|
||||
|
||||
|
||||
def DeserialiseNum(b, o=0) -> int:
|
||||
def DeserialiseNum(b: bytes, o: int = 0) -> int:
|
||||
if b[o] == 0:
|
||||
return 0
|
||||
if b[o] > 0x50 and b[o] <= 0x50 + 16:
|
||||
|
@ -100,13 +100,13 @@ def DeserialiseNum(b, o=0) -> int:
|
|||
return v
|
||||
|
||||
|
||||
def float_to_str(f):
|
||||
def float_to_str(f: float) -> str:
|
||||
# stackoverflow.com/questions/38847690
|
||||
d1 = decimal_ctx.create_decimal(repr(f))
|
||||
return format(d1, 'f')
|
||||
|
||||
|
||||
def make_int(v, scale=8, r=0): # r = 0, no rounding, fail, r > 0 round up, r < 0 floor
|
||||
def make_int(v, scale=8, r=0) -> int: # r = 0, no rounding, fail, r > 0 round up, r < 0 floor
|
||||
if type(v) == float:
|
||||
v = float_to_str(v)
|
||||
elif type(v) == int:
|
||||
|
@ -177,7 +177,7 @@ def format_amount(i, display_scale, scale=None):
|
|||
return rv
|
||||
|
||||
|
||||
def format_timestamp(value: int, with_seconds=False) -> str:
|
||||
def format_timestamp(value: int, with_seconds: bool = False) -> str:
|
||||
str_format = '%Y-%m-%d %H:%M'
|
||||
if with_seconds:
|
||||
str_format += ':%S'
|
||||
|
@ -185,7 +185,7 @@ def format_timestamp(value: int, with_seconds=False) -> str:
|
|||
return time.strftime(str_format, time.localtime(value))
|
||||
|
||||
|
||||
def b2i(b) -> int:
|
||||
def b2i(b: bytes) -> int:
|
||||
# bytes32ToInt
|
||||
return int.from_bytes(b, byteorder='big')
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2022 tecnovert
|
||||
# Copyright (c) 2022-2023 tecnovert
|
||||
# Distributed under the MIT software license, see the accompanying
|
||||
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
|
||||
|
||||
|
@ -59,7 +59,7 @@ def b58encode(v):
|
|||
return (__b58chars[0] * nPad) + result
|
||||
|
||||
|
||||
def encodeStealthAddress(prefix_byte, scan_pubkey, spend_pubkey):
|
||||
def encodeStealthAddress(prefix_byte: int, scan_pubkey: bytes, spend_pubkey: bytes) -> str:
|
||||
data = bytes((0x00,))
|
||||
data += scan_pubkey
|
||||
data += bytes((0x01,))
|
||||
|
@ -72,14 +72,14 @@ def encodeStealthAddress(prefix_byte, scan_pubkey, spend_pubkey):
|
|||
return b58encode(b)
|
||||
|
||||
|
||||
def decodeWif(encoded_key):
|
||||
def decodeWif(encoded_key: str) -> bytes:
|
||||
key = b58decode(encoded_key)[1:-4]
|
||||
if len(key) == 33:
|
||||
return key[:-1]
|
||||
return key
|
||||
|
||||
|
||||
def toWIF(prefix_byte, b, compressed=True):
|
||||
def toWIF(prefix_byte: int, b: bytes, compressed: bool = True) -> str:
|
||||
b = bytes((prefix_byte,)) + b
|
||||
if compressed:
|
||||
b += bytes((0x01,))
|
||||
|
@ -87,9 +87,9 @@ def toWIF(prefix_byte, b, compressed=True):
|
|||
return b58encode(b)
|
||||
|
||||
|
||||
def getKeyID(bytes):
|
||||
data = hashlib.sha256(bytes).digest()
|
||||
return ripemd160(data)
|
||||
def getKeyID(key_data: bytes) -> str:
|
||||
sha256_hash = hashlib.sha256(key_data).digest()
|
||||
return ripemd160(sha256_hash)
|
||||
|
||||
|
||||
def bech32Decode(hrp, addr):
|
||||
|
@ -109,7 +109,7 @@ def bech32Encode(hrp, data):
|
|||
return ret
|
||||
|
||||
|
||||
def decodeAddress(address_str):
|
||||
def decodeAddress(address_str: str):
|
||||
b58_addr = b58decode(address_str)
|
||||
if b58_addr is not None:
|
||||
address = b58_addr[:-4]
|
||||
|
@ -119,10 +119,10 @@ def decodeAddress(address_str):
|
|||
return None
|
||||
|
||||
|
||||
def encodeAddress(address):
|
||||
def encodeAddress(address: bytes) -> str:
|
||||
checksum = hashlib.sha256(hashlib.sha256(address).digest()).digest()
|
||||
return b58encode(address + checksum[0:4])
|
||||
|
||||
|
||||
def pubkeyToAddress(prefix, pubkey):
|
||||
def pubkeyToAddress(prefix: int, pubkey: bytes) -> str:
|
||||
return encodeAddress(bytes((prefix,)) + getKeyID(pubkey))
|
||||
|
|
160
tests/basicswap/extended/test_encrypted_xmr_reload.py
Normal file
160
tests/basicswap/extended/test_encrypted_xmr_reload.py
Normal file
|
@ -0,0 +1,160 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2020-2023 tecnovert
|
||||
# Distributed under the MIT software license, see the accompanying
|
||||
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
|
||||
|
||||
"""
|
||||
export TEST_PATH=/tmp/test_basicswap
|
||||
mkdir -p ${TEST_PATH}/bin
|
||||
cp -r ~/tmp/basicswap_bin/* ${TEST_PATH}/bin
|
||||
export PYTHONPATH=$(pwd)
|
||||
python tests/basicswap/extended/test_encrypted_xmr_reload.py
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import sys
|
||||
import logging
|
||||
import unittest
|
||||
import multiprocessing
|
||||
|
||||
from tests.basicswap.util import (
|
||||
read_json_api,
|
||||
post_json_api,
|
||||
waitForServer,
|
||||
)
|
||||
from tests.basicswap.common import (
|
||||
waitForNumOffers,
|
||||
waitForNumBids,
|
||||
waitForNumSwapping,
|
||||
)
|
||||
from tests.basicswap.common_xmr import (
|
||||
XmrTestBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.level = logging.DEBUG
|
||||
if not len(logger.handlers):
|
||||
logger.addHandler(logging.StreamHandler(sys.stdout))
|
||||
|
||||
|
||||
class Test(XmrTestBase):
|
||||
|
||||
def test_reload(self):
|
||||
self.start_processes()
|
||||
|
||||
waitForServer(self.delay_event, 12700)
|
||||
waitForServer(self.delay_event, 12701)
|
||||
wallets1 = read_json_api(12701, 'wallets')
|
||||
assert (float(wallets1['XMR']['balance']) > 0.0)
|
||||
|
||||
node1_password: str = 'notapassword123'
|
||||
logger.info('Encrypting node 1 wallets')
|
||||
rv = read_json_api(12701, 'setpassword', {'oldpassword': '', 'newpassword': node1_password})
|
||||
assert ('success' in rv)
|
||||
rv = read_json_api(12701, 'unlock', {'password': node1_password})
|
||||
assert ('success' in rv)
|
||||
|
||||
data = {
|
||||
'addr_from': '-1',
|
||||
'coin_from': 'part',
|
||||
'coin_to': 'xmr',
|
||||
'amt_from': '1',
|
||||
'amt_to': '1',
|
||||
'lockhrs': '24'}
|
||||
|
||||
offer_id = post_json_api(12700, 'offers/new', data)['offer_id']
|
||||
summary = read_json_api(12700)
|
||||
assert (summary['num_sent_offers'] == 1)
|
||||
|
||||
logger.info('Waiting for offer')
|
||||
waitForNumOffers(self.delay_event, 12701, 1)
|
||||
|
||||
offers = read_json_api(12701, 'offers')
|
||||
offer = offers[0]
|
||||
|
||||
data = {
|
||||
'offer_id': offer['offer_id'],
|
||||
'amount_from': offer['amount_from']}
|
||||
|
||||
data['valid_for_seconds'] = 24 * 60 * 60 + 1
|
||||
bid = post_json_api(12701, 'bids/new', data)
|
||||
assert (bid['error'] == 'Bid TTL too high')
|
||||
del data['valid_for_seconds']
|
||||
data['validmins'] = 24 * 60 + 1
|
||||
bid = post_json_api(12701, 'bids/new', data)
|
||||
assert (bid['error'] == 'Bid TTL too high')
|
||||
|
||||
del data['validmins']
|
||||
data['valid_for_seconds'] = 10
|
||||
bid = post_json_api(12701, 'bids/new', data)
|
||||
assert (bid['error'] == 'Bid TTL too low')
|
||||
del data['valid_for_seconds']
|
||||
data['validmins'] = 1
|
||||
bid = post_json_api(12701, 'bids/new', data)
|
||||
assert (bid['error'] == 'Bid TTL too low')
|
||||
|
||||
data['validmins'] = 60
|
||||
bid_id = post_json_api(12701, 'bids/new', data)
|
||||
|
||||
waitForNumBids(self.delay_event, 12700, 1)
|
||||
|
||||
for i in range(10):
|
||||
bids = read_json_api(12700, 'bids')
|
||||
bid = bids[0]
|
||||
if bid['bid_state'] == 'Received':
|
||||
break
|
||||
self.delay_event.wait(1)
|
||||
assert (bid['expire_at'] == bid['created_at'] + data['validmins'] * 60)
|
||||
|
||||
data = {
|
||||
'accept': True
|
||||
}
|
||||
rv = post_json_api(12700, 'bids/{}'.format(bid['bid_id']), data)
|
||||
assert (rv['bid_state'] == 'Accepted')
|
||||
|
||||
waitForNumSwapping(self.delay_event, 12701, 1)
|
||||
|
||||
logger.info('Restarting node 1')
|
||||
c1 = self.processes[1]
|
||||
c1.terminate()
|
||||
c1.join()
|
||||
self.processes[1] = multiprocessing.Process(target=self.run_thread, args=(1,))
|
||||
self.processes[1].start()
|
||||
|
||||
waitForServer(self.delay_event, 12701)
|
||||
rv = read_json_api(12701)
|
||||
assert ('error' in rv)
|
||||
|
||||
logger.info('Unlocking node 1')
|
||||
rv = read_json_api(12701, 'unlock', {'password': node1_password})
|
||||
assert ('success' in rv)
|
||||
rv = read_json_api(12701)
|
||||
assert (rv['num_swapping'] == 1)
|
||||
|
||||
rv = read_json_api(12700, 'revokeoffer/{}'.format(offer_id))
|
||||
assert (rv['revoked_offer'] == offer_id)
|
||||
|
||||
logger.info('Completing swap')
|
||||
for i in range(240):
|
||||
if self.delay_event.is_set():
|
||||
raise ValueError('Test stopped.')
|
||||
self.delay_event.wait(4)
|
||||
|
||||
rv = read_json_api(12700, 'bids/{}'.format(bid['bid_id']))
|
||||
if rv['bid_state'] == 'Completed':
|
||||
break
|
||||
assert (rv['bid_state'] == 'Completed')
|
||||
|
||||
# Ensure offer was revoked
|
||||
summary = read_json_api(12700)
|
||||
assert (summary['num_network_offers'] == 0)
|
||||
|
||||
# Wait for bid to be removed from in-progress
|
||||
waitForNumBids(self.delay_event, 12700, 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in a new issue