diff --git a/basicswap/base.py b/basicswap/base.py index 2f21947..a7681f7 100644 --- a/basicswap/base.py +++ b/basicswap/base.py @@ -46,7 +46,7 @@ class BaseApp: self.settings = settings self.coin_clients = {} self.coin_interfaces = {} - self.mxDB = threading.RLock() + self.mxDB = threading.Lock() self.debug = self.settings.get('debug', False) self.delay_event = threading.Event() self.chainstate_delay_event = threading.Event() diff --git a/basicswap/basicswap.py b/basicswap/basicswap.py index 0dc174a..132106d 100644 --- a/basicswap/basicswap.py +++ b/basicswap/basicswap.py @@ -425,14 +425,16 @@ class BasicSwap(BaseApp): def openSession(self, session=None): if session: return session + self.mxDB.acquire() return scoped_session(self.session_factory) - def closeSession(self, use_session, commit=True): + def closeSession(self, session, commit=True): if commit: - use_session.commit() - use_session.close() - use_session.remove() + session.commit() + + session.close() + session.remove() self.mxDB.release() def handleSessionErrors(self, e, session, tag): @@ -459,17 +461,18 @@ class BasicSwap(BaseApp): rpcauth = chain_client_settings['rpcuser'] + ':' + chain_client_settings['rpcpassword'] self.log.debug(f'Read {Coins(coin).name} rpc credentials from json settings') - session = scoped_session(self.session_factory) try: - last_height_checked = session.query(DBKVInt).filter_by(key='last_height_checked_' + chainparams[coin]['name']).first().value - except Exception: - last_height_checked = 0 - try: - block_check_min_time = session.query(DBKVInt).filter_by(key='block_check_min_time_' + chainparams[coin]['name']).first().value - except Exception: - block_check_min_time = 0xffffffffffffffff - session.close() - session.remove() + session = self.openSession() + try: + last_height_checked = session.query(DBKVInt).filter_by(key='last_height_checked_' + chainparams[coin]['name']).first().value + except Exception: + last_height_checked = 0 + try: + block_check_min_time = session.query(DBKVInt).filter_by(key='block_check_min_time_' + chainparams[coin]['name']).first().value + except Exception: + block_check_min_time = 0xffffffffffffffff + finally: + self.closeSession(session) coin_chainparams = chainparams[coin] default_segwit = coin_chainparams.get('has_segwit', False) @@ -1100,22 +1103,18 @@ class BasicSwap(BaseApp): identity_stats.updated_at = self.getTime() session.add(identity_stats) - def setIntKVInSession(self, str_key: str, int_val: int, session) -> None: - kv = session.query(DBKVInt).filter_by(key=str_key).first() - if not kv: - kv = DBKVInt(key=str_key, value=int_val) - else: - kv.value = int_val - session.add(kv) - - def setIntKV(self, str_key: str, int_val: int) -> None: - session = self.openSession() + def setIntKV(self, str_key: str, int_val: int, session=None) -> None: try: - session = scoped_session(self.session_factory) - self.setIntKVInSession(str_key, int_val, session) - session.commit() + use_session = self.openSession(session) + kv = use_session.query(DBKVInt).filter_by(key=str_key).first() + if not kv: + kv = DBKVInt(key=str_key, value=int_val) + else: + kv.value = int_val + use_session.add(kv) finally: - self.closeSession(session, commit=False) + if session is None: + self.closeSession(use_session) def setStringKV(self, str_key: str, str_val: str, session=None) -> None: try: @@ -1178,7 +1177,7 @@ class BasicSwap(BaseApp): if offer.swap_type == SwapTypes.XMR_SWAP: xmr_swap = session.query(XmrSwap).filter_by(bid_id=bid.bid_id).first() - self.watchXmrSwap(bid, offer, xmr_swap) + self.watchXmrSwap(bid, offer, xmr_swap, session) if self.ci(coin_to).watch_blocks_for_scripts() and bid.xmr_a_lock_tx and bid.xmr_a_lock_tx.chain_height: if not bid.xmr_b_lock_tx or not bid.xmr_b_lock_tx.txid: ci_from = self.ci(coin_from) @@ -1188,7 +1187,7 @@ class BasicSwap(BaseApp): chain_b_block_header = ci_to.getBlockHeaderAt(block_time) dest_script = ci_to.getPkDest(xmr_swap.pkbs) self.addWatchedScript(ci_to.coin_type(), bid.bid_id, dest_script, TxTypes.XMR_SWAP_B_LOCK) - self.setLastHeightCheckedStart(ci_to.coin_type(), chain_b_block_header['height']) + self.setLastHeightCheckedStart(ci_to.coin_type(), chain_b_block_header['height'], session) else: self.swaps_in_progress[bid.bid_id] = (bid, offer) @@ -1200,14 +1199,14 @@ class BasicSwap(BaseApp): if bid.participate_tx and bid.participate_tx.txid is None: self.addWatchedScript(coin_to, bid.bid_id, self.ci(coin_to).getScriptDest(bid.participate_tx.script), TxTypes.PTX) if bid.initiate_tx and bid.initiate_tx.chain_height: - self.setLastHeightCheckedStart(coin_to, bid.initiate_tx.chain_height) + self.setLastHeightCheckedStart(coin_to, bid.initiate_tx.chain_height, session) if self.coin_clients[coin_from]['last_height_checked'] < 1: if bid.initiate_tx and bid.initiate_tx.chain_height: - self.setLastHeightCheckedStart(coin_from, bid.initiate_tx.chain_height) + self.setLastHeightCheckedStart(coin_from, bid.initiate_tx.chain_height, session) if self.coin_clients[coin_to]['last_height_checked'] < 1: if bid.participate_tx and bid.participate_tx.chain_height: - self.setLastHeightCheckedStart(coin_to, bid.participate_tx.chain_height) + self.setLastHeightCheckedStart(coin_to, bid.participate_tx.chain_height, session) # TODO process addresspool if bid has previously been abandoned @@ -1610,11 +1609,10 @@ class BasicSwap(BaseApp): reverse_bid: bool = self.is_reverse_ads_bid(coin_from) - self.mxDB.acquire() - session = None try: + session = self.openSession() self.checkCoinsReady(coin_from_t, coin_to_t) - offer_addr = self.newSMSGAddress(use_type=AddressTypes.OFFER)[0] if addr_send_from is None else addr_send_from + offer_addr = self.newSMSGAddress(use_type=AddressTypes.OFFER, session=session)[0] if addr_send_from is None else addr_send_from offer_created_at = self.getTime() msg_buf = OfferMessage() @@ -1690,7 +1688,6 @@ class BasicSwap(BaseApp): raise ValueError('Security token must be 20 bytes long.') bid_reversed: bool = msg_buf.swap_type == SwapTypes.XMR_SWAP and self.is_reverse_ads_bid(msg_buf.coin_from) - session = scoped_session(self.session_factory) offer = Offer( offer_id=offer_id, active_ind=1, @@ -1749,13 +1746,8 @@ class BasicSwap(BaseApp): session.add(offer) session.add(SentOffer(offer_id=offer_id)) - session.commit() - finally: - if session: - session.close() - session.remove() - self.mxDB.release() + self.closeSession(session) self.log.info('Sent OFFER %s', offer_id.hex()) return offer_id @@ -1912,12 +1904,11 @@ class BasicSwap(BaseApp): return sha256(bytes(self.callcoinrpc(Coins.PART, 'extkey', ['info', evkey, path])['key_info']['result'], 'utf-8')) - def getReceiveAddressFromPool(self, coin_type, bid_id: bytes, tx_type): + def getReceiveAddressFromPool(self, coin_type, bid_id: bytes, tx_type, session=None): self.log.debug('Get address from pool bid_id {}, type {}, coin {}'.format(bid_id.hex(), tx_type, coin_type)) - self.mxDB.acquire() try: - session = scoped_session(self.session_factory) - record = session.query(PooledAddress).filter(sa.and_(PooledAddress.coin_type == int(coin_type), PooledAddress.bid_id == None)).first() # noqa: E712,E711 + use_session = self.openSession(session) + record = use_session.query(PooledAddress).filter(sa.and_(PooledAddress.coin_type == int(coin_type), PooledAddress.bid_id == None)).first() # noqa: E712,E711 if not record: address = self.getReceiveAddressForCoin(coin_type) record = PooledAddress( @@ -1927,19 +1918,17 @@ class BasicSwap(BaseApp): record.tx_type = tx_type addr = record.addr ensure(self.ci(coin_type).isAddressMine(addr), 'Pool address not owned by wallet!') - session.add(record) - session.commit() + use_session.add(record) + use_session.commit() finally: - session.close() - session.remove() - self.mxDB.release() + if session is None: + self.closeSession(use_session, commit=False) return addr def returnAddressToPool(self, bid_id: bytes, tx_type): self.log.debug('Return address to pool bid_id {}, type {}'.format(bid_id.hex(), tx_type)) - self.mxDB.acquire() try: - session = scoped_session(self.session_factory) + session = self.openSession() try: record = session.query(PooledAddress).filter(sa.and_(PooledAddress.bid_id == bid_id, PooledAddress.tx_type == tx_type)).one() self.log.debug('Returning address to pool addr {}'.format(record.addr)) @@ -1948,9 +1937,7 @@ class BasicSwap(BaseApp): except Exception as ex: pass finally: - session.close() - session.remove() - self.mxDB.release() + self.closeSession(session, commit=False) def getReceiveAddressForCoin(self, coin_type): new_addr = self.ci(coin_type).getNewAddress(self.coin_clients[coin_type]['use_segwit']) @@ -2000,21 +1987,21 @@ class BasicSwap(BaseApp): self.log.debug('In txn: {}'.format(txid)) return txid - def cacheNewAddressForCoin(self, coin_type): + def cacheNewAddressForCoin(self, coin_type, session=None): self.log.debug('cacheNewAddressForCoin %s', Coins(coin_type).name) key_str = 'receive_addr_' + self.ci(coin_type).coin_name().lower() addr = self.getReceiveAddressForCoin(coin_type) - self.setStringKV(key_str, addr) + self.setStringKV(key_str, addr, session) return addr - def getCachedMainWalletAddress(self, ci): + def getCachedMainWalletAddress(self, ci, session=None): db_key = 'main_wallet_addr_' + ci.coin_name().lower() - cached_addr = self.getStringKV(db_key) + cached_addr = self.getStringKV(db_key, session) if cached_addr is not None: return cached_addr self.log.warning(f'Setting {db_key}') main_address = ci.getMainWalletAddress() - self.setStringKV(db_key, main_address) + self.setStringKV(db_key, main_address, session) return main_address def checkWalletSeed(self, c): @@ -2102,62 +2089,60 @@ class BasicSwap(BaseApp): self.setStringKV(key_str, addr) return addr - def getCachedStealthAddressForCoin(self, coin_type): + def getCachedStealthAddressForCoin(self, coin_type, session=None): self.log.debug('getCachedStealthAddressForCoin %s', Coins(coin_type).name) if coin_type == Coins.LTC_MWEB: coin_type = Coins.LTC ci = self.ci(coin_type) key_str = 'stealth_addr_' + ci.coin_name().lower() - session = self.openSession() + use_session = self.openSession(session) try: try: - addr = session.query(DBKVString).filter_by(key=key_str).first().value + addr = use_session.query(DBKVString).filter_by(key=key_str).first().value except Exception: addr = ci.getNewStealthAddress() self.log.info('Generated new stealth address for %s', coin_type) - session.add(DBKVString( + use_session.add(DBKVString( key=key_str, value=addr )) finally: - self.closeSession(session) + if session is None: + self.closeSession(use_session) return addr - def getCachedWalletRestoreHeight(self, ci): + def getCachedWalletRestoreHeight(self, ci, session=None): self.log.debug('getCachedWalletRestoreHeight %s', ci.coin_name()) key_str = 'restore_height_' + ci.coin_name().lower() - session = self.openSession() + use_session = self.openSession(session) try: try: - wrh = session.query(DBKVInt).filter_by(key=key_str).first().value + wrh = use_session.query(DBKVInt).filter_by(key=key_str).first().value except Exception: wrh = ci.getWalletRestoreHeight() self.log.info('Found restore height for %s, block %d', ci.coin_name(), wrh) - session.add(DBKVInt( + use_session.add(DBKVInt( key=key_str, value=wrh )) finally: - self.closeSession(session) + if session is None: + self.closeSession(use_session) return wrh - def getWalletRestoreHeight(self, ci): + def getWalletRestoreHeight(self, ci, session=None): wrh = ci._restore_height if wrh is not None: return wrh - found_height = self.getCachedWalletRestoreHeight(ci) + found_height = self.getCachedWalletRestoreHeight(ci, session=session) ci.setWalletRestoreHeight(found_height) return found_height - def getNewContractId(self): - session = self.openSession() - try: - self._contract_count += 1 - session.execute('UPDATE kv_int SET value = :value WHERE KEY="contract_count"', {'value': self._contract_count}) - finally: - self.closeSession(session) + def getNewContractId(self, session): + self._contract_count += 1 + session.execute('UPDATE kv_int SET value = :value WHERE KEY="contract_count"', {'value': self._contract_count}) return self._contract_count def getProofOfFunds(self, coin_type, amount_for: int, extra_commit_bytes): @@ -2375,8 +2360,8 @@ class BasicSwap(BaseApp): amount, amount_to, bid_rate = self.setBidAmounts(amount, offer, extra_options, ci_from) self.validateBidAmount(offer, amount, bid_rate) - self.mxDB.acquire() try: + session = self.openSession() self.checkCoinsReady(coin_from, coin_to) msg_buf = BidMessage() @@ -2395,7 +2380,7 @@ class BasicSwap(BaseApp): if len(proof_utxos) > 0: msg_buf.proof_utxos = ci_to.encodeProofUtxos(proof_utxos) - contract_count = self.getNewContractId() + contract_count = self.getNewContractId(session) contract_pubkey = self.getContractPubkey(dt.datetime.fromtimestamp(now).date(), contract_count) msg_buf.pkhash_buyer = ci_from.pkh(contract_pubkey) pkhash_buyer_to = ci_to.pkh(contract_pubkey) @@ -2408,7 +2393,7 @@ class BasicSwap(BaseApp): bid_bytes = msg_buf.to_bytes() payload_hex = str.format('{:02x}', MessageTypes.BID) + bid_bytes.hex() - bid_addr = self.newSMSGAddress(use_type=AddressTypes.BID)[0] if addr_send_from is None else addr_send_from + bid_addr = self.newSMSGAddress(use_type=AddressTypes.BID, session=session)[0] if addr_send_from is None else addr_send_from msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) bid_id = self.sendSmsg(bid_addr, offer.addr_from, payload_hex, msg_valid) @@ -2437,28 +2422,20 @@ class BasicSwap(BaseApp): if len(msg_buf.pkhash_buyer_to) > 0: bid.pkhash_buyer_to = msg_buf.pkhash_buyer_to - try: - session = scoped_session(self.session_factory) - self.saveBidInSession(bid_id, bid, session) - session.commit() - finally: - session.close() - session.remove() + self.saveBidInSession(bid_id, bid, session) self.log.info('Sent BID %s', bid_id.hex()) return bid_id finally: - self.mxDB.release() + self.closeSession(session) - def getOffer(self, offer_id: bytes, sent: bool = False): - self.mxDB.acquire() + def getOffer(self, offer_id: bytes, sent: bool = False, session=None): try: - session = scoped_session(self.session_factory) - return session.query(Offer).filter_by(offer_id=offer_id).first() + use_session = self.openSession(session) + return use_session.query(Offer).filter_by(offer_id=offer_id).first() finally: - session.close() - session.remove() - self.mxDB.release() + if session is None: + self.closeSession(use_session, commit=False) def setTxBlockInfoFromHeight(self, ci, tx, height: int) -> None: try: @@ -2494,14 +2471,11 @@ class BasicSwap(BaseApp): return bid, xmr_swap def getXmrBid(self, bid_id: bytes, sent: bool = False): - self.mxDB.acquire() try: - session = scoped_session(self.session_factory) + session = self.openSession() return self.getXmrBidFromSession(session, bid_id, sent) finally: - session.close() - session.remove() - self.mxDB.release() + self.closeSession(session, commit=False) def getXmrOfferFromSession(self, session, offer_id: bytes, sent: bool = False): offer = session.query(Offer).filter_by(offer_id=offer_id).first() @@ -2510,15 +2484,13 @@ class BasicSwap(BaseApp): xmr_offer = session.query(XmrOffer).filter_by(offer_id=offer_id).first() return offer, xmr_offer - def getXmrOffer(self, offer_id: bytes, sent: bool = False): - self.mxDB.acquire() + def getXmrOffer(self, offer_id: bytes, sent: bool = False, session=None): try: - session = scoped_session(self.session_factory) - return self.getXmrOfferFromSession(session, offer_id, sent) + use_session = self.openSession(session) + return self.getXmrOfferFromSession(use_session, offer_id, sent) finally: - session.close() - session.remove() - self.mxDB.release() + if session is None: + self.closeSession(use_session, commit=False) def getBid(self, bid_id: bytes, session=None): try: @@ -2545,9 +2517,8 @@ class BasicSwap(BaseApp): self.closeSession(use_session, commit=False) def getXmrBidAndOffer(self, bid_id: bytes, list_events=True): - self.mxDB.acquire() try: - session = scoped_session(self.session_factory) + session = self.openSession() xmr_swap = None offer = None xmr_offer = None @@ -2565,20 +2536,15 @@ class BasicSwap(BaseApp): return bid, xmr_swap, offer, xmr_offer, events finally: - session.close() - session.remove() - self.mxDB.release() + self.closeSession(session, commit=False) def getIdentity(self, address: str): - self.mxDB.acquire() try: - session = scoped_session(self.session_factory) + session = self.openSession() identity = session.query(KnownIdentity).filter_by(address=address).first() return identity finally: - session.close() - session.remove() - self.mxDB.release() + self.closeSession(session, commit=False) def list_bid_events(self, bid_id: bytes, session): query_str = 'SELECT created_at, event_type, event_msg FROM eventlog ' + \ @@ -2596,115 +2562,122 @@ class BasicSwap(BaseApp): return events - def acceptBid(self, bid_id: bytes) -> None: + def acceptBid(self, bid_id: bytes, session=None) -> None: self.log.info('Accepting bid %s', bid_id.hex()) - bid, offer = self.getBidAndOffer(bid_id) - ensure(bid, 'Bid not found') - ensure(offer, 'Offer not found') + try: + use_session = self.openSession(session) - # Ensure bid is still valid - now: int = self.getTime() - ensure(bid.expire_at > now, 'Bid expired') - ensure(bid.state in (BidStates.BID_RECEIVED, ), 'Wrong bid state: {}'.format(BidStates(bid.state).name)) + bid, offer = self.getBidAndOffer(bid_id, use_session) + ensure(bid, 'Bid not found') + ensure(offer, 'Offer not found') - if offer.swap_type == SwapTypes.XMR_SWAP: - ensure(bid.protocol_version >= MINPROTO_VERSION_ADAPTOR_SIG, 'Incompatible bid protocol version') - reverse_bid: bool = self.is_reverse_ads_bid(offer.coin_from) - if reverse_bid: - return self.acceptADSReverseBid(bid_id) - return self.acceptXmrBid(bid_id) + # Ensure bid is still valid + now: int = self.getTime() + ensure(bid.expire_at > now, 'Bid expired') + ensure(bid.state in (BidStates.BID_RECEIVED, ), 'Wrong bid state: {}'.format(BidStates(bid.state).name)) - ensure(bid.protocol_version >= MINPROTO_VERSION_SECRET_HASH, 'Incompatible bid protocol version') - if bid.contract_count is None: - bid.contract_count = self.getNewContractId() + if offer.swap_type == SwapTypes.XMR_SWAP: + ensure(bid.protocol_version >= MINPROTO_VERSION_ADAPTOR_SIG, 'Incompatible bid protocol version') + reverse_bid: bool = self.is_reverse_ads_bid(offer.coin_from) + if reverse_bid: + return self.acceptADSReverseBid(bid_id, use_session) + return self.acceptXmrBid(bid_id, use_session) - coin_from = Coins(offer.coin_from) - ci_from = self.ci(coin_from) - ci_to = self.ci(offer.coin_to) - bid_date = dt.datetime.fromtimestamp(bid.created_at).date() + ensure(bid.protocol_version >= MINPROTO_VERSION_SECRET_HASH, 'Incompatible bid protocol version') + if bid.contract_count is None: + bid.contract_count = self.getNewContractId(use_session) - secret = self.getContractSecret(bid_date, bid.contract_count) - secret_hash = sha256(secret) + coin_from = Coins(offer.coin_from) + ci_from = self.ci(coin_from) + ci_to = self.ci(offer.coin_to) + bid_date = dt.datetime.fromtimestamp(bid.created_at).date() - pubkey_refund = self.getContractPubkey(bid_date, bid.contract_count) - pkhash_refund = ci_from.pkh(pubkey_refund) + secret = self.getContractSecret(bid_date, bid.contract_count) + secret_hash = sha256(secret) - if coin_from in (Coins.DCR, ): - op_hash = OpCodes.OP_SHA256_DECRED - else: - op_hash = OpCodes.OP_SHA256 + pubkey_refund = self.getContractPubkey(bid_date, bid.contract_count) + pkhash_refund = ci_from.pkh(pubkey_refund) - if bid.initiate_tx is not None: - self.log.warning('Initiate txn %s already exists for bid %s', bid.initiate_tx.txid, bid_id.hex()) - txid = bid.initiate_tx.txid - script = bid.initiate_tx.script - else: - if offer.lock_type < TxLockTypes.ABS_LOCK_BLOCKS: - sequence = ci_from.getExpectedSequence(offer.lock_type, offer.lock_value) - script = atomic_swap_1.buildContractScript(sequence, secret_hash, bid.pkhash_buyer, pkhash_refund, op_hash=op_hash) + if coin_from in (Coins.DCR, ): + op_hash = OpCodes.OP_SHA256_DECRED else: - if offer.lock_type == TxLockTypes.ABS_LOCK_BLOCKS: - lock_value = ci_from.getChainHeight() + offer.lock_value + op_hash = OpCodes.OP_SHA256 + + if bid.initiate_tx is not None: + self.log.warning('Initiate txn %s already exists for bid %s', bid.initiate_tx.txid, bid_id.hex()) + txid = bid.initiate_tx.txid + script = bid.initiate_tx.script + else: + if offer.lock_type < TxLockTypes.ABS_LOCK_BLOCKS: + sequence = ci_from.getExpectedSequence(offer.lock_type, offer.lock_value) + script = atomic_swap_1.buildContractScript(sequence, secret_hash, bid.pkhash_buyer, pkhash_refund, op_hash=op_hash) else: - lock_value = self.getTime() + offer.lock_value - self.log.debug('Initiate %s lock_value %d %d', ci_from.coin_name(), offer.lock_value, lock_value) - script = atomic_swap_1.buildContractScript(lock_value, secret_hash, bid.pkhash_buyer, pkhash_refund, OpCodes.OP_CHECKLOCKTIMEVERIFY, op_hash=op_hash) + if offer.lock_type == TxLockTypes.ABS_LOCK_BLOCKS: + lock_value = ci_from.getChainHeight() + offer.lock_value + else: + lock_value = self.getTime() + offer.lock_value + self.log.debug('Initiate %s lock_value %d %d', ci_from.coin_name(), offer.lock_value, lock_value) + script = atomic_swap_1.buildContractScript(lock_value, secret_hash, bid.pkhash_buyer, pkhash_refund, OpCodes.OP_CHECKLOCKTIMEVERIFY, op_hash=op_hash) - bid.pkhash_seller = ci_to.pkh(pubkey_refund) + bid.pkhash_seller = ci_to.pkh(pubkey_refund) - prefunded_tx = self.getPreFundedTx(Concepts.OFFER, offer.offer_id, TxTypes.ITX_PRE_FUNDED) - txn, lock_tx_vout = self.createInitiateTxn(coin_from, bid_id, bid, script, prefunded_tx) + prefunded_tx = self.getPreFundedTx(Concepts.OFFER, offer.offer_id, TxTypes.ITX_PRE_FUNDED, session=use_session) + txn, lock_tx_vout = self.createInitiateTxn(coin_from, bid_id, bid, script, prefunded_tx) - # Store the signed refund txn in case wallet is locked when refund is possible - refund_txn = self.createRefundTxn(coin_from, txn, offer, bid, script) - bid.initiate_txn_refund = bytes.fromhex(refund_txn) + # Store the signed refund txn in case wallet is locked when refund is possible + refund_txn = self.createRefundTxn(coin_from, txn, offer, bid, script, session=use_session) + bid.initiate_txn_refund = bytes.fromhex(refund_txn) - txid = ci_from.publishTx(bytes.fromhex(txn)) - self.log.debug('Submitted initiate txn %s to %s chain for bid %s', txid, ci_from.coin_name(), bid_id.hex()) - bid.initiate_tx = SwapTx( - bid_id=bid_id, - tx_type=TxTypes.ITX, - txid=bytes.fromhex(txid), - vout=lock_tx_vout, - tx_data=bytes.fromhex(txn), - script=script, - ) - bid.setITxState(TxStates.TX_SENT) - self.logEvent(Concepts.BID, bid.bid_id, EventLogTypes.ITX_PUBLISHED, '', None) + txid = ci_from.publishTx(bytes.fromhex(txn)) + self.log.debug('Submitted initiate txn %s to %s chain for bid %s', txid, ci_from.coin_name(), bid_id.hex()) + bid.initiate_tx = SwapTx( + bid_id=bid_id, + tx_type=TxTypes.ITX, + txid=bytes.fromhex(txid), + vout=lock_tx_vout, + tx_data=bytes.fromhex(txn), + script=script, + ) + bid.setITxState(TxStates.TX_SENT) + self.logEvent(Concepts.BID, bid.bid_id, EventLogTypes.ITX_PUBLISHED, '', use_session) - # Check non-bip68 final - try: - txid = ci_from.publishTx(bid.initiate_txn_refund) - self.log.error('Submit refund_txn unexpectedly worked: ' + txid) - except Exception as ex: - if ci_from.isTxNonFinalError(str(ex)) is False: - self.log.error('Submit refund_txn unexpected error' + str(ex)) - raise ex + # Check non-bip68 final + try: + txid = ci_from.publishTx(bid.initiate_txn_refund) + self.log.error('Submit refund_txn unexpectedly worked: ' + txid) + except Exception as ex: + if ci_from.isTxNonFinalError(str(ex)) is False: + self.log.error('Submit refund_txn unexpected error' + str(ex)) + raise ex - if txid is not None: - msg_buf = BidAcceptMessage() - msg_buf.bid_msg_id = bid_id - msg_buf.initiate_txid = bytes.fromhex(txid) - msg_buf.contract_script = bytes(script) + if txid is not None: + msg_buf = BidAcceptMessage() + msg_buf.bid_msg_id = bid_id + msg_buf.initiate_txid = bytes.fromhex(txid) + msg_buf.contract_script = bytes(script) - # pkh sent in script is hashed with sha256, Decred expects blake256 - if bid.pkhash_seller != pkhash_refund: - msg_buf.pkhash_seller = bid.pkhash_seller + # pkh sent in script is hashed with sha256, Decred expects blake256 + if bid.pkhash_seller != pkhash_refund: + msg_buf.pkhash_seller = bid.pkhash_seller - bid_bytes = msg_buf.to_bytes() - payload_hex = str.format('{:02x}', MessageTypes.BID_ACCEPT) + bid_bytes.hex() + bid_bytes = msg_buf.to_bytes() + payload_hex = str.format('{:02x}', MessageTypes.BID_ACCEPT) + bid_bytes.hex() - msg_valid: int = self.getAcceptBidMsgValidTime(bid) - accept_msg_id = self.sendSmsg(offer.addr_from, bid.bid_addr, payload_hex, msg_valid) + msg_valid: int = self.getAcceptBidMsgValidTime(bid) + accept_msg_id = self.sendSmsg(offer.addr_from, bid.bid_addr, payload_hex, msg_valid) - self.addMessageLink(Concepts.BID, bid_id, MessageTypes.BID_ACCEPT, accept_msg_id) - self.log.info('Sent BID_ACCEPT %s', accept_msg_id.hex()) + self.addMessageLink(Concepts.BID, bid_id, MessageTypes.BID_ACCEPT, accept_msg_id, session=use_session) + self.log.info('Sent BID_ACCEPT %s', accept_msg_id.hex()) - bid.setState(BidStates.BID_ACCEPTED) + bid.setState(BidStates.BID_ACCEPTED) - self.saveBid(bid_id, bid) - self.swaps_in_progress[bid_id] = (bid, offer) + self.saveBidInSession(bid_id, bid, use_session) + self.swaps_in_progress[bid_id] = (bid, offer) + + finally: + if session is None: + self.closeSession(use_session) def sendXmrSplitMessages(self, msg_type, addr_from: str, addr_to: str, bid_id: bytes, dleag: bytes, msg_valid: int, bid_msg_ids) -> None: msg_buf2 = XmrSplitMessage( @@ -2732,9 +2705,9 @@ class BasicSwap(BaseApp): # Send MSG1L F -> L or MSG0F L -> F self.log.debug('postXmrBid %s', offer_id.hex()) - self.mxDB.acquire() try: - offer, xmr_offer = self.getXmrOffer(offer_id) + session = self.openSession() + offer, xmr_offer = self.getXmrOffer(offer_id, session=session) ensure(offer, 'Offer not found: {}.'.format(offer_id.hex())) ensure(xmr_offer, 'Adaptor-sig offer not found: {}.'.format(offer_id.hex())) @@ -2778,9 +2751,9 @@ class BasicSwap(BaseApp): payload_hex = str.format('{:02x}', MessageTypes.ADS_BID_LF) + bid_bytes.hex() xmr_swap = XmrSwap() - xmr_swap.contract_count = self.getNewContractId() + xmr_swap.contract_count = self.getNewContractId(session) - bid_addr = self.newSMSGAddress(use_type=AddressTypes.BID)[0] if addr_send_from is None else addr_send_from + bid_addr = self.newSMSGAddress(use_type=AddressTypes.BID, session=session)[0] if addr_send_from is None else addr_send_from msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) xmr_swap.bid_id = self.sendSmsg(bid_addr, offer.addr_from, payload_hex, msg_valid) @@ -2802,12 +2775,8 @@ class BasicSwap(BaseApp): bid.setState(BidStates.BID_REQUEST_SENT) - session = self.openSession() - try: - self.saveBidInSession(xmr_swap.bid_id, bid, session, xmr_swap) - session.commit() - finally: - self.closeSession(session, commit=False) + self.saveBidInSession(xmr_swap.bid_id, bid, session, xmr_swap) + session.commit() self.log.info('Sent ADS_BID_LF %s', xmr_swap.bid_id.hex()) return xmr_swap.bid_id @@ -2819,7 +2788,7 @@ class BasicSwap(BaseApp): msg_buf.amount = int(amount) # Amount of coin_from msg_buf.amount_to = amount_to - address_out = self.getReceiveAddressFromPool(coin_from, offer_id, TxTypes.XMR_SWAP_A_LOCK) + address_out = self.getReceiveAddressFromPool(coin_from, offer_id, TxTypes.XMR_SWAP_A_LOCK, session=session) if coin_from in (Coins.PART_BLIND, ): addrinfo = ci_from.rpc('getaddressinfo', [address_out]) msg_buf.dest_af = bytes.fromhex(addrinfo['pubkey']) @@ -2827,7 +2796,7 @@ class BasicSwap(BaseApp): msg_buf.dest_af = ci_from.decodeAddress(address_out) xmr_swap = XmrSwap() - xmr_swap.contract_count = self.getNewContractId() + xmr_swap.contract_count = self.getNewContractId(session) xmr_swap.dest_af = msg_buf.dest_af for_ed25519: bool = True if ci_to.curve_type() == Curves.ed25519 else False @@ -2866,7 +2835,7 @@ class BasicSwap(BaseApp): bid_bytes = msg_buf.to_bytes() payload_hex = str.format('{:02x}', MessageTypes.XMR_BID_FL) + bid_bytes.hex() - bid_addr = self.newSMSGAddress(use_type=AddressTypes.BID)[0] if addr_send_from is None else addr_send_from + bid_addr = self.newSMSGAddress(use_type=AddressTypes.BID, session=session)[0] if addr_send_from is None else addr_send_from msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) xmr_swap.bid_id = self.sendSmsg(bid_addr, offer.addr_from, payload_hex, msg_valid) @@ -2892,34 +2861,30 @@ class BasicSwap(BaseApp): bid.chain_a_height_start = ci_from.getChainHeight() bid.chain_b_height_start = ci_to.getChainHeight() - wallet_restore_height = self.getWalletRestoreHeight(ci_to) + wallet_restore_height = self.getWalletRestoreHeight(ci_to, session) if bid.chain_b_height_start < wallet_restore_height: bid.chain_b_height_start = wallet_restore_height self.log.warning('Adaptor-sig swap restore height clamped to {}'.format(wallet_restore_height)) bid.setState(BidStates.BID_SENT) - session = self.openSession() - try: - self.saveBidInSession(xmr_swap.bid_id, bid, session, xmr_swap) - for k, msg_id in bid_msg_ids.items(): - self.addMessageLink(Concepts.BID, xmr_swap.bid_id, MessageTypes.BID, msg_id, msg_sequence=k, session=session) - finally: - self.closeSession(session) + self.saveBidInSession(xmr_swap.bid_id, bid, session, xmr_swap) + for k, msg_id in bid_msg_ids.items(): + self.addMessageLink(Concepts.BID, xmr_swap.bid_id, MessageTypes.BID, msg_id, msg_sequence=k, session=session) self.log.info('Sent XMR_BID_FL %s', xmr_swap.bid_id.hex()) return xmr_swap.bid_id finally: - self.mxDB.release() + self.closeSession(session) - def acceptXmrBid(self, bid_id: bytes) -> None: + def acceptXmrBid(self, bid_id: bytes, session=None) -> None: # MSG1F and MSG2F L -> F self.log.info('Accepting adaptor-sig bid %s', bid_id.hex()) now: int = self.getTime() - self.mxDB.acquire() try: - bid, xmr_swap = self.getXmrBid(bid_id) + use_session = self.openSession(session) + bid, xmr_swap = self.getXmrBidFromSession(use_session, bid_id) ensure(bid, 'Bid not found: {}.'.format(bid_id.hex())) ensure(xmr_swap, 'Adaptor-sig swap not found: {}.'.format(bid_id.hex())) ensure(bid.expire_at > now, 'Bid expired') @@ -2930,7 +2895,7 @@ class BasicSwap(BaseApp): ensure(last_bid_state == BidStates.BID_RECEIVED, 'Wrong bid state: {}'.format(str(BidStates(last_bid_state)))) - offer, xmr_offer = self.getXmrOffer(bid.offer_id) + offer, xmr_offer = self.getXmrOffer(bid.offer_id, session=use_session) ensure(offer, 'Offer not found: {}.'.format(bid.offer_id.hex())) ensure(xmr_offer, 'Adaptor-sig offer not found: {}.'.format(bid.offer_id.hex())) ensure(offer.expire_at > now, 'Offer has expired') @@ -2945,7 +2910,7 @@ class BasicSwap(BaseApp): b_fee_rate: int = xmr_offer.a_fee_rate if reverse_bid else xmr_offer.b_fee_rate if xmr_swap.contract_count is None: - xmr_swap.contract_count = self.getNewContractId() + xmr_swap.contract_count = self.getNewContractId(use_session) for_ed25519: bool = True if ci_to.curve_type() == Curves.ed25519 else False kbvl = self.getPathKey(coin_from, coin_to, bid.created_at, xmr_swap.contract_count, KeyTypes.KBVL, for_ed25519) @@ -2967,7 +2932,7 @@ class BasicSwap(BaseApp): # MSG2F pi = self.pi(SwapTypes.XMR_SWAP) xmr_swap.a_lock_tx_script = pi.genScriptLockTxScript(ci_from, xmr_swap.pkal, xmr_swap.pkaf) - prefunded_tx = self.getPreFundedTx(Concepts.OFFER, bid.offer_id, TxTypes.ITX_PRE_FUNDED) + prefunded_tx = self.getPreFundedTx(Concepts.OFFER, bid.offer_id, TxTypes.ITX_PRE_FUNDED, session=use_session) if prefunded_tx: xmr_swap.a_lock_tx = pi.promoteMockTx(ci_from, prefunded_tx, xmr_swap.a_lock_tx_script) else: @@ -3079,27 +3044,24 @@ class BasicSwap(BaseApp): bid.setState(BidStates.BID_ACCEPTED) # ADS - session = self.openSession() - try: - self.saveBidInSession(bid_id, bid, session, xmr_swap=xmr_swap) - for k, msg_id in bid_msg_ids.items(): - self.addMessageLink(Concepts.BID, bid_id, MessageTypes.BID_ACCEPT, msg_id, msg_sequence=k, session=session) - finally: - self.closeSession(session) + self.saveBidInSession(bid_id, bid, use_session, xmr_swap=xmr_swap) + for k, msg_id in bid_msg_ids.items(): + self.addMessageLink(Concepts.BID, bid_id, MessageTypes.BID_ACCEPT, msg_id, msg_sequence=k, session=use_session) # Add to swaps_in_progress only when waiting on txns self.log.info('Sent XMR_BID_ACCEPT_LF %s', bid_id.hex()) return bid_id finally: - self.mxDB.release() + if session is None: + self.closeSession(use_session) - def acceptADSReverseBid(self, bid_id: bytes) -> None: + def acceptADSReverseBid(self, bid_id: bytes, session=None) -> None: self.log.info('Accepting reverse adaptor-sig bid %s', bid_id.hex()) now: int = self.getTime() - self.mxDB.acquire() try: - bid, xmr_swap = self.getXmrBid(bid_id) + use_session = self.openSession(session) + bid, xmr_swap = self.getXmrBidFromSession(use_session, bid_id) ensure(bid, 'Bid not found: {}.'.format(bid_id.hex())) ensure(xmr_swap, 'Adaptor-sig swap not found: {}.'.format(bid_id.hex())) ensure(bid.expire_at > now, 'Bid expired') @@ -3110,7 +3072,7 @@ class BasicSwap(BaseApp): ensure(last_bid_state == BidStates.BID_RECEIVED, 'Wrong bid state: {}'.format(str(BidStates(last_bid_state)))) - offer, xmr_offer = self.getXmrOffer(bid.offer_id) + offer, xmr_offer = self.getXmrOffer(bid.offer_id, session=use_session) ensure(offer, 'Offer not found: {}.'.format(bid.offer_id.hex())) ensure(xmr_offer, 'Adaptor-sig offer not found: {}.'.format(bid.offer_id.hex())) ensure(offer.expire_at > now, 'Offer has expired') @@ -3122,7 +3084,7 @@ class BasicSwap(BaseApp): ci_to = self.ci(coin_to) if xmr_swap.contract_count is None: - xmr_swap.contract_count = self.getNewContractId() + xmr_swap.contract_count = self.getNewContractId(use_session) for_ed25519: bool = True if ci_to.curve_type() == Curves.ed25519 else False kbvf = self.getPathKey(coin_from, coin_to, bid.created_at, xmr_swap.contract_count, KeyTypes.KBVF, for_ed25519) @@ -3130,7 +3092,7 @@ class BasicSwap(BaseApp): kaf = self.getPathKey(coin_from, coin_to, bid.created_at, xmr_swap.contract_count, KeyTypes.KAF) - address_out = self.getReceiveAddressFromPool(coin_from, bid.offer_id, TxTypes.XMR_SWAP_A_LOCK) + address_out = self.getReceiveAddressFromPool(coin_from, bid.offer_id, TxTypes.XMR_SWAP_A_LOCK, session=use_session) if coin_from == Coins.PART_BLIND: addrinfo = ci_from.rpc('getaddressinfo', [address_out]) xmr_swap.dest_af = bytes.fromhex(addrinfo['pubkey']) @@ -3167,17 +3129,13 @@ class BasicSwap(BaseApp): bid.setState(BidStates.BID_REQUEST_ACCEPTED) - session = self.openSession() - try: - for k, msg_id in bid_msg_ids.items(): - self.addMessageLink(Concepts.BID, bid_id, MessageTypes.ADS_BID_ACCEPT_FL, msg_id, msg_sequence=k, session=session) - self.log.info('Sent ADS_BID_ACCEPT_FL %s', bid_msg_ids[0].hex()) - self.saveBidInSession(bid_id, bid, session, xmr_swap=xmr_swap) - finally: - self.closeSession(session) - + for k, msg_id in bid_msg_ids.items(): + self.addMessageLink(Concepts.BID, bid_id, MessageTypes.ADS_BID_ACCEPT_FL, msg_id, msg_sequence=k, session=use_session) + self.log.info('Sent ADS_BID_ACCEPT_FL %s', bid_msg_ids[0].hex()) + self.saveBidInSession(bid_id, bid, use_session, xmr_swap=xmr_swap) finally: - self.mxDB.release() + if session is None: + self.closeSession(use_session) def deactivateBidForReason(self, bid_id: bytes, new_state, session_in=None) -> None: try: @@ -3325,7 +3283,7 @@ class BasicSwap(BaseApp): return txn_signed - def createRedeemTxn(self, coin_type, bid, for_txn_type='participate', addr_redeem_out=None, fee_rate=None): + def createRedeemTxn(self, coin_type, bid, for_txn_type='participate', addr_redeem_out=None, fee_rate=None, session=None): self.log.debug('createRedeemTxn for coin %s', Coins(coin_type).name) ci = self.ci(coin_type) @@ -3377,7 +3335,7 @@ class BasicSwap(BaseApp): ensure(amount_out > 0, 'Amount out <= 0') if addr_redeem_out is None: - addr_redeem_out = self.getReceiveAddressFromPool(coin_type, bid.bid_id, TxTypes.PTX_REDEEM if for_txn_type == 'participate' else TxTypes.ITX_REDEEM) + addr_redeem_out = self.getReceiveAddressFromPool(coin_type, bid.bid_id, TxTypes.PTX_REDEEM if for_txn_type == 'participate' else TxTypes.ITX_REDEEM, session) assert (addr_redeem_out is not None) self.log.debug('addr_redeem_out %s', addr_redeem_out) @@ -3440,7 +3398,7 @@ class BasicSwap(BaseApp): self.log.debug('Have valid redeem txn %s for contract %s tx %s', redeem_txid.hex(), for_txn_type, prev_txnid) return redeem_txn - def createRefundTxn(self, coin_type, txn, offer, bid, txn_script: bytearray, addr_refund_out=None, tx_type=TxTypes.ITX_REFUND): + def createRefundTxn(self, coin_type, txn, offer, bid, txn_script: bytearray, addr_refund_out=None, tx_type=TxTypes.ITX_REFUND, session=None): self.log.debug('createRefundTxn for coin %s', Coins(coin_type).name) if self.coin_clients[coin_type]['connection_type'] != 'rpc': return None @@ -3488,7 +3446,7 @@ class BasicSwap(BaseApp): raise ValueError('Refund amount out <= 0') if addr_refund_out is None: - addr_refund_out = self.getReceiveAddressFromPool(coin_type, bid.bid_id, tx_type) + addr_refund_out = self.getReceiveAddressFromPool(coin_type, bid.bid_id, tx_type, session) ensure(addr_refund_out is not None, 'addr_refund_out is null') self.log.debug('addr_refund_out %s', addr_refund_out) @@ -3591,7 +3549,7 @@ class BasicSwap(BaseApp): # Bid saved in checkBidState - def setLastHeightCheckedStart(self, coin_type, tx_height: int) -> int: + def setLastHeightCheckedStart(self, coin_type, tx_height: int, session=None) -> int: ci = self.ci(coin_type) coin_name = ci.coin_name() if tx_height < 1: @@ -3603,12 +3561,12 @@ class BasicSwap(BaseApp): if len(cc['watched_outputs']) == 0 and len(cc['watched_scripts']) == 0: cc['last_height_checked'] = tx_height cc['block_check_min_time'] = block_time - self.setIntKV('block_check_min_time_' + coin_name, block_time) + self.setIntKV('block_check_min_time_' + coin_name, block_time, session) self.log.debug('Start checking %s chain at height %d', coin_name, tx_height) elif cc['last_height_checked'] > tx_height: cc['last_height_checked'] = tx_height cc['block_check_min_time'] = block_time - self.setIntKV('block_check_min_time_' + coin_name, block_time) + self.setIntKV('block_check_min_time_' + coin_name, block_time, session) self.log.debug('Rewind checking of %s chain to height %d', coin_name, tx_height) return tx_height @@ -3777,8 +3735,7 @@ class BasicSwap(BaseApp): session = None try: - self.mxDB.acquire() - session = scoped_session(self.session_factory) + session = self.openSession() xmr_offer = session.query(XmrOffer).filter_by(offer_id=offer.offer_id).first() ensure(xmr_offer, 'Adaptor-sig offer not found: {}.'.format(offer.offer_id.hex())) xmr_swap = session.query(XmrSwap).filter_by(bid_id=bid.bid_id).first() @@ -3930,7 +3887,7 @@ class BasicSwap(BaseApp): chain_b_block_header = ci_to.getBlockHeaderAt(block_time) dest_script = ci_to.getPkDest(xmr_swap.pkbs) self.addWatchedScript(ci_to.coin_type(), bid.bid_id, dest_script, TxTypes.XMR_SWAP_B_LOCK) - self.setLastHeightCheckedStart(ci_to.coin_type(), chain_b_block_header['height']) + self.setLastHeightCheckedStart(ci_to.coin_type(), chain_b_block_header['height'], session) if bid_changed: self.saveBidInSession(bid_id, bid, session, xmr_swap) @@ -3963,7 +3920,7 @@ class BasicSwap(BaseApp): try: txn_hex = ci_from.getMempoolTx(xmr_swap.a_lock_spend_tx_id) self.log.info('Found lock spend txn in %s mempool, %s', ci_from.coin_name(), xmr_swap.a_lock_spend_tx_id.hex()) - self.process_XMR_SWAP_A_LOCK_tx_spend(bid_id, xmr_swap.a_lock_spend_tx_id.hex(), txn_hex) + self.process_XMR_SWAP_A_LOCK_tx_spend(bid_id, xmr_swap.a_lock_spend_tx_id.hex(), txn_hex, session) except Exception as e: self.log.debug('getrawtransaction lock spend tx failed: %s', str(e)) elif state == BidStates.XMR_SWAP_SCRIPT_TX_REDEEMED: @@ -4000,10 +3957,7 @@ class BasicSwap(BaseApp): except Exception as ex: raise ex finally: - if session: - session.close() - session.remove() - self.mxDB.release() + self.closeSession(session) return rv @@ -4314,12 +4268,11 @@ class BasicSwap(BaseApp): self.removeWatchedOutput(coin_to, bid_id, bid.participate_tx.txid.hex()) self.saveBid(bid_id, bid) - def process_XMR_SWAP_A_LOCK_tx_spend(self, bid_id: bytes, spend_txid_hex, spend_txn_hex) -> None: + def process_XMR_SWAP_A_LOCK_tx_spend(self, bid_id: bytes, spend_txid_hex, spend_txn_hex, session=None) -> None: self.log.debug('Detected spend of Adaptor-sig swap coin a lock tx for bid %s', bid_id.hex()) - self.mxDB.acquire() try: - session = scoped_session(self.session_factory) - bid, xmr_swap = self.getXmrBidFromSession(session, bid_id) + use_session = self.openSession(session) + bid, xmr_swap = self.getXmrBidFromSession(use_session, bid_id) ensure(bid, 'Bid not found: {}.'.format(bid_id.hex())) ensure(xmr_swap, 'Adaptor-sig swap not found: {}.'.format(bid_id.hex())) @@ -4327,7 +4280,7 @@ class BasicSwap(BaseApp): self.log.debug('Bid stalled %s', bid_id.hex()) return - offer, xmr_offer = self.getXmrOfferFromSession(session, bid.offer_id, sent=False) + offer, xmr_offer = self.getXmrOfferFromSession(use_session, bid.offer_id, sent=False) ensure(offer, 'Offer not found: {}.'.format(bid.offer_id.hex())) ensure(xmr_offer, 'Adaptor-sig offer not found: {}.'.format(bid.offer_id.hex())) @@ -4357,24 +4310,21 @@ class BasicSwap(BaseApp): elif spending_txid == xmr_swap.a_lock_refund_tx_id: self.log.debug('Coin a lock tx spent by lock refund tx.') bid.setState(BidStates.XMR_SWAP_SCRIPT_TX_PREREFUND) - self.logBidEvent(bid.bid_id, EventLogTypes.LOCK_TX_A_REFUND_TX_SEEN, '', session) + self.logBidEvent(bid.bid_id, EventLogTypes.LOCK_TX_A_REFUND_TX_SEEN, '', use_session) else: self.setBidError(bid.bid_id, bid, 'Unexpected txn spent coin a lock tx: {}'.format(spend_txid_hex), save_bid=False) - self.saveBidInSession(bid_id, bid, session, xmr_swap, save_in_progress=offer) - session.commit() + self.saveBidInSession(bid_id, bid, use_session, xmr_swap, save_in_progress=offer) except Exception as ex: self.logException(f'process_XMR_SWAP_A_LOCK_tx_spend {ex}') finally: - session.close() - session.remove() - self.mxDB.release() + if session is None: + self.closeSession(use_session) def process_XMR_SWAP_A_LOCK_REFUND_tx_spend(self, bid_id: bytes, spend_txid_hex, spend_txn) -> None: self.log.debug('Detected spend of Adaptor-sig swap coin a lock refund tx for bid %s', bid_id.hex()) - self.mxDB.acquire() try: - session = scoped_session(self.session_factory) + session = self.openSession() bid, xmr_swap = self.getXmrBidFromSession(session, bid_id) ensure(bid, 'Bid not found: {}.'.format(bid_id.hex())) ensure(xmr_swap, 'Adaptor-sig swap not found: {}.'.format(bid_id.hex())) @@ -4423,13 +4373,10 @@ class BasicSwap(BaseApp): bid.setState(BidStates.XMR_SWAP_FAILED_SWIPED) self.saveBidInSession(bid_id, bid, session, xmr_swap, save_in_progress=offer) - session.commit() except Exception as ex: self.logException(f'process_XMR_SWAP_A_LOCK_REFUND_tx_spend {ex}') finally: - session.close() - session.remove() - self.mxDB.release() + self.closeSession(session) def processSpentOutput(self, coin_type, watched_output, spend_txid_hex, spend_n, spend_txn) -> None: if watched_output.swap_type == SwapTypes.XMR_SWAP: @@ -4499,7 +4446,7 @@ class BasicSwap(BaseApp): block_height = int(block['height']) if cc['last_height_checked'] != block_height: cc['last_height_checked'] = block_height - self.setIntKVInSession('last_height_checked_' + ci.coin_name().lower(), block_height, use_session) + self.setIntKV('last_height_checked_' + ci.coin_name().lower(), block_height, session=use_session) query = '''INSERT INTO checkedblocks (created_at, coin_type, block_height, block_hash, block_time) VALUES (:now, :coin_type, :block_height, :block_hash, :block_time)''' @@ -4661,12 +4608,10 @@ class BasicSwap(BaseApp): return q.count() def checkQueuedActions(self) -> None: - self.mxDB.acquire() now: int = self.getTime() - session = None reload_in_progress: bool = False try: - session = scoped_session(self.session_factory) + session = self.openSession() q = session.query(Action).filter(sa.and_(Action.active_ind == 1, Action.trigger_at <= now)) for row in q: @@ -4674,10 +4619,10 @@ class BasicSwap(BaseApp): try: if row.action_type == ActionTypes.ACCEPT_BID: accepting_bid = True - self.acceptBid(row.linked_id) + self.acceptBid(row.linked_id, session) elif row.action_type == ActionTypes.ACCEPT_XMR_BID: accepting_bid = True - self.acceptXmrBid(row.linked_id) + self.acceptXmrBid(row.linked_id, session) elif row.action_type == ActionTypes.SIGN_XMR_SWAP_LOCK_TX_A: self.sendXmrBidTxnSigsFtoL(row.linked_id, session) elif row.action_type == ActionTypes.SEND_XMR_SWAP_LOCK_TX_A: @@ -4698,7 +4643,7 @@ class BasicSwap(BaseApp): atomic_swap_1.redeemITx(self, row.linked_id, session) elif row.action_type == ActionTypes.ACCEPT_AS_REV_BID: accepting_bid = True - self.acceptADSReverseBid(row.linked_id) + self.acceptADSReverseBid(row.linked_id, session) else: self.log.warning('Unknown event type: %d', row.event_type) except Exception as ex: @@ -4716,7 +4661,7 @@ class BasicSwap(BaseApp): # If delaying with no (further) queued actions reset state if self.countQueuedActions(session, bid_id, None) < 2: - bid, offer = self.getBidAndOffer(bid_id) + bid, offer = self.getBidAndOffer(bid_id, session) last_state = getLastBidState(bid.states) if bid and bid.state == BidStates.SWAP_DELAYING and last_state == BidStates.BID_RECEIVED: new_state = BidStates.BID_ERROR if offer.bid_reversed else BidStates.BID_RECEIVED @@ -4733,26 +4678,20 @@ class BasicSwap(BaseApp): else: session.execute('DELETE FROM actions WHERE trigger_at <= :now', {'now': now}) - session.commit() except Exception as ex: self.handleSessionErrors(ex, session, 'checkQueuedActions') reload_in_progress = True finally: - if session: - session.close() - session.remove() - self.mxDB.release() + self.closeSession(session) if reload_in_progress: self.loadFromDB() def checkXmrSwaps(self) -> None: - self.mxDB.acquire() now: int = self.getTime() ttl_xmr_split_messages = 60 * 60 - session = None try: - session = scoped_session(self.session_factory) + session = self.openSession() q = session.query(Bid).filter(Bid.state == BidStates.BID_RECEIVING) for bid in q: q = session.execute('SELECT COUNT(*) FROM xmr_split_data WHERE bid_id = x\'{}\' AND msg_type = {}'.format(bid.bid_id.hex(), XmrSplitMsgTypes.BID)).first() @@ -4797,12 +4736,8 @@ class BasicSwap(BaseApp): q = session.query(XmrSplitData).filter(XmrSplitData.created_at + ttl_xmr_split_messages < now) q.delete(synchronize_session=False) - session.commit() finally: - if session: - session.close() - session.remove() - self.mxDB.release() + self.closeSession(session) def processOffer(self, msg) -> None: offer_bytes = bytes.fromhex(msg['hex'][2:-2]) @@ -4864,8 +4799,8 @@ class BasicSwap(BaseApp): if self.isOfferRevoked(offer_id, msg['from']): raise ValueError('Offer has been revoked {}.'.format(offer_id.hex())) - session = scoped_session(self.session_factory) try: + session = self.openSession() # Offers must be received on the public network_addr or manually created addresses if msg['to'] != self.network_addr: # Double check active_ind, shouldn't be possible to receive message if not active @@ -4875,7 +4810,7 @@ class BasicSwap(BaseApp): raise ValueError('Offer received on incorrect address') # Check for sent - existing_offer = self.getOffer(offer_id) + existing_offer = self.getOffer(offer_id, session=session) if existing_offer is None: bid_reversed: bool = offer_data.swap_type == SwapTypes.XMR_SWAP and self.is_reverse_ads_bid(offer_data.coin_from) offer = Offer( @@ -4926,10 +4861,8 @@ class BasicSwap(BaseApp): else: existing_offer.setState(OfferStates.OFFER_RECEIVED) session.add(existing_offer) - session.commit() finally: - session.close() - session.remove() + self.closeSession(session) def processOfferRevoke(self, msg) -> None: ensure(msg['to'] == self.network_addr, 'Message received on wrong address') @@ -5325,7 +5258,7 @@ class BasicSwap(BaseApp): self.log.debug('Receiving adaptor-sig bid accept %s', bid.bid_id.hex()) now: int = self.getTime() - offer, xmr_offer = self.getXmrOffer(bid.offer_id, sent=True) + offer, xmr_offer = self.getXmrOffer(bid.offer_id, sent=True, session=session) ensure(offer, 'Offer not found: {}.'.format(bid.offer_id.hex())) ensure(xmr_offer, 'Adaptor-sig offer not found: {}.'.format(bid.offer_id.hex())) xmr_swap = session.query(XmrSwap).filter_by(bid_id=bid.bid_id).first() @@ -5568,13 +5501,13 @@ class BasicSwap(BaseApp): self.log.error(traceback.format_exc()) self.setBidError(bid.bid_id, bid, str(ex), xmr_swap=xmr_swap) - def watchXmrSwap(self, bid, offer, xmr_swap) -> None: + def watchXmrSwap(self, bid, offer, xmr_swap, session=None) -> None: self.log.debug('Adaptor-sig swap in progress, bid %s', bid.bid_id.hex()) self.swaps_in_progress[bid.bid_id] = (bid, offer) reverse_bid: bool = self.is_reverse_ads_bid(offer.coin_from) coin_from = Coins(offer.coin_to if reverse_bid else offer.coin_from) - self.setLastHeightCheckedStart(coin_from, bid.chain_a_height_start) + self.setLastHeightCheckedStart(coin_from, bid.chain_a_height_start, session) self.addWatchedOutput(coin_from, bid.bid_id, bid.xmr_a_lock_tx.txid.hex(), bid.xmr_a_lock_tx.vout, TxTypes.XMR_SWAP_A_LOCK, SwapTypes.XMR_SWAP) lock_refund_vout = self.ci(coin_from).getLockRefundTxSwapOutput(xmr_swap) @@ -5639,7 +5572,7 @@ class BasicSwap(BaseApp): bid.xmr_a_lock_tx.setState(TxStates.TX_NONE) bid.setState(BidStates.XMR_SWAP_MSG_SCRIPT_LOCK_TX_SIGS) - self.watchXmrSwap(bid, offer, xmr_swap) + self.watchXmrSwap(bid, offer, xmr_swap, session) self.saveBidInSession(bid_id, bid, session, xmr_swap) except Exception as ex: if self.debug: @@ -5714,7 +5647,7 @@ class BasicSwap(BaseApp): self.logBidEvent(bid.bid_id, EventLogTypes.LOCK_TX_A_PUBLISHED, '', session) bid.setState(BidStates.XMR_SWAP_HAVE_SCRIPT_COIN_SPEND_TX) - self.watchXmrSwap(bid, offer, xmr_swap) + self.watchXmrSwap(bid, offer, xmr_swap, session) delay = self.get_short_delay_event_seconds() self.log.info('Sending lock spend tx message for bid %s in %d seconds', bid_id.hex(), delay) @@ -5928,11 +5861,11 @@ class BasicSwap(BaseApp): vkbs = ci_to.sumKeys(kbsl, kbsf) if coin_to == Coins.XMR: - address_to = self.getCachedMainWalletAddress(ci_to) + address_to = self.getCachedMainWalletAddress(ci_to, session) elif coin_to in (Coins.PART_BLIND, Coins.PART_ANON): - address_to = self.getCachedStealthAddressForCoin(coin_to) + address_to = self.getCachedStealthAddressForCoin(coin_to, session) else: - address_to = self.getReceiveAddressFromPool(coin_to, bid_id, TxTypes.XMR_SWAP_B_LOCK_SPEND) + address_to = self.getReceiveAddressFromPool(coin_to, bid_id, TxTypes.XMR_SWAP_B_LOCK_SPEND, session) lock_tx_vout = bid.getLockTXBVout() txid = ci_to.spendBLockTx(xmr_swap.b_lock_tx_id, address_to, xmr_swap.vkbv, vkbs, bid.amount_to, b_fee_rate, bid.chain_b_height_start, lock_tx_vout=lock_tx_vout) @@ -5995,11 +5928,11 @@ class BasicSwap(BaseApp): try: if offer.coin_to == Coins.XMR: - address_to = self.getCachedMainWalletAddress(ci_to) + address_to = self.getCachedMainWalletAddress(ci_to, session) elif coin_to in (Coins.PART_BLIND, Coins.PART_ANON): - address_to = self.getCachedStealthAddressForCoin(coin_to) + address_to = self.getCachedStealthAddressForCoin(coin_to, session) else: - address_to = self.getReceiveAddressFromPool(coin_to, bid_id, TxTypes.XMR_SWAP_B_LOCK_REFUND) + address_to = self.getReceiveAddressFromPool(coin_to, bid_id, TxTypes.XMR_SWAP_B_LOCK_REFUND, session) lock_tx_vout = bid.getLockTXBVout() txid = ci_to.spendBLockTx(xmr_swap.b_lock_tx_id, address_to, xmr_swap.vkbv, vkbs, bid.amount_to, b_fee_rate, bid.chain_b_height_start, lock_tx_vout=lock_tx_vout) @@ -6420,7 +6353,6 @@ class BasicSwap(BaseApp): self.closeSession(session) def processMsg(self, msg) -> None: - self.mxDB.acquire() try: msg_type = int(msg['hex'][:2], 16) @@ -6463,9 +6395,6 @@ class BasicSwap(BaseApp): str(ex), None) - finally: - self.mxDB.release() - def processZmqSmsg(self) -> None: message = self.zmqSubscriber.recv() clear = self.zmqSubscriber.recv() @@ -6569,7 +6498,6 @@ class BasicSwap(BaseApp): self.log.debug(f'Expired {bids_expired} bid{mb} and {offers_expired} offer{mo}') def update(self) -> None: - # Run every half second from basicswap-run if self._zmq_queue_enabled: try: if self._read_zmq_queue: @@ -6590,7 +6518,6 @@ class BasicSwap(BaseApp): for msg in msgs['messages']: self.processMsg(msg) - self.mxDB.acquire() try: # TODO: Wait for blocks / txns, would need to check multiple coins now: int = self.getTime() @@ -6640,8 +6567,6 @@ class BasicSwap(BaseApp): except Exception as ex: self.logException(f'update {ex}') - finally: - self.mxDB.release() def manualBidUpdate(self, bid_id: bytes, data): self.log.info('Manually updating bid %s', bid_id.hex()) @@ -7132,9 +7057,8 @@ class BasicSwap(BaseApp): def getCachedWalletsInfo(self, opts=None): rv = {} - # Requires? self.mxDB.acquire() try: - session = scoped_session(self.session_factory) + session = self.openSession() where_str = '' if opts is not None and 'coin_id' in opts: where_str = 'WHERE coin_id = {}'.format(opts['coin_id']) @@ -7171,8 +7095,7 @@ class BasicSwap(BaseApp): else: rv[coin_id] = wallet_data finally: - session.close() - session.remove() + self.closeSession(session) if opts is not None and 'coin_id' in opts: return rv diff --git a/basicswap/db_upgrades.py b/basicswap/db_upgrades.py index 605ff09..f76db35 100644 --- a/basicswap/db_upgrades.py +++ b/basicswap/db_upgrades.py @@ -93,7 +93,7 @@ def upgradeDatabaseData(self, data_version): created_at=now)) self.db_data_version = CURRENT_DB_DATA_VERSION - self.setIntKVInSession('db_data_version', self.db_data_version, session) + self.setIntKV('db_data_version', self.db_data_version, session) session.commit() self.log.info('Upgraded database records to version {}'.format(self.db_data_version)) finally: @@ -314,7 +314,7 @@ def upgradeDatabase(self, db_version): session.execute('ALTER TABLE bids ADD COLUMN pkhash_buyer_to BLOB') if current_version != db_version: self.db_version = db_version - self.setIntKVInSession('db_version', db_version, session) + self.setIntKV('db_version', db_version, session) session.commit() session.close() session.remove() diff --git a/basicswap/protocols/atomic_swap_1.py b/basicswap/protocols/atomic_swap_1.py index e65a97e..6bea165 100644 --- a/basicswap/protocols/atomic_swap_1.py +++ b/basicswap/protocols/atomic_swap_1.py @@ -105,7 +105,7 @@ def redeemITx(self, bid_id: bytes, session): bid, offer = self.getBidAndOffer(bid_id, session) ci_from = self.ci(offer.coin_from) - txn = self.createRedeemTxn(ci_from.coin_type(), bid, for_txn_type='initiate') + txn = self.createRedeemTxn(ci_from.coin_type(), bid, for_txn_type='initiate', session=session) txid = ci_from.publishTx(bytes.fromhex(txn)) bid.initiate_tx.spend_txid = bytes.fromhex(txid) diff --git a/basicswap/protocols/xmr_swap_1.py b/basicswap/protocols/xmr_swap_1.py index 833449c..33a252f 100644 --- a/basicswap/protocols/xmr_swap_1.py +++ b/basicswap/protocols/xmr_swap_1.py @@ -1,11 +1,9 @@ # -*- coding: utf-8 -*- -# Copyright (c) 2020-2023 tecnovert +# Copyright (c) 2020-2024 tecnovert # Distributed under the MIT software license, see the accompanying # file LICENSE or http://www.opensource.org/licenses/mit-license.php. -from sqlalchemy.orm import scoped_session - from basicswap.util import ( ensure, ) @@ -45,7 +43,7 @@ def addLockRefundSigs(self, xmr_swap, ci): def recoverNoScriptTxnWithKey(self, bid_id: bytes, encoded_key): self.log.info('Manually recovering %s', bid_id.hex()) # Manually recover txn if other key is known - session = scoped_session(self.session_factory) + session = self.openSession() try: bid, xmr_swap = self.getXmrBidFromSession(session, bid_id) ensure(bid, 'Bid not found: {}.'.format(bid_id.hex())) @@ -86,8 +84,7 @@ def recoverNoScriptTxnWithKey(self, bid_id: bytes, encoded_key): return txid finally: - session.close() - session.remove() + self.closeSession(session, commit=False) def getChainBSplitKey(swap_client, bid, xmr_swap, offer): diff --git a/tests/basicswap/extended/test_dcr.py b/tests/basicswap/extended/test_dcr.py index c0c668f..ff8cad4 100644 --- a/tests/basicswap/extended/test_dcr.py +++ b/tests/basicswap/extended/test_dcr.py @@ -345,9 +345,8 @@ def run_test_ads_both_refund(self, coin_from: Coins, coin_to: Coins, lock_value: ci_from = swap_clients[id_offerer].ci(coin_from) ci_to = swap_clients[id_offerer].ci(coin_to) - if reverse_bid: - self.prepare_balance(coin_to, 100.0, 1801, 1800) - self.prepare_balance(coin_from, 100.0, 1800, 1801) + self.prepare_balance(coin_to, 100.0, 1801, 1800) + self.prepare_balance(coin_from, 100.0, 1800, 1801) id_leader: int = id_bidder if reverse_bid else id_offerer id_follower: int = id_offerer if reverse_bid else id_bidder diff --git a/tests/basicswap/test_xmr.py b/tests/basicswap/test_xmr.py index 599cbef..80b25f8 100644 --- a/tests/basicswap/test_xmr.py +++ b/tests/basicswap/test_xmr.py @@ -1388,9 +1388,13 @@ class Test(BaseTest): js_0 = read_json_api(1800, 'wallets/part') node0_blind_before = js_0['blind_balance'] + js_0['blind_unconfirmed'] - amt_swap = make_int(random.uniform(0.1, 2.0), scale=8, r=1) - rate_swap = make_int(random.uniform(2.0, 20.0), scale=8, r=1) - offer_id = swap_clients[0].postOffer(Coins.PART_BLIND, Coins.XMR, amt_swap, rate_swap, amt_swap, SwapTypes.XMR_SWAP) + coin_from = Coins.PART_BLIND + coin_to = Coins.XMR + ci_from = swap_clients[0].ci(coin_from) + ci_to = swap_clients[0].ci(coin_to) + amt_swap = ci_from.make_int(random.uniform(0.1, 2.0), r=1) + rate_swap = ci_to.make_int(random.uniform(0.2, 20.0), r=1) + offer_id = swap_clients[0].postOffer(coin_from, coin_to, amt_swap, rate_swap, amt_swap, SwapTypes.XMR_SWAP) wait_for_offer(test_delay_event, swap_clients[1], offer_id) offers = swap_clients[0].listOffers(filters={'offer_id': offer_id}) offer = offers[0]