diff --git a/basicswap/basicswap.py b/basicswap/basicswap.py index 3fa699f..dc1284b 100644 --- a/basicswap/basicswap.py +++ b/basicswap/basicswap.py @@ -19,6 +19,11 @@ import secrets from sqlalchemy.orm import sessionmaker, scoped_session from enum import IntEnum, auto +from .interface_part import PARTInterface +from .interface_btc import BTCInterface +from .interface_ltc import LTCInterface +from .interface_xmr import XMRInterface + from . import __version__ from .util import ( COIN, @@ -31,7 +36,7 @@ from .util import ( decodeWif, toWIF, getKeyID, - makeInt, + make_int, ) from .chainparams import ( chainparams, @@ -417,6 +422,27 @@ class BasicSwap(BaseApp): 'chain_lookups': chain_client_settings.get('chain_lookups', 'local'), } + if self.coin_clients[coin]['connection_type'] == 'rpc': + if coin == Coins.XMR: + self.coin_clients[coin]['walletrpcport'] = chain_client_settings.get('walletrpcport', chainparams[coin][self.chain]['walletrpcport']) + if 'walletrpcpassword' in chain_client_settings: + self.coin_clients[coin]['walletrpcauth'] = chain_client_settings['walletrpcuser'] + ':' + chain_client_settings['walletrpcpassword'] + else: + raise ValueError('Missing XMR wallet rpc credentials.') + self.coin_clients[coin]['interface'] = self.createInterface(coin) + + def createInterface(self, coin): + if coin == Coins.PART: + return PARTInterface(self.coin_clients[coin]) + elif coin == Coins.BTC: + return BTCInterface(self.coin_clients[coin]) + elif coin == Coins.LTC: + return LTCInterface(self.coin_clients[coin]) + elif coin == Coins.XMR: + return XMRInterface(self.coin_clients[coin]) + else: + raise ValueError('Unknown coin type') + def setCoinRunParams(self, coin): cc = self.coin_clients[coin] if cc['connection_type'] == 'rpc' and cc['rpcauth'] is None: @@ -1699,7 +1725,7 @@ class BasicSwap(BaseApp): continue # Verify amount if assert_amount: - assert(makeInt(o['amount']) == int(assert_amount)), 'Incorrect output amount in txn {}: {} != {}.'.format(assert_txid, makeInt(o['amount']), int(assert_amount)) + assert(make_int(o['amount']) == int(assert_amount)), 'Incorrect output amount in txn {}: {} != {}.'.format(assert_txid, make_int(o['amount']), int(assert_amount)) if not sum_output: if o['height'] > 0: @@ -1711,7 +1737,7 @@ class BasicSwap(BaseApp): 'index': o['vout'], 'height': o['height'], 'n_conf': n_conf, - 'value': makeInt(o['amount']), + 'value': make_int(o['amount']), } else: sum_unspent += o['amount'] * COIN @@ -1744,7 +1770,7 @@ class BasicSwap(BaseApp): # Verify amount vout = getVoutByAddress(initiate_txn, p2sh) - out_value = makeInt(initiate_txn['vout'][vout]['value']) + out_value = make_int(initiate_txn['vout'][vout]['value']) assert(out_value == int(bid.amount)), 'Incorrect output amount in initiate txn {}: {} != {}.'.format(initiate_txnid_hex, out_value, int(bid.amount)) bid.initiate_tx.conf = initiate_txn['confirmations'] @@ -2442,8 +2468,8 @@ class BasicSwap(BaseApp): 'deposit_address': self.getCachedAddressForCoin(coin), 'name': chainparams[coin]['name'].capitalize(), 'blocks': blockchaininfo['blocks'], - 'balance': format8(makeInt(walletinfo['balance'])), - 'unconfirmed': format8(makeInt(walletinfo.get('unconfirmed_balance'))), + 'balance': format8(make_int(walletinfo['balance'])), + 'unconfirmed': format8(make_int(walletinfo.get('unconfirmed_balance'))), 'synced': '{0:.2f}'.format(round(blockchaininfo['verificationprogress'], 2)), } return rv diff --git a/basicswap/chainparams.py b/basicswap/chainparams.py index 2adf2ee..d778f10 100644 --- a/basicswap/chainparams.py +++ b/basicswap/chainparams.py @@ -14,8 +14,9 @@ class Coins(IntEnum): PART = 1 BTC = 2 LTC = 3 - # DCR = 4 + #DCR = 4 NMC = 5 + XMR = 6 chainparams = { @@ -156,5 +157,26 @@ chainparams = { 'min_amount': 1000, 'max_amount': 100000 * COIN, } + }, + Coins.XMR: { + 'name': 'monero', + 'ticker': 'XMR', + 'client': 'xmr', + 'mainnet': { + 'rpcport': 18081, + 'walletrpcport': 18082, + }, + 'testnet': { + 'rpcport': 28081, + 'walletrpcport': 28082, + }, + 'regtest': { + 'rpcport': 18081, + 'walletrpcport': 18082, + } } } + +class CoinInterface: + pass + diff --git a/basicswap/contrib/MoneroPy/__init__.py b/basicswap/contrib/MoneroPy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/basicswap/contrib/MoneroPy/base58.py b/basicswap/contrib/MoneroPy/base58.py new file mode 100644 index 0000000..83424de --- /dev/null +++ b/basicswap/contrib/MoneroPy/base58.py @@ -0,0 +1,168 @@ +# MoneroPy - A python toolbox for Monero +# Copyright (C) 2016 The MoneroPy Developers. +# +# MoneroPy is released under the BSD 3-Clause license. Use and redistribution of +# this software is subject to the license terms in the LICENSE file found in the +# top-level directory of this distribution. + +__alphabet = [ord(s) for s in '123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz'] +__b58base = 58 +__UINT64MAX = 2**64 +__encodedBlockSizes = [0, 2, 3, 5, 6, 7, 9, 10, 11] +__fullBlockSize = 8 +__fullEncodedBlockSize = 11 + +def _hexToBin(hex): + if len(hex) % 2 != 0: + return "Hex string has invalid length!" + return [int(hex[i*2:i*2+2], 16) for i in range(len(hex)//2)] + +def _binToHex(bin): + return "".join([("0" + hex(int(bin[i])).split('x')[1])[-2:] for i in range(len(bin))]) + +def _strToBin(a): + return [ord(s) for s in a] + +def _binToStr(bin): + return ''.join([chr(bin[i]) for i in range(len(bin))]) + +def _uint8be_to_64(data): + l_data = len(data) + + if l_data < 1 or l_data > 8: + return "Invalid input length" + + res = 0 + switch = 9 - l_data + for i in range(l_data): + if switch == 1: + res = res << 8 | data[i] + elif switch == 2: + res = res << 8 | data[i] + elif switch == 3: + res = res << 8 | data[i] + elif switch == 4: + res = res << 8 | data[i] + elif switch == 5: + res = res << 8 | data[i] + elif switch == 6: + res = res << 8 | data[i] + elif switch == 7: + res = res << 8 | data[i] + elif switch == 8: + res = res << 8 | data[i] + else: + return "Impossible condition" + return res + +def _uint64_to_8be(num, size): + res = [0] * size; + if size < 1 or size > 8: + return "Invalid input length" + + twopow8 = 2**8 + for i in range(size-1,-1,-1): + res[i] = num % twopow8 + num = num // twopow8 + + return res + +def encode_block(data, buf, index): + l_data = len(data) + + if l_data < 1 or l_data > __fullEncodedBlockSize: + return "Invalid block length: " + str(l_data) + + num = _uint8be_to_64(data) + i = __encodedBlockSizes[l_data] - 1 + + while num > 0: + remainder = num % __b58base + num = num // __b58base + buf[index+i] = __alphabet[remainder]; + i -= 1 + + return buf + +def encode(hex): + '''Encode hexadecimal string as base58 (ex: encoding a Monero address).''' + data = _hexToBin(hex) + l_data = len(data) + + if l_data == 0: + return "" + + full_block_count = l_data // __fullBlockSize + last_block_size = l_data % __fullBlockSize + res_size = full_block_count * __fullEncodedBlockSize + __encodedBlockSizes[last_block_size] + + res = [0] * res_size + for i in range(res_size): + res[i] = __alphabet[0] + + for i in range(full_block_count): + res = encode_block(data[(i*__fullBlockSize):(i*__fullBlockSize+__fullBlockSize)], res, i * __fullEncodedBlockSize) + + if last_block_size > 0: + res = encode_block(data[(full_block_count*__fullBlockSize):(full_block_count*__fullBlockSize+last_block_size)], res, full_block_count * __fullEncodedBlockSize) + + return _binToStr(res) + +def decode_block(data, buf, index): + l_data = len(data) + + if l_data < 1 or l_data > __fullEncodedBlockSize: + return "Invalid block length: " + l_data + + res_size = __encodedBlockSizes.index(l_data) + if res_size <= 0: + return "Invalid block size" + + res_num = 0 + order = 1 + for i in range(l_data-1, -1, -1): + digit = __alphabet.index(data[i]) + if digit < 0: + return "Invalid symbol" + + product = order * digit + res_num + if product > __UINT64MAX: + return "Overflow" + + res_num = product + order = order * __b58base + + if res_size < __fullBlockSize and 2**(8 * res_size) <= res_num: + return "Overflow 2" + + tmp_buf = _uint64_to_8be(res_num, res_size) + for i in range(len(tmp_buf)): + buf[i+index] = tmp_buf[i] + + return buf + +def decode(enc): + '''Decode a base58 string (ex: a Monero address) into hexidecimal form.''' + enc = _strToBin(enc) + l_enc = len(enc) + + if l_enc == 0: + return "" + + full_block_count = l_enc // __fullEncodedBlockSize + last_block_size = l_enc % __fullEncodedBlockSize + last_block_decoded_size = __encodedBlockSizes.index(last_block_size) + + if last_block_decoded_size < 0: + return "Invalid encoded length" + + data_size = full_block_count * __fullBlockSize + last_block_decoded_size + + data = [0] * data_size + for i in range(full_block_count): + data = decode_block(enc[(i*__fullEncodedBlockSize):(i*__fullEncodedBlockSize+__fullEncodedBlockSize)], data, i * __fullBlockSize) + + if last_block_size > 0: + data = decode_block(enc[(full_block_count*__fullEncodedBlockSize):(full_block_count*__fullEncodedBlockSize+last_block_size)], data, full_block_count * __fullBlockSize) + + return _binToHex(data) diff --git a/basicswap/contrib/ellipticcurve.py b/basicswap/contrib/ellipticcurve.py new file mode 100644 index 0000000..8a58166 --- /dev/null +++ b/basicswap/contrib/ellipticcurve.py @@ -0,0 +1,486 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# +# Implementation of elliptic curves, for cryptographic applications. +# +# This module doesn't provide any way to choose a random elliptic +# curve, nor to verify that an elliptic curve was chosen randomly, +# because one can simply use NIST's standard curves. +# +# Notes from X9.62-1998 (draft): +# Nomenclature: +# - Q is a public key. +# The "Elliptic Curve Domain Parameters" include: +# - q is the "field size", which in our case equals p. +# - p is a big prime. +# - G is a point of prime order (5.1.1.1). +# - n is the order of G (5.1.1.1). +# Public-key validation (5.2.2): +# - Verify that Q is not the point at infinity. +# - Verify that X_Q and Y_Q are in [0,p-1]. +# - Verify that Q is on the curve. +# - Verify that nQ is the point at infinity. +# Signature generation (5.3): +# - Pick random k from [1,n-1]. +# Signature checking (5.4.2): +# - Verify that r and s are in [1,n-1]. +# +# Version of 2008.11.25. +# +# Revision history: +# 2005.12.31 - Initial version. +# 2008.11.25 - Change CurveFp.is_on to contains_point. +# +# Written in 2005 by Peter Pearson and placed in the public domain. + +def inverse_mod(a, m): + """Inverse of a mod m.""" + + if a < 0 or m <= a: + a = a % m + + # From Ferguson and Schneier, roughly: + + c, d = a, m + uc, vc, ud, vd = 1, 0, 0, 1 + while c != 0: + q, c, d = divmod(d, c) + (c,) + uc, vc, ud, vd = ud - q * uc, vd - q * vc, uc, vc + + # At this point, d is the GCD, and ud*a+vd*m = d. + # If d == 1, this means that ud is a inverse. + + assert d == 1 + if ud > 0: + return ud + else: + return ud + m + + +def modular_sqrt(a, p): + # from http://eli.thegreenplace.net/2009/03/07/computing-modular-square-roots-in-python/ + """ Find a quadratic residue (mod p) of 'a'. p + must be an odd prime. + + Solve the congruence of the form: + x^2 = a (mod p) + And returns x. Note that p - x is also a root. + + 0 is returned is no square root exists for + these a and p. + + The Tonelli-Shanks algorithm is used (except + for some simple cases in which the solution + is known from an identity). This algorithm + runs in polynomial time (unless the + generalized Riemann hypothesis is false). + """ + # Simple cases + # + if legendre_symbol(a, p) != 1: + return 0 + elif a == 0: + return 0 + elif p == 2: + return p + elif p % 4 == 3: + return pow(a, (p + 1) // 4, p) + + # Partition p-1 to s * 2^e for an odd s (i.e. + # reduce all the powers of 2 from p-1) + # + s = p - 1 + e = 0 + while s % 2 == 0: + s /= 2 + e += 1 + + # Find some 'n' with a legendre symbol n|p = -1. + # Shouldn't take long. + # + n = 2 + while legendre_symbol(n, p) != -1: + n += 1 + + # Here be dragons! + # Read the paper "Square roots from 1; 24, 51, + # 10 to Dan Shanks" by Ezra Brown for more + # information + # + + # x is a guess of the square root that gets better + # with each iteration. + # b is the "fudge factor" - by how much we're off + # with the guess. The invariant x^2 = ab (mod p) + # is maintained throughout the loop. + # g is used for successive powers of n to update + # both a and b + # r is the exponent - decreases with each update + # + x = pow(a, (s + 1) // 2, p) + b = pow(a, s, p) + g = pow(n, s, p) + r = e + + while True: + t = b + m = 0 + for m in range(r): + if t == 1: + break + t = pow(t, 2, p) + + if m == 0: + return x + + gs = pow(g, 2 ** (r - m - 1), p) + g = (gs * gs) % p + x = (x * gs) % p + b = (b * g) % p + r = m + + +def legendre_symbol(a, p): + """ Compute the Legendre symbol a|p using + Euler's criterion. p is a prime, a is + relatively prime to p (if p divides + a, then a|p = 0) + + Returns 1 if a has a square root modulo + p, -1 otherwise. + """ + ls = pow(a, (p - 1) // 2, p) + return -1 if ls == p - 1 else ls + + +def jacobi_symbol(n, k): + """Compute the Jacobi symbol of n modulo k + + See http://en.wikipedia.org/wiki/Jacobi_symbol + + For our application k is always prime, so this is the same as the Legendre symbol.""" + assert k > 0 and k & 1, "jacobi symbol is only defined for positive odd k" + n %= k + t = 0 + while n != 0: + while n & 1 == 0: + n >>= 1 + r = k & 7 + t ^= (r == 3 or r == 5) + n, k = k, n + t ^= (n & k & 3 == 3) + n = n % k + if k == 1: + return -1 if t else 1 + return 0 + + +class CurveFp(object): + """Elliptic Curve over the field of integers modulo a prime.""" + def __init__(self, p, a, b): + """The curve of points satisfying y^2 = x^3 + a*x + b (mod p).""" + self.__p = p + self.__a = a + self.__b = b + + def p(self): + return self.__p + + def a(self): + return self.__a + + def b(self): + return self.__b + + def contains_point(self, x, y): + """Is the point (x,y) on this curve?""" + return (y * y - (x * x * x + self.__a * x + self.__b)) % self.__p == 0 + + +class Point(object): + """ A point on an elliptic curve. Altering x and y is forbidding, + but they can be read by the x() and y() methods.""" + def __init__(self, curve, x, y, order=None): + """curve, x, y, order; order (optional) is the order of this point.""" + self.__curve = curve + self.__x = x + self.__y = y + self.__order = order + # self.curve is allowed to be None only for INFINITY: + if self.__curve: + assert self.__curve.contains_point(x, y) + if order: + assert self * order == INFINITY + + def __eq__(self, other): + """Return 1 if the points are identical, 0 otherwise.""" + if self.__curve == other.__curve \ + and self.__x == other.__x \ + and self.__y == other.__y: + return 1 + else: + return 0 + + def __add__(self, other): + """Add one point to another point.""" + + # X9.62 B.3: + if other == INFINITY: + return self + if self == INFINITY: + return other + assert self.__curve == other.__curve + if self.__x == other.__x: + if (self.__y + other.__y) % self.__curve.p() == 0: + return INFINITY + else: + return self.double() + + p = self.__curve.p() + + l = ((other.__y - self.__y) * inverse_mod(other.__x - self.__x, p)) % p + + x3 = (l * l - self.__x - other.__x) % p + y3 = (l * (self.__x - x3) - self.__y) % p + + return Point(self.__curve, x3, y3) + + def __sub__(self, other): + #The inverse of a point P=(xP,yP) is its reflexion across the x-axis : P′=(xP,−yP). + #If you want to compute Q−P, just replace yP by −yP in the usual formula for point addition. + + # X9.62 B.3: + if other == INFINITY: + return self + if self == INFINITY: + return other + assert self.__curve == other.__curve + + p = self.__curve.p() + #opi = inverse_mod(other.__y, p) + opi = -other.__y % p + #print(opi) + #print(-other.__y % p) + + if self.__x == other.__x: + if (self.__y + opi) % self.__curve.p() == 0: + return INFINITY + else: + return self.double + + l = ((opi - self.__y) * inverse_mod(other.__x - self.__x, p)) % p + + x3 = (l * l - self.__x - other.__x) % p + y3 = (l * (self.__x - x3) - self.__y) % p + + return Point(self.__curve, x3, y3) + + def __mul__(self, e): + if self.__order: + e %= self.__order + if e == 0 or self == INFINITY: + return INFINITY + result, q = INFINITY, self + while e: + if e & 1: + result += q + e, q = e >> 1, q.double() + return result + + """ + def __mul__(self, other): + #Multiply a point by an integer. + + def leftmost_bit( x ): + assert x > 0 + result = 1 + while result <= x: result = 2 * result + return result // 2 + + e = other + if self.__order: e = e % self.__order + if e == 0: return INFINITY + if self == INFINITY: return INFINITY + assert e > 0 + + # From X9.62 D.3.2: + + e3 = 3 * e + negative_self = Point( self.__curve, self.__x, -self.__y, self.__order ) + i = leftmost_bit( e3 ) // 2 + result = self + # print "Multiplying %s by %d (e3 = %d):" % ( self, other, e3 ) + while i > 1: + result = result.double() + if ( e3 & i ) != 0 and ( e & i ) == 0: result = result + self + if ( e3 & i ) == 0 and ( e & i ) != 0: result = result + negative_self + # print ". . . i = %d, result = %s" % ( i, result ) + i = i // 2 + + return result + """ + + def __rmul__(self, other): + """Multiply a point by an integer.""" + + return self * other + + def __str__(self): + if self == INFINITY: + return "infinity" + return "(%d, %d)" % (self.__x, self.__y) + + def inverse(self): + return Point(self.__curve, self.__x, -self.__y % self.__curve.p()) + + def double(self): + """Return a new point that is twice the old.""" + + if self == INFINITY: + return INFINITY + + # X9.62 B.3: + + p = self.__curve.p() + a = self.__curve.a() + + l = ((3 * self.__x * self.__x + a) * inverse_mod(2 * self.__y, p)) % p + + x3 = (l * l - 2 * self.__x) % p + y3 = (l * (self.__x - x3) - self.__y) % p + + return Point(self.__curve, x3, y3) + + def x(self): + return self.__x + + def y(self): + return self.__y + + def pair(self): + return (self.__x, self.__y) + + def curve(self): + return self.__curve + + def order(self): + return self.__order + + +# This one point is the Point At Infinity for all purposes: +INFINITY = Point(None, None, None) + + +def __main__(): + + class FailedTest(Exception): + pass + + def test_add(c, x1, y1, x2, y2, x3, y3): + """We expect that on curve c, (x1,y1) + (x2, y2 ) = (x3, y3).""" + p1 = Point(c, x1, y1) + p2 = Point(c, x2, y2) + p3 = p1 + p2 + print("%s + %s = %s" % (p1, p2, p3)) + if p3.x() != x3 or p3.y() != y3: + raise FailedTest("Failure: should give (%d,%d)." % (x3, y3)) + else: + print(" Good.") + + def test_double(c, x1, y1, x3, y3): + """We expect that on curve c, 2*(x1,y1) = (x3, y3).""" + p1 = Point(c, x1, y1) + p3 = p1.double() + print("%s doubled = %s" % (p1, p3)) + if p3.x() != x3 or p3.y() != y3: + raise FailedTest("Failure: should give (%d,%d)." % (x3, y3)) + else: + print(" Good.") + + def test_double_infinity(c): + """We expect that on curve c, 2*INFINITY = INFINITY.""" + p1 = INFINITY + p3 = p1.double() + print("%s doubled = %s" % (p1, p3)) + if p3.x() != INFINITY.x() or p3.y() != INFINITY.y(): + raise FailedTest("Failure: should give (%d,%d)." % (INFINITY.x(), INFINITY.y())) + else: + print(" Good.") + + def test_multiply(c, x1, y1, m, x3, y3): + """We expect that on curve c, m*(x1,y1) = (x3,y3).""" + p1 = Point(c, x1, y1) + p3 = p1 * m + print("%s * %d = %s" % (p1, m, p3)) + if p3.x() != x3 or p3.y() != y3: + raise FailedTest("Failure: should give (%d,%d)." % (x3, y3)) + else: + print(" Good.") + + # A few tests from X9.62 B.3: + + c = CurveFp(23, 1, 1) + test_add(c, 3, 10, 9, 7, 17, 20) + test_double(c, 3, 10, 7, 12) + test_add(c, 3, 10, 3, 10, 7, 12) # (Should just invoke double.) + test_multiply(c, 3, 10, 2, 7, 12) + + test_double_infinity(c) + + # From X9.62 I.1 (p. 96): + + g = Point(c, 13, 7, 7) + + check = INFINITY + for i in range(7 + 1): + p = (i % 7) * g + print("%s * %d = %s, expected %s . . ." % (g, i, p, check)) + if p == check: + print(" Good.") + else: + raise FailedTest("Bad.") + check = check + g + + # NIST Curve P-192: + p = 6277101735386680763835789423207666416083908700390324961279 + r = 6277101735386680763835789423176059013767194773182842284081 + #s = 0x3045ae6fc8422f64ed579528d38120eae12196d5L + c = 0x3099d2bbbfcb2538542dcd5fb078b6ef5f3d6fe2c745de65 + b = 0x64210519e59c80e70fa7e9ab72243049feb8deecc146b9b1 + Gx = 0x188da80eb03090f67cbf20eb43a18800f4ff0afd82ff1012 + Gy = 0x07192b95ffc8da78631011ed6b24cdd573f977a11e794811 + + c192 = CurveFp(p, -3, b) + p192 = Point(c192, Gx, Gy, r) + + # Checking against some sample computations presented + # in X9.62: + + d = 651056770906015076056810763456358567190100156695615665659 + Q = d * p192 + if Q.x() != 0x62B12D60690CDCF330BABAB6E69763B471F994DD702D16A5: + raise FailedTest("p192 * d came out wrong.") + else: + print("p192 * d came out right.") + + k = 6140507067065001063065065565667405560006161556565665656654 + R = k * p192 + if R.x() != 0x885052380FF147B734C330C43D39B2C4A89F29B0F749FEAD \ + or R.y() != 0x9CF9FA1CBEFEFB917747A3BB29C072B9289C2547884FD835: + raise FailedTest("k * p192 came out wrong.") + else: + print("k * p192 came out right.") + + u1 = 2563697409189434185194736134579731015366492496392189760599 + u2 = 6266643813348617967186477710235785849136406323338782220568 + temp = u1 * p192 + u2 * Q + if temp.x() != 0x885052380FF147B734C330C43D39B2C4A89F29B0F749FEAD \ + or temp.y() != 0x9CF9FA1CBEFEFB917747A3BB29C072B9289C2547884FD835: + raise FailedTest("u1 * p192 + u2 * Q came out wrong.") + else: + print("u1 * p192 + u2 * Q came out right.") + + +if __name__ == "__main__": + __main__() diff --git a/basicswap/contrib/test_framework/__init__.py b/basicswap/contrib/test_framework/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/basicswap/contrib/test_framework/address.py b/basicswap/contrib/test_framework/address.py new file mode 100644 index 0000000..7d15167 --- /dev/null +++ b/basicswap/contrib/test_framework/address.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +# Copyright (c) 2016-2020 The Bitcoin Core developers +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +"""Encode and decode BASE58, P2PKH and P2SH addresses.""" + +import enum +import unittest + +from .script import hash256, hash160, sha256, CScript, OP_0 +from .util import hex_str_to_bytes + +from . import segwit_addr + +from .util import assert_equal + +ADDRESS_BCRT1_UNSPENDABLE = 'bcrt1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq3xueyj' +ADDRESS_BCRT1_UNSPENDABLE_DESCRIPTOR = 'addr(bcrt1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq3xueyj)#juyq9d97' +# Coins sent to this address can be spent with a witness stack of just OP_TRUE +ADDRESS_BCRT1_P2WSH_OP_TRUE = 'bcrt1qft5p2uhsdcdc3l2ua4ap5qqfg4pjaqlp250x7us7a8qqhrxrxfsqseac85' + + +class AddressType(enum.Enum): + bech32 = 'bech32' + p2sh_segwit = 'p2sh-segwit' + legacy = 'legacy' # P2PKH + + +chars = '123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz' + + +def byte_to_base58(b, version): + result = '' + str = b.hex() + str = chr(version).encode('latin-1').hex() + str + checksum = hash256(hex_str_to_bytes(str)).hex() + str += checksum[:8] + value = int('0x'+str,0) + while value > 0: + result = chars[value % 58] + result + value //= 58 + while (str[:2] == '00'): + result = chars[0] + result + str = str[2:] + return result + + +def base58_to_byte(s, verify_checksum=True): + if not s: + return b'' + n = 0 + for c in s: + n *= 58 + assert c in chars + digit = chars.index(c) + n += digit + h = '%x' % n + if len(h) % 2: + h = '0' + h + res = n.to_bytes((n.bit_length() + 7) // 8, 'big') + pad = 0 + for c in s: + if c == chars[0]: + pad += 1 + else: + break + res = b'\x00' * pad + res + if verify_checksum: + assert_equal(hash256(res[:-4])[:4], res[-4:]) + + return res[1:-4], int(res[0]) + + +def keyhash_to_p2pkh(hash, main = False, btc = True): + assert (len(hash) == 20 or len(hash) == 32) + if len(hash) == 20: + if btc: + version = 0 if main else 111 + else: + version = 56 if main else 118 + return byte_to_base58(hash, version) + version = 57 if main else 119 + return byte_to_base58(hash, version) + +def scripthash_to_p2sh(hash, main = False, btc = True): + assert (len(hash) == 20) + if btc: + version = 5 if main else 196 + else: + version = 60 if main else 122 + return byte_to_base58(hash, version) + +def key_to_p2pkh(key, main = False): + key = check_key(key) + return keyhash_to_p2pkh(hash160(key), main) + +def script_to_p2sh(script, main = False, btc = True): + script = check_script(script) + return scripthash_to_p2sh(hash160(script), main, btc) + +def key_to_p2sh_p2wpkh(key, main = False): + key = check_key(key) + p2shscript = CScript([OP_0, hash160(key)]) + return script_to_p2sh(p2shscript, main) + +def program_to_witness(version, program, main = False): + if (type(program) is str): + program = hex_str_to_bytes(program) + assert 0 <= version <= 16 + assert 2 <= len(program) <= 40 + assert version > 0 or len(program) in [20, 32] + return segwit_addr.encode("bc" if main else "bcrt", version, program) + +def script_to_p2wsh(script, main = False): + script = check_script(script) + return program_to_witness(0, sha256(script), main) + +def key_to_p2wpkh(key, main = False): + key = check_key(key) + return program_to_witness(0, hash160(key), main) + +def script_to_p2sh_p2wsh(script, main = False): + script = check_script(script) + p2shscript = CScript([OP_0, sha256(script)]) + return script_to_p2sh(p2shscript, main) + +def check_key(key): + if (type(key) is str): + key = hex_str_to_bytes(key) # Assuming this is hex string + if (type(key) is bytes and (len(key) == 33 or len(key) == 65)): + return key + assert False + +def check_script(script): + if (type(script) is str): + script = hex_str_to_bytes(script) # Assuming this is hex string + if (type(script) is bytes or type(script) is CScript): + return script + assert False + + +class TestFrameworkScript(unittest.TestCase): + def test_base58encodedecode(self): + def check_base58(data, version): + self.assertEqual(base58_to_byte(byte_to_base58(data, version)), (data, version)) + + check_base58(b'\x1f\x8e\xa1p*{\xd4\x94\x1b\xca\tA\xb8R\xc4\xbb\xfe\xdb.\x05', 111) + check_base58(b':\x0b\x05\xf4\xd7\xf6l;\xa7\x00\x9fE50)l\x84\\\xc9\xcf', 111) + check_base58(b'A\xc1\xea\xf1\x11\x80%Y\xba\xd6\x1b`\xd6+\x1f\x89|c\x92\x8a', 111) + check_base58(b'\0A\xc1\xea\xf1\x11\x80%Y\xba\xd6\x1b`\xd6+\x1f\x89|c\x92\x8a', 111) + check_base58(b'\0\0A\xc1\xea\xf1\x11\x80%Y\xba\xd6\x1b`\xd6+\x1f\x89|c\x92\x8a', 111) + check_base58(b'\0\0\0A\xc1\xea\xf1\x11\x80%Y\xba\xd6\x1b`\xd6+\x1f\x89|c\x92\x8a', 111) + check_base58(b'\x1f\x8e\xa1p*{\xd4\x94\x1b\xca\tA\xb8R\xc4\xbb\xfe\xdb.\x05', 0) + check_base58(b':\x0b\x05\xf4\xd7\xf6l;\xa7\x00\x9fE50)l\x84\\\xc9\xcf', 0) + check_base58(b'A\xc1\xea\xf1\x11\x80%Y\xba\xd6\x1b`\xd6+\x1f\x89|c\x92\x8a', 0) + check_base58(b'\0A\xc1\xea\xf1\x11\x80%Y\xba\xd6\x1b`\xd6+\x1f\x89|c\x92\x8a', 0) + check_base58(b'\0\0A\xc1\xea\xf1\x11\x80%Y\xba\xd6\x1b`\xd6+\x1f\x89|c\x92\x8a', 0) + check_base58(b'\0\0\0A\xc1\xea\xf1\x11\x80%Y\xba\xd6\x1b`\xd6+\x1f\x89|c\x92\x8a', 0) diff --git a/basicswap/contrib/test_framework/authproxy.py b/basicswap/contrib/test_framework/authproxy.py new file mode 100644 index 0000000..0530893 --- /dev/null +++ b/basicswap/contrib/test_framework/authproxy.py @@ -0,0 +1,204 @@ +# Copyright (c) 2011 Jeff Garzik +# +# Previous copyright, from python-jsonrpc/jsonrpc/proxy.py: +# +# Copyright (c) 2007 Jan-Klaas Kollhof +# +# This file is part of jsonrpc. +# +# jsonrpc is free software; you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation; either version 2.1 of the License, or +# (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this software; if not, write to the Free Software +# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +"""HTTP proxy for opening RPC connection to bitcoind. + +AuthServiceProxy has the following improvements over python-jsonrpc's +ServiceProxy class: + +- HTTP connections persist for the life of the AuthServiceProxy object + (if server supports HTTP/1.1) +- sends protocol 'version', per JSON-RPC 1.1 +- sends proper, incrementing 'id' +- sends Basic HTTP authentication headers +- parses all JSON numbers that look like floats as Decimal +- uses standard Python json lib +""" + +import base64 +import decimal +from http import HTTPStatus +import http.client +import json +import logging +import os +import socket +import time +import urllib.parse + +HTTP_TIMEOUT = 30 +USER_AGENT = "AuthServiceProxy/0.1" + +log = logging.getLogger("BitcoinRPC") + +class JSONRPCException(Exception): + def __init__(self, rpc_error, http_status=None): + try: + errmsg = '%(message)s (%(code)i)' % rpc_error + except (KeyError, TypeError): + errmsg = '' + super().__init__(errmsg) + self.error = rpc_error + self.http_status = http_status + + +def EncodeDecimal(o): + if isinstance(o, decimal.Decimal): + return str(o) + raise TypeError(repr(o) + " is not JSON serializable") + +class AuthServiceProxy(): + __id_count = 0 + + # ensure_ascii: escape unicode as \uXXXX, passed to json.dumps + def __init__(self, service_url, service_name=None, timeout=HTTP_TIMEOUT, connection=None, ensure_ascii=True): + self.__service_url = service_url + self._service_name = service_name + self.ensure_ascii = ensure_ascii # can be toggled on the fly by tests + self.__url = urllib.parse.urlparse(service_url) + user = None if self.__url.username is None else self.__url.username.encode('utf8') + passwd = None if self.__url.password is None else self.__url.password.encode('utf8') + authpair = user + b':' + passwd + self.__auth_header = b'Basic ' + base64.b64encode(authpair) + self.timeout = timeout + self._set_conn(connection) + + def __getattr__(self, name): + if name.startswith('__') and name.endswith('__'): + # Python internal stuff + raise AttributeError + if self._service_name is not None: + name = "%s.%s" % (self._service_name, name) + return AuthServiceProxy(self.__service_url, name, connection=self.__conn) + + def _request(self, method, path, postdata): + ''' + Do a HTTP request, with retry if we get disconnected (e.g. due to a timeout). + This is a workaround for https://bugs.python.org/issue3566 which is fixed in Python 3.5. + ''' + headers = {'Host': self.__url.hostname, + 'User-Agent': USER_AGENT, + 'Authorization': self.__auth_header, + 'Content-type': 'application/json'} + if os.name == 'nt': + # Windows somehow does not like to re-use connections + # TODO: Find out why the connection would disconnect occasionally and make it reusable on Windows + # Avoid "ConnectionAbortedError: [WinError 10053] An established connection was aborted by the software in your host machine" + self._set_conn() + try: + self.__conn.request(method, path, postdata, headers) + return self._get_response() + except (BrokenPipeError, ConnectionResetError): + # Python 3.5+ raises BrokenPipeError when the connection was reset + # ConnectionResetError happens on FreeBSD + self.__conn.close() + self.__conn.request(method, path, postdata, headers) + return self._get_response() + except OSError as e: + retry = ( + '[WinError 10053] An established connection was aborted by the software in your host machine' in str(e)) + if retry: + self.__conn.close() + self.__conn.request(method, path, postdata, headers) + return self._get_response() + else: + raise + + def get_request(self, *args, **argsn): + AuthServiceProxy.__id_count += 1 + + log.debug("-{}-> {} {}".format( + AuthServiceProxy.__id_count, + self._service_name, + json.dumps(args or argsn, default=EncodeDecimal, ensure_ascii=self.ensure_ascii), + )) + if args and argsn: + raise ValueError('Cannot handle both named and positional arguments') + return {'version': '1.1', + 'method': self._service_name, + 'params': args or argsn, + 'id': AuthServiceProxy.__id_count} + + def __call__(self, *args, **argsn): + postdata = json.dumps(self.get_request(*args, **argsn), default=EncodeDecimal, ensure_ascii=self.ensure_ascii) + response, status = self._request('POST', self.__url.path, postdata.encode('utf-8')) + if response['error'] is not None: + raise JSONRPCException(response['error'], status) + elif 'result' not in response: + raise JSONRPCException({ + 'code': -343, 'message': 'missing JSON-RPC result'}, status) + elif status != HTTPStatus.OK: + raise JSONRPCException({ + 'code': -342, 'message': 'non-200 HTTP status code but no JSON-RPC error'}, status) + else: + return response['result'] + + def batch(self, rpc_call_list): + postdata = json.dumps(list(rpc_call_list), default=EncodeDecimal, ensure_ascii=self.ensure_ascii) + log.debug("--> " + postdata) + response, status = self._request('POST', self.__url.path, postdata.encode('utf-8')) + if status != HTTPStatus.OK: + raise JSONRPCException({ + 'code': -342, 'message': 'non-200 HTTP status code but no JSON-RPC error'}, status) + return response + + def _get_response(self): + req_start_time = time.time() + try: + http_response = self.__conn.getresponse() + except socket.timeout: + raise JSONRPCException({ + 'code': -344, + 'message': '%r RPC took longer than %f seconds. Consider ' + 'using larger timeout for calls that take ' + 'longer to return.' % (self._service_name, + self.__conn.timeout)}) + if http_response is None: + raise JSONRPCException({ + 'code': -342, 'message': 'missing HTTP response from server'}) + + content_type = http_response.getheader('Content-Type') + if content_type != 'application/json': + raise JSONRPCException( + {'code': -342, 'message': 'non-JSON HTTP response with \'%i %s\' from server' % (http_response.status, http_response.reason)}, + http_response.status) + + responsedata = http_response.read().decode('utf8') + response = json.loads(responsedata, parse_float=decimal.Decimal) + elapsed = time.time() - req_start_time + if "error" in response and response["error"] is None: + log.debug("<-%s- [%.6f] %s" % (response["id"], elapsed, json.dumps(response["result"], default=EncodeDecimal, ensure_ascii=self.ensure_ascii))) + else: + log.debug("<-- [%.6f] %s" % (elapsed, responsedata)) + return response, http_response.status + + def __truediv__(self, relative_uri): + return AuthServiceProxy("{}/{}".format(self.__service_url, relative_uri), self._service_name, connection=self.__conn) + + def _set_conn(self, connection=None): + port = 80 if self.__url.port is None else self.__url.port + if connection: + self.__conn = connection + self.timeout = connection.timeout + elif self.__url.scheme == 'https': + self.__conn = http.client.HTTPSConnection(self.__url.hostname, port, timeout=self.timeout) + else: + self.__conn = http.client.HTTPConnection(self.__url.hostname, port, timeout=self.timeout) diff --git a/basicswap/contrib/test_framework/coverage.py b/basicswap/contrib/test_framework/coverage.py new file mode 100644 index 0000000..7705dd3 --- /dev/null +++ b/basicswap/contrib/test_framework/coverage.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# Copyright (c) 2015-2018 The Bitcoin Core developers +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +"""Utilities for doing coverage analysis on the RPC interface. + +Provides a way to track which RPC commands are exercised during +testing. +""" + +import os + + +REFERENCE_FILENAME = 'rpc_interface.txt' + + +class AuthServiceProxyWrapper(): + """ + An object that wraps AuthServiceProxy to record specific RPC calls. + + """ + def __init__(self, auth_service_proxy_instance, coverage_logfile=None): + """ + Kwargs: + auth_service_proxy_instance (AuthServiceProxy): the instance + being wrapped. + coverage_logfile (str): if specified, write each service_name + out to a file when called. + + """ + self.auth_service_proxy_instance = auth_service_proxy_instance + self.coverage_logfile = coverage_logfile + + def __getattr__(self, name): + return_val = getattr(self.auth_service_proxy_instance, name) + if not isinstance(return_val, type(self.auth_service_proxy_instance)): + # If proxy getattr returned an unwrapped value, do the same here. + return return_val + return AuthServiceProxyWrapper(return_val, self.coverage_logfile) + + def __call__(self, *args, **kwargs): + """ + Delegates to AuthServiceProxy, then writes the particular RPC method + called to a file. + + """ + return_val = self.auth_service_proxy_instance.__call__(*args, **kwargs) + self._log_call() + return return_val + + def _log_call(self): + rpc_method = self.auth_service_proxy_instance._service_name + + if self.coverage_logfile: + with open(self.coverage_logfile, 'a+', encoding='utf8') as f: + f.write("%s\n" % rpc_method) + + def __truediv__(self, relative_uri): + return AuthServiceProxyWrapper(self.auth_service_proxy_instance / relative_uri, + self.coverage_logfile) + + def get_request(self, *args, **kwargs): + self._log_call() + return self.auth_service_proxy_instance.get_request(*args, **kwargs) + +def get_filename(dirname, n_node): + """ + Get a filename unique to the test process ID and node. + + This file will contain a list of RPC commands covered. + """ + pid = str(os.getpid()) + return os.path.join( + dirname, "coverage.pid%s.node%s.txt" % (pid, str(n_node))) + + +def write_all_rpc_commands(dirname, node): + """ + Write out a list of all RPC functions available in `bitcoin-cli` for + coverage comparison. This will only happen once per coverage + directory. + + Args: + dirname (str): temporary test dir + node (AuthServiceProxy): client + + Returns: + bool. if the RPC interface file was written. + + """ + filename = os.path.join(dirname, REFERENCE_FILENAME) + + if os.path.isfile(filename): + return False + + help_output = node.help().split('\n') + commands = set() + + for line in help_output: + line = line.strip() + + # Ignore blanks and headers + if line and not line.startswith('='): + commands.add("%s\n" % line.split()[0]) + + with open(filename, 'w', encoding='utf8') as f: + f.writelines(list(commands)) + + return True diff --git a/basicswap/contrib/test_framework/key.py b/basicswap/contrib/test_framework/key.py new file mode 100644 index 0000000..55e2de1 --- /dev/null +++ b/basicswap/contrib/test_framework/key.py @@ -0,0 +1,393 @@ +# Copyright (c) 2019 Pieter Wuille +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +"""Test-only secp256k1 elliptic curve implementation + +WARNING: This code is slow, uses bad randomness, does not properly protect +keys, and is trivially vulnerable to side channel attacks. Do not use for +anything but tests.""" +import random + +def modinv(a, n): + """Compute the modular inverse of a modulo n + + See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Modular_integers. + """ + t1, t2 = 0, 1 + r1, r2 = n, a + while r2 != 0: + q = r1 // r2 + t1, t2 = t2, t1 - q * t2 + r1, r2 = r2, r1 - q * r2 + if r1 > 1: + return None + if t1 < 0: + t1 += n + return t1 + +def jacobi_symbol(n, k): + """Compute the Jacobi symbol of n modulo k + + See http://en.wikipedia.org/wiki/Jacobi_symbol + + For our application k is always prime, so this is the same as the Legendre symbol.""" + assert k > 0 and k & 1, "jacobi symbol is only defined for positive odd k" + n %= k + t = 0 + while n != 0: + while n & 1 == 0: + n >>= 1 + r = k & 7 + t ^= (r == 3 or r == 5) + n, k = k, n + t ^= (n & k & 3 == 3) + n = n % k + if k == 1: + return -1 if t else 1 + return 0 + +def modsqrt(a, p): + """Compute the square root of a modulo p when p % 4 = 3. + + The Tonelli-Shanks algorithm can be used. See https://en.wikipedia.org/wiki/Tonelli-Shanks_algorithm + + Limiting this function to only work for p % 4 = 3 means we don't need to + iterate through the loop. The highest n such that p - 1 = 2^n Q with Q odd + is n = 1. Therefore Q = (p-1)/2 and sqrt = a^((Q+1)/2) = a^((p+1)/4) + + secp256k1's is defined over field of size 2**256 - 2**32 - 977, which is 3 mod 4. + """ + if p % 4 != 3: + raise NotImplementedError("modsqrt only implemented for p % 4 = 3") + sqrt = pow(a, (p + 1)//4, p) + if pow(sqrt, 2, p) == a % p: + return sqrt + return None + +class EllipticCurve: + def __init__(self, p, a, b): + """Initialize elliptic curve y^2 = x^3 + a*x + b over GF(p).""" + self.p = p + self.a = a % p + self.b = b % p + + def affine(self, p1): + """Convert a Jacobian point tuple p1 to affine form, or None if at infinity. + + An affine point is represented as the Jacobian (x, y, 1)""" + x1, y1, z1 = p1 + if z1 == 0: + return None + inv = modinv(z1, self.p) + inv_2 = (inv**2) % self.p + inv_3 = (inv_2 * inv) % self.p + return ((inv_2 * x1) % self.p, (inv_3 * y1) % self.p, 1) + + def negate(self, p1): + """Negate a Jacobian point tuple p1.""" + x1, y1, z1 = p1 + return (x1, (self.p - y1) % self.p, z1) + + def on_curve(self, p1): + """Determine whether a Jacobian tuple p is on the curve (and not infinity)""" + x1, y1, z1 = p1 + z2 = pow(z1, 2, self.p) + z4 = pow(z2, 2, self.p) + return z1 != 0 and (pow(x1, 3, self.p) + self.a * x1 * z4 + self.b * z2 * z4 - pow(y1, 2, self.p)) % self.p == 0 + + def is_x_coord(self, x): + """Test whether x is a valid X coordinate on the curve.""" + x_3 = pow(x, 3, self.p) + return jacobi_symbol(x_3 + self.a * x + self.b, self.p) != -1 + + def lift_x(self, x): + """Given an X coordinate on the curve, return a corresponding affine point.""" + x_3 = pow(x, 3, self.p) + v = x_3 + self.a * x + self.b + y = modsqrt(v, self.p) + if y is None: + return None + return (x, y, 1) + + def double(self, p1): + """Double a Jacobian tuple p1 + + See https://en.wikibooks.org/wiki/Cryptography/Prime_Curve/Jacobian_Coordinates - Point Doubling""" + x1, y1, z1 = p1 + if z1 == 0: + return (0, 1, 0) + y1_2 = (y1**2) % self.p + y1_4 = (y1_2**2) % self.p + x1_2 = (x1**2) % self.p + s = (4*x1*y1_2) % self.p + m = 3*x1_2 + if self.a: + m += self.a * pow(z1, 4, self.p) + m = m % self.p + x2 = (m**2 - 2*s) % self.p + y2 = (m*(s - x2) - 8*y1_4) % self.p + z2 = (2*y1*z1) % self.p + return (x2, y2, z2) + + def add_mixed(self, p1, p2): + """Add a Jacobian tuple p1 and an affine tuple p2 + + See https://en.wikibooks.org/wiki/Cryptography/Prime_Curve/Jacobian_Coordinates - Point Addition (with affine point)""" + x1, y1, z1 = p1 + x2, y2, z2 = p2 + assert(z2 == 1) + # Adding to the point at infinity is a no-op + if z1 == 0: + return p2 + z1_2 = (z1**2) % self.p + z1_3 = (z1_2 * z1) % self.p + u2 = (x2 * z1_2) % self.p + s2 = (y2 * z1_3) % self.p + if x1 == u2: + if (y1 != s2): + # p1 and p2 are inverses. Return the point at infinity. + return (0, 1, 0) + # p1 == p2. The formulas below fail when the two points are equal. + return self.double(p1) + h = u2 - x1 + r = s2 - y1 + h_2 = (h**2) % self.p + h_3 = (h_2 * h) % self.p + u1_h_2 = (x1 * h_2) % self.p + x3 = (r**2 - h_3 - 2*u1_h_2) % self.p + y3 = (r*(u1_h_2 - x3) - y1*h_3) % self.p + z3 = (h*z1) % self.p + return (x3, y3, z3) + + def add(self, p1, p2): + """Add two Jacobian tuples p1 and p2 + + See https://en.wikibooks.org/wiki/Cryptography/Prime_Curve/Jacobian_Coordinates - Point Addition""" + x1, y1, z1 = p1 + x2, y2, z2 = p2 + # Adding the point at infinity is a no-op + if z1 == 0: + return p2 + if z2 == 0: + return p1 + # Adding an Affine to a Jacobian is more efficient since we save field multiplications and squarings when z = 1 + if z1 == 1: + return self.add_mixed(p2, p1) + if z2 == 1: + return self.add_mixed(p1, p2) + z1_2 = (z1**2) % self.p + z1_3 = (z1_2 * z1) % self.p + z2_2 = (z2**2) % self.p + z2_3 = (z2_2 * z2) % self.p + u1 = (x1 * z2_2) % self.p + u2 = (x2 * z1_2) % self.p + s1 = (y1 * z2_3) % self.p + s2 = (y2 * z1_3) % self.p + if u1 == u2: + if (s1 != s2): + # p1 and p2 are inverses. Return the point at infinity. + return (0, 1, 0) + # p1 == p2. The formulas below fail when the two points are equal. + return self.double(p1) + h = u2 - u1 + r = s2 - s1 + h_2 = (h**2) % self.p + h_3 = (h_2 * h) % self.p + u1_h_2 = (u1 * h_2) % self.p + x3 = (r**2 - h_3 - 2*u1_h_2) % self.p + y3 = (r*(u1_h_2 - x3) - s1*h_3) % self.p + z3 = (h*z1*z2) % self.p + return (x3, y3, z3) + + def mul(self, ps): + """Compute a (multi) point multiplication + + ps is a list of (Jacobian tuple, scalar) pairs. + """ + r = (0, 1, 0) + for i in range(255, -1, -1): + r = self.double(r) + for (p, n) in ps: + if ((n >> i) & 1): + r = self.add(r, p) + return r + +SECP256K1 = EllipticCurve(2**256 - 2**32 - 977, 0, 7) +SECP256K1_G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8, 1) +SECP256K1_ORDER = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 +SECP256K1_ORDER_HALF = SECP256K1_ORDER // 2 + +class ECPubKey(): + """A secp256k1 public key""" + + def __init__(self): + """Construct an uninitialized public key""" + self.valid = False + + def set_int(self, x, y): + p = (x, y, 1) + self.valid = SECP256K1.on_curve(p) + if self.valid: + self.p = p + self.compressed = False + + def set(self, data): + """Construct a public key from a serialization in compressed or uncompressed format""" + if (len(data) == 65 and data[0] == 0x04): + p = (int.from_bytes(data[1:33], 'big'), int.from_bytes(data[33:65], 'big'), 1) + self.valid = SECP256K1.on_curve(p) + if self.valid: + self.p = p + self.compressed = False + elif (len(data) == 33 and (data[0] == 0x02 or data[0] == 0x03)): + x = int.from_bytes(data[1:33], 'big') + if SECP256K1.is_x_coord(x): + p = SECP256K1.lift_x(x) + # if the oddness of the y co-ord isn't correct, find the other + # valid y + if (p[1] & 1) != (data[0] & 1): + p = SECP256K1.negate(p) + self.p = p + self.valid = True + self.compressed = True + else: + self.valid = False + else: + self.valid = False + + @property + def is_compressed(self): + return self.compressed + + @property + def is_valid(self): + return self.valid + + def get_bytes(self): + assert(self.valid) + p = SECP256K1.affine(self.p) + if p is None: + return None + if self.compressed: + return bytes([0x02 + (p[1] & 1)]) + p[0].to_bytes(32, 'big') + else: + return bytes([0x04]) + p[0].to_bytes(32, 'big') + p[1].to_bytes(32, 'big') + + def verify_ecdsa(self, sig, msg, low_s=True): + """Verify a strictly DER-encoded ECDSA signature against this pubkey. + + See https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm for the + ECDSA verifier algorithm""" + assert(self.valid) + + # Extract r and s from the DER formatted signature. Return false for + # any DER encoding errors. + if (sig[1] + 2 != len(sig)): + return False + if (len(sig) < 4): + return False + if (sig[0] != 0x30): + return False + if (sig[2] != 0x02): + return False + rlen = sig[3] + if (len(sig) < 6 + rlen): + return False + if rlen < 1 or rlen > 33: + return False + if sig[4] >= 0x80: + return False + if (rlen > 1 and (sig[4] == 0) and not (sig[5] & 0x80)): + return False + r = int.from_bytes(sig[4:4+rlen], 'big') + if (sig[4+rlen] != 0x02): + return False + slen = sig[5+rlen] + if slen < 1 or slen > 33: + return False + if (len(sig) != 6 + rlen + slen): + return False + if sig[6+rlen] >= 0x80: + return False + if (slen > 1 and (sig[6+rlen] == 0) and not (sig[7+rlen] & 0x80)): + return False + s = int.from_bytes(sig[6+rlen:6+rlen+slen], 'big') + + # Verify that r and s are within the group order + if r < 1 or s < 1 or r >= SECP256K1_ORDER or s >= SECP256K1_ORDER: + return False + if low_s and s >= SECP256K1_ORDER_HALF: + return False + z = int.from_bytes(msg, 'big') + + # Run verifier algorithm on r, s + w = modinv(s, SECP256K1_ORDER) + u1 = z*w % SECP256K1_ORDER + u2 = r*w % SECP256K1_ORDER + R = SECP256K1.affine(SECP256K1.mul([(SECP256K1_G, u1), (self.p, u2)])) + if R is None or R[0] != r: + return False + return True + +class ECKey(): + """A secp256k1 private key""" + + def __init__(self): + self.valid = False + + def set(self, secret, compressed): + """Construct a private key object with given 32-byte secret and compressed flag.""" + assert(len(secret) == 32) + secret = int.from_bytes(secret, 'big') + self.valid = (secret > 0 and secret < SECP256K1_ORDER) + if self.valid: + self.secret = secret + self.compressed = compressed + + def generate(self, compressed=True): + """Generate a random private key (compressed or uncompressed).""" + self.set(random.randrange(1, SECP256K1_ORDER).to_bytes(32, 'big'), compressed) + + def get_bytes(self): + """Retrieve the 32-byte representation of this key.""" + assert(self.valid) + return self.secret.to_bytes(32, 'big') + + @property + def is_valid(self): + return self.valid + + @property + def is_compressed(self): + return self.compressed + + def get_pubkey(self): + """Compute an ECPubKey object for this secret key.""" + assert(self.valid) + ret = ECPubKey() + p = SECP256K1.mul([(SECP256K1_G, self.secret)]) + ret.p = p + ret.valid = True + ret.compressed = self.compressed + return ret + + def sign_ecdsa(self, msg, low_s=True): + """Construct a DER-encoded ECDSA signature with this key. + + See https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm for the + ECDSA signer algorithm.""" + assert(self.valid) + z = int.from_bytes(msg, 'big') + # Note: no RFC6979, but a simple random nonce (some tests rely on distinct transactions for the same operation) + k = random.randrange(1, SECP256K1_ORDER) + R = SECP256K1.affine(SECP256K1.mul([(SECP256K1_G, k)])) + r = R[0] % SECP256K1_ORDER + s = (modinv(k, SECP256K1_ORDER) * (z + self.secret * r)) % SECP256K1_ORDER + if low_s and s > SECP256K1_ORDER_HALF: + s = SECP256K1_ORDER - s + # Represent in DER format. The byte representations of r and s have + # length rounded up (255 bits becomes 32 bytes and 256 bits becomes 33 + # bytes). + rb = r.to_bytes((r.bit_length() + 8) // 8, 'big') + sb = s.to_bytes((s.bit_length() + 8) // 8, 'big') + return b'\x30' + bytes([4 + len(rb) + len(sb), 2, len(rb)]) + rb + bytes([2, len(sb)]) + sb diff --git a/basicswap/contrib/test_framework/messages.py b/basicswap/contrib/test_framework/messages.py new file mode 100755 index 0000000..70e353e --- /dev/null +++ b/basicswap/contrib/test_framework/messages.py @@ -0,0 +1,1756 @@ +#!/usr/bin/env python3 +# Copyright (c) 2010 ArtForz -- public domain half-a-node +# Copyright (c) 2012 Jeff Garzik +# Copyright (c) 2010-2020 The Bitcoin Core developers +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +"""Bitcoin test framework primitive and message structures + +CBlock, CTransaction, CBlockHeader, CTxIn, CTxOut, etc....: + data structures that should map to corresponding structures in + bitcoin/primitives + +msg_block, msg_tx, msg_headers, etc.: + data structures that represent network messages + +ser_*, deser_*: functions that handle serialization/deserialization. + +Classes use __slots__ to ensure extraneous attributes aren't accidentally added +by tests, compromising their intended effect. +""" +from codecs import encode +import copy +import hashlib +from io import BytesIO +import random +import socket +import struct +import time + +from .siphash import siphash256 +from .util import hex_str_to_bytes, assert_equal + +MIN_VERSION_SUPPORTED = 60001 +#MY_VERSION = 70014 # past bip-31 for ping/pong +MY_VERSION = 90009 +MY_SUBVERSION = b"/python-mininode-tester:0.0.3/" +MY_RELAY = 1 # from version 70001 onwards, fRelay should be appended to version messages (BIP37) + +MAX_LOCATOR_SZ = 101 +MAX_BLOCK_BASE_SIZE = 1000000 +MAX_BLOOM_FILTER_SIZE = 36000 +MAX_BLOOM_HASH_FUNCS = 50 + +COIN = 100000000 # 1 btc in satoshis +MAX_MONEY = 21000000 * COIN + +BIP125_SEQUENCE_NUMBER = 0xfffffffd # Sequence number that is BIP 125 opt-in and BIP 68-opt-out + +NODE_NETWORK = (1 << 0) +NODE_GETUTXO = (1 << 1) +NODE_BLOOM = (1 << 2) +NODE_WITNESS = (1 << 3) +NODE_NETWORK_LIMITED = (1 << 10) + +MSG_TX = 1 +MSG_BLOCK = 2 +MSG_FILTERED_BLOCK = 3 +MSG_CMPCT_BLOCK = 4 +MSG_WITNESS_FLAG = 1 << 30 +MSG_TYPE_MASK = 0xffffffff >> 2 + +FILTER_TYPE_BASIC = 0 + +PARTICL_TX_VERSION = 0xa0 +PARTICL_TX_ANON_MARKER = 0xffffffa0 +OUTPUT_TYPE_STANDARD = 1 +OUTPUT_TYPE_CT = 2 +OUTPUT_TYPE_RINGCT = 3 +OUTPUT_TYPE_DATA = 4 + + +# Serialization/deserialization tools +def sha256(s): + return hashlib.new('sha256', s).digest() + +def hash256(s): + return sha256(sha256(s)) + +def ser_compact_size(l): + r = b"" + if l < 253: + r = struct.pack("B", l) + elif l < 0x10000: + r = struct.pack("<BH", 253, l) + elif l < 0x100000000: + r = struct.pack("<BI", 254, l) + else: + r = struct.pack("<BQ", 255, l) + return r + +def deser_compact_size(f): + nit = struct.unpack("<B", f.read(1))[0] + if nit == 253: + nit = struct.unpack("<H", f.read(2))[0] + elif nit == 254: + nit = struct.unpack("<I", f.read(4))[0] + elif nit == 255: + nit = struct.unpack("<Q", f.read(8))[0] + return nit + +def deser_string(f): + nit = deser_compact_size(f) + return f.read(nit) + +def ser_string(s): + return ser_compact_size(len(s)) + s + +def deser_uint256(f): + r = 0 + for i in range(8): + t = struct.unpack("<I", f.read(4))[0] + r += t << (i * 32) + return r + + +def ser_uint256(u): + rs = b"" + for i in range(8): + rs += struct.pack("<I", u & 0xFFFFFFFF) + u >>= 32 + return rs + + +def uint256_from_str(s): + r = 0 + t = struct.unpack("<IIIIIIII", s[:32]) + for i in range(8): + r += t[i] << (i * 32) + return r + + +def uint256_from_compact(c): + nbytes = (c >> 24) & 0xFF + v = (c & 0xFFFFFF) << (8 * (nbytes - 3)) + return v + + +def deser_vector(f, c): + nit = deser_compact_size(f) + r = [] + for i in range(nit): + t = c() + t.deserialize(f) + r.append(t) + return r + + +# ser_function_name: Allow for an alternate serialization function on the +# entries in the vector (we use this for serializing the vector of transactions +# for a witness block). +def ser_vector(l, ser_function_name=None): + r = ser_compact_size(len(l)) + for i in l: + if ser_function_name: + r += getattr(i, ser_function_name)() + else: + r += i.serialize() + return r + + +def deser_uint256_vector(f): + nit = deser_compact_size(f) + r = [] + for i in range(nit): + t = deser_uint256(f) + r.append(t) + return r + + +def ser_uint256_vector(l): + r = ser_compact_size(len(l)) + for i in l: + r += ser_uint256(i) + return r + + +def deser_string_vector(f): + nit = deser_compact_size(f) + r = [] + for i in range(nit): + t = deser_string(f) + r.append(t) + return r + + +def ser_string_vector(l): + r = ser_compact_size(len(l)) + for sv in l: + r += ser_string(sv) + return r + + +# Deserialize from a hex string representation (eg from RPC) +def FromHex(obj, hex_string): + obj.deserialize(BytesIO(hex_str_to_bytes(hex_string))) + return obj + +# Convert a binary-serializable object to hex (eg for submission via RPC) +def ToHex(obj): + return obj.serialize().hex() + +# Objects that map to bitcoind objects, which can be serialized/deserialized + + +class CAddress: + __slots__ = ("ip", "nServices", "pchReserved", "port", "time") + + def __init__(self): + self.time = 0 + self.nServices = 1 + self.pchReserved = b"\x00" * 10 + b"\xff" * 2 + self.ip = "0.0.0.0" + self.port = 0 + + def deserialize(self, f, with_time=True): + if with_time: + self.time = struct.unpack("<i", f.read(4))[0] + self.nServices = struct.unpack("<Q", f.read(8))[0] + self.pchReserved = f.read(12) + self.ip = socket.inet_ntoa(f.read(4)) + self.port = struct.unpack(">H", f.read(2))[0] + + def serialize(self, with_time=True): + r = b"" + if with_time: + r += struct.pack("<i", self.time) + r += struct.pack("<Q", self.nServices) + r += self.pchReserved + r += socket.inet_aton(self.ip) + r += struct.pack(">H", self.port) + return r + + def __repr__(self): + return "CAddress(nServices=%i ip=%s port=%i)" % (self.nServices, + self.ip, self.port) + + +class CInv: + __slots__ = ("hash", "type") + + typemap = { + 0: "Error", + MSG_TX: "TX", + MSG_BLOCK: "Block", + MSG_TX | MSG_WITNESS_FLAG: "WitnessTx", + MSG_BLOCK | MSG_WITNESS_FLAG: "WitnessBlock", + MSG_FILTERED_BLOCK: "filtered Block", + 4: "CompactBlock" + } + + def __init__(self, t=0, h=0): + self.type = t + self.hash = h + + def deserialize(self, f): + self.type = struct.unpack("<i", f.read(4))[0] + self.hash = deser_uint256(f) + + def serialize(self): + r = b"" + r += struct.pack("<i", self.type) + r += ser_uint256(self.hash) + return r + + def __repr__(self): + return "CInv(type=%s hash=%064x)" \ + % (self.typemap[self.type], self.hash) + + +class CBlockLocator: + __slots__ = ("nVersion", "vHave") + + def __init__(self): + self.nVersion = MY_VERSION + self.vHave = [] + + def deserialize(self, f): + self.nVersion = struct.unpack("<i", f.read(4))[0] + self.vHave = deser_uint256_vector(f) + + def serialize(self): + r = b"" + r += struct.pack("<i", self.nVersion) + r += ser_uint256_vector(self.vHave) + return r + + def __repr__(self): + return "CBlockLocator(nVersion=%i vHave=%s)" \ + % (self.nVersion, repr(self.vHave)) + + +class COutPoint: + __slots__ = ("hash", "n") + + def __init__(self, hash=0, n=0): + self.hash = hash + self.n = n + + def deserialize(self, f): + self.hash = deser_uint256(f) + self.n = struct.unpack("<I", f.read(4))[0] + + def serialize(self): + r = b"" + r += ser_uint256(self.hash) + r += struct.pack("<I", self.n) + return r + + def __repr__(self): + return "COutPoint(hash=%064x n=%i)" % (self.hash, self.n) + + +class CTxIn: + __slots__ = ("nSequence", "prevout", "scriptSig") + + def __init__(self, outpoint=None, scriptSig=b"", nSequence=0): + if outpoint is None: + self.prevout = COutPoint() + else: + self.prevout = outpoint + self.scriptSig = scriptSig + self.nSequence = nSequence + + def deserialize(self, f): + self.prevout = COutPoint() + self.prevout.deserialize(f) + self.scriptSig = deser_string(f) + self.nSequence = struct.unpack("<I", f.read(4))[0] + + def serialize(self): + r = b"" + r += self.prevout.serialize() + r += ser_string(self.scriptSig) + r += struct.pack("<I", self.nSequence) + return r + + def __repr__(self): + return "CTxIn(prevout=%s scriptSig=%s nSequence=%i)" \ + % (repr(self.prevout), self.scriptSig.hex(), + self.nSequence) + +class CTxOutPart: + __slots__ = ("nVersion", "nValue", "scriptPubKey") + + def __init__(self, nValue=0, scriptPubKey=b""): + self.nVersion = OUTPUT_TYPE_STANDARD + self.nValue = nValue + self.scriptPubKey = scriptPubKey + + def deserialize(self, f): + self.nValue = struct.unpack("<q", f.read(8))[0] + self.scriptPubKey = deser_string(f) + + def serialize(self): + r = b"" + r += struct.pack("<q", self.nValue) + r += ser_string(self.scriptPubKey) + return r + + def __repr__(self): + return "CTxOutPart(nValue=%i.%08i scriptPubKey=%s)" \ + % (self.nValue // COIN, self.nValue % COIN, + self.scriptPubKey.hex()) + +class CTxOut: + __slots__ = ("nValue", "scriptPubKey") + + def __init__(self, nValue=0, scriptPubKey=b""): + self.nValue = nValue + self.scriptPubKey = scriptPubKey + + def deserialize(self, f): + self.nValue = struct.unpack("<q", f.read(8))[0] + self.scriptPubKey = deser_string(f) + + def serialize(self): + r = b"" + r += struct.pack("<q", self.nValue) + r += ser_string(self.scriptPubKey) + return r + + def __repr__(self): + return "CTxOut(nValue=%i.%08i scriptPubKey=%s)" \ + % (self.nValue // COIN, self.nValue % COIN, + self.scriptPubKey.hex()) + + +class CScriptWitness: + __slots__ = ("stack",) + + def __init__(self): + # stack is a vector of strings + self.stack = [] + + def __repr__(self): + return "CScriptWitness(%s)" % \ + (",".join([x.hex() for x in self.stack])) + + def is_null(self): + if self.stack: + return False + return True + + +class CTxInWitness: + __slots__ = ("scriptWitness",) + + def __init__(self): + self.scriptWitness = CScriptWitness() + + def deserialize(self, f): + self.scriptWitness.stack = deser_string_vector(f) + + def serialize(self): + return ser_string_vector(self.scriptWitness.stack) + + def __repr__(self): + return repr(self.scriptWitness) + + def is_null(self): + return self.scriptWitness.is_null() + + +class CTxWitness: + __slots__ = ("vtxinwit",) + + def __init__(self): + self.vtxinwit = [] + + def deserialize(self, f): + for i in range(len(self.vtxinwit)): + self.vtxinwit[i].deserialize(f) + + def serialize(self): + r = b"" + # This is different than the usual vector serialization -- + # we omit the length of the vector, which is required to be + # the same length as the transaction's vin vector. + for x in self.vtxinwit: + r += x.serialize() + return r + + def __repr__(self): + return "CTxWitness(%s)" % \ + (';'.join([repr(x) for x in self.vtxinwit])) + + def is_null(self): + for x in self.vtxinwit: + if not x.is_null(): + return False + return True + + +class CTransaction: + __slots__ = ("hash", "nLockTime", "nVersion", "sha256", "vin", "vout", + "wit") + + def __init__(self, tx=None): + if tx is None: + self.nVersion = 1 + self.vin = [] + self.vout = [] + self.wit = CTxWitness() + self.nLockTime = 0 + self.sha256 = None + self.hash = None + else: + self.nVersion = tx.nVersion + self.vin = copy.deepcopy(tx.vin) + self.vout = copy.deepcopy(tx.vout) + self.nLockTime = tx.nLockTime + self.sha256 = tx.sha256 + self.hash = tx.hash + self.wit = copy.deepcopy(tx.wit) + + def deserialize(self, f): + self.nVersion = int(struct.unpack("<B", f.read(1))[0]) + if self.nVersion == PARTICL_TX_VERSION: + self.nVersion |= int(struct.unpack("<B", f.read(1))[0]) << 8 + + self.nLockTime = struct.unpack("<I", f.read(4))[0] + + self.vin = deser_vector(f, CTxIn) + + num_outputs = deser_compact_size(f) + self.vout.clear() + for i in range(num_outputs): + txo = CTxOutPart() + txo.nVersion = int(struct.unpack("<B", f.read(1))[0]) + txo.deserialize(f) + self.vout.append(txo) + + self.wit.vtxinwit = [CTxInWitness() for i in range(len(self.vin))] + self.wit.deserialize(f) + + self.sha256 = None + self.hash = None + return + + self.nVersion |= int(struct.unpack("<B", f.read(1))[0]) << 8 + self.nVersion |= int(struct.unpack("<B", f.read(1))[0]) << 16 + self.nVersion |= int(struct.unpack("<B", f.read(1))[0]) << 24 + #self.nVersion = struct.unpack("<i", f.read(4))[0] + self.vin = deser_vector(f, CTxIn) + flags = 0 + if len(self.vin) == 0: + flags = struct.unpack("<B", f.read(1))[0] + # Not sure why flags can't be zero, but this + # matches the implementation in bitcoind + if (flags != 0): + self.vin = deser_vector(f, CTxIn) + self.vout = deser_vector(f, CTxOut) + else: + self.vout = deser_vector(f, CTxOut) + if flags != 0: + self.wit.vtxinwit = [CTxInWitness() for i in range(len(self.vin))] + self.wit.deserialize(f) + else: + self.wit = CTxWitness() + self.nLockTime = struct.unpack("<I", f.read(4))[0] + self.sha256 = None + self.hash = None + + def serialize_without_witness(self): + if self.nVersion == PARTICL_TX_VERSION: + r = struct.pack("<H", self.nVersion) + r += struct.pack("<I", self.nLockTime) + r += ser_vector(self.vin) + r += ser_compact_size(len(self.vout)) + for txo in self.vout: + r += bytes((txo.nVersion,)) + r += txo.serialize() + return r + r = b"" + r += struct.pack("<i", self.nVersion) + r += ser_vector(self.vin) + r += ser_vector(self.vout) + r += struct.pack("<I", self.nLockTime) + return r + + # Only serialize with witness when explicitly called for + def serialize_with_witness(self): + if self.nVersion == PARTICL_TX_VERSION: + r = self.serialize_without_witness() + while len(self.wit.vtxinwit) < len(self.vin): + self.wit.vtxinwit.append(CTxInWitness()) + r += self.wit.serialize() + return r + flags = 0 + if not self.wit.is_null(): + flags |= 1 + r = b"" + r += struct.pack("<i", self.nVersion) + if flags: + dummy = [] + r += ser_vector(dummy) + r += struct.pack("<B", flags) + r += ser_vector(self.vin) + r += ser_vector(self.vout) + if flags & 1: + if (len(self.wit.vtxinwit) != len(self.vin)): + # vtxinwit must have the same length as vin + self.wit.vtxinwit = self.wit.vtxinwit[:len(self.vin)] + for i in range(len(self.wit.vtxinwit), len(self.vin)): + self.wit.vtxinwit.append(CTxInWitness()) + r += self.wit.serialize() + r += struct.pack("<I", self.nLockTime) + return r + + # Regular serialization is with witness -- must explicitly + # call serialize_without_witness to exclude witness data. + def serialize(self): + return self.serialize_with_witness() + + # Recalculate the txid (transaction hash without witness) + def rehash(self): + self.sha256 = None + self.calc_sha256() + return self.hash + + # We will only cache the serialization without witness in + # self.sha256 and self.hash -- those are expected to be the txid. + def calc_sha256(self, with_witness=False): + if with_witness: + # Don't cache the result, just return it + return uint256_from_str(hash256(self.serialize_with_witness())) + + if self.sha256 is None: + self.sha256 = uint256_from_str(hash256(self.serialize_without_witness())) + self.hash = encode(hash256(self.serialize_without_witness())[::-1], 'hex_codec').decode('ascii') + + def is_valid(self): + self.calc_sha256() + for tout in self.vout: + if tout.nValue < 0 or tout.nValue > 21000000 * COIN: + return False + return True + + def __repr__(self): + return "CTransaction(nVersion=%i vin=%s vout=%s wit=%s nLockTime=%i)" \ + % (self.nVersion, repr(self.vin), repr(self.vout), repr(self.wit), self.nLockTime) + + +class CBlockHeader: + __slots__ = ("hash", "hashMerkleRoot", "hashPrevBlock", "nBits", "nNonce", + "nTime", "nVersion", "sha256", + "is_part", "hashWitnessMerkleRoot") + + def __init__(self, header=None, is_part=False): + self.is_part = is_part + if header is None: + self.set_null() + else: + self.is_part = header.is_part + self.nVersion = header.nVersion + self.hashPrevBlock = header.hashPrevBlock + self.hashMerkleRoot = header.hashMerkleRoot + if self.is_part: + self.hashWitnessMerkleRoot = header.hashWitnessMerkleRoot + self.nTime = header.nTime + self.nBits = header.nBits + self.nNonce = header.nNonce + self.sha256 = header.sha256 + self.hash = header.hash + self.calc_sha256() + + def set_null(self): + self.nVersion = 1 + self.hashPrevBlock = 0 + self.hashMerkleRoot = 0 + if self.is_part: + self.hashWitnessMerkleRoot = 0 + self.nTime = 0 + self.nBits = 0 + self.nNonce = 0 + self.sha256 = None + self.hash = None + + def deserialize(self, f): + self.nVersion = struct.unpack("<i", f.read(4))[0] + self.hashPrevBlock = deser_uint256(f) + self.hashMerkleRoot = deser_uint256(f) + if self.is_part: + self.hashWitnessMerkleRoot = deser_uint256(f) + self.nTime = struct.unpack("<I", f.read(4))[0] + self.nBits = struct.unpack("<I", f.read(4))[0] + self.nNonce = struct.unpack("<I", f.read(4))[0] + self.sha256 = None + self.hash = None + + def serialize(self): + r = b"" + r += struct.pack("<i", self.nVersion) + r += ser_uint256(self.hashPrevBlock) + r += ser_uint256(self.hashMerkleRoot) + if self.is_part: + r += ser_uint256(self.hashWitnessMerkleRoot) + r += struct.pack("<I", self.nTime) + r += struct.pack("<I", self.nBits) + r += struct.pack("<I", self.nNonce) + return r + + def calc_sha256(self): + if self.sha256 is None: + r = b"" + r += struct.pack("<i", self.nVersion) + r += ser_uint256(self.hashPrevBlock) + r += ser_uint256(self.hashMerkleRoot) + if self.is_part: + r += ser_uint256(self.hashWitnessMerkleRoot) + r += struct.pack("<I", self.nTime) + r += struct.pack("<I", self.nBits) + r += struct.pack("<I", self.nNonce) + self.sha256 = uint256_from_str(hash256(r)) + self.hash = encode(hash256(r)[::-1], 'hex_codec').decode('ascii') + + def rehash(self): + self.sha256 = None + self.calc_sha256() + return self.sha256 + + def __repr__(self): + return "CBlockHeader(nVersion=%i hashPrevBlock=%064x hashMerkleRoot=%064x nTime=%s nBits=%08x nNonce=%08x)" \ + % (self.nVersion, self.hashPrevBlock, self.hashMerkleRoot, + time.ctime(self.nTime), self.nBits, self.nNonce) + +BLOCK_HEADER_SIZE = len(CBlockHeader().serialize()) +assert_equal(BLOCK_HEADER_SIZE, 80) + +class CBlock(CBlockHeader): + __slots__ = ("vtx",) + + def __init__(self, header=None): + super().__init__(header) + self.vtx = [] + + def deserialize(self, f): + super().deserialize(f) + self.vtx = deser_vector(f, CTransaction) + + def serialize(self, with_witness=True): + r = b"" + r += super().serialize() + if with_witness: + r += ser_vector(self.vtx, "serialize_with_witness") + else: + r += ser_vector(self.vtx, "serialize_without_witness") + return r + + # Calculate the merkle root given a vector of transaction hashes + @classmethod + def get_merkle_root(cls, hashes): + while len(hashes) > 1: + newhashes = [] + for i in range(0, len(hashes), 2): + i2 = min(i+1, len(hashes)-1) + newhashes.append(hash256(hashes[i] + hashes[i2])) + hashes = newhashes + return uint256_from_str(hashes[0]) + + def calc_merkle_root(self): + hashes = [] + for tx in self.vtx: + tx.calc_sha256() + hashes.append(ser_uint256(tx.sha256)) + return self.get_merkle_root(hashes) + + def calc_witness_merkle_root(self): + # For witness root purposes, the hash of the + # coinbase, with witness, is defined to be 0...0 + hashes = [ser_uint256(0)] + + for tx in self.vtx[1:]: + # Calculate the hashes with witness data + hashes.append(ser_uint256(tx.calc_sha256(True))) + + return self.get_merkle_root(hashes) + + def is_valid(self): + self.calc_sha256() + target = uint256_from_compact(self.nBits) + if self.sha256 > target: + return False + for tx in self.vtx: + if not tx.is_valid(): + return False + if self.calc_merkle_root() != self.hashMerkleRoot: + return False + return True + + def solve(self): + self.rehash() + target = uint256_from_compact(self.nBits) + while self.sha256 > target: + self.nNonce += 1 + self.rehash() + + def __repr__(self): + return "CBlock(nVersion=%i hashPrevBlock=%064x hashMerkleRoot=%064x nTime=%s nBits=%08x nNonce=%08x vtx=%s)" \ + % (self.nVersion, self.hashPrevBlock, self.hashMerkleRoot, + time.ctime(self.nTime), self.nBits, self.nNonce, repr(self.vtx)) + + +class PrefilledTransaction: + __slots__ = ("index", "tx") + + def __init__(self, index=0, tx = None): + self.index = index + self.tx = tx + + def deserialize(self, f): + self.index = deser_compact_size(f) + self.tx = CTransaction() + self.tx.deserialize(f) + + def serialize(self, with_witness=True): + r = b"" + r += ser_compact_size(self.index) + if with_witness: + r += self.tx.serialize_with_witness() + else: + r += self.tx.serialize_without_witness() + return r + + def serialize_without_witness(self): + return self.serialize(with_witness=False) + + def serialize_with_witness(self): + return self.serialize(with_witness=True) + + def __repr__(self): + return "PrefilledTransaction(index=%d, tx=%s)" % (self.index, repr(self.tx)) + + +# This is what we send on the wire, in a cmpctblock message. +class P2PHeaderAndShortIDs: + __slots__ = ("header", "nonce", "prefilled_txn", "prefilled_txn_length", + "shortids", "shortids_length") + + def __init__(self): + self.header = CBlockHeader() + self.nonce = 0 + self.shortids_length = 0 + self.shortids = [] + self.prefilled_txn_length = 0 + self.prefilled_txn = [] + + def deserialize(self, f): + self.header.deserialize(f) + self.nonce = struct.unpack("<Q", f.read(8))[0] + self.shortids_length = deser_compact_size(f) + for i in range(self.shortids_length): + # shortids are defined to be 6 bytes in the spec, so append + # two zero bytes and read it in as an 8-byte number + self.shortids.append(struct.unpack("<Q", f.read(6) + b'\x00\x00')[0]) + self.prefilled_txn = deser_vector(f, PrefilledTransaction) + self.prefilled_txn_length = len(self.prefilled_txn) + + # When using version 2 compact blocks, we must serialize with_witness. + def serialize(self, with_witness=False): + r = b"" + r += self.header.serialize() + r += struct.pack("<Q", self.nonce) + r += ser_compact_size(self.shortids_length) + for x in self.shortids: + # We only want the first 6 bytes + r += struct.pack("<Q", x)[0:6] + if with_witness: + r += ser_vector(self.prefilled_txn, "serialize_with_witness") + else: + r += ser_vector(self.prefilled_txn, "serialize_without_witness") + return r + + def __repr__(self): + return "P2PHeaderAndShortIDs(header=%s, nonce=%d, shortids_length=%d, shortids=%s, prefilled_txn_length=%d, prefilledtxn=%s" % (repr(self.header), self.nonce, self.shortids_length, repr(self.shortids), self.prefilled_txn_length, repr(self.prefilled_txn)) + + +# P2P version of the above that will use witness serialization (for compact +# block version 2) +class P2PHeaderAndShortWitnessIDs(P2PHeaderAndShortIDs): + __slots__ = () + def serialize(self): + return super().serialize(with_witness=True) + +# Calculate the BIP 152-compact blocks shortid for a given transaction hash +def calculate_shortid(k0, k1, tx_hash): + expected_shortid = siphash256(k0, k1, tx_hash) + expected_shortid &= 0x0000ffffffffffff + return expected_shortid + + +# This version gets rid of the array lengths, and reinterprets the differential +# encoding into indices that can be used for lookup. +class HeaderAndShortIDs: + __slots__ = ("header", "nonce", "prefilled_txn", "shortids", "use_witness") + + def __init__(self, p2pheaders_and_shortids = None): + self.header = CBlockHeader() + self.nonce = 0 + self.shortids = [] + self.prefilled_txn = [] + self.use_witness = False + + if p2pheaders_and_shortids is not None: + self.header = p2pheaders_and_shortids.header + self.nonce = p2pheaders_and_shortids.nonce + self.shortids = p2pheaders_and_shortids.shortids + last_index = -1 + for x in p2pheaders_and_shortids.prefilled_txn: + self.prefilled_txn.append(PrefilledTransaction(x.index + last_index + 1, x.tx)) + last_index = self.prefilled_txn[-1].index + + def to_p2p(self): + if self.use_witness: + ret = P2PHeaderAndShortWitnessIDs() + else: + ret = P2PHeaderAndShortIDs() + ret.header = self.header + ret.nonce = self.nonce + ret.shortids_length = len(self.shortids) + ret.shortids = self.shortids + ret.prefilled_txn_length = len(self.prefilled_txn) + ret.prefilled_txn = [] + last_index = -1 + for x in self.prefilled_txn: + ret.prefilled_txn.append(PrefilledTransaction(x.index - last_index - 1, x.tx)) + last_index = x.index + return ret + + def get_siphash_keys(self): + header_nonce = self.header.serialize() + header_nonce += struct.pack("<Q", self.nonce) + hash_header_nonce_as_str = sha256(header_nonce) + key0 = struct.unpack("<Q", hash_header_nonce_as_str[0:8])[0] + key1 = struct.unpack("<Q", hash_header_nonce_as_str[8:16])[0] + return [ key0, key1 ] + + # Version 2 compact blocks use wtxid in shortids (rather than txid) + def initialize_from_block(self, block, nonce=0, prefill_list=None, use_witness=False): + if prefill_list is None: + prefill_list = [0] + self.header = CBlockHeader(block) + self.nonce = nonce + self.prefilled_txn = [ PrefilledTransaction(i, block.vtx[i]) for i in prefill_list ] + self.shortids = [] + self.use_witness = use_witness + [k0, k1] = self.get_siphash_keys() + for i in range(len(block.vtx)): + if i not in prefill_list: + tx_hash = block.vtx[i].sha256 + if use_witness: + tx_hash = block.vtx[i].calc_sha256(with_witness=True) + self.shortids.append(calculate_shortid(k0, k1, tx_hash)) + + def __repr__(self): + return "HeaderAndShortIDs(header=%s, nonce=%d, shortids=%s, prefilledtxn=%s" % (repr(self.header), self.nonce, repr(self.shortids), repr(self.prefilled_txn)) + + +class BlockTransactionsRequest: + __slots__ = ("blockhash", "indexes") + + def __init__(self, blockhash=0, indexes = None): + self.blockhash = blockhash + self.indexes = indexes if indexes is not None else [] + + def deserialize(self, f): + self.blockhash = deser_uint256(f) + indexes_length = deser_compact_size(f) + for i in range(indexes_length): + self.indexes.append(deser_compact_size(f)) + + def serialize(self): + r = b"" + r += ser_uint256(self.blockhash) + r += ser_compact_size(len(self.indexes)) + for x in self.indexes: + r += ser_compact_size(x) + return r + + # helper to set the differentially encoded indexes from absolute ones + def from_absolute(self, absolute_indexes): + self.indexes = [] + last_index = -1 + for x in absolute_indexes: + self.indexes.append(x-last_index-1) + last_index = x + + def to_absolute(self): + absolute_indexes = [] + last_index = -1 + for x in self.indexes: + absolute_indexes.append(x+last_index+1) + last_index = absolute_indexes[-1] + return absolute_indexes + + def __repr__(self): + return "BlockTransactionsRequest(hash=%064x indexes=%s)" % (self.blockhash, repr(self.indexes)) + + +class BlockTransactions: + __slots__ = ("blockhash", "transactions") + + def __init__(self, blockhash=0, transactions = None): + self.blockhash = blockhash + self.transactions = transactions if transactions is not None else [] + + def deserialize(self, f): + self.blockhash = deser_uint256(f) + self.transactions = deser_vector(f, CTransaction) + + def serialize(self, with_witness=True): + r = b"" + r += ser_uint256(self.blockhash) + if with_witness: + r += ser_vector(self.transactions, "serialize_with_witness") + else: + r += ser_vector(self.transactions, "serialize_without_witness") + return r + + def __repr__(self): + return "BlockTransactions(hash=%064x transactions=%s)" % (self.blockhash, repr(self.transactions)) + + +class CPartialMerkleTree: + __slots__ = ("nTransactions", "vBits", "vHash") + + def __init__(self): + self.nTransactions = 0 + self.vHash = [] + self.vBits = [] + + def deserialize(self, f): + self.nTransactions = struct.unpack("<i", f.read(4))[0] + self.vHash = deser_uint256_vector(f) + vBytes = deser_string(f) + self.vBits = [] + for i in range(len(vBytes) * 8): + self.vBits.append(vBytes[i//8] & (1 << (i % 8)) != 0) + + def serialize(self): + r = b"" + r += struct.pack("<i", self.nTransactions) + r += ser_uint256_vector(self.vHash) + vBytesArray = bytearray([0x00] * ((len(self.vBits) + 7)//8)) + for i in range(len(self.vBits)): + vBytesArray[i // 8] |= self.vBits[i] << (i % 8) + r += ser_string(bytes(vBytesArray)) + return r + + def __repr__(self): + return "CPartialMerkleTree(nTransactions=%d, vHash=%s, vBits=%s)" % (self.nTransactions, repr(self.vHash), repr(self.vBits)) + + +class CMerkleBlock: + __slots__ = ("header", "txn") + + def __init__(self): + self.header = CBlockHeader() + self.txn = CPartialMerkleTree() + + def deserialize(self, f): + self.header.deserialize(f) + self.txn.deserialize(f) + + def serialize(self): + r = b"" + r += self.header.serialize() + r += self.txn.serialize() + return r + + def __repr__(self): + return "CMerkleBlock(header=%s, txn=%s)" % (repr(self.header), repr(self.txn)) + + +# Objects that correspond to messages on the wire +class msg_version: + __slots__ = ("addrFrom", "addrTo", "nNonce", "nRelay", "nServices", + "nStartingHeight", "nTime", "nVersion", "strSubVer") + msgtype = b"version" + + def __init__(self): + self.nVersion = MY_VERSION + self.nServices = NODE_NETWORK | NODE_WITNESS + self.nTime = int(time.time()) + self.addrTo = CAddress() + self.addrFrom = CAddress() + self.nNonce = random.getrandbits(64) + self.strSubVer = MY_SUBVERSION + self.nStartingHeight = -1 + self.nRelay = MY_RELAY + + def deserialize(self, f): + self.nVersion = struct.unpack("<i", f.read(4))[0] + self.nServices = struct.unpack("<Q", f.read(8))[0] + self.nTime = struct.unpack("<q", f.read(8))[0] + self.addrTo = CAddress() + self.addrTo.deserialize(f, False) + + self.addrFrom = CAddress() + self.addrFrom.deserialize(f, False) + self.nNonce = struct.unpack("<Q", f.read(8))[0] + self.strSubVer = deser_string(f) + + self.nStartingHeight = struct.unpack("<i", f.read(4))[0] + + if self.nVersion >= 70001: + # Relay field is optional for version 70001 onwards + try: + self.nRelay = struct.unpack("<b", f.read(1))[0] + except: + self.nRelay = 0 + else: + self.nRelay = 0 + + def serialize(self): + r = b"" + r += struct.pack("<i", self.nVersion) + r += struct.pack("<Q", self.nServices) + r += struct.pack("<q", self.nTime) + r += self.addrTo.serialize(False) + r += self.addrFrom.serialize(False) + r += struct.pack("<Q", self.nNonce) + r += ser_string(self.strSubVer) + r += struct.pack("<i", self.nStartingHeight) + r += struct.pack("<b", self.nRelay) + return r + + def __repr__(self): + return 'msg_version(nVersion=%i nServices=%i nTime=%s addrTo=%s addrFrom=%s nNonce=0x%016X strSubVer=%s nStartingHeight=%i nRelay=%i)' \ + % (self.nVersion, self.nServices, time.ctime(self.nTime), + repr(self.addrTo), repr(self.addrFrom), self.nNonce, + self.strSubVer, self.nStartingHeight, self.nRelay) + + +class msg_verack: + __slots__ = () + msgtype = b"verack" + + def __init__(self): + pass + + def deserialize(self, f): + pass + + def serialize(self): + return b"" + + def __repr__(self): + return "msg_verack()" + + +class msg_addr: + __slots__ = ("addrs",) + msgtype = b"addr" + + def __init__(self): + self.addrs = [] + + def deserialize(self, f): + self.addrs = deser_vector(f, CAddress) + + def serialize(self): + return ser_vector(self.addrs) + + def __repr__(self): + return "msg_addr(addrs=%s)" % (repr(self.addrs)) + + +class msg_inv: + __slots__ = ("inv",) + msgtype = b"inv" + + def __init__(self, inv=None): + if inv is None: + self.inv = [] + else: + self.inv = inv + + def deserialize(self, f): + self.inv = deser_vector(f, CInv) + + def serialize(self): + return ser_vector(self.inv) + + def __repr__(self): + return "msg_inv(inv=%s)" % (repr(self.inv)) + + +class msg_getdata: + __slots__ = ("inv",) + msgtype = b"getdata" + + def __init__(self, inv=None): + self.inv = inv if inv is not None else [] + + def deserialize(self, f): + self.inv = deser_vector(f, CInv) + + def serialize(self): + return ser_vector(self.inv) + + def __repr__(self): + return "msg_getdata(inv=%s)" % (repr(self.inv)) + + +class msg_getblocks: + __slots__ = ("locator", "hashstop") + msgtype = b"getblocks" + + def __init__(self): + self.locator = CBlockLocator() + self.hashstop = 0 + + def deserialize(self, f): + self.locator = CBlockLocator() + self.locator.deserialize(f) + self.hashstop = deser_uint256(f) + + def serialize(self): + r = b"" + r += self.locator.serialize() + r += ser_uint256(self.hashstop) + return r + + def __repr__(self): + return "msg_getblocks(locator=%s hashstop=%064x)" \ + % (repr(self.locator), self.hashstop) + + +class msg_tx: + __slots__ = ("tx",) + msgtype = b"tx" + + def __init__(self, tx=CTransaction()): + self.tx = tx + + def deserialize(self, f): + self.tx.deserialize(f) + + def serialize(self): + return self.tx.serialize_with_witness() + + def __repr__(self): + return "msg_tx(tx=%s)" % (repr(self.tx)) + + +class msg_no_witness_tx(msg_tx): + __slots__ = () + + def serialize(self): + return self.tx.serialize_without_witness() + + +class msg_block: + __slots__ = ("block",) + msgtype = b"block" + + def __init__(self, block=None): + if block is None: + self.block = CBlock() + else: + self.block = block + + def deserialize(self, f): + self.block.deserialize(f) + + def serialize(self): + return self.block.serialize() + + def __repr__(self): + return "msg_block(block=%s)" % (repr(self.block)) + + +# for cases where a user needs tighter control over what is sent over the wire +# note that the user must supply the name of the msgtype, and the data +class msg_generic: + __slots__ = ("msgtype", "data") + + def __init__(self, msgtype, data=None): + self.msgtype = msgtype + self.data = data + + def serialize(self): + return self.data + + def __repr__(self): + return "msg_generic()" + + +class msg_no_witness_block(msg_block): + __slots__ = () + def serialize(self): + return self.block.serialize(with_witness=False) + + +class msg_getaddr: + __slots__ = () + msgtype = b"getaddr" + + def __init__(self): + pass + + def deserialize(self, f): + pass + + def serialize(self): + return b"" + + def __repr__(self): + return "msg_getaddr()" + + +class msg_ping: + __slots__ = ("nonce", "height") + msgtype = b"ping" + + def __init__(self, nonce=0, height=0): + self.nonce = nonce + self.height = height + + def deserialize(self, f): + self.nonce = struct.unpack("<Q", f.read(8))[0] + self.height = struct.unpack("<i", f.read(4))[0] + + def serialize(self): + r = b"" + r += struct.pack("<Q", self.nonce) + r += struct.pack("<i", self.height) + return r + + def __repr__(self): + return "msg_ping(nonce=%08x)" % self.nonce + + +class msg_pong: + __slots__ = ("nonce",) + msgtype = b"pong" + + def __init__(self, nonce=0): + self.nonce = nonce + + def deserialize(self, f): + self.nonce = struct.unpack("<Q", f.read(8))[0] + + def serialize(self): + r = b"" + r += struct.pack("<Q", self.nonce) + return r + + def __repr__(self): + return "msg_pong(nonce=%08x)" % self.nonce + + +class msg_mempool: + __slots__ = () + msgtype = b"mempool" + + def __init__(self): + pass + + def deserialize(self, f): + pass + + def serialize(self): + return b"" + + def __repr__(self): + return "msg_mempool()" + + +class msg_notfound: + __slots__ = ("vec", ) + msgtype = b"notfound" + + def __init__(self, vec=None): + self.vec = vec or [] + + def deserialize(self, f): + self.vec = deser_vector(f, CInv) + + def serialize(self): + return ser_vector(self.vec) + + def __repr__(self): + return "msg_notfound(vec=%s)" % (repr(self.vec)) + + +class msg_sendheaders: + __slots__ = () + msgtype = b"sendheaders" + + def __init__(self): + pass + + def deserialize(self, f): + pass + + def serialize(self): + return b"" + + def __repr__(self): + return "msg_sendheaders()" + + +# getheaders message has +# number of entries +# vector of hashes +# hash_stop (hash of last desired block header, 0 to get as many as possible) +class msg_getheaders: + __slots__ = ("hashstop", "locator",) + msgtype = b"getheaders" + + def __init__(self): + self.locator = CBlockLocator() + self.hashstop = 0 + + def deserialize(self, f): + self.locator = CBlockLocator() + self.locator.deserialize(f) + self.hashstop = deser_uint256(f) + + def serialize(self): + r = b"" + r += self.locator.serialize() + r += ser_uint256(self.hashstop) + return r + + def __repr__(self): + return "msg_getheaders(locator=%s, stop=%064x)" \ + % (repr(self.locator), self.hashstop) + + +# headers message has +# <count> <vector of block headers> +class msg_headers: + __slots__ = ("headers",) + msgtype = b"headers" + + def __init__(self, headers=None): + self.headers = headers if headers is not None else [] + + def deserialize(self, f): + # comment in bitcoind indicates these should be deserialized as blocks + blocks = deser_vector(f, CBlock) + for x in blocks: + self.headers.append(CBlockHeader(x)) + + def serialize(self): + blocks = [CBlock(x) for x in self.headers] + return ser_vector(blocks) + + def __repr__(self): + return "msg_headers(headers=%s)" % repr(self.headers) + + +class msg_merkleblock: + __slots__ = ("merkleblock",) + msgtype = b"merkleblock" + + def __init__(self, merkleblock=None): + if merkleblock is None: + self.merkleblock = CMerkleBlock() + else: + self.merkleblock = merkleblock + + def deserialize(self, f): + self.merkleblock.deserialize(f) + + def serialize(self): + return self.merkleblock.serialize() + + def __repr__(self): + return "msg_merkleblock(merkleblock=%s)" % (repr(self.merkleblock)) + + +class msg_filterload: + __slots__ = ("data", "nHashFuncs", "nTweak", "nFlags") + msgtype = b"filterload" + + def __init__(self, data=b'00', nHashFuncs=0, nTweak=0, nFlags=0): + self.data = data + self.nHashFuncs = nHashFuncs + self.nTweak = nTweak + self.nFlags = nFlags + + def deserialize(self, f): + self.data = deser_string(f) + self.nHashFuncs = struct.unpack("<I", f.read(4))[0] + self.nTweak = struct.unpack("<I", f.read(4))[0] + self.nFlags = struct.unpack("<B", f.read(1))[0] + + def serialize(self): + r = b"" + r += ser_string(self.data) + r += struct.pack("<I", self.nHashFuncs) + r += struct.pack("<I", self.nTweak) + r += struct.pack("<B", self.nFlags) + return r + + def __repr__(self): + return "msg_filterload(data={}, nHashFuncs={}, nTweak={}, nFlags={})".format( + self.data, self.nHashFuncs, self.nTweak, self.nFlags) + + +class msg_filteradd: + __slots__ = ("data") + msgtype = b"filteradd" + + def __init__(self, data): + self.data = data + + def deserialize(self, f): + self.data = deser_string(f) + + def serialize(self): + r = b"" + r += ser_string(self.data) + return r + + def __repr__(self): + return "msg_filteradd(data={})".format(self.data) + + +class msg_filterclear: + __slots__ = () + msgtype = b"filterclear" + + def __init__(self): + pass + + def deserialize(self, f): + pass + + def serialize(self): + return b"" + + def __repr__(self): + return "msg_filterclear()" + + +class msg_feefilter: + __slots__ = ("feerate",) + msgtype = b"feefilter" + + def __init__(self, feerate=0): + self.feerate = feerate + + def deserialize(self, f): + self.feerate = struct.unpack("<Q", f.read(8))[0] + + def serialize(self): + r = b"" + r += struct.pack("<Q", self.feerate) + return r + + def __repr__(self): + return "msg_feefilter(feerate=%08x)" % self.feerate + + +class msg_sendcmpct: + __slots__ = ("announce", "version") + msgtype = b"sendcmpct" + + def __init__(self): + self.announce = False + self.version = 1 + + def deserialize(self, f): + self.announce = struct.unpack("<?", f.read(1))[0] + self.version = struct.unpack("<Q", f.read(8))[0] + + def serialize(self): + r = b"" + r += struct.pack("<?", self.announce) + r += struct.pack("<Q", self.version) + return r + + def __repr__(self): + return "msg_sendcmpct(announce=%s, version=%lu)" % (self.announce, self.version) + + +class msg_cmpctblock: + __slots__ = ("header_and_shortids",) + msgtype = b"cmpctblock" + + def __init__(self, header_and_shortids = None): + self.header_and_shortids = header_and_shortids + + def deserialize(self, f): + self.header_and_shortids = P2PHeaderAndShortIDs() + self.header_and_shortids.deserialize(f) + + def serialize(self): + r = b"" + r += self.header_and_shortids.serialize() + return r + + def __repr__(self): + return "msg_cmpctblock(HeaderAndShortIDs=%s)" % repr(self.header_and_shortids) + + +class msg_getblocktxn: + __slots__ = ("block_txn_request",) + msgtype = b"getblocktxn" + + def __init__(self): + self.block_txn_request = None + + def deserialize(self, f): + self.block_txn_request = BlockTransactionsRequest() + self.block_txn_request.deserialize(f) + + def serialize(self): + r = b"" + r += self.block_txn_request.serialize() + return r + + def __repr__(self): + return "msg_getblocktxn(block_txn_request=%s)" % (repr(self.block_txn_request)) + + +class msg_blocktxn: + __slots__ = ("block_transactions",) + msgtype = b"blocktxn" + + def __init__(self): + self.block_transactions = BlockTransactions() + + def deserialize(self, f): + self.block_transactions.deserialize(f) + + def serialize(self): + r = b"" + r += self.block_transactions.serialize() + return r + + def __repr__(self): + return "msg_blocktxn(block_transactions=%s)" % (repr(self.block_transactions)) + + +class msg_no_witness_blocktxn(msg_blocktxn): + __slots__ = () + + def serialize(self): + return self.block_transactions.serialize(with_witness=False) + + +class msg_getcfilters: + __slots__ = ("filter_type", "start_height", "stop_hash") + msgtype = b"getcfilters" + + def __init__(self, filter_type, start_height, stop_hash): + self.filter_type = filter_type + self.start_height = start_height + self.stop_hash = stop_hash + + def deserialize(self, f): + self.filter_type = struct.unpack("<B", f.read(1))[0] + self.start_height = struct.unpack("<I", f.read(4))[0] + self.stop_hash = deser_uint256(f) + + def serialize(self): + r = b"" + r += struct.pack("<B", self.filter_type) + r += struct.pack("<I", self.start_height) + r += ser_uint256(self.stop_hash) + return r + + def __repr__(self): + return "msg_getcfilters(filter_type={:#x}, start_height={}, stop_hash={:x})".format( + self.filter_type, self.start_height, self.stop_hash) + +class msg_cfilter: + __slots__ = ("filter_type", "block_hash", "filter_data") + msgtype = b"cfilter" + + def __init__(self, filter_type=None, block_hash=None, filter_data=None): + self.filter_type = filter_type + self.block_hash = block_hash + self.filter_data = filter_data + + def deserialize(self, f): + self.filter_type = struct.unpack("<B", f.read(1))[0] + self.block_hash = deser_uint256(f) + self.filter_data = deser_string(f) + + def serialize(self): + r = b"" + r += struct.pack("<B", self.filter_type) + r += ser_uint256(self.block_hash) + r += ser_string(self.filter_data) + return r + + def __repr__(self): + return "msg_cfilter(filter_type={:#x}, block_hash={:x})".format( + self.filter_type, self.block_hash) + +class msg_getcfheaders: + __slots__ = ("filter_type", "start_height", "stop_hash") + msgtype = b"getcfheaders" + + def __init__(self, filter_type, start_height, stop_hash): + self.filter_type = filter_type + self.start_height = start_height + self.stop_hash = stop_hash + + def deserialize(self, f): + self.filter_type = struct.unpack("<B", f.read(1))[0] + self.start_height = struct.unpack("<I", f.read(4))[0] + self.stop_hash = deser_uint256(f) + + def serialize(self): + r = b"" + r += struct.pack("<B", self.filter_type) + r += struct.pack("<I", self.start_height) + r += ser_uint256(self.stop_hash) + return r + + def __repr__(self): + return "msg_getcfheaders(filter_type={:#x}, start_height={}, stop_hash={:x})".format( + self.filter_type, self.start_height, self.stop_hash) + +class msg_cfheaders: + __slots__ = ("filter_type", "stop_hash", "prev_header", "hashes") + msgtype = b"cfheaders" + + def __init__(self, filter_type=None, stop_hash=None, prev_header=None, hashes=None): + self.filter_type = filter_type + self.stop_hash = stop_hash + self.prev_header = prev_header + self.hashes = hashes + + def deserialize(self, f): + self.filter_type = struct.unpack("<B", f.read(1))[0] + self.stop_hash = deser_uint256(f) + self.prev_header = deser_uint256(f) + self.hashes = deser_uint256_vector(f) + + def serialize(self): + r = b"" + r += struct.pack("<B", self.filter_type) + r += ser_uint256(self.stop_hash) + r += ser_uint256(self.prev_header) + r += ser_uint256_vector(self.hashes) + return r + + def __repr__(self): + return "msg_cfheaders(filter_type={:#x}, stop_hash={:x})".format( + self.filter_type, self.stop_hash) + +class msg_getcfcheckpt: + __slots__ = ("filter_type", "stop_hash") + msgtype = b"getcfcheckpt" + + def __init__(self, filter_type, stop_hash): + self.filter_type = filter_type + self.stop_hash = stop_hash + + def deserialize(self, f): + self.filter_type = struct.unpack("<B", f.read(1))[0] + self.stop_hash = deser_uint256(f) + + def serialize(self): + r = b"" + r += struct.pack("<B", self.filter_type) + r += ser_uint256(self.stop_hash) + return r + + def __repr__(self): + return "msg_getcfcheckpt(filter_type={:#x}, stop_hash={:x})".format( + self.filter_type, self.stop_hash) + +class msg_cfcheckpt: + __slots__ = ("filter_type", "stop_hash", "headers") + msgtype = b"cfcheckpt" + + def __init__(self, filter_type=None, stop_hash=None, headers=None): + self.filter_type = filter_type + self.stop_hash = stop_hash + self.headers = headers + + def deserialize(self, f): + self.filter_type = struct.unpack("<B", f.read(1))[0] + self.stop_hash = deser_uint256(f) + self.headers = deser_uint256_vector(f) + + def serialize(self): + r = b"" + r += struct.pack("<B", self.filter_type) + r += ser_uint256(self.stop_hash) + r += ser_uint256_vector(self.headers) + return r + + def __repr__(self): + return "msg_cfcheckpt(filter_type={:#x}, stop_hash={:x})".format( + self.filter_type, self.stop_hash) diff --git a/basicswap/contrib/test_framework/script.py b/basicswap/contrib/test_framework/script.py new file mode 100644 index 0000000..cc5f830 --- /dev/null +++ b/basicswap/contrib/test_framework/script.py @@ -0,0 +1,740 @@ +#!/usr/bin/env python3 +# Copyright (c) 2015-2020 The Bitcoin Core developers +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +"""Functionality to build scripts, as well as signature hash functions. + +This file is modified from python-bitcoinlib. +""" +import hashlib +import struct +import unittest +from typing import List, Dict + +from .messages import ( + CTransaction, + CTxOut, + hash256, + ser_string, + ser_uint256, + sha256, + uint256_from_str, +) + +MAX_SCRIPT_ELEMENT_SIZE = 520 +OPCODE_NAMES = {} # type: Dict[CScriptOp, str] + +def hash160(s): + return hashlib.new('ripemd160', sha256(s)).digest() + +def bn2vch(v): + """Convert number to bitcoin-specific little endian format.""" + # We need v.bit_length() bits, plus a sign bit for every nonzero number. + n_bits = v.bit_length() + (v != 0) + # The number of bytes for that is: + n_bytes = (n_bits + 7) // 8 + # Convert number to absolute value + sign in top bit. + encoded_v = 0 if v == 0 else abs(v) | ((v < 0) << (n_bytes * 8 - 1)) + # Serialize to bytes + return encoded_v.to_bytes(n_bytes, 'little') + +_opcode_instances = [] # type: List[CScriptOp] +class CScriptOp(int): + """A single script opcode""" + __slots__ = () + + @staticmethod + def encode_op_pushdata(d): + """Encode a PUSHDATA op, returning bytes""" + if len(d) < 0x4c: + return b'' + bytes([len(d)]) + d # OP_PUSHDATA + elif len(d) <= 0xff: + return b'\x4c' + bytes([len(d)]) + d # OP_PUSHDATA1 + elif len(d) <= 0xffff: + return b'\x4d' + struct.pack(b'<H', len(d)) + d # OP_PUSHDATA2 + elif len(d) <= 0xffffffff: + return b'\x4e' + struct.pack(b'<I', len(d)) + d # OP_PUSHDATA4 + else: + raise ValueError("Data too long to encode in a PUSHDATA op") + + @staticmethod + def encode_op_n(n): + """Encode a small integer op, returning an opcode""" + if not (0 <= n <= 16): + raise ValueError('Integer must be in range 0 <= n <= 16, got %d' % n) + + if n == 0: + return OP_0 + else: + return CScriptOp(OP_1 + n - 1) + + def decode_op_n(self): + """Decode a small integer opcode, returning an integer""" + if self == OP_0: + return 0 + + if not (self == OP_0 or OP_1 <= self <= OP_16): + raise ValueError('op %r is not an OP_N' % self) + + return int(self - OP_1 + 1) + + def is_small_int(self): + """Return true if the op pushes a small integer to the stack""" + if 0x51 <= self <= 0x60 or self == 0: + return True + else: + return False + + def __str__(self): + return repr(self) + + def __repr__(self): + if self in OPCODE_NAMES: + return OPCODE_NAMES[self] + else: + return 'CScriptOp(0x%x)' % self + + def __new__(cls, n): + try: + return _opcode_instances[n] + except IndexError: + assert len(_opcode_instances) == n + _opcode_instances.append(super().__new__(cls, n)) + return _opcode_instances[n] + +# Populate opcode instance table +for n in range(0xff + 1): + CScriptOp(n) + + +# push value +OP_0 = CScriptOp(0x00) +OP_FALSE = OP_0 +OP_PUSHDATA1 = CScriptOp(0x4c) +OP_PUSHDATA2 = CScriptOp(0x4d) +OP_PUSHDATA4 = CScriptOp(0x4e) +OP_1NEGATE = CScriptOp(0x4f) +OP_RESERVED = CScriptOp(0x50) +OP_1 = CScriptOp(0x51) +OP_TRUE = OP_1 +OP_2 = CScriptOp(0x52) +OP_3 = CScriptOp(0x53) +OP_4 = CScriptOp(0x54) +OP_5 = CScriptOp(0x55) +OP_6 = CScriptOp(0x56) +OP_7 = CScriptOp(0x57) +OP_8 = CScriptOp(0x58) +OP_9 = CScriptOp(0x59) +OP_10 = CScriptOp(0x5a) +OP_11 = CScriptOp(0x5b) +OP_12 = CScriptOp(0x5c) +OP_13 = CScriptOp(0x5d) +OP_14 = CScriptOp(0x5e) +OP_15 = CScriptOp(0x5f) +OP_16 = CScriptOp(0x60) + +# control +OP_NOP = CScriptOp(0x61) +OP_VER = CScriptOp(0x62) +OP_IF = CScriptOp(0x63) +OP_NOTIF = CScriptOp(0x64) +OP_VERIF = CScriptOp(0x65) +OP_VERNOTIF = CScriptOp(0x66) +OP_ELSE = CScriptOp(0x67) +OP_ENDIF = CScriptOp(0x68) +OP_VERIFY = CScriptOp(0x69) +OP_RETURN = CScriptOp(0x6a) + +# stack ops +OP_TOALTSTACK = CScriptOp(0x6b) +OP_FROMALTSTACK = CScriptOp(0x6c) +OP_2DROP = CScriptOp(0x6d) +OP_2DUP = CScriptOp(0x6e) +OP_3DUP = CScriptOp(0x6f) +OP_2OVER = CScriptOp(0x70) +OP_2ROT = CScriptOp(0x71) +OP_2SWAP = CScriptOp(0x72) +OP_IFDUP = CScriptOp(0x73) +OP_DEPTH = CScriptOp(0x74) +OP_DROP = CScriptOp(0x75) +OP_DUP = CScriptOp(0x76) +OP_NIP = CScriptOp(0x77) +OP_OVER = CScriptOp(0x78) +OP_PICK = CScriptOp(0x79) +OP_ROLL = CScriptOp(0x7a) +OP_ROT = CScriptOp(0x7b) +OP_SWAP = CScriptOp(0x7c) +OP_TUCK = CScriptOp(0x7d) + +# splice ops +OP_CAT = CScriptOp(0x7e) +OP_SUBSTR = CScriptOp(0x7f) +OP_LEFT = CScriptOp(0x80) +OP_RIGHT = CScriptOp(0x81) +OP_SIZE = CScriptOp(0x82) + +# bit logic +OP_INVERT = CScriptOp(0x83) +OP_AND = CScriptOp(0x84) +OP_OR = CScriptOp(0x85) +OP_XOR = CScriptOp(0x86) +OP_EQUAL = CScriptOp(0x87) +OP_EQUALVERIFY = CScriptOp(0x88) +OP_RESERVED1 = CScriptOp(0x89) +OP_RESERVED2 = CScriptOp(0x8a) + +# numeric +OP_1ADD = CScriptOp(0x8b) +OP_1SUB = CScriptOp(0x8c) +OP_2MUL = CScriptOp(0x8d) +OP_2DIV = CScriptOp(0x8e) +OP_NEGATE = CScriptOp(0x8f) +OP_ABS = CScriptOp(0x90) +OP_NOT = CScriptOp(0x91) +OP_0NOTEQUAL = CScriptOp(0x92) + +OP_ADD = CScriptOp(0x93) +OP_SUB = CScriptOp(0x94) +OP_MUL = CScriptOp(0x95) +OP_DIV = CScriptOp(0x96) +OP_MOD = CScriptOp(0x97) +OP_LSHIFT = CScriptOp(0x98) +OP_RSHIFT = CScriptOp(0x99) + +OP_BOOLAND = CScriptOp(0x9a) +OP_BOOLOR = CScriptOp(0x9b) +OP_NUMEQUAL = CScriptOp(0x9c) +OP_NUMEQUALVERIFY = CScriptOp(0x9d) +OP_NUMNOTEQUAL = CScriptOp(0x9e) +OP_LESSTHAN = CScriptOp(0x9f) +OP_GREATERTHAN = CScriptOp(0xa0) +OP_LESSTHANOREQUAL = CScriptOp(0xa1) +OP_GREATERTHANOREQUAL = CScriptOp(0xa2) +OP_MIN = CScriptOp(0xa3) +OP_MAX = CScriptOp(0xa4) + +OP_WITHIN = CScriptOp(0xa5) + +# crypto +OP_RIPEMD160 = CScriptOp(0xa6) +OP_SHA1 = CScriptOp(0xa7) +OP_SHA256 = CScriptOp(0xa8) +OP_HASH160 = CScriptOp(0xa9) +OP_HASH256 = CScriptOp(0xaa) +OP_CODESEPARATOR = CScriptOp(0xab) +OP_CHECKSIG = CScriptOp(0xac) +OP_CHECKSIGVERIFY = CScriptOp(0xad) +OP_CHECKMULTISIG = CScriptOp(0xae) +OP_CHECKMULTISIGVERIFY = CScriptOp(0xaf) + +# expansion +OP_NOP1 = CScriptOp(0xb0) +OP_CHECKLOCKTIMEVERIFY = CScriptOp(0xb1) +OP_CHECKSEQUENCEVERIFY = CScriptOp(0xb2) +OP_NOP4 = CScriptOp(0xb3) +OP_NOP5 = CScriptOp(0xb4) +OP_NOP6 = CScriptOp(0xb5) +OP_NOP7 = CScriptOp(0xb6) +OP_NOP8 = CScriptOp(0xb7) +OP_NOP9 = CScriptOp(0xb8) +OP_NOP10 = CScriptOp(0xb9) + +# template matching params +OP_SMALLINTEGER = CScriptOp(0xfa) +OP_PUBKEYS = CScriptOp(0xfb) +OP_PUBKEYHASH = CScriptOp(0xfd) +OP_PUBKEY = CScriptOp(0xfe) + +OP_INVALIDOPCODE = CScriptOp(0xff) + +OPCODE_NAMES.update({ + OP_0: 'OP_0', + OP_PUSHDATA1: 'OP_PUSHDATA1', + OP_PUSHDATA2: 'OP_PUSHDATA2', + OP_PUSHDATA4: 'OP_PUSHDATA4', + OP_1NEGATE: 'OP_1NEGATE', + OP_RESERVED: 'OP_RESERVED', + OP_1: 'OP_1', + OP_2: 'OP_2', + OP_3: 'OP_3', + OP_4: 'OP_4', + OP_5: 'OP_5', + OP_6: 'OP_6', + OP_7: 'OP_7', + OP_8: 'OP_8', + OP_9: 'OP_9', + OP_10: 'OP_10', + OP_11: 'OP_11', + OP_12: 'OP_12', + OP_13: 'OP_13', + OP_14: 'OP_14', + OP_15: 'OP_15', + OP_16: 'OP_16', + OP_NOP: 'OP_NOP', + OP_VER: 'OP_VER', + OP_IF: 'OP_IF', + OP_NOTIF: 'OP_NOTIF', + OP_VERIF: 'OP_VERIF', + OP_VERNOTIF: 'OP_VERNOTIF', + OP_ELSE: 'OP_ELSE', + OP_ENDIF: 'OP_ENDIF', + OP_VERIFY: 'OP_VERIFY', + OP_RETURN: 'OP_RETURN', + OP_TOALTSTACK: 'OP_TOALTSTACK', + OP_FROMALTSTACK: 'OP_FROMALTSTACK', + OP_2DROP: 'OP_2DROP', + OP_2DUP: 'OP_2DUP', + OP_3DUP: 'OP_3DUP', + OP_2OVER: 'OP_2OVER', + OP_2ROT: 'OP_2ROT', + OP_2SWAP: 'OP_2SWAP', + OP_IFDUP: 'OP_IFDUP', + OP_DEPTH: 'OP_DEPTH', + OP_DROP: 'OP_DROP', + OP_DUP: 'OP_DUP', + OP_NIP: 'OP_NIP', + OP_OVER: 'OP_OVER', + OP_PICK: 'OP_PICK', + OP_ROLL: 'OP_ROLL', + OP_ROT: 'OP_ROT', + OP_SWAP: 'OP_SWAP', + OP_TUCK: 'OP_TUCK', + OP_CAT: 'OP_CAT', + OP_SUBSTR: 'OP_SUBSTR', + OP_LEFT: 'OP_LEFT', + OP_RIGHT: 'OP_RIGHT', + OP_SIZE: 'OP_SIZE', + OP_INVERT: 'OP_INVERT', + OP_AND: 'OP_AND', + OP_OR: 'OP_OR', + OP_XOR: 'OP_XOR', + OP_EQUAL: 'OP_EQUAL', + OP_EQUALVERIFY: 'OP_EQUALVERIFY', + OP_RESERVED1: 'OP_RESERVED1', + OP_RESERVED2: 'OP_RESERVED2', + OP_1ADD: 'OP_1ADD', + OP_1SUB: 'OP_1SUB', + OP_2MUL: 'OP_2MUL', + OP_2DIV: 'OP_2DIV', + OP_NEGATE: 'OP_NEGATE', + OP_ABS: 'OP_ABS', + OP_NOT: 'OP_NOT', + OP_0NOTEQUAL: 'OP_0NOTEQUAL', + OP_ADD: 'OP_ADD', + OP_SUB: 'OP_SUB', + OP_MUL: 'OP_MUL', + OP_DIV: 'OP_DIV', + OP_MOD: 'OP_MOD', + OP_LSHIFT: 'OP_LSHIFT', + OP_RSHIFT: 'OP_RSHIFT', + OP_BOOLAND: 'OP_BOOLAND', + OP_BOOLOR: 'OP_BOOLOR', + OP_NUMEQUAL: 'OP_NUMEQUAL', + OP_NUMEQUALVERIFY: 'OP_NUMEQUALVERIFY', + OP_NUMNOTEQUAL: 'OP_NUMNOTEQUAL', + OP_LESSTHAN: 'OP_LESSTHAN', + OP_GREATERTHAN: 'OP_GREATERTHAN', + OP_LESSTHANOREQUAL: 'OP_LESSTHANOREQUAL', + OP_GREATERTHANOREQUAL: 'OP_GREATERTHANOREQUAL', + OP_MIN: 'OP_MIN', + OP_MAX: 'OP_MAX', + OP_WITHIN: 'OP_WITHIN', + OP_RIPEMD160: 'OP_RIPEMD160', + OP_SHA1: 'OP_SHA1', + OP_SHA256: 'OP_SHA256', + OP_HASH160: 'OP_HASH160', + OP_HASH256: 'OP_HASH256', + OP_CODESEPARATOR: 'OP_CODESEPARATOR', + OP_CHECKSIG: 'OP_CHECKSIG', + OP_CHECKSIGVERIFY: 'OP_CHECKSIGVERIFY', + OP_CHECKMULTISIG: 'OP_CHECKMULTISIG', + OP_CHECKMULTISIGVERIFY: 'OP_CHECKMULTISIGVERIFY', + OP_NOP1: 'OP_NOP1', + OP_CHECKLOCKTIMEVERIFY: 'OP_CHECKLOCKTIMEVERIFY', + OP_CHECKSEQUENCEVERIFY: 'OP_CHECKSEQUENCEVERIFY', + OP_NOP4: 'OP_NOP4', + OP_NOP5: 'OP_NOP5', + OP_NOP6: 'OP_NOP6', + OP_NOP7: 'OP_NOP7', + OP_NOP8: 'OP_NOP8', + OP_NOP9: 'OP_NOP9', + OP_NOP10: 'OP_NOP10', + OP_SMALLINTEGER: 'OP_SMALLINTEGER', + OP_PUBKEYS: 'OP_PUBKEYS', + OP_PUBKEYHASH: 'OP_PUBKEYHASH', + OP_PUBKEY: 'OP_PUBKEY', + OP_INVALIDOPCODE: 'OP_INVALIDOPCODE', +}) + +class CScriptInvalidError(Exception): + """Base class for CScript exceptions""" + pass + +class CScriptTruncatedPushDataError(CScriptInvalidError): + """Invalid pushdata due to truncation""" + def __init__(self, msg, data): + self.data = data + super().__init__(msg) + + +# This is used, eg, for blockchain heights in coinbase scripts (bip34) +class CScriptNum: + __slots__ = ("value",) + + def __init__(self, d=0): + self.value = d + + @staticmethod + def encode(obj): + r = bytearray(0) + if obj.value == 0: + return bytes(r) + neg = obj.value < 0 + absvalue = -obj.value if neg else obj.value + while (absvalue): + r.append(absvalue & 0xff) + absvalue >>= 8 + if r[-1] & 0x80: + r.append(0x80 if neg else 0) + elif neg: + r[-1] |= 0x80 + return bytes([len(r)]) + r + + @staticmethod + def decode(vch): + result = 0 + # We assume valid push_size and minimal encoding + value = vch[1:] + if len(value) == 0: + return result + for i, byte in enumerate(value): + result |= int(byte) << 8 * i + if value[-1] >= 0x80: + # Mask for all but the highest result bit + num_mask = (2**(len(value) * 8) - 1) >> 1 + result &= num_mask + result *= -1 + return result + + +class CScript(bytes): + """Serialized script + + A bytes subclass, so you can use this directly whenever bytes are accepted. + Note that this means that indexing does *not* work - you'll get an index by + byte rather than opcode. This format was chosen for efficiency so that the + general case would not require creating a lot of little CScriptOP objects. + + iter(script) however does iterate by opcode. + """ + __slots__ = () + + @classmethod + def __coerce_instance(cls, other): + # Coerce other into bytes + if isinstance(other, CScriptOp): + other = bytes([other]) + elif isinstance(other, CScriptNum): + if (other.value == 0): + other = bytes([CScriptOp(OP_0)]) + else: + other = CScriptNum.encode(other) + elif isinstance(other, int): + if 0 <= other <= 16: + other = bytes([CScriptOp.encode_op_n(other)]) + elif other == -1: + other = bytes([OP_1NEGATE]) + else: + other = CScriptOp.encode_op_pushdata(bn2vch(other)) + elif isinstance(other, (bytes, bytearray)): + other = CScriptOp.encode_op_pushdata(other) + return other + + def __add__(self, other): + # add makes no sense for a CScript() + raise NotImplementedError + + def join(self, iterable): + # join makes no sense for a CScript() + raise NotImplementedError + + def __new__(cls, value=b''): + if isinstance(value, bytes) or isinstance(value, bytearray): + return super().__new__(cls, value) + else: + def coerce_iterable(iterable): + for instance in iterable: + yield cls.__coerce_instance(instance) + # Annoyingly on both python2 and python3 bytes.join() always + # returns a bytes instance even when subclassed. + return super().__new__(cls, b''.join(coerce_iterable(value))) + + def raw_iter(self): + """Raw iteration + + Yields tuples of (opcode, data, sop_idx) so that the different possible + PUSHDATA encodings can be accurately distinguished, as well as + determining the exact opcode byte indexes. (sop_idx) + """ + i = 0 + while i < len(self): + sop_idx = i + opcode = self[i] + i += 1 + + if opcode > OP_PUSHDATA4: + yield (opcode, None, sop_idx) + else: + datasize = None + pushdata_type = None + if opcode < OP_PUSHDATA1: + pushdata_type = 'PUSHDATA(%d)' % opcode + datasize = opcode + + elif opcode == OP_PUSHDATA1: + pushdata_type = 'PUSHDATA1' + if i >= len(self): + raise CScriptInvalidError('PUSHDATA1: missing data length') + datasize = self[i] + i += 1 + + elif opcode == OP_PUSHDATA2: + pushdata_type = 'PUSHDATA2' + if i + 1 >= len(self): + raise CScriptInvalidError('PUSHDATA2: missing data length') + datasize = self[i] + (self[i + 1] << 8) + i += 2 + + elif opcode == OP_PUSHDATA4: + pushdata_type = 'PUSHDATA4' + if i + 3 >= len(self): + raise CScriptInvalidError('PUSHDATA4: missing data length') + datasize = self[i] + (self[i + 1] << 8) + (self[i + 2] << 16) + (self[i + 3] << 24) + i += 4 + + else: + assert False # shouldn't happen + + data = bytes(self[i:i + datasize]) + + # Check for truncation + if len(data) < datasize: + raise CScriptTruncatedPushDataError('%s: truncated data' % pushdata_type, data) + + i += datasize + + yield (opcode, data, sop_idx) + + def __iter__(self): + """'Cooked' iteration + + Returns either a CScriptOP instance, an integer, or bytes, as + appropriate. + + See raw_iter() if you need to distinguish the different possible + PUSHDATA encodings. + """ + for (opcode, data, sop_idx) in self.raw_iter(): + if data is not None: + yield data + else: + opcode = CScriptOp(opcode) + + if opcode.is_small_int(): + yield opcode.decode_op_n() + else: + yield CScriptOp(opcode) + + def __repr__(self): + def _repr(o): + if isinstance(o, bytes): + return "x('%s')" % o.hex() + else: + return repr(o) + + ops = [] + i = iter(self) + while True: + op = None + try: + op = _repr(next(i)) + except CScriptTruncatedPushDataError as err: + op = '%s...<ERROR: %s>' % (_repr(err.data), err) + break + except CScriptInvalidError as err: + op = '<ERROR: %s>' % err + break + except StopIteration: + break + finally: + if op is not None: + ops.append(op) + + return "CScript([%s])" % ', '.join(ops) + + def GetSigOpCount(self, fAccurate): + """Get the SigOp count. + + fAccurate - Accurately count CHECKMULTISIG, see BIP16 for details. + + Note that this is consensus-critical. + """ + n = 0 + lastOpcode = OP_INVALIDOPCODE + for (opcode, data, sop_idx) in self.raw_iter(): + if opcode in (OP_CHECKSIG, OP_CHECKSIGVERIFY): + n += 1 + elif opcode in (OP_CHECKMULTISIG, OP_CHECKMULTISIGVERIFY): + if fAccurate and (OP_1 <= lastOpcode <= OP_16): + n += opcode.decode_op_n() + else: + n += 20 + lastOpcode = opcode + return n + + +SIGHASH_ALL = 1 +SIGHASH_NONE = 2 +SIGHASH_SINGLE = 3 +SIGHASH_ANYONECANPAY = 0x80 + +def FindAndDelete(script, sig): + """Consensus critical, see FindAndDelete() in Satoshi codebase""" + r = b'' + last_sop_idx = sop_idx = 0 + skip = True + for (opcode, data, sop_idx) in script.raw_iter(): + if not skip: + r += script[last_sop_idx:sop_idx] + last_sop_idx = sop_idx + if script[sop_idx:sop_idx + len(sig)] == sig: + skip = True + else: + skip = False + if not skip: + r += script[last_sop_idx:] + return CScript(r) + + +def LegacySignatureHash(script, txTo, inIdx, hashtype): + """Consensus-correct SignatureHash + + Returns (hash, err) to precisely match the consensus-critical behavior of + the SIGHASH_SINGLE bug. (inIdx is *not* checked for validity) + """ + HASH_ONE = b'\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' + + if inIdx >= len(txTo.vin): + return (HASH_ONE, "inIdx %d out of range (%d)" % (inIdx, len(txTo.vin))) + txtmp = CTransaction(txTo) + + for txin in txtmp.vin: + txin.scriptSig = b'' + txtmp.vin[inIdx].scriptSig = FindAndDelete(script, CScript([OP_CODESEPARATOR])) + + if (hashtype & 0x1f) == SIGHASH_NONE: + txtmp.vout = [] + + for i in range(len(txtmp.vin)): + if i != inIdx: + txtmp.vin[i].nSequence = 0 + + elif (hashtype & 0x1f) == SIGHASH_SINGLE: + outIdx = inIdx + if outIdx >= len(txtmp.vout): + return (HASH_ONE, "outIdx %d out of range (%d)" % (outIdx, len(txtmp.vout))) + + tmp = txtmp.vout[outIdx] + txtmp.vout = [] + for i in range(outIdx): + txtmp.vout.append(CTxOut(-1)) + txtmp.vout.append(tmp) + + for i in range(len(txtmp.vin)): + if i != inIdx: + txtmp.vin[i].nSequence = 0 + + if hashtype & SIGHASH_ANYONECANPAY: + tmp = txtmp.vin[inIdx] + txtmp.vin = [] + txtmp.vin.append(tmp) + + s = txtmp.serialize_without_witness() + s += struct.pack(b"<I", hashtype) + + hash = hash256(s) + + return (hash, None) + +# TODO: Allow cached hashPrevouts/hashSequence/hashOutputs to be provided. +# Performance optimization probably not necessary for python tests, however. +# Note that this corresponds to sigversion == 1 in EvalScript, which is used +# for version 0 witnesses. +def SegwitV0SignatureHash(script, txTo, inIdx, hashtype, amount): + + hashPrevouts = 0 + hashSequence = 0 + hashOutputs = 0 + + if not (hashtype & SIGHASH_ANYONECANPAY): + serialize_prevouts = bytes() + for i in txTo.vin: + serialize_prevouts += i.prevout.serialize() + hashPrevouts = uint256_from_str(hash256(serialize_prevouts)) + + if (not (hashtype & SIGHASH_ANYONECANPAY) and (hashtype & 0x1f) != SIGHASH_SINGLE and (hashtype & 0x1f) != SIGHASH_NONE): + serialize_sequence = bytes() + for i in txTo.vin: + serialize_sequence += struct.pack("<I", i.nSequence) + hashSequence = uint256_from_str(hash256(serialize_sequence)) + + if ((hashtype & 0x1f) != SIGHASH_SINGLE and (hashtype & 0x1f) != SIGHASH_NONE): + serialize_outputs = bytes() + for o in txTo.vout: + serialize_outputs += o.serialize() + hashOutputs = uint256_from_str(hash256(serialize_outputs)) + elif ((hashtype & 0x1f) == SIGHASH_SINGLE and inIdx < len(txTo.vout)): + serialize_outputs = txTo.vout[inIdx].serialize() + hashOutputs = uint256_from_str(hash256(serialize_outputs)) + + ss = bytes() + ss += struct.pack("<i", txTo.nVersion) + ss += ser_uint256(hashPrevouts) + ss += ser_uint256(hashSequence) + ss += txTo.vin[inIdx].prevout.serialize() + ss += ser_string(script) + ss += struct.pack("<q", amount) + ss += struct.pack("<I", txTo.vin[inIdx].nSequence) + ss += ser_uint256(hashOutputs) + ss += struct.pack("<i", txTo.nLockTime) + ss += struct.pack("<I", hashtype) + + return hash256(ss) + +class TestFrameworkScript(unittest.TestCase): + def test_bn2vch(self): + self.assertEqual(bn2vch(0), bytes([])) + self.assertEqual(bn2vch(1), bytes([0x01])) + self.assertEqual(bn2vch(-1), bytes([0x81])) + self.assertEqual(bn2vch(0x7F), bytes([0x7F])) + self.assertEqual(bn2vch(-0x7F), bytes([0xFF])) + self.assertEqual(bn2vch(0x80), bytes([0x80, 0x00])) + self.assertEqual(bn2vch(-0x80), bytes([0x80, 0x80])) + self.assertEqual(bn2vch(0xFF), bytes([0xFF, 0x00])) + self.assertEqual(bn2vch(-0xFF), bytes([0xFF, 0x80])) + self.assertEqual(bn2vch(0x100), bytes([0x00, 0x01])) + self.assertEqual(bn2vch(-0x100), bytes([0x00, 0x81])) + self.assertEqual(bn2vch(0x7FFF), bytes([0xFF, 0x7F])) + self.assertEqual(bn2vch(-0x8000), bytes([0x00, 0x80, 0x80])) + self.assertEqual(bn2vch(-0x7FFFFF), bytes([0xFF, 0xFF, 0xFF])) + self.assertEqual(bn2vch(0x80000000), bytes([0x00, 0x00, 0x00, 0x80, 0x00])) + self.assertEqual(bn2vch(-0x80000000), bytes([0x00, 0x00, 0x00, 0x80, 0x80])) + self.assertEqual(bn2vch(0xFFFFFFFF), bytes([0xFF, 0xFF, 0xFF, 0xFF, 0x00])) + self.assertEqual(bn2vch(123456789), bytes([0x15, 0xCD, 0x5B, 0x07])) + self.assertEqual(bn2vch(-54321), bytes([0x31, 0xD4, 0x80])) + + def test_cscriptnum_encoding(self): + # round-trip negative and multi-byte CScriptNums + values = [0, 1, -1, -2, 127, 128, -255, 256, (1 << 15) - 1, -(1 << 16), (1 << 24) - 1, (1 << 31), 1 - (1 << 32), 1 << 40, 1500, -1500] + for value in values: + self.assertEqual(CScriptNum.decode(CScriptNum.encode(CScriptNum(value))), value) diff --git a/basicswap/contrib/test_framework/segwit_addr.py b/basicswap/contrib/test_framework/segwit_addr.py new file mode 100644 index 0000000..02368e9 --- /dev/null +++ b/basicswap/contrib/test_framework/segwit_addr.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +# Copyright (c) 2017 Pieter Wuille +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +"""Reference implementation for Bech32 and segwit addresses.""" + + +CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l" + + +def bech32_polymod(values): + """Internal function that computes the Bech32 checksum.""" + generator = [0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3] + chk = 1 + for value in values: + top = chk >> 25 + chk = (chk & 0x1ffffff) << 5 ^ value + for i in range(5): + chk ^= generator[i] if ((top >> i) & 1) else 0 + return chk + + +def bech32_hrp_expand(hrp): + """Expand the HRP into values for checksum computation.""" + return [ord(x) >> 5 for x in hrp] + [0] + [ord(x) & 31 for x in hrp] + + +def bech32_verify_checksum(hrp, data): + """Verify a checksum given HRP and converted data characters.""" + return bech32_polymod(bech32_hrp_expand(hrp) + data) == 1 + + +def bech32_create_checksum(hrp, data): + """Compute the checksum values given HRP and data.""" + values = bech32_hrp_expand(hrp) + data + polymod = bech32_polymod(values + [0, 0, 0, 0, 0, 0]) ^ 1 + return [(polymod >> 5 * (5 - i)) & 31 for i in range(6)] + + +def bech32_encode(hrp, data): + """Compute a Bech32 string given HRP and data values.""" + combined = data + bech32_create_checksum(hrp, data) + return hrp + '1' + ''.join([CHARSET[d] for d in combined]) + + +def bech32_decode(bech): + """Validate a Bech32 string, and determine HRP and data.""" + if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or + (bech.lower() != bech and bech.upper() != bech)): + return (None, None) + bech = bech.lower() + pos = bech.rfind('1') + if pos < 1 or pos + 7 > len(bech) or len(bech) > 90: + return (None, None) + if not all(x in CHARSET for x in bech[pos+1:]): + return (None, None) + hrp = bech[:pos] + data = [CHARSET.find(x) for x in bech[pos+1:]] + if not bech32_verify_checksum(hrp, data): + return (None, None) + return (hrp, data[:-6]) + + +def convertbits(data, frombits, tobits, pad=True): + """General power-of-2 base conversion.""" + acc = 0 + bits = 0 + ret = [] + maxv = (1 << tobits) - 1 + max_acc = (1 << (frombits + tobits - 1)) - 1 + for value in data: + if value < 0 or (value >> frombits): + return None + acc = ((acc << frombits) | value) & max_acc + bits += frombits + while bits >= tobits: + bits -= tobits + ret.append((acc >> bits) & maxv) + if pad: + if bits: + ret.append((acc << (tobits - bits)) & maxv) + elif bits >= frombits or ((acc << (tobits - bits)) & maxv): + return None + return ret + + +def decode(hrp, addr): + """Decode a segwit address.""" + hrpgot, data = bech32_decode(addr) + if hrpgot != hrp: + return (None, None) + decoded = convertbits(data[1:], 5, 8, False) + if decoded is None or len(decoded) < 2 or len(decoded) > 40: + return (None, None) + if data[0] > 16: + return (None, None) + if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32: + return (None, None) + return (data[0], decoded) + + +def encode(hrp, witver, witprog): + """Encode a segwit address.""" + ret = bech32_encode(hrp, [witver] + convertbits(witprog, 8, 5)) + if decode(hrp, ret) == (None, None): + return None + return ret diff --git a/basicswap/contrib/test_framework/siphash.py b/basicswap/contrib/test_framework/siphash.py new file mode 100644 index 0000000..8583684 --- /dev/null +++ b/basicswap/contrib/test_framework/siphash.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# Copyright (c) 2016-2018 The Bitcoin Core developers +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +"""Specialized SipHash-2-4 implementations. + +This implements SipHash-2-4 for 256-bit integers. +""" + +def rotl64(n, b): + return n >> (64 - b) | (n & ((1 << (64 - b)) - 1)) << b + +def siphash_round(v0, v1, v2, v3): + v0 = (v0 + v1) & ((1 << 64) - 1) + v1 = rotl64(v1, 13) + v1 ^= v0 + v0 = rotl64(v0, 32) + v2 = (v2 + v3) & ((1 << 64) - 1) + v3 = rotl64(v3, 16) + v3 ^= v2 + v0 = (v0 + v3) & ((1 << 64) - 1) + v3 = rotl64(v3, 21) + v3 ^= v0 + v2 = (v2 + v1) & ((1 << 64) - 1) + v1 = rotl64(v1, 17) + v1 ^= v2 + v2 = rotl64(v2, 32) + return (v0, v1, v2, v3) + +def siphash256(k0, k1, h): + n0 = h & ((1 << 64) - 1) + n1 = (h >> 64) & ((1 << 64) - 1) + n2 = (h >> 128) & ((1 << 64) - 1) + n3 = (h >> 192) & ((1 << 64) - 1) + v0 = 0x736f6d6570736575 ^ k0 + v1 = 0x646f72616e646f6d ^ k1 + v2 = 0x6c7967656e657261 ^ k0 + v3 = 0x7465646279746573 ^ k1 ^ n0 + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0 ^= n0 + v3 ^= n1 + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0 ^= n1 + v3 ^= n2 + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0 ^= n2 + v3 ^= n3 + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0 ^= n3 + v3 ^= 0x2000000000000000 + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0 ^= 0x2000000000000000 + v2 ^= 0xFF + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + return v0 ^ v1 ^ v2 ^ v3 diff --git a/basicswap/contrib/test_framework/util.py b/basicswap/contrib/test_framework/util.py new file mode 100644 index 0000000..c9f55e8 --- /dev/null +++ b/basicswap/contrib/test_framework/util.py @@ -0,0 +1,619 @@ +#!/usr/bin/env python3 +# Copyright (c) 2014-2020 The Bitcoin Core developers +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +"""Helpful routines for regression testing.""" + +from base64 import b64encode +from binascii import unhexlify +from decimal import Decimal, ROUND_DOWN +from subprocess import CalledProcessError +import inspect +import json +import logging +import os +import random +import re +import time + +from . import coverage +from .authproxy import AuthServiceProxy, JSONRPCException +from io import BytesIO + +logger = logging.getLogger("TestFramework.utils") + +# Assert functions +################## + + +def assert_approx(v, vexp, vspan=0.00001): + """Assert that `v` is within `vspan` of `vexp`""" + if v < vexp - vspan: + raise AssertionError("%s < [%s..%s]" % (str(v), str(vexp - vspan), str(vexp + vspan))) + if v > vexp + vspan: + raise AssertionError("%s > [%s..%s]" % (str(v), str(vexp - vspan), str(vexp + vspan))) + + +def assert_fee_amount(fee, tx_size, fee_per_kB): + """Assert the fee was in range""" + target_fee = round(tx_size * fee_per_kB / 1000, 8) + if fee < target_fee: + raise AssertionError("Fee of %s BTC too low! (Should be %s BTC)" % (str(fee), str(target_fee))) + # allow the wallet's estimation to be at most 2 bytes off + if fee > (tx_size + 2) * fee_per_kB / 1000: + raise AssertionError("Fee of %s BTC too high! (Should be %s BTC)" % (str(fee), str(target_fee))) + + +def assert_equal(thing1, thing2, *args): + if thing1 != thing2 or any(thing1 != arg for arg in args): + raise AssertionError("not(%s)" % " == ".join(str(arg) for arg in (thing1, thing2) + args)) + + +def assert_greater_than(thing1, thing2): + if thing1 <= thing2: + raise AssertionError("%s <= %s" % (str(thing1), str(thing2))) + + +def assert_greater_than_or_equal(thing1, thing2): + if thing1 < thing2: + raise AssertionError("%s < %s" % (str(thing1), str(thing2))) + + +def assert_raises(exc, fun, *args, **kwds): + assert_raises_message(exc, None, fun, *args, **kwds) + + +def assert_raises_message(exc, message, fun, *args, **kwds): + try: + fun(*args, **kwds) + except JSONRPCException: + raise AssertionError("Use assert_raises_rpc_error() to test RPC failures") + except exc as e: + if message is not None and message not in e.error['message']: + raise AssertionError( + "Expected substring not found in error message:\nsubstring: '{}'\nerror message: '{}'.".format( + message, e.error['message'])) + except Exception as e: + raise AssertionError("Unexpected exception raised: " + type(e).__name__) + else: + raise AssertionError("No exception raised") + + +def assert_raises_process_error(returncode, output, fun, *args, **kwds): + """Execute a process and asserts the process return code and output. + + Calls function `fun` with arguments `args` and `kwds`. Catches a CalledProcessError + and verifies that the return code and output are as expected. Throws AssertionError if + no CalledProcessError was raised or if the return code and output are not as expected. + + Args: + returncode (int): the process return code. + output (string): [a substring of] the process output. + fun (function): the function to call. This should execute a process. + args*: positional arguments for the function. + kwds**: named arguments for the function. + """ + try: + fun(*args, **kwds) + except CalledProcessError as e: + if returncode != e.returncode: + raise AssertionError("Unexpected returncode %i" % e.returncode) + if output not in e.output: + raise AssertionError("Expected substring not found:" + e.output) + else: + raise AssertionError("No exception raised") + + +def assert_raises_rpc_error(code, message, fun, *args, **kwds): + """Run an RPC and verify that a specific JSONRPC exception code and message is raised. + + Calls function `fun` with arguments `args` and `kwds`. Catches a JSONRPCException + and verifies that the error code and message are as expected. Throws AssertionError if + no JSONRPCException was raised or if the error code/message are not as expected. + + Args: + code (int), optional: the error code returned by the RPC call (defined + in src/rpc/protocol.h). Set to None if checking the error code is not required. + message (string), optional: [a substring of] the error string returned by the + RPC call. Set to None if checking the error string is not required. + fun (function): the function to call. This should be the name of an RPC. + args*: positional arguments for the function. + kwds**: named arguments for the function. + """ + assert try_rpc(code, message, fun, *args, **kwds), "No exception raised" + + +def try_rpc(code, message, fun, *args, **kwds): + """Tries to run an rpc command. + + Test against error code and message if the rpc fails. + Returns whether a JSONRPCException was raised.""" + try: + fun(*args, **kwds) + except JSONRPCException as e: + # JSONRPCException was thrown as expected. Check the code and message values are correct. + if (code is not None) and (code != e.error["code"]): + raise AssertionError("Unexpected JSONRPC error code %i" % e.error["code"]) + if (message is not None) and (message not in e.error['message']): + raise AssertionError( + "Expected substring not found in error message:\nsubstring: '{}'\nerror message: '{}'.".format( + message, e.error['message'])) + return True + except Exception as e: + raise AssertionError("Unexpected exception raised: " + type(e).__name__) + else: + return False + + +def assert_is_hex_string(string): + try: + int(string, 16) + except Exception as e: + raise AssertionError("Couldn't interpret %r as hexadecimal; raised: %s" % (string, e)) + + +def assert_is_hash_string(string, length=64): + if not isinstance(string, str): + raise AssertionError("Expected a string, got type %r" % type(string)) + elif length and len(string) != length: + raise AssertionError("String of length %d expected; got %d" % (length, len(string))) + elif not re.match('[abcdef0-9]+$', string): + raise AssertionError("String %r contains invalid characters for a hash." % string) + + +def assert_array_result(object_array, to_match, expected, should_not_find=False): + """ + Pass in array of JSON objects, a dictionary with key/value pairs + to match against, and another dictionary with expected key/value + pairs. + If the should_not_find flag is true, to_match should not be found + in object_array + """ + if should_not_find: + assert_equal(expected, {}) + num_matched = 0 + for item in object_array: + all_match = True + for key, value in to_match.items(): + if item[key] != value: + all_match = False + if not all_match: + continue + elif should_not_find: + num_matched = num_matched + 1 + for key, value in expected.items(): + if item[key] != value: + raise AssertionError("%s : expected %s=%s" % (str(item), str(key), str(value))) + num_matched = num_matched + 1 + if num_matched == 0 and not should_not_find: + raise AssertionError("No objects matched %s" % (str(to_match))) + if num_matched > 0 and should_not_find: + raise AssertionError("Objects were found %s" % (str(to_match))) + + +# Utility functions +################### + + +def check_json_precision(): + """Make sure json library being used does not lose precision converting BTC values""" + n = Decimal("20000000.00000003") + satoshis = int(json.loads(json.dumps(float(n))) * 1.0e8) + if satoshis != 2000000000000003: + raise RuntimeError("JSON encode/decode loses precision") + + +def EncodeDecimal(o): + if isinstance(o, Decimal): + return str(o) + raise TypeError(repr(o) + " is not JSON serializable") + + +def count_bytes(hex_string): + return len(bytearray.fromhex(hex_string)) + + +def hex_str_to_bytes(hex_str): + return unhexlify(hex_str.encode('ascii')) + + +def str_to_b64str(string): + return b64encode(string.encode('utf-8')).decode('ascii') + + +def satoshi_round(amount): + return Decimal(amount).quantize(Decimal('0.00000001'), rounding=ROUND_DOWN) + + +def wait_until(predicate, *, attempts=float('inf'), timeout=float('inf'), lock=None, timeout_factor=1.0): + if attempts == float('inf') and timeout == float('inf'): + timeout = 60 + timeout = timeout * timeout_factor + attempt = 0 + time_end = time.time() + timeout + + while attempt < attempts and time.time() < time_end: + if lock: + with lock: + if predicate(): + return + else: + if predicate(): + return + attempt += 1 + time.sleep(0.05) + + # Print the cause of the timeout + predicate_source = "''''\n" + inspect.getsource(predicate) + "'''" + logger.error("wait_until() failed. Predicate: {}".format(predicate_source)) + if attempt >= attempts: + raise AssertionError("Predicate {} not true after {} attempts".format(predicate_source, attempts)) + elif time.time() >= time_end: + raise AssertionError("Predicate {} not true after {} seconds".format(predicate_source, timeout)) + raise RuntimeError('Unreachable') + + +# RPC/P2P connection constants and functions +############################################ + +# The maximum number of nodes a single test can spawn +MAX_NODES = 12 +# Don't assign rpc or p2p ports lower than this +PORT_MIN = int(os.getenv('TEST_RUNNER_PORT_MIN', default=11000)) +# The number of ports to "reserve" for p2p and rpc, each +PORT_RANGE = 5000 + + +class PortSeed: + # Must be initialized with a unique integer for each process + n = None + + +def get_rpc_proxy(url, node_number, *, timeout=None, coveragedir=None): + """ + Args: + url (str): URL of the RPC server to call + node_number (int): the node number (or id) that this calls to + + Kwargs: + timeout (int): HTTP timeout in seconds + coveragedir (str): Directory + + Returns: + AuthServiceProxy. convenience object for making RPC calls. + + """ + proxy_kwargs = {} + if timeout is not None: + proxy_kwargs['timeout'] = int(timeout) + + proxy = AuthServiceProxy(url, **proxy_kwargs) + proxy.url = url # store URL on proxy for info + + coverage_logfile = coverage.get_filename(coveragedir, node_number) if coveragedir else None + + return coverage.AuthServiceProxyWrapper(proxy, coverage_logfile) + + +def p2p_port(n): + assert n <= MAX_NODES + return PORT_MIN + n + (MAX_NODES * PortSeed.n) % (PORT_RANGE - 1 - MAX_NODES) + + +def rpc_port(n): + return PORT_MIN + PORT_RANGE + n + (MAX_NODES * PortSeed.n) % (PORT_RANGE - 1 - MAX_NODES) + + +def rpc_url(datadir, i, chain, rpchost): + rpc_u, rpc_p = get_auth_cookie(datadir, chain) + host = '127.0.0.1' + port = rpc_port(i) + if rpchost: + parts = rpchost.split(':') + if len(parts) == 2: + host, port = parts + else: + host = rpchost + return "http://%s:%s@%s:%d" % (rpc_u, rpc_p, host, int(port)) + + +# Node functions +################ + + +def initialize_datadir(dirname, n, chain): + datadir = get_datadir_path(dirname, n) + if not os.path.isdir(datadir): + os.makedirs(datadir) + # Translate chain name to config name + if chain == 'testnet3': + chain_name_conf_arg = 'testnet' + chain_name_conf_section = 'test' + else: + chain_name_conf_arg = chain + chain_name_conf_section = chain + with open(os.path.join(datadir, "particl.conf"), 'w', encoding='utf8') as f: + f.write("{}=1\n".format(chain_name_conf_arg)) + f.write("[{}]\n".format(chain_name_conf_section)) + f.write("port=" + str(p2p_port(n)) + "\n") + f.write("rpcport=" + str(rpc_port(n)) + "\n") + f.write("fallbackfee=0.0002\n") + f.write("server=1\n") + f.write("keypool=1\n") + f.write("discover=0\n") + f.write("dnsseed=0\n") + f.write("listenonion=0\n") + f.write("printtoconsole=0\n") + f.write("upnp=0\n") + f.write("shrinkdebugfile=0\n") + os.makedirs(os.path.join(datadir, 'stderr'), exist_ok=True) + os.makedirs(os.path.join(datadir, 'stdout'), exist_ok=True) + return datadir + + +def get_datadir_path(dirname, n): + return os.path.join(dirname, "node" + str(n)) + + +def append_config(datadir, options): + with open(os.path.join(datadir, "particl.conf"), 'a', encoding='utf8') as f: + for option in options: + f.write(option + "\n") + + +def get_auth_cookie(datadir, chain): + user = None + password = None + if os.path.isfile(os.path.join(datadir, "particl.conf")): + with open(os.path.join(datadir, "particl.conf"), 'r', encoding='utf8') as f: + for line in f: + if line.startswith("rpcuser="): + assert user is None # Ensure that there is only one rpcuser line + user = line.split("=")[1].strip("\n") + if line.startswith("rpcpassword="): + assert password is None # Ensure that there is only one rpcpassword line + password = line.split("=")[1].strip("\n") + try: + with open(os.path.join(datadir, chain, ".cookie"), 'r', encoding="ascii") as f: + userpass = f.read() + split_userpass = userpass.split(':') + user = split_userpass[0] + password = split_userpass[1] + except OSError: + pass + if user is None or password is None: + raise ValueError("No RPC credentials") + return user, password + + +# If a cookie file exists in the given datadir, delete it. +def delete_cookie_file(datadir, chain): + if os.path.isfile(os.path.join(datadir, chain, ".cookie")): + logger.debug("Deleting leftover cookie file") + os.remove(os.path.join(datadir, chain, ".cookie")) + + +def softfork_active(node, key): + """Return whether a softfork is active.""" + return node.getblockchaininfo()['softforks'][key]['active'] + + +def set_node_times(nodes, t): + for node in nodes: + node.setmocktime(t) + + +def disconnect_nodes(from_connection, node_num): + def get_peer_ids(): + result = [] + for peer in from_connection.getpeerinfo(): + if "testnode{}".format(node_num) in peer['subver']: + result.append(peer['id']) + return result + + peer_ids = get_peer_ids() + if not peer_ids: + logger.warning("disconnect_nodes: {} and {} were not connected".format( + from_connection.index, + node_num, + )) + return + for peer_id in peer_ids: + try: + from_connection.disconnectnode(nodeid=peer_id) + except JSONRPCException as e: + # If this node is disconnected between calculating the peer id + # and issuing the disconnect, don't worry about it. + # This avoids a race condition if we're mass-disconnecting peers. + if e.error['code'] != -29: # RPC_CLIENT_NODE_NOT_CONNECTED + raise + + # wait to disconnect + wait_until(lambda: not get_peer_ids(), timeout=5) + + +def connect_nodes(from_connection, node_num): + ip_port = "127.0.0.1:" + str(p2p_port(node_num)) + from_connection.addnode(ip_port, "onetry") + # poll until version handshake complete to avoid race conditions + # with transaction relaying + # See comments in net_processing: + # * Must have a version message before anything else + # * Must have a verack message before anything else + wait_until(lambda: all(peer['version'] != 0 for peer in from_connection.getpeerinfo())) + wait_until(lambda: all(peer['bytesrecv_per_msg'].pop('verack', 0) == 24 for peer in from_connection.getpeerinfo())) + + +# Transaction/Block functions +############################# + + +def find_output(node, txid, amount, *, blockhash=None): + """ + Return index to output of txid with value amount + Raises exception if there is none. + """ + txdata = node.getrawtransaction(txid, 1, blockhash) + for i in range(len(txdata["vout"])): + if txdata["vout"][i]["value"] == amount: + return i + raise RuntimeError("find_output txid %s : %s not found" % (txid, str(amount))) + + +def gather_inputs(from_node, amount_needed, confirmations_required=1): + """ + Return a random set of unspent txouts that are enough to pay amount_needed + """ + assert confirmations_required >= 0 + utxo = from_node.listunspent(confirmations_required) + random.shuffle(utxo) + inputs = [] + total_in = Decimal("0.00000000") + while total_in < amount_needed and len(utxo) > 0: + t = utxo.pop() + total_in += t["amount"] + inputs.append({"txid": t["txid"], "vout": t["vout"], "address": t["address"]}) + if total_in < amount_needed: + raise RuntimeError("Insufficient funds: need %d, have %d" % (amount_needed, total_in)) + return (total_in, inputs) + + +def make_change(from_node, amount_in, amount_out, fee): + """ + Create change output(s), return them + """ + outputs = {} + amount = amount_out + fee + change = amount_in - amount + if change > amount * 2: + # Create an extra change output to break up big inputs + change_address = from_node.getnewaddress() + # Split change in two, being careful of rounding: + outputs[change_address] = Decimal(change / 2).quantize(Decimal('0.00000001'), rounding=ROUND_DOWN) + change = amount_in - amount - outputs[change_address] + if change > 0: + outputs[from_node.getnewaddress()] = change + return outputs + + +def random_transaction(nodes, amount, min_fee, fee_increment, fee_variants): + """ + Create a random transaction. + Returns (txid, hex-encoded-transaction-data, fee) + """ + from_node = random.choice(nodes) + to_node = random.choice(nodes) + fee = min_fee + fee_increment * random.randint(0, fee_variants) + + (total_in, inputs) = gather_inputs(from_node, amount + fee) + outputs = make_change(from_node, total_in, amount, fee) + outputs[to_node.getnewaddress()] = float(amount) + + rawtx = from_node.createrawtransaction(inputs, outputs) + signresult = from_node.signrawtransactionwithwallet(rawtx) + txid = from_node.sendrawtransaction(signresult["hex"], 0) + + return (txid, signresult["hex"], fee) + + +# Helper to create at least "count" utxos +# Pass in a fee that is sufficient for relay and mining new transactions. +def create_confirmed_utxos(fee, node, count): + to_generate = int(0.5 * count) + 101 + while to_generate > 0: + node.generate(min(25, to_generate)) + to_generate -= 25 + utxos = node.listunspent() + iterations = count - len(utxos) + addr1 = node.getnewaddress() + addr2 = node.getnewaddress() + if iterations <= 0: + return utxos + for i in range(iterations): + t = utxos.pop() + inputs = [] + inputs.append({"txid": t["txid"], "vout": t["vout"]}) + outputs = {} + send_value = t['amount'] - fee + outputs[addr1] = satoshi_round(send_value / 2) + outputs[addr2] = satoshi_round(send_value / 2) + raw_tx = node.createrawtransaction(inputs, outputs) + signed_tx = node.signrawtransactionwithwallet(raw_tx)["hex"] + node.sendrawtransaction(signed_tx) + + while (node.getmempoolinfo()['size'] > 0): + node.generate(1) + + utxos = node.listunspent() + assert len(utxos) >= count + return utxos + + +# Create large OP_RETURN txouts that can be appended to a transaction +# to make it large (helper for constructing large transactions). +def gen_return_txouts(): + # Some pre-processing to create a bunch of OP_RETURN txouts to insert into transactions we create + # So we have big transactions (and therefore can't fit very many into each block) + # create one script_pubkey + script_pubkey = "6a4d0200" # OP_RETURN OP_PUSH2 512 bytes + for i in range(512): + script_pubkey = script_pubkey + "01" + # concatenate 128 txouts of above script_pubkey which we'll insert before the txout for change + txouts = [] + from .messages import CTxOut + txout = CTxOut() + txout.nValue = 0 + txout.scriptPubKey = hex_str_to_bytes(script_pubkey) + for k in range(128): + txouts.append(txout) + return txouts + + +# Create a spend of each passed-in utxo, splicing in "txouts" to each raw +# transaction to make it large. See gen_return_txouts() above. +def create_lots_of_big_transactions(node, txouts, utxos, num, fee): + addr = node.getnewaddress() + txids = [] + from .messages import CTransaction + for _ in range(num): + t = utxos.pop() + inputs = [{"txid": t["txid"], "vout": t["vout"]}] + outputs = {} + change = t['amount'] - fee + outputs[addr] = satoshi_round(change) + rawtx = node.createrawtransaction(inputs, outputs) + tx = CTransaction() + tx.deserialize(BytesIO(hex_str_to_bytes(rawtx))) + for txout in txouts: + tx.vout.append(txout) + newtx = tx.serialize().hex() + signresult = node.signrawtransactionwithwallet(newtx, None, "NONE") + txid = node.sendrawtransaction(signresult["hex"], 0) + txids.append(txid) + return txids + + +def mine_large_block(node, utxos=None): + # generate a 66k transaction, + # and 14 of them is close to the 1MB block limit + num = 14 + txouts = gen_return_txouts() + utxos = utxos if utxos is not None else [] + if len(utxos) < num: + utxos.clear() + utxos.extend(node.listunspent()) + fee = 100 * node.getnetworkinfo()["relayfee"] + create_lots_of_big_transactions(node, txouts, utxos, num, fee=fee) + node.generate(1) + + +def find_vout_for_address(node, txid, addr): + """ + Locate the vout index of the given transaction sending to the + given address. Raises runtime error exception if not found. + """ + tx = node.getrawtransaction(txid, True) + for i in range(len(tx["vout"])): + if any([addr == a for a in tx["vout"][i]["scriptPubKey"]["addresses"]]): + return i + raise RuntimeError("Vout not found for address: txid=%s, addr=%s" % (txid, addr)) diff --git a/basicswap/contrib/test_framework/wallet_util.py b/basicswap/contrib/test_framework/wallet_util.py new file mode 100755 index 0000000..688d565 --- /dev/null +++ b/basicswap/contrib/test_framework/wallet_util.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +# Copyright (c) 2018-2020 The Bitcoin Core developers +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +"""Useful util functions for testing the wallet""" +from collections import namedtuple + +from .address import ( + byte_to_base58, + key_to_p2pkh, + key_to_p2sh_p2wpkh, + key_to_p2wpkh, + script_to_p2sh, + script_to_p2sh_p2wsh, + script_to_p2wsh, +) +from .key import ECKey +from .script import ( + CScript, + OP_0, + OP_2, + OP_3, + OP_CHECKMULTISIG, + OP_CHECKSIG, + OP_DUP, + OP_EQUAL, + OP_EQUALVERIFY, + OP_HASH160, + hash160, + sha256, +) +from .util import hex_str_to_bytes + +Key = namedtuple('Key', ['privkey', + 'pubkey', + 'p2pkh_script', + 'p2pkh_addr', + 'p2wpkh_script', + 'p2wpkh_addr', + 'p2sh_p2wpkh_script', + 'p2sh_p2wpkh_redeem_script', + 'p2sh_p2wpkh_addr']) + +Multisig = namedtuple('Multisig', ['privkeys', + 'pubkeys', + 'p2sh_script', + 'p2sh_addr', + 'redeem_script', + 'p2wsh_script', + 'p2wsh_addr', + 'p2sh_p2wsh_script', + 'p2sh_p2wsh_addr']) + +def get_key(node): + """Generate a fresh key on node + + Returns a named tuple of privkey, pubkey and all address and scripts.""" + addr = node.getnewaddress() + pubkey = node.getaddressinfo(addr)['pubkey'] + pkh = hash160(hex_str_to_bytes(pubkey)) + return Key(privkey=node.dumpprivkey(addr), + pubkey=pubkey, + p2pkh_script=CScript([OP_DUP, OP_HASH160, pkh, OP_EQUALVERIFY, OP_CHECKSIG]).hex(), + p2pkh_addr=key_to_p2pkh(pubkey), + p2wpkh_script=CScript([OP_0, pkh]).hex(), + p2wpkh_addr=key_to_p2wpkh(pubkey), + p2sh_p2wpkh_script=CScript([OP_HASH160, hash160(CScript([OP_0, pkh])), OP_EQUAL]).hex(), + p2sh_p2wpkh_redeem_script=CScript([OP_0, pkh]).hex(), + p2sh_p2wpkh_addr=key_to_p2sh_p2wpkh(pubkey)) + +def get_generate_key(): + """Generate a fresh key + + Returns a named tuple of privkey, pubkey and all address and scripts.""" + eckey = ECKey() + eckey.generate() + privkey = bytes_to_wif(eckey.get_bytes()) + pubkey = eckey.get_pubkey().get_bytes().hex() + pkh = hash160(hex_str_to_bytes(pubkey)) + return Key(privkey=privkey, + pubkey=pubkey, + p2pkh_script=CScript([OP_DUP, OP_HASH160, pkh, OP_EQUALVERIFY, OP_CHECKSIG]).hex(), + p2pkh_addr=key_to_p2pkh(pubkey), + p2wpkh_script=CScript([OP_0, pkh]).hex(), + p2wpkh_addr=key_to_p2wpkh(pubkey), + p2sh_p2wpkh_script=CScript([OP_HASH160, hash160(CScript([OP_0, pkh])), OP_EQUAL]).hex(), + p2sh_p2wpkh_redeem_script=CScript([OP_0, pkh]).hex(), + p2sh_p2wpkh_addr=key_to_p2sh_p2wpkh(pubkey)) + +def get_multisig(node): + """Generate a fresh 2-of-3 multisig on node + + Returns a named tuple of privkeys, pubkeys and all address and scripts.""" + addrs = [] + pubkeys = [] + for _ in range(3): + addr = node.getaddressinfo(node.getnewaddress()) + addrs.append(addr['address']) + pubkeys.append(addr['pubkey']) + script_code = CScript([OP_2] + [hex_str_to_bytes(pubkey) for pubkey in pubkeys] + [OP_3, OP_CHECKMULTISIG]) + witness_script = CScript([OP_0, sha256(script_code)]) + return Multisig(privkeys=[node.dumpprivkey(addr) for addr in addrs], + pubkeys=pubkeys, + p2sh_script=CScript([OP_HASH160, hash160(script_code), OP_EQUAL]).hex(), + p2sh_addr=script_to_p2sh(script_code), + redeem_script=script_code.hex(), + p2wsh_script=witness_script.hex(), + p2wsh_addr=script_to_p2wsh(script_code), + p2sh_p2wsh_script=CScript([OP_HASH160, witness_script, OP_EQUAL]).hex(), + p2sh_p2wsh_addr=script_to_p2sh_p2wsh(script_code)) + +def test_address(node, address, **kwargs): + """Get address info for `address` and test whether the returned values are as expected.""" + addr_info = node.getaddressinfo(address) + for key, value in kwargs.items(): + if value is None: + if key in addr_info.keys(): + raise AssertionError("key {} unexpectedly returned in getaddressinfo.".format(key)) + elif addr_info[key] != value: + raise AssertionError("key {} value {} did not match expected value {}".format(key, addr_info[key], value)) + +def bytes_to_wif(b, compressed=True, prefix=239): + if compressed: + b += b'\x01' + return byte_to_base58(b, prefix) + +def generate_wif_key(): + # Makes a WIF privkey for imports + k = ECKey() + k.generate() + return bytes_to_wif(k.get_bytes(), k.is_compressed) diff --git a/basicswap/db.py b/basicswap/db.py index ffdc322..02f32a8 100644 --- a/basicswap/db.py +++ b/basicswap/db.py @@ -97,7 +97,7 @@ class Bid(Base): participate_txn_refund = sa.Column(sa.LargeBinary) state = sa.Column(sa.Integer) - state_time = sa.Column(sa.BigInteger) # timestamp of last state change + state_time = sa.Column(sa.BigInteger) # Timestamp of last state change states = sa.Column(sa.LargeBinary) # Packed states and times state_note = sa.Column(sa.String) diff --git a/basicswap/ecc_util.py b/basicswap/ecc_util.py new file mode 100644 index 0000000..e88a010 --- /dev/null +++ b/basicswap/ecc_util.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import codecs +import hashlib +import secrets + +from .contrib.ellipticcurve import CurveFp, Point, INFINITY, jacobi_symbol + + +class ECCParameters(): + def __init__(self, p, a, b, Gx, Gy, o): + self.p = p + self.a = a + self.b = b + self.Gx = Gx + self.Gy = Gy + self.o = o + + +ep = ECCParameters( \ + p = 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f, \ + a = 0x0, \ + b = 0x7, \ + Gx = 0x79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798, \ + Gy = 0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8, \ + o = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141) # noqa: E221,E251,E502 + +curve_secp256k1 = CurveFp(ep.p, ep.a, ep.b) +G = Point(curve_secp256k1, ep.Gx, ep.Gy, ep.o) +SECP256K1_ORDER_HALF = ep.o // 2 + + +def ToDER(P): + return bytes((4, )) + int(P.x()).to_bytes(32, byteorder='big') + int(P.y()).to_bytes(32, byteorder='big') + + +def bytes32ToInt(b): + return int.from_bytes(b, byteorder='big') + + +def intToBytes32(i): + return i.to_bytes(32, byteorder='big') + + +def intToBytes32_le(i): + return i.to_bytes(32, byteorder='little') + + +def bytesToHexStr(b): + return codecs.encode(b, 'hex').decode('utf-8') + + +def hexStrToBytes(h): + if h.startswith('0x'): + h = h[2:] + return bytes.fromhex(h) + + +def getSecretBytes(): + i = 1 + secrets.randbelow(ep.o - 1) + return intToBytes32(i) + + +def getSecretInt(): + return 1 + secrets.randbelow(ep.o - 1) + + +def getInsecureBytes(): + while True: + s = os.urandom(32) + + s_test = int.from_bytes(s, byteorder='big') + if s_test > 1 and s_test < ep.o: + return s + + +def getInsecureInt(): + while True: + s = os.urandom(32) + + s_test = int.from_bytes(s, byteorder='big') + if s_test > 1 and s_test < ep.o: + return s_test + + +def powMod(x, y, z): + # Calculate (x ** y) % z efficiently. + number = 1 + while y: + if y & 1: + number = number * x % z + y >>= 1 # y //= 2 + + x = x * x % z + return number + + +def ExpandPoint(xb, sign): + x = int.from_bytes(xb, byteorder='big') + a = (powMod(x, 3, ep.p) + 7) % ep.p + y = powMod(a, (ep.p + 1) // 4, ep.p) + + if sign: + y = ep.p - y + return Point(curve_secp256k1, x, y, ep.o) + + +def CPKToPoint(cpk): + y_parity = cpk[0] - 2 + + x = int.from_bytes(cpk[1:], byteorder='big') + a = (powMod(x, 3, ep.p) + 7) % ep.p + y = powMod(a, (ep.p + 1) // 4, ep.p) + + if y % 2 != y_parity: + y = ep.p - y + + return Point(curve_secp256k1, x, y, ep.o) + + +def pointToCPK2(point, ind=0x09): + # The function is_square(x), where x is an integer, returns whether or not x is a quadratic residue modulo p. Since p is prime, it is equivalent to the Legendre symbol (x / p) = x(p-1)/2 mod p being equal to 1[8]. + ind = bytes((ind ^ (1 if jacobi_symbol(point.y(), ep.p) == 1 else 0),)) + return ind + point.x().to_bytes(32, byteorder='big') + + +def pointToCPK(point): + + y = point.y().to_bytes(32, byteorder='big') + ind = bytes((0x03,)) if y[31] % 2 else bytes((0x02,)) + + cpk = ind + point.x().to_bytes(32, byteorder='big') + return cpk + + +def secretToCPK(secret): + secretInt = secret if isinstance(secret, int) \ + else int.from_bytes(secret, byteorder='big') + + R = G * secretInt + + Y = R.y().to_bytes(32, byteorder='big') + ind = bytes((0x03,)) if Y[31] % 2 else bytes((0x02,)) + + pubkey = ind + R.x().to_bytes(32, byteorder='big') + + return pubkey + + +def getKeypair(): + secretBytes = getSecretBytes() + return secretBytes, secretToCPK(secretBytes) + + +def hashToCurve(pubkey): + + xBytes = hashlib.sha256(pubkey).digest() + x = int.from_bytes(xBytes, byteorder='big') + + for k in range(0, 100): + # get matching y element for point + y_parity = 0 # always pick 0, + a = (powMod(x, 3, ep.p) + 7) % ep.p + y = powMod(a, (ep.p + 1) // 4, ep.p) + + # print("before parity %x" % (y)) + if y % 2 != y_parity: + y = ep.p - y + + # If x is always mod P, can R ever not be on the curve? + try: + R = Point(curve_secp256k1, x, y, ep.o) + except Exception: + x = (x + 1) % ep.p # % P? + continue + + if R == INFINITY or R * ep.o != INFINITY: # is R * O != INFINITY check necessary? Validation of Elliptic Curve Public Keys says no if cofactor = 1 + x = (x + 1) % ep.p # % P? + continue + return R + + raise ValueError('hashToCurve failed for 100 tries') + + +def hash256(inb): + return hashlib.sha256(inb).digest() + + +i2b = intToBytes32 +b2i = bytes32ToInt +b2h = bytesToHexStr +h2b = hexStrToBytes + + +def i2h(x): + return b2h(i2b(x)) + + +def testEccUtils(): + print('testEccUtils()') + + G_enc = ToDER(G) + assert(G_enc.hex() == '0479be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8') + + G_enc = pointToCPK(G) + assert(G_enc.hex() == '0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798') + G_dec = CPKToPoint(G_enc) + assert(G_dec == G) + + G_enc = pointToCPK2(G) + assert(G_enc.hex() == '0879be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798') + + H = hashToCurve(ToDER(G)) + assert(pointToCPK(H).hex() == '0250929b74c1a04954b78b4b6035e97a5e078a5a0f28ec96d547bfee9ace803ac0') + + print('Passed.') + + +if __name__ == "__main__": + testEccUtils() diff --git a/basicswap/http_server.py b/basicswap/http_server.py index 1eeac72..1731a78 100644 --- a/basicswap/http_server.py +++ b/basicswap/http_server.py @@ -19,7 +19,7 @@ from . import __version__ from .util import ( COIN, format8, - makeInt, + make_int, dumpj, ) from .chainparams import ( @@ -129,7 +129,7 @@ def validateAmountString(amount): def inputAmount(amount_str): validateAmountString(amount_str) - return makeInt(amount_str) + return make_int(amount_str) def setCoinFilter(form_data, field_name): diff --git a/basicswap/interface_btc.py b/basicswap/interface_btc.py new file mode 100644 index 0000000..7c989e8 --- /dev/null +++ b/basicswap/interface_btc.py @@ -0,0 +1,805 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright (c) 2020 tecnovert +# Distributed under the MIT software license, see the accompanying +# file LICENSE or http://www.opensource.org/licenses/mit-license.php. + +import time +import hashlib +import logging +from io import BytesIO + +from .util import ( + decodeScriptNum, + getCompactSizeLen, + dumpj, + format_amount, + make_int +) + +from .ecc_util import ( + G, ep, + pointToCPK, CPKToPoint, + getSecretInt, + b2h, i2b, b2i, i2h) + +from .contrib.test_framework.messages import ( + COIN, + COutPoint, + CTransaction, + CTxIn, + CTxInWitness, + CTxOut, + FromHex, + ToHex) + +from .contrib.test_framework.script import ( + CScript, + CScriptOp, + CScriptNum, + OP_IF, OP_ELSE, OP_ENDIF, + OP_0, + OP_2, + OP_16, + OP_EQUALVERIFY, + OP_CHECKSIG, + OP_SIZE, + OP_SHA256, + OP_CHECKMULTISIG, + OP_CHECKSEQUENCEVERIFY, + OP_DROP, + SIGHASH_ALL, + SegwitV0SignatureHash, + hash160) + +from .contrib.test_framework.key import ECKey, ECPubKey + +from .chainparams import CoinInterface +from .rpc import make_rpc_func +from .util import assert_cond + + +def findOutput(tx, script_pk): + for i in range(len(tx.vout)): + if tx.vout[i].scriptPubKey == script_pk: + return i + return None + + +class BTCInterface(CoinInterface): + @staticmethod + def exp(): + return 8 + + @staticmethod + def nbk(): + return 32 + + @staticmethod + def nbK(): # No. of bytes requires to encode a public key + return 33 + + @staticmethod + def witnessScaleFactor(): + return 4 + + @staticmethod + def txVersion(): + return 2 + + @staticmethod + def getTxOutputValue(tx): + rv = 0 + for output in tx.vout: + rv += output.nValue + return rv + + def compareFeeRates(self, a, b): + return abs(a - b) < 20 + + def __init__(self, coin_settings): + self.rpc_callback = make_rpc_func(coin_settings['rpcport'], coin_settings['rpcauth']) + self.txoType = CTxOut + + def getNewSecretKey(self): + return getSecretInt() + + def pubkey(self, key): + return G * key + + def encodePubkey(self, pk): + return pointToCPK(pk) + + def decodePubkey(self, pke): + return CPKToPoint(pke) + + def decodeKey(self, k): + i = b2i(k) + assert(i < ep.o) + return i + + def sumKeys(self, ka, kb): + return (ka + kb) % ep.o + + def sumPubkeys(self, Ka, Kb): + return Ka + Kb + + def extractScriptLockScriptValues(self, script_bytes): + script_len = len(script_bytes) + assert_cond(script_len > 112, 'Bad script length') + assert_cond(script_bytes[0] == OP_IF) + assert_cond(script_bytes[1] == OP_SIZE) + assert_cond(script_bytes[2:4] == bytes((1, 32))) # 0120, CScriptNum length, then data + assert_cond(script_bytes[4] == OP_EQUALVERIFY) + assert_cond(script_bytes[5] == OP_SHA256) + assert_cond(script_bytes[6] == 32) + secret_hash = script_bytes[7: 7 + 32] + assert_cond(script_bytes[39] == OP_EQUALVERIFY) + assert_cond(script_bytes[40] == OP_2) + assert_cond(script_bytes[41] == 33) + pk1 = script_bytes[42: 42 + 33] + assert_cond(script_bytes[75] == 33) + pk2 = script_bytes[76: 76 + 33] + assert_cond(script_bytes[109] == OP_2) + assert_cond(script_bytes[110] == OP_CHECKMULTISIG) + assert_cond(script_bytes[111] == OP_ELSE) + o = 112 + + # Decode script num + csv_val, nb = decodeScriptNum(script_bytes, o) + o += nb + + assert_cond(script_len == o + 8 + 66, 'Bad script length') # Fails if script too long + assert_cond(script_bytes[o] == OP_CHECKSEQUENCEVERIFY) + o += 1 + assert_cond(script_bytes[o] == OP_DROP) + o += 1 + assert_cond(script_bytes[o] == OP_2) + o += 1 + assert_cond(script_bytes[o] == 33) + o += 1 + pk3 = script_bytes[o: o + 33] + o += 33 + assert_cond(script_bytes[o] == 33) + o += 1 + pk4 = script_bytes[o: o + 33] + o += 33 + assert_cond(script_bytes[o] == OP_2) + o += 1 + assert_cond(script_bytes[o] == OP_CHECKMULTISIG) + o += 1 + assert_cond(script_bytes[o] == OP_ENDIF) + + return secret_hash, pk1, pk2, csv_val, pk3, pk4 + + def genScriptLockTxScript(self, sh, Kal, Kaf, lock_blocks, Karl, Karf): + return CScript([ + CScriptOp(OP_IF), + CScriptOp(OP_SIZE), 32, CScriptOp(OP_EQUALVERIFY), + CScriptOp(OP_SHA256), sh, CScriptOp(OP_EQUALVERIFY), + 2, self.encodePubkey(Kal), self.encodePubkey(Kaf), 2, CScriptOp(OP_CHECKMULTISIG), + CScriptOp(OP_ELSE), + lock_blocks, CScriptOp(OP_CHECKSEQUENCEVERIFY), CScriptOp(OP_DROP), + 2, self.encodePubkey(Karl), self.encodePubkey(Karf), 2, CScriptOp(OP_CHECKMULTISIG), + CScriptOp(OP_ENDIF)]) + + def createScriptLockTx(self, value, sh, Kal, Kaf, lock_blocks, Karl, Karf): + + script = self.genScriptLockTxScript(sh, Kal, Kaf, lock_blocks, Karl, Karf) + tx = CTransaction() + tx.nVersion = self.txVersion() + tx.vout.append(self.txoType(value, CScript([OP_0, hashlib.sha256(script).digest()]))) + + return tx, script + + def extractScriptLockRefundScriptValues(self, script_bytes): + script_len = len(script_bytes) + assert_cond(script_len > 73, 'Bad script length') + assert_cond(script_bytes[0] == OP_IF) + assert_cond(script_bytes[1] == OP_2) + assert_cond(script_bytes[2] == 33) + pk1 = script_bytes[3: 3 + 33] + assert_cond(script_bytes[36] == 33) + pk2 = script_bytes[37: 37 + 33] + assert_cond(script_bytes[70] == OP_2) + assert_cond(script_bytes[71] == OP_CHECKMULTISIG) + assert_cond(script_bytes[72] == OP_ELSE) + o = 73 + csv_val, nb = decodeScriptNum(script_bytes, o) + o += nb + + assert_cond(script_len == o + 5 + 33, 'Bad script length') # Fails if script too long + assert_cond(script_bytes[o] == OP_CHECKSEQUENCEVERIFY) + o += 1 + assert_cond(script_bytes[o] == OP_DROP) + o += 1 + assert_cond(script_bytes[o] == 33) + o += 1 + pk3 = script_bytes[o: o + 33] + o += 33 + assert_cond(script_bytes[o] == OP_CHECKSIG) + o += 1 + assert_cond(script_bytes[o] == OP_ENDIF) + + return pk1, pk2, csv_val, pk3 + + def genScriptLockRefundTxScript(self, Karl, Karf, csv_val, Kaf): + return CScript([ + CScriptOp(OP_IF), + 2, self.encodePubkey(Karl), self.encodePubkey(Karf), 2, CScriptOp(OP_CHECKMULTISIG), + CScriptOp(OP_ELSE), + csv_val, CScriptOp(OP_CHECKSEQUENCEVERIFY), CScriptOp(OP_DROP), + self.encodePubkey(Kaf), CScriptOp(OP_CHECKSIG), + CScriptOp(OP_ENDIF)]) + + def createScriptLockRefundTx(self, tx_lock, script_lock, Karl, Karf, csv_val, Kaf, tx_fee_rate): + + output_script = CScript([OP_0, hashlib.sha256(script_lock).digest()]) + locked_n = findOutput(tx_lock, output_script) + assert_cond(locked_n is not None, 'Output not found in tx') + locked_coin = tx_lock.vout[locked_n].nValue + + tx_lock.rehash() + tx_lock_hash_int = tx_lock.sha256 + + sh, A, B, lock1_value, C, D = self.extractScriptLockScriptValues(script_lock) + + refund_script = self.genScriptLockRefundTxScript(Karl, Karf, csv_val, Kaf) + tx = CTransaction() + tx.nVersion = self.txVersion() + tx.vin.append(CTxIn(COutPoint(tx_lock_hash_int, locked_n), nSequence=lock1_value)) + tx.vout.append(self.txoType(locked_coin, CScript([OP_0, hashlib.sha256(refund_script).digest()]))) + + witness_bytes = len(script_lock) + witness_bytes += 73 * 2 # 2 signatures (72 + 1 byts size) + witness_bytes += 2 # 2 empty witness stack values + witness_bytes += getCompactSizeLen(witness_bytes) + vsize = self.getTxVSize(tx, add_witness_bytes=witness_bytes) + pay_fee = int(tx_fee_rate * vsize / 1000) + tx.vout[0].nValue = locked_coin - pay_fee + + tx.rehash() + logging.info('createScriptLockRefundTx %s:\n fee_rate, vsize, fee: %ld, %ld, %ld.', + i2h(tx.sha256), tx_fee_rate, vsize, pay_fee) + + return tx, refund_script, tx.vout[0].nValue + + def createScriptLockRefundSpendTx(self, tx_lock_refund, script_lock_refund, Kal, tx_fee_rate): + # Returns the coinA locked coin to the leader + # The follower will sign the multisig path with a signature encumbered by the leader's coinB spend pubkey + # When the leader publishes the decrypted signature the leader's coinB spend privatekey will be revealed to the follower + + output_script = CScript([OP_0, hashlib.sha256(script_lock_refund).digest()]) + locked_n = findOutput(tx_lock_refund, output_script) + assert_cond(locked_n is not None, 'Output not found in tx') + locked_coin = tx_lock_refund.vout[locked_n].nValue + + tx_lock_refund.rehash() + tx_lock_refund_hash_int = tx_lock_refund.sha256 + + tx = CTransaction() + tx.nVersion = self.txVersion() + tx.vin.append(CTxIn(COutPoint(tx_lock_refund_hash_int, locked_n), nSequence=0)) + + pubkeyhash = hash160(self.encodePubkey(Kal)) + tx.vout.append(self.txoType(locked_coin, CScript([OP_0, pubkeyhash]))) + + witness_bytes = len(script_lock_refund) + witness_bytes += 73 * 2 # 2 signatures (72 + 1 byte size) + witness_bytes += 4 # 1 empty, 1 true witness stack values + witness_bytes += getCompactSizeLen(witness_bytes) + vsize = self.getTxVSize(tx, add_witness_bytes=witness_bytes) + pay_fee = int(tx_fee_rate * vsize / 1000) + tx.vout[0].nValue = locked_coin - pay_fee + + tx.rehash() + logging.info('createScriptLockRefundSpendTx %s:\n fee_rate, vsize, fee: %ld, %ld, %ld.', + i2h(tx.sha256), tx_fee_rate, vsize, pay_fee) + + return tx + + def createScriptLockRefundSpendToFTx(self, tx_lock_refund, script_lock_refund, pkh_dest, tx_fee_rate): + # Sends the coinA locked coin to the follower + output_script = CScript([OP_0, hashlib.sha256(script_lock_refund).digest()]) + locked_n = findOutput(tx_lock_refund, output_script) + assert_cond(locked_n is not None, 'Output not found in tx') + locked_coin = tx_lock_refund.vout[locked_n].nValue + + A, B, lock2_value, C = self.extractScriptLockRefundScriptValues(script_lock_refund) + + tx_lock_refund.rehash() + tx_lock_refund_hash_int = tx_lock_refund.sha256 + + tx = CTransaction() + tx.nVersion = self.txVersion() + tx.vin.append(CTxIn(COutPoint(tx_lock_refund_hash_int, locked_n), nSequence=lock2_value)) + + tx.vout.append(self.txoType(locked_coin, CScript([OP_0, pkh_dest]))) + + witness_bytes = len(script_lock_refund) + witness_bytes += 73 # signature (72 + 1 byte size) + witness_bytes += 1 # 1 empty stack value + witness_bytes += getCompactSizeLen(witness_bytes) + vsize = self.getTxVSize(tx, add_witness_bytes=witness_bytes) + pay_fee = int(tx_fee_rate * vsize / 1000) + tx.vout[0].nValue = locked_coin - pay_fee + + tx.rehash() + logging.info('createScriptLockRefundSpendToFTx %s:\n fee_rate, vsize, fee: %ld, %ld, %ld.', + i2h(tx.sha256), tx_fee_rate, vsize, pay_fee) + + return tx + + def createScriptLockSpendTx(self, tx_lock, script_lock, pkh_dest, tx_fee_rate): + + output_script = CScript([OP_0, hashlib.sha256(script_lock).digest()]) + locked_n = findOutput(tx_lock, output_script) + assert_cond(locked_n is not None, 'Output not found in tx') + locked_coin = tx_lock.vout[locked_n].nValue + + tx_lock.rehash() + tx_lock_hash_int = tx_lock.sha256 + + tx = CTransaction() + tx.nVersion = self.txVersion() + tx.vin.append(CTxIn(COutPoint(tx_lock_hash_int, locked_n))) + + p2wpkh = CScript([OP_0, pkh_dest]) + tx.vout.append(self.txoType(locked_coin, p2wpkh)) + + witness_bytes = len(script_lock) + witness_bytes += 33 # sv, size + witness_bytes += 73 * 2 # 2 signatures (72 + 1 byts size) + witness_bytes += 4 # 1 empty, 1 true witness stack values + witness_bytes += getCompactSizeLen(witness_bytes) + vsize = self.getTxVSize(tx, add_witness_bytes=witness_bytes) + pay_fee = int(tx_fee_rate * vsize / 1000) + tx.vout[0].nValue = locked_coin - pay_fee + + tx.rehash() + logging.info('createScriptLockSpendTx %s:\n fee_rate, vsize, fee: %ld, %ld, %ld.', + i2h(tx.sha256), tx_fee_rate, vsize, pay_fee) + + return tx + + def verifyLockTx(self, tx, script_out, + swap_value, + sh, + Kal, Kaf, + lock_value, feerate, + Karl, Karf, + check_lock_tx_inputs): + # Verify: + # + + # Not necessary to check the lock txn is mineable, as protocol will wait for it to confirm + # However by checking early we can avoid wasting time processing unmineable txns + # Check fee is reasonable + + tx_hash = self.getTxHash(tx) + logging.info('Verifying lock tx: {}.'.format(b2h(tx_hash))) + + assert_cond(tx.nVersion == self.txVersion(), 'Bad version') + assert_cond(tx.nLockTime == 0, 'Bad nLockTime') + + script_pk = CScript([OP_0, hashlib.sha256(script_out).digest()]) + locked_n = findOutput(tx, script_pk) + assert_cond(locked_n is not None, 'Output not found in tx') + locked_coin = tx.vout[locked_n].nValue + + assert_cond(locked_coin == swap_value, 'Bad locked value') + + # Check script and values + shv, A, B, csv_val, C, D = self.extractScriptLockScriptValues(script_out) + assert_cond(shv == sh, 'Bad hash lock') + assert_cond(A == self.encodePubkey(Kal), 'Bad script pubkey') + assert_cond(B == self.encodePubkey(Kaf), 'Bad script pubkey') + assert_cond(csv_val == lock_value, 'Bad script csv value') + assert_cond(C == self.encodePubkey(Karl), 'Bad script pubkey') + assert_cond(D == self.encodePubkey(Karf), 'Bad script pubkey') + + if check_lock_tx_inputs: + # Check that inputs are unspent and verify fee rate + inputs_value = 0 + add_bytes = 0 + add_witness_bytes = getCompactSizeLen(len(tx.vin)) + for pi in tx.vin: + ptx = self.rpc_callback('getrawtransaction', [i2h(pi.prevout.hash), True]) + print('ptx', dumpj(ptx)) + prevout = ptx['vout'][pi.prevout.n] + inputs_value += make_int(prevout['value']) + + prevout_type = prevout['scriptPubKey']['type'] + if prevout_type == 'witness_v0_keyhash': + add_witness_bytes += 107 # sig 72, pk 33 and 2 size bytes + add_witness_bytes += getCompactSizeLen(107) + else: + # Assume P2PKH, TODO more types + add_bytes += 107 # OP_PUSH72 <ecdsa_signature> OP_PUSH33 <public_key> + + outputs_value = 0 + for txo in tx.vout: + outputs_value += txo.nValue + fee_paid = inputs_value - outputs_value + assert(fee_paid > 0) + + vsize = self.getTxVSize(tx, add_bytes, add_witness_bytes) + fee_rate_paid = fee_paid * 1000 / vsize + + logging.info('tx amount, vsize, feerate: %ld, %ld, %ld', locked_coin, vsize, fee_rate_paid) + + if not self.compareFeeRates(fee_rate_paid, feerate): + logging.warning('feerate paid doesn\'t match expected: %ld, %ld', fee_rate_paid, feerate) + # TODO: Display warning to user + + return tx_hash, locked_n + + def verifyLockRefundTx(self, tx, script_out, + prevout_id, prevout_n, prevout_seq, prevout_script, + Karl, Karf, csv_val_expect, Kaf, swap_value, feerate): + # Verify: + # Must have only one input with correct prevout and sequence + # Must have only one output to the p2wsh of the lock refund script + # Output value must be locked_coin - lock tx fee + + tx_hash = self.getTxHash(tx) + logging.info('Verifying lock refund tx: {}.'.format(b2h(tx_hash))) + + assert_cond(tx.nVersion == self.txVersion(), 'Bad version') + assert_cond(tx.nLockTime == 0, 'nLockTime not 0') + assert_cond(len(tx.vin) == 1, 'tx doesn\'t have one input') + + assert_cond(tx.vin[0].nSequence == prevout_seq, 'Bad input nSequence') + assert_cond(len(tx.vin[0].scriptSig) == 0, 'Input scriptsig not empty') + assert_cond(tx.vin[0].prevout.hash == b2i(prevout_id) and tx.vin[0].prevout.n == prevout_n, 'Input prevout mismatch') + + assert_cond(len(tx.vout) == 1, 'tx doesn\'t have one output') + + script_pk = CScript([OP_0, hashlib.sha256(script_out).digest()]) + locked_n = findOutput(tx, script_pk) + assert_cond(locked_n is not None, 'Output not found in tx') + locked_coin = tx.vout[locked_n].nValue + + # Check script and values + A, B, csv_val, C = self.extractScriptLockRefundScriptValues(script_out) + assert_cond(A == self.encodePubkey(Karl), 'Bad script pubkey') + assert_cond(B == self.encodePubkey(Karf), 'Bad script pubkey') + assert_cond(csv_val == csv_val_expect, 'Bad script csv value') + assert_cond(C == self.encodePubkey(Kaf), 'Bad script pubkey') + + fee_paid = swap_value - locked_coin + assert(fee_paid > 0) + + witness_bytes = len(prevout_script) + witness_bytes += 73 * 2 # 2 signatures (72 + 1 byts size) + witness_bytes += 2 # 2 empty witness stack values + witness_bytes += getCompactSizeLen(witness_bytes) + vsize = self.getTxVSize(tx, add_witness_bytes=witness_bytes) + fee_rate_paid = fee_paid * 1000 / vsize + + logging.info('tx amount, vsize, feerate: %ld, %ld, %ld', locked_coin, vsize, fee_rate_paid) + + if not self.compareFeeRates(fee_rate_paid, feerate): + raise ValueError('Bad fee rate') + + return tx_hash, locked_coin + + def verifyLockRefundSpendTx(self, tx, + lock_refund_tx_id, prevout_script, + Kal, + prevout_value, feerate): + # Verify: + # Must have only one input with correct prevout (n is always 0) and sequence + # Must have only one output sending lock refund tx value - fee to leader's address, TODO: follower shouldn't need to verify destination addr + tx_hash = self.getTxHash(tx) + logging.info('Verifying lock refund spend tx: {}.'.format(b2h(tx_hash))) + + assert_cond(tx.nVersion == self.txVersion(), 'Bad version') + assert_cond(tx.nLockTime == 0, 'nLockTime not 0') + assert_cond(len(tx.vin) == 1, 'tx doesn\'t have one input') + + assert_cond(tx.vin[0].nSequence == 0, 'Bad input nSequence') + assert_cond(len(tx.vin[0].scriptSig) == 0, 'Input scriptsig not empty') + assert_cond(tx.vin[0].prevout.hash == b2i(lock_refund_tx_id) and tx.vin[0].prevout.n == 0, 'Input prevout mismatch') + + assert_cond(len(tx.vout) == 1, 'tx doesn\'t have one output') + + p2wpkh = CScript([OP_0, hash160(self.encodePubkey(Kal))]) + locked_n = findOutput(tx, p2wpkh) + assert_cond(locked_n is not None, 'Output not found in lock refund spend tx') + tx_value = tx.vout[locked_n].nValue + + fee_paid = prevout_value - tx_value + assert(fee_paid > 0) + + witness_bytes = len(prevout_script) + witness_bytes += 73 * 2 # 2 signatures (72 + 1 byts size) + witness_bytes += 4 # 1 empty, 1 true witness stack values + witness_bytes += getCompactSizeLen(witness_bytes) + vsize = self.getTxVSize(tx, add_witness_bytes=witness_bytes) + fee_rate_paid = fee_paid * 1000 / vsize + + logging.info('tx amount, vsize, feerate: %ld, %ld, %ld', tx_value, vsize, fee_rate_paid) + + if not self.compareFeeRates(fee_rate_paid, feerate): + raise ValueError('Bad fee rate') + + return True + + def verifyLockSpendTx(self, tx, + lock_tx, lock_tx_script, + a_pkhash_f, feerate): + # Verify: + # Must have only one input with correct prevout (n is always 0) and sequence + # Must have only one output with destination and amount + + tx_hash = self.getTxHash(tx) + logging.info('Verifying lock spend tx: {}.'.format(b2h(tx_hash))) + + assert_cond(tx.nVersion == self.txVersion(), 'Bad version') + assert_cond(tx.nLockTime == 0, 'nLockTime not 0') + assert_cond(len(tx.vin) == 1, 'tx doesn\'t have one input') + + lock_tx_id = self.getTxHash(lock_tx) + + output_script = CScript([OP_0, hashlib.sha256(lock_tx_script).digest()]) + locked_n = findOutput(lock_tx, output_script) + assert_cond(locked_n is not None, 'Output not found in tx') + locked_coin = lock_tx.vout[locked_n].nValue + + assert_cond(tx.vin[0].nSequence == 0, 'Bad input nSequence') + assert_cond(len(tx.vin[0].scriptSig) == 0, 'Input scriptsig not empty') + assert_cond(tx.vin[0].prevout.hash == b2i(lock_tx_id) and tx.vin[0].prevout.n == locked_n, 'Input prevout mismatch') + + assert_cond(len(tx.vout) == 1, 'tx doesn\'t have one output') + p2wpkh = CScript([OP_0, a_pkhash_f]) + assert_cond(tx.vout[0].scriptPubKey == p2wpkh, 'Bad output destination') + + fee_paid = locked_coin - tx.vout[0].nValue + assert(fee_paid > 0) + + witness_bytes = len(lock_tx_script) + witness_bytes += 33 # sv, size + witness_bytes += 73 * 2 # 2 signatures (72 + 1 byts size) + witness_bytes += 4 # 1 empty, 1 true witness stack values + witness_bytes += getCompactSizeLen(witness_bytes) + vsize = self.getTxVSize(tx, add_witness_bytes=witness_bytes) + fee_rate_paid = fee_paid * 1000 / vsize + + logging.info('tx amount, vsize, feerate: %ld, %ld, %ld', tx.vout[0].nValue, vsize, fee_rate_paid) + + if not self.compareFeeRates(fee_rate_paid, feerate): + raise ValueError('Bad fee rate') + + return True + + def signTx(self, key_int, tx, prevout_n, prevout_script, prevout_value): + sig_hash = SegwitV0SignatureHash(prevout_script, tx, prevout_n, SIGHASH_ALL, prevout_value) + + eck = ECKey() + eck.set(i2b(key_int), compressed=True) + + return eck.sign_ecdsa(sig_hash) + b'\x01' # 0x1 is SIGHASH_ALL + + def signTxOtVES(self, key_sign, key_encrypt, tx, prevout_n, prevout_script, prevout_value): + sig_hash = SegwitV0SignatureHash(prevout_script, tx, prevout_n, SIGHASH_ALL, prevout_value) + return otves.EncSign(key_sign, key_encrypt, sig_hash) + + def verifyTxOtVES(self, tx, sig, Ks, Ke, prevout_n, prevout_script, prevout_value): + sig_hash = SegwitV0SignatureHash(prevout_script, tx, prevout_n, SIGHASH_ALL, prevout_value) + return otves.EncVrfy(Ks, Ke, sig_hash, sig) + + def decryptOtVES(self, k, esig): + return otves.DecSig(k, esig) + b'\x01' # 0x1 is SIGHASH_ALL + + def verifyTxSig(self, tx, sig, K, prevout_n, prevout_script, prevout_value): + sig_hash = SegwitV0SignatureHash(prevout_script, tx, prevout_n, SIGHASH_ALL, prevout_value) + + ecK = ECPubKey() + ecK.set_int(K.x(), K.y()) + return ecK.verify_ecdsa(sig[: -1], sig_hash) # Pop the hashtype byte + + def fundTx(self, tx, feerate): + feerate_str = format_amount(feerate, self.exp()) + rv = self.rpc_callback('fundrawtransaction', [ToHex(tx), {'feeRate': feerate_str}]) + return FromHex(tx, rv['hex']) + + def signTxWithWallet(self, tx): + rv = self.rpc_callback('signrawtransactionwithwallet', [ToHex(tx)]) + + return FromHex(tx, rv['hex']) + + def publishTx(self, tx): + return self.rpc_callback('sendrawtransaction', [ToHex(tx)]) + + def encodeTx(self, tx): + return tx.serialize() + + def loadTx(self, tx_bytes): + # Load tx from bytes to internal representation + tx = CTransaction() + tx.deserialize(BytesIO(tx_bytes)) + return tx + + def getTxHash(self, tx): + tx.rehash() + return i2b(tx.sha256) + + def getPubkeyHash(self, K): + return hash160(self.encodePubkey(K)) + + def getScriptDest(self, script): + return CScript([OP_0, hashlib.sha256(script).digest()]) + + def getPkDest(self, K): + return CScript([OP_0, self.getPubkeyHash(K)]) + + def scanTxOutset(self, dest): + return self.rpc_callback('scantxoutset', ['start', ['raw({})'.format(dest.hex())]]) + + def getTransaction(self, txid): + try: + return self.rpc_callback('getrawtransaction', [txid.hex()]) + except Exception as ex: + # TODO: filter errors + return None + + def setTxSignature(self, tx, stack): + tx.wit.vtxinwit.clear() + tx.wit.vtxinwit.append(CTxInWitness()) + tx.wit.vtxinwit[0].scriptWitness.stack = stack + return True + + def extractLeaderSig(self, tx): + return tx.wit.vtxinwit[0].scriptWitness.stack[1] + + def extractFollowerSig(self, tx): + return tx.wit.vtxinwit[0].scriptWitness.stack[2] + + def createBLockTx(self, Kbs, output_amount): + tx = CTransaction() + tx.nVersion = self.txVersion() + p2wpkh = self.getPkDest(Kbs) + tx.vout.append(self.txoType(output_amount, p2wpkh)) + return tx + + def publishBLockTx(self, Kbv, Kbs, output_amount, feerate): + b_lock_tx = self.createBLockTx(Kbs, output_amount) + + b_lock_tx = self.fundTx(b_lock_tx, feerate) + b_lock_tx_id = self.getTxHash(b_lock_tx) + b_lock_tx = self.signTxWithWallet(b_lock_tx) + + return self.publishTx(b_lock_tx) + + def recoverEncKey(self, esig, sig, K): + return otves.RecoverEncKey(esig, sig[:-1], K) # Strip sighash type + + def getTxVSize(self, tx, add_bytes=0, add_witness_bytes=0): + wsf = self.witnessScaleFactor() + len_full = len(tx.serialize_with_witness()) + add_bytes + add_witness_bytes + len_nwit = len(tx.serialize_without_witness()) + add_bytes + weight = len_nwit * (wsf - 1) + len_full + return (weight + wsf - 1) // wsf + + def findTxB(self, kbv, Kbs, cb_swap_value, cb_block_confirmed, restore_height): + raw_dest = self.getPkDest(Kbs) + + rv = self.scanTxOutset(raw_dest) + print('scanTxOutset', dumpj(rv)) + + for utxo in rv['unspents']: + if 'height' in utxo and utxo['height'] > 0 and rv['height'] - utxo['height'] > cb_block_confirmed: + if utxo['amount'] * COIN != cb_swap_value: + logging.warning('Found output to lock tx pubkey of incorrect value: %s', str(utxo['amount'])) + else: + return True + return False + + def waitForLockTxB(self, kbv, Kbs, cb_swap_value, cb_block_confirmed): + + raw_dest = self.getPkDest(Kbs) + + for i in range(20): + time.sleep(1) + rv = self.scanTxOutset(raw_dest) + print('scanTxOutset', dumpj(rv)) + + for utxo in rv['unspents']: + if 'height' in utxo and utxo['height'] > 0 and rv['height'] - utxo['height'] > cb_block_confirmed: + + if utxo['amount'] * COIN != cb_swap_value: + logging.warning('Found output to lock tx pubkey of incorrect value: %s', str(utxo['amount'])) + else: + return True + return False + + def spendBLockTx(self, address_to, kbv, kbs, cb_swap_value, b_fee, restore_height): + print('TODO: spendBLockTx') + + +def testBTCInterface(): + print('testBTCInterface') + script_bytes = bytes.fromhex('6382012088a820aaf125ff9a34a74c7a17f5e7ee9d07d17cc5e53a539f345d5f73baa7e79b65e28852210224019219ad43c47288c937ae508f26998dd81ec066827773db128fd5e262c04f21039a0fd752bd1a2234820707852e7a30253620052ecd162948a06532a817710b5952ae670114b2755221038689deba25c5578e5457ddadbaf8aeb8badf438dc22f540503dbd4ae10e14f512103c9c5d5acc996216d10852a72cd67c701bfd4b9137a4076350fd32f08db39575552ae68') + i = BTCInterface(None) + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes) + assert(csv_val == 20) + + script_bytes_t = script_bytes + bytes((0x00,)) + try: + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(False), 'Should fail' + except Exception as e: + assert(str(e) == 'Bad script length') + + script_bytes_t = script_bytes[:-1] + try: + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(False), 'Should fail' + except Exception as e: + assert(str(e) == 'Bad script length') + + script_bytes_t = bytes((0x00,)) + script_bytes[1:] + try: + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(False), 'Should fail' + except Exception as e: + assert(str(e) == 'Bad opcode') + + # Remove the csv value + script_part_a = script_bytes[:112] + script_part_b = script_bytes[114:] + + script_bytes_t = script_part_a + bytes((0x00,)) + script_part_b + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(csv_val == 0) + + script_bytes_t = script_part_a + bytes((OP_16,)) + script_part_b + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(csv_val == 16) + + script_bytes_t = script_part_a + CScriptNum.encode(CScriptNum(17)) + script_part_b + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(csv_val == 17) + + script_bytes_t = script_part_a + CScriptNum.encode(CScriptNum(-15)) + script_part_b + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(csv_val == -15) + + script_bytes_t = script_part_a + CScriptNum.encode(CScriptNum(4000)) + script_part_b + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(csv_val == 4000) + + max_pos = 0x7FFFFFFF + script_bytes_t = script_part_a + CScriptNum.encode(CScriptNum(max_pos)) + script_part_b + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(csv_val == max_pos) + script_bytes_t = script_part_a + CScriptNum.encode(CScriptNum(max_pos - 1)) + script_part_b + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(csv_val == max_pos - 1) + + script_bytes_t = script_part_a + CScriptNum.encode(CScriptNum(max_pos + 1)) + script_part_b + try: + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(False), 'Should fail' + except Exception as e: + assert(str(e) == 'Bad scriptnum length') + + min_neg = -2147483647 + script_bytes_t = script_part_a + CScriptNum.encode(CScriptNum(min_neg)) + script_part_b + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(csv_val == min_neg) + + script_bytes_t = script_part_a + CScriptNum.encode(CScriptNum(min_neg - 1)) + script_part_b + try: + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(False), 'Should fail' + except Exception as e: + assert(str(e) == 'Bad scriptnum length') + + print('Passed.') + + +if __name__ == "__main__": + testBTCInterface() diff --git a/basicswap/interface_ltc.py b/basicswap/interface_ltc.py new file mode 100644 index 0000000..c052766 --- /dev/null +++ b/basicswap/interface_ltc.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright (c) 2020 tecnovert +# Distributed under the MIT software license, see the accompanying +# file LICENSE or http://www.opensource.org/licenses/mit-license.php. + +from .interface_btc import BTCInterface + + +class LTCInterface(BTCInterface): + pass diff --git a/basicswap/interface_part.py b/basicswap/interface_part.py new file mode 100644 index 0000000..a12ce97 --- /dev/null +++ b/basicswap/interface_part.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright (c) 2020 tecnovert +# Distributed under the MIT software license, see the accompanying +# file LICENSE or http://www.opensource.org/licenses/mit-license.php. + +from .contrib.test_framework.messages import ( + CTxOutPart, +) + +from .interface_btc import BTCInterface +from .chainparams import CoinInterface +from .rpc import make_rpc_func + + +class PARTInterface(BTCInterface): + @staticmethod + def witnessScaleFactor(): + return 2 + + @staticmethod + def txVersion(): + return 0xa0 + + def __init__(self, coin_settings): + self.rpc_callback = make_rpc_func(coin_settings['rpcport'], coin_settings['rpcauth']) + self.txoType = CTxOutPart diff --git a/basicswap/interface_xmr.py b/basicswap/interface_xmr.py new file mode 100644 index 0000000..59289e1 --- /dev/null +++ b/basicswap/interface_xmr.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright (c) 2020 tecnovert +# Distributed under the MIT software license, see the accompanying +# file LICENSE or http://www.opensource.org/licenses/mit-license.php. + +import time +import logging + +from .chainparams import CoinInterface +from .rpc_xmr import make_xmr_rpc_func, make_xmr_wallet_rpc_func + +XMR_COIN = 10 ** 12 + + +class XMRInterface(CoinInterface): + @staticmethod + def exp(): + return 12 + + @staticmethod + def nbk(): + return 32 + + @staticmethod + def nbK(): # No. of bytes requires to encode a public key + return 32 + + def __init__(self, coin_settings): + rpc_cb = make_xmr_rpc_func(coin_settings['rpcport']) + rpc_wallet_cb = make_xmr_wallet_rpc_func(coin_settings['walletrpcport'], coin_settings['walletrpcauth']) + + self.rpc_cb = rpc_cb # Not essential + self.rpc_wallet_cb = rpc_wallet_cb + + def getNewSecretKey(self): + return edu.get_secret() + + def pubkey(self, key): + return edf.scalarmult_B(key) + + def encodePubkey(self, pk): + return edu.encodepoint(pk) + + def decodePubkey(self, pke): + return edf.decodepoint(pke) + + def decodeKey(self, k): + i = b2i(k) + assert(i < edf.l and i > 8) + return i + + def sumKeys(self, ka, kb): + return (ka + kb) % edf.l + + def sumPubkeys(self, Ka, Kb): + return edf.edwards_add(Ka, Kb) + + def publishBLockTx(self, Kbv, Kbs, output_amount, feerate): + + shared_addr = xmr_util.encode_address(self.encodePubkey(Kbv), self.encodePubkey(Kbs)) + + # TODO: How to set feerate? + params = {'destinations': [{'amount': output_amount, 'address': shared_addr}]} + rv = self.rpc_wallet_cb('transfer', params) + logging.info('publishBLockTx %s to address_b58 %s', rv['tx_hash'], shared_addr) + + return rv['tx_hash'] + + def findTxB(self, kbv, Kbs, cb_swap_value, cb_block_confirmed, restore_height): + Kbv_enc = self.encodePubkey(self.pubkey(kbv)) + address_b58 = xmr_util.encode_address(Kbv_enc, self.encodePubkey(Kbs)) + + try: + self.rpc_wallet_cb('close_wallet') + except Exception as e: + logging.warning('close_wallet failed %s', str(e)) + + params = { + 'restore_height': restore_height, + 'filename': address_b58, + 'address': address_b58, + 'viewkey': b2h(intToBytes32_le(kbv)), + } + + try: + rv = self.rpc_wallet_cb('open_wallet', {'filename': address_b58}) + except Exception as e: + rv = self.rpc_wallet_cb('generate_from_keys', params) + logging.info('generate_from_keys %s', dumpj(rv)) + rv = self.rpc_wallet_cb('open_wallet', {'filename': address_b58}) + + # Debug + try: + current_height = self.rpc_cb('get_block_count')['count'] + logging.info('findTxB XMR current_height %d\nAddress: %s', current_height, address_b58) + except Exception as e: + logging.info('rpc_cb failed %s', str(e)) + current_height = None # If the transfer is available it will be deep enough + + # For a while after opening the wallet rpc cmds return empty data + for i in range(5): + params = {'transfer_type': 'available'} + rv = self.rpc_wallet_cb('incoming_transfers', params) + if 'transfers' in rv: + for transfer in rv['transfers']: + if transfer['amount'] == cb_swap_value \ + and (current_height is None or current_height - transfer['block_height'] > cb_block_confirmed): + return True + time.sleep(1 + i) + + return False + + def waitForLockTxB(self, kbv, Kbs, cb_swap_value, cb_block_confirmed, restore_height): + + Kbv_enc = self.encodePubkey(self.pubkey(kbv)) + address_b58 = xmr_util.encode_address(Kbv_enc, self.encodePubkey(Kbs)) + + try: + self.rpc_wallet_cb('close_wallet') + except Exception as e: + logging.warning('close_wallet failed %s', str(e)) + + params = { + 'filename': address_b58, + 'address': address_b58, + 'viewkey': b2h(intToBytes32_le(kbv)), + 'restore_height': restore_height, + } + self.rpc_wallet_cb('generate_from_keys', params) + + self.rpc_wallet_cb('open_wallet', {'filename': address_b58}) + # For a while after opening the wallet rpc cmds return empty data + + num_tries = 40 + for i in range(num_tries + 1): + try: + current_height = self.rpc_cb('get_block_count')['count'] + print('current_height', current_height) + except Exception as e: + logging.warning('rpc_cb failed %s', str(e)) + current_height = None # If the transfer is available it will be deep enough + + # TODO: Make accepting current_height == None a user selectable option + # Or look for all transfers and check height + + params = {'transfer_type': 'available'} + rv = self.rpc_wallet_cb('incoming_transfers', params) + print('rv', rv) + + if 'transfers' in rv: + for transfer in rv['transfers']: + if transfer['amount'] == cb_swap_value \ + and (current_height is None or current_height - transfer['block_height'] > cb_block_confirmed): + return True + + # TODO: Is it necessary to check the address? + + ''' + rv = self.rpc_wallet_cb('get_balance') + print('get_balance', rv) + + if 'per_subaddress' in rv: + for sub_addr in rv['per_subaddress']: + if sub_addr['address'] == address_b58: + + ''' + + if i >= num_tries: + raise ValueError('Balance not confirming on node') + time.sleep(1) + + return False + + def spendBLockTx(self, address_to, kbv, kbs, cb_swap_value, b_fee_rate, restore_height): + + Kbv_enc = self.encodePubkey(self.pubkey(kbv)) + Kbs_enc = self.encodePubkey(self.pubkey(kbs)) + address_b58 = xmr_util.encode_address(Kbv_enc, Kbs_enc) + + try: + self.rpc_wallet_cb('close_wallet') + except Exception as e: + logging.warning('close_wallet failed %s', str(e)) + + wallet_filename = address_b58 + '_spend' + + params = { + 'filename': wallet_filename, + 'address': address_b58, + 'viewkey': b2h(intToBytes32_le(kbv)), + 'spendkey': b2h(intToBytes32_le(kbs)), + 'restore_height': restore_height, + } + + try: + self.rpc_wallet_cb('open_wallet', {'filename': wallet_filename}) + except Exception as e: + rv = self.rpc_wallet_cb('generate_from_keys', params) + logging.info('generate_from_keys %s', dumpj(rv)) + self.rpc_wallet_cb('open_wallet', {'filename': wallet_filename}) + + # For a while after opening the wallet rpc cmds return empty data + for i in range(10): + rv = self.rpc_wallet_cb('get_balance') + print('get_balance', rv) + if rv['balance'] >= cb_swap_value: + break + + time.sleep(1 + i) + + # TODO: need a subfee from output option + b_fee = b_fee_rate * 10 # Guess + + num_tries = 20 + for i in range(1 + num_tries): + try: + params = {'destinations': [{'amount': cb_swap_value - b_fee, 'address': address_to}]} + rv = self.rpc_wallet_cb('transfer', params) + print('transfer', rv) + break + except Exception as e: + print('str(e)', str(e)) + if i >= num_tries: + raise ValueError('transfer failed.') + b_fee += b_fee_rate + logging.info('Raising fee to %d', b_fee) + + return rv['tx_hash'] diff --git a/basicswap/rpc.py b/basicswap/rpc.py index 3cf87cf..bcd7b59 100644 --- a/basicswap/rpc.py +++ b/basicswap/rpc.py @@ -93,8 +93,8 @@ class Jsonrpc(): def callrpc(rpc_port, auth, method, params=[], wallet=None): try: url = 'http://%s@127.0.0.1:%d/' % (auth, rpc_port) - if wallet: - url += 'wallet/' + wallet + if wallet is not None: + url += 'wallet/' + urllib.parse.quote(wallet) x = Jsonrpc(url) v = x.json_request(method, params) @@ -126,3 +126,14 @@ def callrpc_cli(bindir, datadir, chain, cmd, cli_bin='particl-cli'): except Exception: pass return r + + +def make_rpc_func(port, auth, wallet=None): + port = port + auth = auth + wallet = wallet + + def rpc_func(method, params=None, wallet_override=None): + nonlocal port, auth, wallet + return callrpc(port, auth, method, params, wallet if wallet_override is None else wallet_override) + return rpc_func diff --git a/basicswap/rpc_xmr.py b/basicswap/rpc_xmr.py new file mode 100644 index 0000000..9f15650 --- /dev/null +++ b/basicswap/rpc_xmr.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- + +import json +import requests + + +def callrpc_xmr(rpc_port, auth, method, params=[], path='json_rpc'): + # auth is a tuple: (username, password) + try: + url = 'http://127.0.0.1:{}/{}'.format(rpc_port, path) + request_body = { + 'method': method, + 'params': params, + 'id': 2, + 'jsonrpc': '2.0' + } + headers = { + 'content-type': 'application/json' + } + p = requests.post(url, data=json.dumps(request_body), auth=requests.auth.HTTPDigestAuth(auth[0], auth[1]), headers=headers) + r = json.loads(p.text) + except Exception as ex: + raise ValueError('RPC Server Error: {}'.format(str(ex))) + + if 'error' in r and r['error'] is not None: + raise ValueError('RPC error ' + str(r['error'])) + + return r['result'] + + +def callrpc_xmr_na(rpc_port, method, params=[], path='json_rpc'): + try: + url = 'http://127.0.0.1:{}/{}'.format(rpc_port, path) + request_body = { + 'method': method, + 'params': params, + 'id': 2, + 'jsonrpc': '2.0' + } + headers = { + 'content-type': 'application/json' + } + p = requests.post(url, data=json.dumps(request_body), headers=headers) + r = json.loads(p.text) + except Exception as ex: + raise ValueError('RPC Server Error: {}'.format(str(ex))) + + if 'error' in r and r['error'] is not None: + raise ValueError('RPC error ' + str(r['error'])) + + return r['result'] + + +def callrpc_xmr2(rpc_port, method, params=[]): + try: + url = 'http://127.0.0.1:{}/{}'.format(rpc_port, method) + headers = { + 'content-type': 'application/json' + } + p = requests.post(url, data=json.dumps(params), headers=headers) + r = json.loads(p.text) + except Exception as ex: + raise ValueError('RPC Server Error: {}'.format(str(ex))) + + return r + + +def make_xmr_rpc_func(port): + port = port + + def rpc_func(method, params=None, wallet=None): + nonlocal port + return callrpc_xmr_na(port, method, params) + return rpc_func + + +def make_xmr_wallet_rpc_func(port, auth): + port = port + auth = auth + + def rpc_func(method, params=None, wallet=None): + nonlocal port, auth + return callrpc_xmr(port, auth, method, params) + return rpc_func + diff --git a/basicswap/util.py b/basicswap/util.py index c700869..024a864 100644 --- a/basicswap/util.py +++ b/basicswap/util.py @@ -9,12 +9,15 @@ import json import hashlib from .contrib.segwit_addr import bech32_decode, convertbits, bech32_encode +OP_1 = 0x51 +OP_16 = 0x60 COIN = 100000000 DCOIN = decimal.Decimal(COIN) -def makeInt(v): - return int(dquantize(decimal.Decimal(v) * DCOIN).quantize(decimal.Decimal(1))) +def assert_cond(v, err='Bad opcode'): + if not v: + raise ValueError(err) def format8(i): @@ -188,3 +191,105 @@ def DeserialiseNum(b, o=0): if b[o + nb - 1] & 0x80: return -(v & ~(0x80 << (8 * (nb - 1)))) return v + + +def decodeScriptNum(script_bytes, o): + v = 0 + num_len = script_bytes[o] + if num_len >= OP_1 and num_len <= OP_16: + return((num_len - OP_1) + 1, 1) + + if num_len > 4: + raise ValueError('Bad scriptnum length') # Max 4 bytes + if num_len + o >= len(script_bytes): + raise ValueError('Bad script length') + o += 1 + for i in range(num_len): + b = script_bytes[o + i] + # Negative flag set in last byte, if num is positive and > 0x80 an extra 0x00 byte will be appended + if i == num_len - 1 and b & 0x80: + b &= (~(0x80) & 0xFF) + v += int(b) << 8 * i + v *= -1 + else: + v += int(b) << 8 * i + return(v, 1 + num_len) + + +def getCompactSizeLen(v): + # Compact Size + if v < 253: + return 1 + if v < 0xffff: # USHRT_MAX + return 3 + if v < 0xffffffff: # UINT_MAX + return 5 + if v < 0xffffffffffffffff: # UINT_MAX + return 9 + raise ValueError('Value too large') + + +def make_int(v, precision=8, r=0): # r = 0, no rounding, fail, r > 0 round up, r < 0 floor + if type(v) == float: + v = str(v) + elif type(v) == int: + return v * 10 ** precision + + ep = 10 ** precision + have_dp = False + rv = 0 + for c in v: + if c == '.': + rv *= ep + have_dp = True + continue + if not c.isdigit(): + raise ValueError('Invalid char') + if have_dp: + ep //= 10 + if ep <= 0: + if r == 0: + raise ValueError('Mantissa too long') + if r > 0: + # Round up + if int(c) > 4: + rv += 1 + break + + rv += ep * int(c) + else: + rv = rv * 10 + int(c) + if not have_dp: + rv *= ep + return rv + + +def validate_amount(amount, precision=8): + str_amount = str(amount) + has_decimal = False + for c in str_amount: + if c == '.' and not has_decimal: + has_decimal = True + continue + if not c.isdigit(): + raise ValueError('Invalid amount') + + ar = str_amount.split('.') + if len(ar) > 1 and len(ar[1]) > precision: + raise ValueError('Too many decimal places in amount {}'.format(str_amount)) + return True + + +def format_amount(i, display_precision, precision=None): + if precision is None: + precision = display_precision + ep = 10 ** precision + n = abs(i) + quotient = n // ep + remainder = n % ep + if display_precision != precision: + remainder %= (10 ** display_precision) + rv = '{}.{:0>{prec}}'.format(quotient, remainder, prec=display_precision) + if i < 0: + rv = '-' + rv + return rv diff --git a/basicswap/util_xmr.py b/basicswap/util_xmr.py new file mode 100644 index 0000000..75f5185 --- /dev/null +++ b/basicswap/util_xmr.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- + +import xmrswap.contrib.Keccak as Keccak +from .contrib.MoneroPy.base58 import encode as xmr_b58encode + + +def cn_fast_hash(s): + k = Keccak.Keccak() + return k.Keccak((len(s) * 8, s.hex()), 1088, 512, 32 * 8, False).lower() # r = bitrate = 1088, c = capacity, n = output length in bits + + +def encode_address(view_point, spend_point, version=18): + buf = bytes((version,)) + spend_point + view_point + h = cn_fast_hash(buf) + buf = buf + bytes.fromhex(h[0: 8]) + + return xmr_b58encode(buf.hex()) diff --git a/setup.py b/setup.py index 65e89dc..deed653 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ setuptools.setup( "sqlalchemy", "python-gnupg", "Jinja2", + "requests", ], entry_points={ "console_scripts": [ diff --git a/tests/basicswap/__init__.py b/tests/basicswap/__init__.py index 0c173cf..30104e0 100644 --- a/tests/basicswap/__init__.py +++ b/tests/basicswap/__init__.py @@ -4,6 +4,7 @@ import tests.basicswap.test_other as test_other import tests.basicswap.test_prepare as test_prepare import tests.basicswap.test_run as test_run import tests.basicswap.test_reload as test_reload +import tests.basicswap.test_xmr as test_xmr def test_suite(): @@ -12,5 +13,6 @@ def test_suite(): suite.addTests(loader.loadTestsFromModule(test_prepare)) suite.addTests(loader.loadTestsFromModule(test_run)) suite.addTests(loader.loadTestsFromModule(test_reload)) + suite.addTests(loader.loadTestsFromModule(test_xmr)) return suite diff --git a/tests/basicswap/common.py b/tests/basicswap/common.py new file mode 100644 index 0000000..3139587 --- /dev/null +++ b/tests/basicswap/common.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright (c) 2020 tecnovert +# Distributed under the MIT software license, see the accompanying +# file LICENSE.txt or http://www.opensource.org/licenses/mit-license.php. + +def checkForks(ro): + if 'bip9_softforks' in ro: + assert(ro['bip9_softforks']['csv']['status'] == 'active') + assert(ro['bip9_softforks']['segwit']['status'] == 'active') + else: + assert(ro['softforks']['csv']['active']) + assert(ro['softforks']['segwit']['active']) diff --git a/tests/basicswap/test_other.py b/tests/basicswap/test_other.py index 08971ea..5d1c92b 100644 --- a/tests/basicswap/test_other.py +++ b/tests/basicswap/test_other.py @@ -9,7 +9,7 @@ import unittest from basicswap.util import ( SerialiseNum, DeserialiseNum, - makeInt, + make_int, format8, ) from basicswap.basicswap import ( @@ -57,19 +57,28 @@ class Test(unittest.TestCase): decoded = decodeSequence(encoded) assert(decoded == blocks_val) - def test_makeInt(self): + def test_make_int(self): def test_case(vs, vf, expect_int): - assert(makeInt(vs) == expect_int) - assert(makeInt(vf) == expect_int) - vs_out = format8(makeInt(vs)) + i = make_int(vs) + assert(i == expect_int and isinstance(i, int)) + i = make_int(vf) + assert(i == expect_int and isinstance(i, int)) + vs_out = format_amount(i, 8) # Strip for i in range(7): if vs_out[-1] == '0': vs_out = vs_out[:-1] - assert(vs_out == vs) + if '.' in vs: + assert(vs_out == vs) + else: + assert(vs_out[:-2] == vs) + test_case('0', 0, 0) + test_case('1', 1, 100000000) + test_case('10', 10, 1000000000) test_case('0.00899999', 0.00899999, 899999) test_case('899999.0', 899999.0, 89999900000000) test_case('899999.00899999', 899999.00899999, 89999900899999) + test_case('0.0', 0.0, 0) test_case('1.0', 1.0, 100000000) test_case('1.1', 1.1, 110000000) test_case('1.2', 1.2, 120000000) @@ -79,6 +88,52 @@ class Test(unittest.TestCase): test_case('0.123', 0.123, 12300000) test_case('123000.000123', 123000.000123, 12300000012300) + try: + make_int('0.123456789') + assert(False) + except Exception as e: + assert(str(e) == 'Mantissa too long') + validate_amount('0.12345678') + + # floor + assert(make_int('0.123456789', r=-1) == 12345678) + # Round up + assert(make_int('0.123456789', r=1) == 12345679) + + def test_make_int12(self): + def test_case(vs, vf, expect_int): + i = make_int(vs, 12) + assert(i == expect_int and isinstance(i, int)) + i = make_int(vf, 12) + assert(i == expect_int and isinstance(i, int)) + vs_out = format_amount(i, 12) + # Strip + for i in range(7): + if vs_out[-1] == '0': + vs_out = vs_out[:-1] + if '.' in vs: + assert(vs_out == vs) + else: + assert(vs_out[:-2] == vs) + test_case('0.123456789', 0.123456789, 123456789000) + test_case('0.123456789123', 0.123456789123, 123456789123) + try: + make_int('0.1234567891234', 12) + assert(False) + except Exception as e: + assert(str(e) == 'Mantissa too long') + validate_amount('0.123456789123', 12) + try: + validate_amount('0.1234567891234', 12) + assert(False) + except Exception as e: + assert('Too many decimal places' in str(e)) + try: + validate_amount(0.1234567891234, 12) + assert(False) + except Exception as e: + assert('Too many decimal places' in str(e)) + if __name__ == '__main__': unittest.main() diff --git a/tests/basicswap/test_run.py b/tests/basicswap/test_run.py index d130870..43a0426 100644 --- a/tests/basicswap/test_run.py +++ b/tests/basicswap/test_run.py @@ -48,6 +48,9 @@ from basicswap.contrib.key import ( from basicswap.http_server import ( HttpThread, ) +from tests.basicswap.common import ( + checkForks, +) from bin.basicswap_run import startDaemon logger = logging.getLogger() @@ -205,15 +208,6 @@ def run_loop(self): btcRpc('generatetoaddress 1 {}'.format(self.btc_addr)) -def checkForks(ro): - if 'bip9_softforks' in ro: - assert(ro['bip9_softforks']['csv']['status'] == 'active') - assert(ro['bip9_softforks']['segwit']['status'] == 'active') - else: - assert(ro['softforks']['csv']['active']) - assert(ro['softforks']['segwit']['active']) - - class Test(unittest.TestCase): @classmethod diff --git a/tests/basicswap/test_xmr.py b/tests/basicswap/test_xmr.py new file mode 100644 index 0000000..06933cc --- /dev/null +++ b/tests/basicswap/test_xmr.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright (c) 2020 tecnovert +# Distributed under the MIT software license, see the accompanying +# file LICENSE or http://www.opensource.org/licenses/mit-license.php. + +import os +import sys +import unittest +import json +import logging +import shutil +import time +import signal +import threading +from urllib.request import urlopen +from coincurve.ecdsaotves import ( + ecdsaotves_enc_sign, + ecdsaotves_enc_verify, + ecdsaotves_dec_sig, + ecdsaotves_rec_enc_key) +from coincurve.dleag import ( + dleag_prove, + dleag_verify) + +import basicswap.config as cfg +from basicswap.basicswap import ( + BasicSwap, + Coins, + SwapTypes, + BidStates, + TxStates, + SEQUENCE_LOCK_BLOCKS, +) +from basicswap.util import ( + COIN, + toWIF, + dumpje, +) +from basicswap.rpc import ( + callrpc_cli, + waitForRPC, +) +from basicswap.contrib.key import ( + ECKey, +) +from basicswap.http_server import ( + HttpThread, +) +from bin.basicswap_run import startDaemon + +logger = logging.getLogger() +logger.level = logging.DEBUG +if not len(logger.handlers): + logger.addHandler(logging.StreamHandler(sys.stdout)) + +NUM_NODES = 3 +BASE_PORT = 14792 +BASE_RPC_PORT = 19792 +BASE_ZMQ_PORT = 20792 +PREFIX_SECRET_KEY_REGTEST = 0x2e +TEST_HTML_PORT = 1800 +stop_test = False + + + +def prepareOtherDir(datadir, nodeId, conf_file='litecoin.conf'): + node_dir = os.path.join(datadir, str(nodeId)) + if not os.path.exists(node_dir): + os.makedirs(node_dir) + filePath = os.path.join(node_dir, conf_file) + + with open(filePath, 'w+') as fp: + fp.write('regtest=1\n') + fp.write('[regtest]\n') + fp.write('port=' + str(BASE_PORT + nodeId) + '\n') + fp.write('rpcport=' + str(BASE_RPC_PORT + nodeId) + '\n') + + fp.write('daemon=0\n') + fp.write('printtoconsole=0\n') + fp.write('server=1\n') + fp.write('discover=0\n') + fp.write('listenonion=0\n') + fp.write('bind=127.0.0.1\n') + fp.write('findpeers=0\n') + fp.write('debug=1\n') + fp.write('debugexclude=libevent\n') + fp.write('fallbackfee=0.0002\n') + + fp.write('acceptnonstdtxn=0\n') + + +def prepareDir(datadir, nodeId, network_key, network_pubkey): + node_dir = os.path.join(datadir, str(nodeId)) + if not os.path.exists(node_dir): + os.makedirs(node_dir) + filePath = os.path.join(node_dir, 'particl.conf') + + with open(filePath, 'w+') as fp: + fp.write('regtest=1\n') + fp.write('[regtest]\n') + fp.write('port=' + str(BASE_PORT + nodeId) + '\n') + fp.write('rpcport=' + str(BASE_RPC_PORT + nodeId) + '\n') + + fp.write('daemon=0\n') + fp.write('printtoconsole=0\n') + fp.write('server=1\n') + fp.write('discover=0\n') + fp.write('listenonion=0\n') + fp.write('bind=127.0.0.1\n') + fp.write('findpeers=0\n') + fp.write('debug=1\n') + fp.write('debugexclude=libevent\n') + fp.write('zmqpubsmsg=tcp://127.0.0.1:' + str(BASE_ZMQ_PORT + nodeId) + '\n') + + fp.write('acceptnonstdtxn=0\n') + fp.write('minstakeinterval=5\n') + + for i in range(0, NUM_NODES): + if nodeId == i: + continue + fp.write('addnode=127.0.0.1:%d\n' % (BASE_PORT + i)) + + if nodeId < 2: + fp.write('spentindex=1\n') + fp.write('txindex=1\n') + + basicswap_dir = os.path.join(datadir, str(nodeId), 'basicswap') + if not os.path.exists(basicswap_dir): + os.makedirs(basicswap_dir) + + ltcdatadir = os.path.join(datadir, str(LTC_NODE)) + btcdatadir = os.path.join(datadir, str(BTC_NODE)) + settings_path = os.path.join(basicswap_dir, cfg.CONFIG_FILENAME) + settings = { + 'zmqhost': 'tcp://127.0.0.1', + 'zmqport': BASE_ZMQ_PORT + nodeId, + 'htmlhost': 'localhost', + 'htmlport': 12700 + nodeId, + 'network_key': network_key, + 'network_pubkey': network_pubkey, + 'chainclients': { + 'particl': { + 'connection_type': 'rpc', + 'manage_daemon': False, + 'rpcport': BASE_RPC_PORT + nodeId, + 'datadir': node_dir, + 'bindir': cfg.PARTICL_BINDIR, + 'blocks_confirmed': 2, # Faster testing + }, + 'litecoin': { + 'connection_type': 'rpc', + 'manage_daemon': False, + 'rpcport': BASE_RPC_PORT + LTC_NODE, + 'datadir': ltcdatadir, + 'bindir': cfg.LITECOIN_BINDIR, + # 'use_segwit': True, + }, + 'bitcoin': { + 'connection_type': 'rpc', + 'manage_daemon': False, + 'rpcport': BASE_RPC_PORT + BTC_NODE, + 'datadir': btcdatadir, + 'bindir': cfg.BITCOIN_BINDIR, + 'use_segwit': True, + } + }, + 'check_progress_seconds': 2, + 'check_watched_seconds': 4, + 'check_expired_seconds': 60, + 'check_events_seconds': 1, + 'min_delay_auto_accept': 1, + 'max_delay_auto_accept': 5 + } + with open(settings_path, 'w') as fp: + json.dump(settings, fp, indent=4) + + +def partRpc(cmd, node_id=0): + return callrpc_cli(cfg.PARTICL_BINDIR, os.path.join(cfg.TEST_DATADIRS, str(node_id)), 'regtest', cmd, cfg.PARTICL_CLI) + + +def btcRpc(cmd): + return callrpc_cli(cfg.BITCOIN_BINDIR, os.path.join(cfg.TEST_DATADIRS, str(BTC_NODE)), 'regtest', cmd, cfg.BITCOIN_CLI) + + +def signal_handler(sig, frame): + global stop_test + print('signal {} detected.'.format(sig)) + stop_test = True + + +def run_loop(self): + while not stop_test: + time.sleep(1) + for c in self.swap_clients: + c.update() + btcRpc('generatetoaddress 1 {}'.format(self.btc_addr)) + + +def checkForks(ro): + if 'bip9_softforks' in ro: + assert(ro['bip9_softforks']['csv']['status'] == 'active') + assert(ro['bip9_softforks']['segwit']['status'] == 'active') + else: + assert(ro['softforks']['csv']['active']) + assert(ro['softforks']['segwit']['active']) + + +class Test(unittest.TestCase): + + @classmethod + def setUpClass(cls): + super(Test, cls).setUpClass() + + cls.swap_clients = [] + cls.xmr_daemons = [] + cls.xmr_wallet_auth = [] + + cls.part_stakelimit = 0 + cls.xmr_addr = None + + signal.signal(signal.SIGINT, signal_handler) + cls.update_thread = threading.Thread(target=run_loop, args=(cls,)) + cls.update_thread.start() + + @classmethod + def tearDownClass(cls): + global stop_test + logging.info('Finalising') + stop_test = True + cls.update_thread.join() + + super(Test, cls).tearDownClass() + + def test_01_part_xmr(self): + logging.info('---------- Test PART to XMR') + #swap_clients = self.swap_clients + + #offer_id = swap_clients[0].postOffer(Coins.PART, Coins.XMR, 100 * COIN, 0.5 * COIN, 100 * COIN, SwapTypes.SELLER_FIRST) + + + +if __name__ == '__main__': + unittest.main()