From c7269c051439ff74ef9d3e7605bfe66d03e34aab Mon Sep 17 00:00:00 2001 From: Brandon Goodell Date: Mon, 18 Dec 2017 14:41:21 -0500 Subject: [PATCH] More robust FS challenge computation. --- .../BulletProofs/MultiBulletproof.java | 554 ++++++++++++++++++ 1 file changed, 554 insertions(+) create mode 100644 source-code/BulletProofs/MultiBulletproof.java diff --git a/source-code/BulletProofs/MultiBulletproof.java b/source-code/BulletProofs/MultiBulletproof.java new file mode 100644 index 0000000..48277ef --- /dev/null +++ b/source-code/BulletProofs/MultiBulletproof.java @@ -0,0 +1,554 @@ +// NOTE: this interchanges the roles of G and H to match other code's behavior + +package how.monero.hodl.bulletproof; + +import how.monero.hodl.crypto.Curve25519Point; +import how.monero.hodl.crypto.Scalar; +import how.monero.hodl.crypto.CryptoUtil; +import java.math.BigInteger; +import java.util.Random; + +import static how.monero.hodl.crypto.Scalar.randomScalar; +import static how.monero.hodl.crypto.CryptoUtil.*; +import static how.monero.hodl.util.ByteUtil.*; + +public class MultiBulletproof +{ + private static int N; + private static int logMN; + private static int M; + private static Curve25519Point G; + private static Curve25519Point H; + private static Curve25519Point[] Gi; + private static Curve25519Point[] Hi; + + public static class ProofTuple + { + private Curve25519Point V[]; + private Curve25519Point A; + private Curve25519Point S; + private Curve25519Point T1; + private Curve25519Point T2; + private Scalar taux; + private Scalar mu; + private Curve25519Point[] L; + private Curve25519Point[] R; + private Scalar a; + private Scalar b; + private Scalar t; + + public ProofTuple(Curve25519Point V[], Curve25519Point A, Curve25519Point S, Curve25519Point T1, Curve25519Point T2, Scalar taux, Scalar mu, Curve25519Point[] L, Curve25519Point[] R, Scalar a, Scalar b, Scalar t) + { + this.V = V; + this.A = A; + this.S = S; + this.T1 = T1; + this.T2 = T2; + this.taux = taux; + this.mu = mu; + this.L = L; + this.R = R; + this.a = a; + this.b = b; + this.t = t; + } + } + + /* Given two scalar arrays, construct a vector commitment */ + public static Curve25519Point VectorExponent(Scalar[] a, Scalar[] b) + { + assert a.length == M*N && b.length == M*N; + + Curve25519Point Result = Curve25519Point.ZERO; + for (int i = 0; i < M*N; i++) + { + Result = Result.add(Gi[i].scalarMultiply(a[i])); + Result = Result.add(Hi[i].scalarMultiply(b[i])); + } + return Result; + } + + /* Compute a custom vector-scalar commitment */ + public static Curve25519Point VectorExponentCustom(Curve25519Point[] A, Curve25519Point[] B, Scalar[] a, Scalar[] b) + { + assert a.length == A.length && b.length == B.length && a.length == b.length; + + Curve25519Point Result = Curve25519Point.ZERO; + for (int i = 0; i < a.length; i++) + { + Result = Result.add(A[i].scalarMultiply(a[i])); + Result = Result.add(B[i].scalarMultiply(b[i])); + } + return Result; + } + + /* Given a scalar, construct a vector of powers */ + public static Scalar[] VectorPowers(Scalar x, int size) + { + Scalar[] result = new Scalar[size]; + for (int i = 0; i < size; i++) + { + result[i] = x.pow(i); + } + return result; + } + + /* Given two scalar arrays, construct the inner product */ + public static Scalar InnerProduct(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar result = Scalar.ZERO; + for (int i = 0; i < a.length; i++) + { + result = result.add(a[i].mul(b[i])); + } + return result; + } + + /* Given two scalar arrays, construct the Hadamard product */ + public static Scalar[] Hadamard(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].mul(b[i]); + } + return result; + } + + /* Given two curvepoint arrays, construct the Hadamard product */ + public static Curve25519Point[] Hadamard2(Curve25519Point[] A, Curve25519Point[] B) + { + assert A.length == B.length; + + Curve25519Point[] Result = new Curve25519Point[A.length]; + for (int i = 0; i < A.length; i++) + { + Result[i] = A[i].add(B[i]); + } + return Result; + } + + /* Add two vectors */ + public static Scalar[] VectorAdd(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].add(b[i]); + } + return result; + } + + /* Subtract two vectors */ + public static Scalar[] VectorSubtract(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].sub(b[i]); + } + return result; + } + + /* Multiply a scalar and a vector */ + public static Scalar[] VectorScalar(Scalar[] a, Scalar x) + { + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].mul(x); + } + return result; + } + + /* Exponentiate a curve vector by a scalar */ + public static Curve25519Point[] VectorScalar2(Curve25519Point[] A, Scalar x) + { + Curve25519Point[] Result = new Curve25519Point[A.length]; + for (int i = 0; i < A.length; i++) + { + Result[i] = A[i].scalarMultiply(x); + } + return Result; + } + + /* Compute the inverse of a scalar, the stupid way */ + public static Scalar Invert(Scalar x) + { + Scalar inverse = new Scalar(x.toBigInteger().modInverse(CryptoUtil.l)); + + assert x.mul(inverse).equals(Scalar.ONE); + return inverse; + } + + /* Compute the slice of a curvepoint vector */ + public static Curve25519Point[] CurveSlice(Curve25519Point[] a, int start, int stop) + { + Curve25519Point[] Result = new Curve25519Point[stop-start]; + for (int i = start; i < stop; i++) + { + Result[i-start] = a[i]; + } + return Result; + } + + /* Compute the slice of a scalar vector */ + public static Scalar[] ScalarSlice(Scalar[] a, int start, int stop) + { + Scalar[] result = new Scalar[stop-start]; + for (int i = start; i < stop; i++) + { + result[i-start] = a[i]; + } + return result; + } + + /* Construct an aggregate range proof */ + public static ProofTuple PROVE(Scalar[] v, Scalar[] gamma) + { + Curve25519Point[] V = new Curve25519Point[M]; + + V[0] = H.scalarMultiply(v[0]).add(G.scalarMultiply(gamma[0])); + // This hash is updated for Fiat-Shamir throughout the proof + Scalar hashCache = hashToScalar(V[0].toBytes()); + for (int j = 1; j < M; j++) + { + V[j] = H.scalarMultiply(v[j]).add(G.scalarMultiply(gamma[j])); + hashCache = hashToScalar(concat(hashCache.bytes,V[j].toBytes())); + } + + // PAPER LINES 36-37 + Scalar[] aL = new Scalar[M*N]; + Scalar[] aR = new Scalar[M*N]; + + for (int j = 0; j < M; j++) + { + BigInteger tempV = v[j].toBigInteger(); + for (int i = N-1; i >= 0; i--) + { + BigInteger basePow = BigInteger.valueOf(2).pow(i); + if (tempV.divide(basePow).equals(BigInteger.ZERO)) + { + aL[j*N+i] = Scalar.ZERO; + } + else + { + aL[j*N+i] = Scalar.ONE; + tempV = tempV.subtract(basePow); + } + + aR[j*N+i] = aL[j*N+i].sub(Scalar.ONE); + } + } + + // PAPER LINES 38-39 + Scalar alpha = randomScalar(); + Curve25519Point A = VectorExponent(aL,aR).add(G.scalarMultiply(alpha)); + + // PAPER LINES 40-42 + Scalar[] sL = new Scalar[M*N]; + Scalar[] sR = new Scalar[M*N]; + for (int i = 0; i < M*N; i++) + { + sL[i] = randomScalar(); + sR[i] = randomScalar(); + } + Scalar rho = randomScalar(); + Curve25519Point S = VectorExponent(sL,sR).add(G.scalarMultiply(rho)); + + // PAPER LINES 43-45 + hashCache = hashToScalar(concat(hashCache.bytes,A.toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,S.toBytes())); + Scalar y = hashCache; + hashCache = hashToScalar(hashCache.bytes); + Scalar z = hashCache; + + // Polynomial construction by coefficients + Scalar[] l0; + Scalar[] l1; + Scalar[] r0; + Scalar[] r1; + + l0 = VectorSubtract(aL,VectorScalar(VectorPowers(Scalar.ONE,M*N),z)); + l1 = sL; + + // This computes the ugly sum/concatenation from PAPER LINE 65 + Scalar[] zerosTwos = new Scalar[M*N]; + for (int i = 0; i < M*N; i++) + { + zerosTwos[i] = Scalar.ZERO; + for (int j = 1; j <= M; j++) // note this starts from 1 + { + Scalar temp = Scalar.ZERO; + if (i >= (j-1)*N && i < j*N) + temp = Scalar.TWO.pow(i-(j-1)*N); // exponent ranges from 0..N-1 + zerosTwos[i] = zerosTwos[i].add(z.pow(1+j).mul(temp)); + } + } + + r0 = VectorAdd(aR,VectorScalar(VectorPowers(Scalar.ONE,M*N),z)); + r0 = Hadamard(r0,VectorPowers(y,M*N)); + r0 = VectorAdd(r0,zerosTwos); + r1 = Hadamard(VectorPowers(y,M*N),sR); + + // Polynomial construction before PAPER LINE 46 + Scalar t0 = InnerProduct(l0,r0); + Scalar t1 = InnerProduct(l0,r1).add(InnerProduct(l1,r0)); + Scalar t2 = InnerProduct(l1,r1); + + // PAPER LINES 47-48 + Scalar tau1 = randomScalar(); + Scalar tau2 = randomScalar(); + Curve25519Point T1 = H.scalarMultiply(t1).add(G.scalarMultiply(tau1)); + Curve25519Point T2 = H.scalarMultiply(t2).add(G.scalarMultiply(tau2)); + + // PAPER LINES 49-51 + hashCache = hashToScalar(concat(hashCache.bytes,z.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,T1.toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,T2.toBytes())); + Scalar x = hashCache; + + // PAPER LINES 52-53 + Scalar taux = tau1.mul(x); + taux = taux.add(tau2.mul(x.sq())); + for (int j = 1; j <= M; j++) // note this starts from 1 + { + taux = taux.add(z.pow(1+j).mul(gamma[j-1])); + } + Scalar mu = x.mul(rho).add(alpha); + + // PAPER LINES 54-57 + Scalar[] l = l0; + l = VectorAdd(l,VectorScalar(l1,x)); + Scalar[] r = r0; + r = VectorAdd(r,VectorScalar(r1,x)); + + Scalar t = InnerProduct(l,r); + + // PAPER LINES 32-33 + hashCache = hashToScalar(concat(hashCache.bytes,x.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,taux.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,mu.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,t.bytes)); + Scalar x_ip = hashCache; + + // These are used in the inner product rounds + int nprime = M*N; + Curve25519Point[] Gprime = new Curve25519Point[M*N]; + Curve25519Point[] Hprime = new Curve25519Point[M*N]; + Scalar[] aprime = new Scalar[M*N]; + Scalar[] bprime = new Scalar[M*N]; + for (int i = 0; i < M*N; i++) + { + Gprime[i] = Gi[i]; + Hprime[i] = Hi[i].scalarMultiply(Invert(y).pow(i)); + aprime[i] = l[i]; + bprime[i] = r[i]; + } + Curve25519Point[] L = new Curve25519Point[logMN]; + Curve25519Point[] R = new Curve25519Point[logMN]; + int round = 0; // track the index based on number of rounds + Scalar[] w = new Scalar[logMN]; // this is the challenge x in the inner product protocol + + // PAPER LINE 13 + while (nprime > 1) + { + // PAPER LINE 15 + nprime /= 2; + + // PAPER LINES 16-17 + Scalar cL = InnerProduct(ScalarSlice(aprime,0,nprime),ScalarSlice(bprime,nprime,bprime.length)); + Scalar cR = InnerProduct(ScalarSlice(aprime,nprime,aprime.length),ScalarSlice(bprime,0,nprime)); + + // PAPER LINES 18-19 + L[round] = VectorExponentCustom(CurveSlice(Gprime,nprime,Gprime.length),CurveSlice(Hprime,0,nprime),ScalarSlice(aprime,0,nprime),ScalarSlice(bprime,nprime,bprime.length)).add(H.scalarMultiply(cL.mul(x_ip))); + R[round] = VectorExponentCustom(CurveSlice(Gprime,0,nprime),CurveSlice(Hprime,nprime,Hprime.length),ScalarSlice(aprime,nprime,aprime.length),ScalarSlice(bprime,0,nprime)).add(H.scalarMultiply(cR.mul(x_ip))); + + // PAPER LINES 21-22 + hashCache = hashToScalar(concat(hashCache.bytes,L[round].toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,R[round].toBytes())); + w[round] = hashCache; + + // PAPER LINES 24-25 + Gprime = Hadamard2(VectorScalar2(CurveSlice(Gprime,0,nprime),Invert(w[round])),VectorScalar2(CurveSlice(Gprime,nprime,Gprime.length),w[round])); + Hprime = Hadamard2(VectorScalar2(CurveSlice(Hprime,0,nprime),w[round]),VectorScalar2(CurveSlice(Hprime,nprime,Hprime.length),Invert(w[round]))); + + // PAPER LINES 28-29 + aprime = VectorAdd(VectorScalar(ScalarSlice(aprime,0,nprime),w[round]),VectorScalar(ScalarSlice(aprime,nprime,aprime.length),Invert(w[round]))); + bprime = VectorAdd(VectorScalar(ScalarSlice(bprime,0,nprime),Invert(w[round])),VectorScalar(ScalarSlice(bprime,nprime,bprime.length),w[round])); + + round += 1; + } + + // PAPER LINE 58 (with inclusions from PAPER LINE 8 and PAPER LINE 20) + return new ProofTuple(V,A,S,T1,T2,taux,mu,L,R,aprime[0],bprime[0],t); + } + + /* Given a range proof, determine if it is valid */ + public static boolean VERIFY(ProofTuple proof) + { + // Reconstruct the challenges + Scalar hashCache = hashToScalar(proof.V[0].toBytes()); + for (int j = 1; j < M; j++) + hashCache = hashToScalar(concat(hashCache.bytes,proof.V[j].toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,proof.A.toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,proof.S.toBytes())); + Scalar y = hashCache; + hashCache = hashToScalar(hashCache.bytes); + Scalar z = hashCache; + hashCache = hashToScalar(concat(hashCache.bytes,z.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,proof.T1.toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,proof.T2.toBytes())); + Scalar x = hashCache; + hashCache = hashToScalar(concat(hashCache.bytes,x.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,proof.taux.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,proof.mu.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,proof.t.bytes)); + Scalar x_ip = hashCache; + + // PAPER LINE 61 + Curve25519Point L61Left = G.scalarMultiply(proof.taux).add(H.scalarMultiply(proof.t)); + + Scalar k = Scalar.ZERO.sub(z.sq().mul(InnerProduct(VectorPowers(Scalar.ONE,M*N),VectorPowers(y,M*N)))); + for (int j = 1; j <= M; j++) // note this starts from 1 + { + k = k.sub(z.pow(j+2).mul(InnerProduct(VectorPowers(Scalar.ONE,N),VectorPowers(Scalar.TWO,N)))); + } + + Curve25519Point L61Right = H.scalarMultiply(k.add(z.mul(InnerProduct(VectorPowers(Scalar.ONE,M*N),VectorPowers(y,M*N))))); + for (int j = 0; j < M; j++) + { + L61Right = L61Right.add(proof.V[j].scalarMultiply(z.pow(j+2))); + } + L61Right = L61Right.add(proof.T1.scalarMultiply(x)); + L61Right = L61Right.add(proof.T2.scalarMultiply(x.sq())); + + if (!L61Right.equals(L61Left)) + return false; + + // PAPER LINE 62 + Curve25519Point P = Curve25519Point.ZERO; + P = P.add(proof.A); + P = P.add(proof.S.scalarMultiply(x)); + + // PAPER LINES 21-22 + // The inner product challenges are computed per round + Scalar[] w = new Scalar[logMN]; + hashCache = hashToScalar(concat(hashCache.bytes,proof.L[0].toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,proof.R[0].toBytes())); + w[0] = hashCache; + if (logMN > 1) + { + for (int i = 1; i < logMN; i++) + { + hashCache = hashToScalar(concat(hashCache.bytes,proof.L[i].toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,proof.R[i].toBytes())); + w[i] = hashCache; + } + } + + // Basically PAPER LINES 24-25 + // Compute the curvepoints from G[i] and H[i] + Curve25519Point InnerProdG = Curve25519Point.ZERO; + Curve25519Point InnerProdH = Curve25519Point.ZERO; + for (int i = 0; i < M*N; i++) + { + // Convert the index to binary IN REVERSE and construct the scalar exponent + int index = i; + Scalar gScalar = proof.a; + Scalar hScalar = proof.b.mul(Invert(y).pow(i)); + + for (int j = logMN-1; j >= 0; j--) + { + int J = w.length - j - 1; // because this is done in reverse bit order + int basePow = (int) Math.pow(2,j); // assumes we don't get too big + if (index / basePow == 0) // bit is zero + { + gScalar = gScalar.mul(Invert(w[J])); + hScalar = hScalar.mul(w[J]); + } + else // bit is one + { + gScalar = gScalar.mul(w[J]); + hScalar = hScalar.mul(Invert(w[J])); + index -= basePow; + } + } + + gScalar = gScalar.add(z); + hScalar = hScalar.sub(z.mul(y.pow(i)).add(z.pow(2+i/N).mul(Scalar.TWO.pow(i%N))).mul(Invert(y).pow(i))); + + // Now compute the basepoint's scalar multiplication + // Each of these could be written as a multiexp operation instead + InnerProdG = InnerProdG.add(Gi[i].scalarMultiply(gScalar)); + InnerProdH = InnerProdH.add(Hi[i].scalarMultiply(hScalar)); + } + + // PAPER LINE 26 + Curve25519Point Pprime = P.add(G.scalarMultiply(Scalar.ZERO.sub(proof.mu))); + + for (int i = 0; i < logMN; i++) + { + Pprime = Pprime.add(proof.L[i].scalarMultiply(w[i].sq())); + Pprime = Pprime.add(proof.R[i].scalarMultiply(Invert(w[i]).sq())); + } + Pprime = Pprime.add(H.scalarMultiply(proof.t.mul(x_ip))); + + if (!Pprime.equals(InnerProdG.add(InnerProdH).add(H.scalarMultiply(proof.a.mul(proof.b).mul(x_ip))))) + return false; + + return true; + } + + public static void main(String[] args) + { + // Test parameters + N = 64; // number of bits in amount range (so amounts are 0..2^(N-1)) + M = 4; // number of outputs (must be a power of 2) + logMN = 8; // must be manually set to log_2(MN) + int TRIALS = 25; // number of randomized trials to run + + // Set the curve base points + G = Curve25519Point.G; + H = Curve25519Point.hashToPoint(G); + Gi = new Curve25519Point[M*N]; + Hi = new Curve25519Point[M*N]; + for (int i = 0; i < M*N; i++) + { + Gi[i] = getHpnGLookup(2*i); + Hi[i] = getHpnGLookup(2*i+1); + } + + // Run a bunch of randomized trials + Random rando = new Random(); + int count = 0; + + Scalar[] amounts = new Scalar[M]; + Scalar[] masks = new Scalar[M]; + while (count < TRIALS) + { + for (int j = 0; j < M; j++) + { + long amount = -1L; + while (amount > Math.pow(2,N)-1 || amount < 0L) + amount = rando.nextLong(); + amounts[j] = new Scalar(BigInteger.valueOf(amount)); + masks[j] = randomScalar(); + } + + ProofTuple proof = PROVE(amounts,masks); + if (!VERIFY(proof)) + System.out.println("Test failed"); + + count += 1; + } + } +}