From ba0c1dd037d35039558a1be69397d95a605067a1 Mon Sep 17 00:00:00 2001 From: Brandon Goodell Date: Tue, 28 Nov 2017 10:03:50 -0700 Subject: [PATCH 01/11] BP Code --- .../hodl/bulletproof/LinearBulletproof.java | 348 ++++++++++++ .../hodl/bulletproof/LogBulletproof.java | 497 ++++++++++++++++++ .../bulletproof/OptimizedLogBulletproof.java | 487 +++++++++++++++++ 3 files changed, 1332 insertions(+) create mode 100644 source-code/StringCT-java/src/how/monero/hodl/bulletproof/LinearBulletproof.java create mode 100644 source-code/StringCT-java/src/how/monero/hodl/bulletproof/LogBulletproof.java create mode 100644 source-code/StringCT-java/src/how/monero/hodl/bulletproof/OptimizedLogBulletproof.java diff --git a/source-code/StringCT-java/src/how/monero/hodl/bulletproof/LinearBulletproof.java b/source-code/StringCT-java/src/how/monero/hodl/bulletproof/LinearBulletproof.java new file mode 100644 index 0000000..42c1d8a --- /dev/null +++ b/source-code/StringCT-java/src/how/monero/hodl/bulletproof/LinearBulletproof.java @@ -0,0 +1,348 @@ +package how.monero.hodl.bulletproof; + +import how.monero.hodl.crypto.Curve25519Point; +import how.monero.hodl.crypto.Scalar; +import how.monero.hodl.crypto.CryptoUtil; +import how.monero.hodl.util.ByteUtil; +import java.math.BigInteger; +import how.monero.hodl.util.VarInt; +import java.util.Random; + +import static how.monero.hodl.crypto.Scalar.randomScalar; +import static how.monero.hodl.crypto.CryptoUtil.*; +import static how.monero.hodl.util.ByteUtil.*; + +public class LinearBulletproof +{ + private static int N; + private static Curve25519Point G; + private static Curve25519Point H; + private static Curve25519Point[] Gi; + private static Curve25519Point[] Hi; + + public static class ProofTuple + { + private Curve25519Point V; + private Curve25519Point A; + private Curve25519Point S; + private Curve25519Point T1; + private Curve25519Point T2; + private Scalar taux; + private Scalar mu; + private Scalar[] l; + private Scalar[] r; + + public ProofTuple(Curve25519Point V, Curve25519Point A, Curve25519Point S, Curve25519Point T1, Curve25519Point T2, Scalar taux, Scalar mu, Scalar[] l, Scalar[] r) + { + this.V = V; + this.A = A; + this.S = S; + this.T1 = T1; + this.T2 = T2; + this.taux = taux; + this.mu = mu; + this.l = l; + this.r = r; + } + } + + /* Given two scalar arrays, construct a vector commitment */ + public static Curve25519Point VectorExponent(Scalar[] a, Scalar[] b) + { + Curve25519Point Result = Curve25519Point.ZERO; + for (int i = 0; i < N; i++) + { + Result = Result.add(Gi[i].scalarMultiply(a[i])); + Result = Result.add(Hi[i].scalarMultiply(b[i])); + } + return Result; + } + + /* Given a scalar, construct a vector of powers */ + public static Scalar[] VectorPowers(Scalar x) + { + Scalar[] result = new Scalar[N]; + for (int i = 0; i < N; i++) + { + result[i] = x.pow(i); + } + return result; + } + + /* Given two scalar arrays, construct the inner product */ + public static Scalar InnerProduct(Scalar[] a, Scalar[] b) + { + Scalar result = Scalar.ZERO; + for (int i = 0; i < N; i++) + { + result = result.add(a[i].mul(b[i])); + } + return result; + } + + /* Given two scalar arrays, construct the Hadamard product */ + public static Scalar[] Hadamard(Scalar[] a, Scalar[] b) + { + Scalar[] result = new Scalar[N]; + for (int i = 0; i < N; i++) + { + result[i] = a[i].mul(b[i]); + } + return result; + } + + /* Add two vectors */ + public static Scalar[] VectorAdd(Scalar[] a, Scalar[] b) + { + Scalar[] result = new Scalar[N]; + for (int i = 0; i < N; i++) + { + result[i] = a[i].add(b[i]); + } + return result; + } + + /* Subtract two vectors */ + public static Scalar[] VectorSubtract(Scalar[] a, Scalar[] b) + { + Scalar[] result = new Scalar[N]; + for (int i = 0; i < N; i++) + { + result[i] = a[i].sub(b[i]); + } + return result; + } + + /* Multiply a scalar and a vector */ + public static Scalar[] VectorScalar(Scalar[] a, Scalar x) + { + Scalar[] result = new Scalar[N]; + for (int i = 0; i < N; i++) + { + result[i] = a[i].mul(x); + } + return result; + } + + /* Compute the inverse of a scalar, the stupid way */ + public static Scalar Invert(Scalar x) + { + Scalar inverse = new Scalar(x.toBigInteger().modInverse(CryptoUtil.l)); + assert x.mul(inverse).equals(Scalar.ONE); + + return inverse; + } + + /* Given a value v (0..2^N-1) and a mask gamma, construct a range proof */ + public static ProofTuple PROVE(Scalar v, Scalar gamma) + { + Curve25519Point V = G.scalarMultiply(v).add(H.scalarMultiply(gamma)); + + // PAPER LINES 36-37 + Scalar[] aL = new Scalar[N]; + Scalar[] aR = new Scalar[N]; + + BigInteger tempV = v.toBigInteger(); + for (int i = N-1; i >= 0; i--) + { + BigInteger basePow = BigInteger.valueOf(2).pow(i); + if (tempV.divide(basePow).equals(BigInteger.ZERO)) + { + aL[i] = Scalar.ZERO; + } + else + { + aL[i] = Scalar.ONE; + tempV = tempV.subtract(basePow); + } + + aR[i] = aL[i].sub(Scalar.ONE); + } + + // DEBUG: Test to ensure this recovers the value + BigInteger test_aL = BigInteger.ZERO; + BigInteger test_aR = BigInteger.ZERO; + for (int i = 0; i < N; i++) + { + if (aL[i].equals(Scalar.ONE)) + test_aL = test_aL.add(BigInteger.valueOf(2).pow(i)); + if (aR[i].equals(Scalar.ZERO)) + test_aR = test_aR.add(BigInteger.valueOf(2).pow(i)); + } + assert test_aL.equals(v.toBigInteger()); + assert test_aR.equals(v.toBigInteger()); + + // PAPER LINES 38-39 + Scalar alpha = randomScalar(); + Curve25519Point A = VectorExponent(aL,aR).add(H.scalarMultiply(alpha)); + + // PAPER LINES 40-42 + Scalar[] sL = new Scalar[N]; + Scalar[] sR = new Scalar[N]; + for (int i = 0; i < N; i++) + { + sL[i] = randomScalar(); + sR[i] = randomScalar(); + } + Scalar rho = randomScalar(); + Curve25519Point S = VectorExponent(sL,sR).add(H.scalarMultiply(rho)); + + // PAPER LINES 43-45 + Scalar y = hashToScalar(concat(A.toBytes(),S.toBytes())); + Scalar z = hashToScalar(y.bytes); + + Scalar t0 = Scalar.ZERO; + Scalar t1 = Scalar.ZERO; + Scalar t2 = Scalar.ZERO; + + t0 = t0.add(z.mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y)))); + t0 = t0.add(z.sq().mul(v)); + Scalar k = Scalar.ZERO; + k = k.sub(z.sq().mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y)))); + k = k.sub(z.pow(3).mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(Scalar.TWO)))); + t0 = t0.add(k); + + // DEBUG: Test the value of t0 has the correct form + Scalar test_t0 = Scalar.ZERO; + test_t0 = test_t0.add(InnerProduct(aL,Hadamard(aR,VectorPowers(y)))); + test_t0 = test_t0.add(z.mul(InnerProduct(VectorSubtract(aL,aR),VectorPowers(y)))); + test_t0 = test_t0.add(z.sq().mul(InnerProduct(VectorPowers(Scalar.TWO),aL))); + test_t0 = test_t0.add(k); + assert test_t0.equals(t0); + + t1 = t1.add(InnerProduct(VectorSubtract(aL,VectorScalar(VectorPowers(Scalar.ONE),z)),Hadamard(VectorPowers(y),sR))); + t1 = t1.add(InnerProduct(sL,VectorAdd(Hadamard(VectorPowers(y),VectorAdd(aR,VectorScalar(VectorPowers(Scalar.ONE),z))),VectorScalar(VectorPowers(Scalar.TWO),z.sq())))); + t2 = t2.add(InnerProduct(sL,Hadamard(VectorPowers(y),sR))); + + // PAPER LINES 47-48 + Scalar tau1 = randomScalar(); + Scalar tau2 = randomScalar(); + Curve25519Point T1 = G.scalarMultiply(t1).add(H.scalarMultiply(tau1)); + Curve25519Point T2 = G.scalarMultiply(t2).add(H.scalarMultiply(tau2)); + + // PAPER LINES 49-51 + Scalar x = hashToScalar(concat(z.bytes,T1.toBytes(),T2.toBytes())); + + // PAPER LINES 52-53 + Scalar taux = Scalar.ZERO; + taux = tau1.mul(x); + taux = taux.add(tau2.mul(x.sq())); + taux = taux.add(gamma.mul(z.sq())); + Scalar mu = x.mul(rho).add(alpha); + + // PAPER LINES 54-57 + Scalar[] l = new Scalar[N]; + Scalar[] r = new Scalar[N]; + + l = VectorAdd(VectorSubtract(aL,VectorScalar(VectorPowers(Scalar.ONE),z)),VectorScalar(sL,x)); + r = VectorAdd(Hadamard(VectorPowers(y),VectorAdd(aR,VectorAdd(VectorScalar(VectorPowers(Scalar.ONE),z),VectorScalar(sR,x)))),VectorScalar(VectorPowers(Scalar.TWO),z.sq())); + + // DEBUG: Test if the l and r vectors match the polynomial forms + Scalar test_t = Scalar.ZERO; + test_t = test_t.add(t0).add(t1.mul(x)); + test_t = test_t.add(t2.mul(x.sq())); + assert test_t.equals(InnerProduct(l,r)); + + // PAPER LINE 58 + return new ProofTuple(V,A,S,T1,T2,taux,mu,l,r); + } + + /* Given a range proof, determine if it is valid */ + public static boolean VERIFY(ProofTuple proof) + { + // Reconstruct the challenges + Scalar y = hashToScalar(concat(proof.A.toBytes(),proof.S.toBytes())); + Scalar z = hashToScalar(y.bytes); + Scalar x = hashToScalar(concat(z.bytes,proof.T1.toBytes(),proof.T2.toBytes())); + + // PAPER LINE 60 + Scalar t = InnerProduct(proof.l,proof.r); + + // PAPER LINE 61 + Curve25519Point L61Left = H.scalarMultiply(proof.taux).add(G.scalarMultiply(t)); + + Scalar k = Scalar.ZERO; + k = k.sub(z.sq().mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y)))); + k = k.sub(z.pow(3).mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(Scalar.TWO)))); + + Curve25519Point L61Right = G.scalarMultiply(k.add(z.mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y))))); + L61Right = L61Right.add(proof.V.scalarMultiply(z.sq())); + L61Right = L61Right.add(proof.T1.scalarMultiply(x)); + L61Right = L61Right.add(proof.T2.scalarMultiply(x.sq())); + + if (!L61Right.equals(L61Left)) + { + return false; + } + + // PAPER LINE 62 + Curve25519Point P = Curve25519Point.ZERO; + P = P.add(proof.A); + P = P.add(proof.S.scalarMultiply(x)); + + Scalar[] Gexp = new Scalar[N]; + for (int i = 0; i < N; i++) + Gexp[i] = Scalar.ZERO.sub(z); + + Scalar[] Hexp = new Scalar[N]; + for (int i = 0; i < N; i++) + { + Hexp[i] = Scalar.ZERO; + Hexp[i] = Hexp[i].add(z.mul(y.pow(i))); + Hexp[i] = Hexp[i].add(z.sq().mul(Scalar.TWO.pow(i))); + Hexp[i] = Hexp[i].mul(Invert(y).pow(i)); + } + P = P.add(VectorExponent(Gexp,Hexp)); + + // PAPER LINE 63 + for (int i = 0; i < N; i++) + { + Hexp[i] = Scalar.ZERO; + Hexp[i] = Hexp[i].add(proof.r[i]); + Hexp[i] = Hexp[i].mul(Invert(y).pow(i)); + } + Curve25519Point L63Right = VectorExponent(proof.l,Hexp).add(H.scalarMultiply(proof.mu)); + + if (!L63Right.equals(P)) + { + return false; + } + + return true; + } + + public static void main(String[] args) + { + // Number of bits in the range + N = 64; + + // Set the curve base points + G = Curve25519Point.G; + H = Curve25519Point.hashToPoint(G); + Gi = new Curve25519Point[N]; + Hi = new Curve25519Point[N]; + for (int i = 0; i < N; i++) + { + Gi[i] = getHpnGLookup(i); + Hi[i] = getHpnGLookup(N+i); + } + + // Run a bunch of randomized trials + Random rando = new Random(); + int TRIALS = 250; + int count = 0; + + while (count < TRIALS) + { + long amount = rando.nextLong(); + if (amount > Math.pow(2,N)-1 || amount < 0) + continue; + + ProofTuple proof = PROVE(new Scalar(BigInteger.valueOf(amount)),randomScalar()); + if (!VERIFY(proof)) + System.out.println("Test failed"); + + count += 1; + } + } +} diff --git a/source-code/StringCT-java/src/how/monero/hodl/bulletproof/LogBulletproof.java b/source-code/StringCT-java/src/how/monero/hodl/bulletproof/LogBulletproof.java new file mode 100644 index 0000000..cf1cda7 --- /dev/null +++ b/source-code/StringCT-java/src/how/monero/hodl/bulletproof/LogBulletproof.java @@ -0,0 +1,497 @@ +package how.monero.hodl.bulletproof; + +import how.monero.hodl.crypto.Curve25519Point; +import how.monero.hodl.crypto.Scalar; +import how.monero.hodl.crypto.CryptoUtil; +import java.math.BigInteger; +import java.util.Random; + +import static how.monero.hodl.crypto.Scalar.randomScalar; +import static how.monero.hodl.crypto.CryptoUtil.*; +import static how.monero.hodl.util.ByteUtil.*; + +public class LogBulletproof +{ + private static int N; + private static int logN; + private static Curve25519Point G; + private static Curve25519Point H; + private static Curve25519Point[] Gi; + private static Curve25519Point[] Hi; + + public static class ProofTuple + { + private Curve25519Point V; + private Curve25519Point A; + private Curve25519Point S; + private Curve25519Point T1; + private Curve25519Point T2; + private Scalar taux; + private Scalar mu; + private Curve25519Point[] L; + private Curve25519Point[] R; + private Scalar a; + private Scalar b; + private Scalar t; + + public ProofTuple(Curve25519Point V, Curve25519Point A, Curve25519Point S, Curve25519Point T1, Curve25519Point T2, Scalar taux, Scalar mu, Curve25519Point[] L, Curve25519Point[] R, Scalar a, Scalar b, Scalar t) + { + this.V = V; + this.A = A; + this.S = S; + this.T1 = T1; + this.T2 = T2; + this.taux = taux; + this.mu = mu; + this.L = L; + this.R = R; + this.a = a; + this.b = b; + this.t = t; + } + } + + /* Given two scalar arrays, construct a vector commitment */ + public static Curve25519Point VectorExponent(Scalar[] a, Scalar[] b) + { + assert a.length == N && b.length == N; + + Curve25519Point Result = Curve25519Point.ZERO; + for (int i = 0; i < N; i++) + { + Result = Result.add(Gi[i].scalarMultiply(a[i])); + Result = Result.add(Hi[i].scalarMultiply(b[i])); + } + return Result; + } + + /* Compute a custom vector-scalar commitment */ + public static Curve25519Point VectorExponentCustom(Curve25519Point[] A, Curve25519Point[] B, Scalar[] a, Scalar[] b) + { + assert a.length == A.length && b.length == B.length && a.length == b.length; + + Curve25519Point Result = Curve25519Point.ZERO; + for (int i = 0; i < a.length; i++) + { + Result = Result.add(A[i].scalarMultiply(a[i])); + Result = Result.add(B[i].scalarMultiply(b[i])); + } + return Result; + } + + /* Given a scalar, construct a vector of powers */ + public static Scalar[] VectorPowers(Scalar x) + { + Scalar[] result = new Scalar[N]; + for (int i = 0; i < N; i++) + { + result[i] = x.pow(i); + } + return result; + } + + /* Given two scalar arrays, construct the inner product */ + public static Scalar InnerProduct(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar result = Scalar.ZERO; + for (int i = 0; i < a.length; i++) + { + result = result.add(a[i].mul(b[i])); + } + return result; + } + + /* Given two scalar arrays, construct the Hadamard product */ + public static Scalar[] Hadamard(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].mul(b[i]); + } + return result; + } + + /* Given two curvepoint arrays, construct the Hadamard product */ + public static Curve25519Point[] Hadamard2(Curve25519Point[] A, Curve25519Point[] B) + { + assert A.length == B.length; + + Curve25519Point[] Result = new Curve25519Point[A.length]; + for (int i = 0; i < A.length; i++) + { + Result[i] = A[i].add(B[i]); + } + return Result; + } + + /* Add two vectors */ + public static Scalar[] VectorAdd(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].add(b[i]); + } + return result; + } + + /* Subtract two vectors */ + public static Scalar[] VectorSubtract(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].sub(b[i]); + } + return result; + } + + /* Multiply a scalar and a vector */ + public static Scalar[] VectorScalar(Scalar[] a, Scalar x) + { + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].mul(x); + } + return result; + } + + /* Exponentiate a curve vector by a scalar */ + public static Curve25519Point[] VectorScalar2(Curve25519Point[] A, Scalar x) + { + Curve25519Point[] Result = new Curve25519Point[A.length]; + for (int i = 0; i < A.length; i++) + { + Result[i] = A[i].scalarMultiply(x); + } + return Result; + } + + /* Compute the inverse of a scalar, the stupid way */ + public static Scalar Invert(Scalar x) + { + Scalar inverse = new Scalar(x.toBigInteger().modInverse(CryptoUtil.l)); + + assert x.mul(inverse).equals(Scalar.ONE); + return inverse; + } + + /* Compute the slice of a curvepoint vector */ + public static Curve25519Point[] CurveSlice(Curve25519Point[] a, int start, int stop) + { + Curve25519Point[] Result = new Curve25519Point[stop-start]; + for (int i = start; i < stop; i++) + { + Result[i-start] = a[i]; + } + return Result; + } + + /* Compute the slice of a scalar vector */ + public static Scalar[] ScalarSlice(Scalar[] a, int start, int stop) + { + Scalar[] result = new Scalar[stop-start]; + for (int i = start; i < stop; i++) + { + result[i-start] = a[i]; + } + return result; + } + + /* Given a value v (0..2^N-1) and a mask gamma, construct a range proof */ + public static ProofTuple PROVE(Scalar v, Scalar gamma) + { + Curve25519Point V = G.scalarMultiply(v).add(H.scalarMultiply(gamma)); + + // PAPER LINES 36-37 + Scalar[] aL = new Scalar[N]; + Scalar[] aR = new Scalar[N]; + + BigInteger tempV = v.toBigInteger(); + for (int i = N-1; i >= 0; i--) + { + BigInteger basePow = BigInteger.valueOf(2).pow(i); + if (tempV.divide(basePow).equals(BigInteger.ZERO)) + { + aL[i] = Scalar.ZERO; + } + else + { + aL[i] = Scalar.ONE; + tempV = tempV.subtract(basePow); + } + + aR[i] = aL[i].sub(Scalar.ONE); + } + + // PAPER LINES 38-39 + Scalar alpha = randomScalar(); + Curve25519Point A = VectorExponent(aL,aR).add(H.scalarMultiply(alpha)); + + // PAPER LINES 40-42 + Scalar[] sL = new Scalar[N]; + Scalar[] sR = new Scalar[N]; + for (int i = 0; i < N; i++) + { + sL[i] = randomScalar(); + sR[i] = randomScalar(); + } + Scalar rho = randomScalar(); + Curve25519Point S = VectorExponent(sL,sR).add(H.scalarMultiply(rho)); + + // PAPER LINES 43-45 + Scalar y = hashToScalar(concat(A.toBytes(),S.toBytes())); + Scalar z = hashToScalar(y.bytes); + + // Polynomial construction before PAPER LINE 46 + Scalar t0 = Scalar.ZERO; + Scalar t1 = Scalar.ZERO; + Scalar t2 = Scalar.ZERO; + + t0 = t0.add(z.mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y)))); + t0 = t0.add(z.sq().mul(v)); + Scalar k = Scalar.ZERO; + k = k.sub(z.sq().mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y)))); + k = k.sub(z.pow(3).mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(Scalar.TWO)))); + t0 = t0.add(k); + + t1 = t1.add(InnerProduct(VectorSubtract(aL,VectorScalar(VectorPowers(Scalar.ONE),z)),Hadamard(VectorPowers(y),sR))); + t1 = t1.add(InnerProduct(sL,VectorAdd(Hadamard(VectorPowers(y),VectorAdd(aR,VectorScalar(VectorPowers(Scalar.ONE),z))),VectorScalar(VectorPowers(Scalar.TWO),z.sq())))); + + t2 = t2.add(InnerProduct(sL,Hadamard(VectorPowers(y),sR))); + + // PAPER LINES 47-48 + Scalar tau1 = randomScalar(); + Scalar tau2 = randomScalar(); + Curve25519Point T1 = G.scalarMultiply(t1).add(H.scalarMultiply(tau1)); + Curve25519Point T2 = G.scalarMultiply(t2).add(H.scalarMultiply(tau2)); + + // PAPER LINES 49-51 + Scalar x = hashToScalar(concat(z.bytes,T1.toBytes(),T2.toBytes())); + + // PAPER LINES 52-53 + Scalar taux = Scalar.ZERO; + taux = tau1.mul(x); + taux = taux.add(tau2.mul(x.sq())); + taux = taux.add(gamma.mul(z.sq())); + Scalar mu = x.mul(rho).add(alpha); + + // PAPER LINES 54-57 + Scalar[] l = new Scalar[N]; + Scalar[] r = new Scalar[N]; + + l = VectorAdd(VectorSubtract(aL,VectorScalar(VectorPowers(Scalar.ONE),z)),VectorScalar(sL,x)); + r = VectorAdd(Hadamard(VectorPowers(y),VectorAdd(aR,VectorAdd(VectorScalar(VectorPowers(Scalar.ONE),z),VectorScalar(sR,x)))),VectorScalar(VectorPowers(Scalar.TWO),z.sq())); + + Scalar t = InnerProduct(l,r); + + // PAPER LINES 32-33 + Scalar x_ip = hashToScalar(concat(x.bytes,taux.bytes,mu.bytes,t.bytes)); + + // These are used in the inner product rounds + int nprime = N; + Curve25519Point[] Gprime = new Curve25519Point[N]; + Curve25519Point[] Hprime = new Curve25519Point[N]; + Scalar[] aprime = new Scalar[N]; + Scalar[] bprime = new Scalar[N]; + for (int i = 0; i < N; i++) + { + Gprime[i] = Gi[i]; + Hprime[i] = Hi[i].scalarMultiply(Invert(y).pow(i)); + aprime[i] = l[i]; + bprime[i] = r[i]; + } + Curve25519Point[] L = new Curve25519Point[logN]; + Curve25519Point[] R = new Curve25519Point[logN]; + int round = 0; // track the index based on number of rounds + Scalar[] w = new Scalar[logN]; // this is the challenge x in the inner product protocol + + // PAPER LINE 13 + while (nprime > 1) + { + // PAPER LINE 15 + nprime /= 2; + + // PAPER LINES 16-17 + Scalar cL = InnerProduct(ScalarSlice(aprime,0,nprime),ScalarSlice(bprime,nprime,bprime.length)); + Scalar cR = InnerProduct(ScalarSlice(aprime,nprime,aprime.length),ScalarSlice(bprime,0,nprime)); + + // PAPER LINES 18-19 + L[round] = VectorExponentCustom(CurveSlice(Gprime,nprime,Gprime.length),CurveSlice(Hprime,0,nprime),ScalarSlice(aprime,0,nprime),ScalarSlice(bprime,nprime,bprime.length)).add(G.scalarMultiply(cL.mul(x_ip))); + R[round] = VectorExponentCustom(CurveSlice(Gprime,0,nprime),CurveSlice(Hprime,nprime,Hprime.length),ScalarSlice(aprime,nprime,aprime.length),ScalarSlice(bprime,0,nprime)).add(G.scalarMultiply(cR.mul(x_ip))); + + // PAPER LINES 21-22 + if (round == 0) + w[0] = hashToScalar(concat(L[0].toBytes(),R[0].toBytes())); + else + w[round] = hashToScalar(concat(w[round-1].bytes,L[round].toBytes(),R[round].toBytes())); + + // PAPER LINES 24-25 + Gprime = Hadamard2(VectorScalar2(CurveSlice(Gprime,0,nprime),Invert(w[round])),VectorScalar2(CurveSlice(Gprime,nprime,Gprime.length),w[round])); + Hprime = Hadamard2(VectorScalar2(CurveSlice(Hprime,0,nprime),w[round]),VectorScalar2(CurveSlice(Hprime,nprime,Hprime.length),Invert(w[round]))); + + // PAPER LINES 28-29 + aprime = VectorAdd(VectorScalar(ScalarSlice(aprime,0,nprime),w[round]),VectorScalar(ScalarSlice(aprime,nprime,aprime.length),Invert(w[round]))); + bprime = VectorAdd(VectorScalar(ScalarSlice(bprime,0,nprime),Invert(w[round])),VectorScalar(ScalarSlice(bprime,nprime,bprime.length),w[round])); + + round += 1; + } + + // PAPER LINE 58 (with inclusions from PAPER LINE 8 and PAPER LINE 20) + return new ProofTuple(V,A,S,T1,T2,taux,mu,L,R,aprime[0],bprime[0],t); + } + + /* Given a range proof, determine if it is valid */ + public static boolean VERIFY(ProofTuple proof) + { + // Reconstruct the challenges + Scalar y = hashToScalar(concat(proof.A.toBytes(),proof.S.toBytes())); + Scalar z = hashToScalar(y.bytes); + Scalar x = hashToScalar(concat(z.bytes,proof.T1.toBytes(),proof.T2.toBytes())); + Scalar x_ip = hashToScalar(concat(x.bytes,proof.taux.bytes,proof.mu.bytes,proof.t.bytes)); + + // PAPER LINE 61 + Curve25519Point L61Left = H.scalarMultiply(proof.taux).add(G.scalarMultiply(proof.t)); + + Scalar k = Scalar.ZERO; + k = k.sub(z.sq().mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y)))); + k = k.sub(z.pow(3).mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(Scalar.TWO)))); + + Curve25519Point L61Right = G.scalarMultiply(k.add(z.mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y))))); + L61Right = L61Right.add(proof.V.scalarMultiply(z.sq())); + L61Right = L61Right.add(proof.T1.scalarMultiply(x)); + L61Right = L61Right.add(proof.T2.scalarMultiply(x.sq())); + + if (!L61Right.equals(L61Left)) + return false; + + // PAPER LINE 62 + Curve25519Point P = Curve25519Point.ZERO; + P = P.add(proof.A); + P = P.add(proof.S.scalarMultiply(x)); + + Scalar[] Gexp = new Scalar[N]; + for (int i = 0; i < N; i++) + Gexp[i] = Scalar.ZERO.sub(z); + + Scalar[] Hexp = new Scalar[N]; + for (int i = 0; i < N; i++) + { + Hexp[i] = Scalar.ZERO; + Hexp[i] = Hexp[i].add(z.mul(y.pow(i))); + Hexp[i] = Hexp[i].add(z.sq().mul(Scalar.TWO.pow(i))); + Hexp[i] = Hexp[i].mul(Invert(y).pow(i)); + } + P = P.add(VectorExponent(Gexp,Hexp)); + + // Compute the number of rounds for the inner product + int rounds = proof.L.length; + + // PAPER LINES 21-22 + // The inner product challenges are computed per round + Scalar[] w = new Scalar[rounds]; + w[0] = hashToScalar(concat(proof.L[0].toBytes(),proof.R[0].toBytes())); + if (rounds > 1) + { + for (int i = 1; i < rounds; i++) + { + w[i] = hashToScalar(concat(w[i-1].bytes,proof.L[i].toBytes(),proof.R[i].toBytes())); + } + } + + // Basically PAPER LINES 24-25 + // Compute the curvepoints from G[i] and H[i] + Curve25519Point InnerProdG = Curve25519Point.ZERO; + Curve25519Point InnerProdH = Curve25519Point.ZERO; + for (int i = 0; i < N; i++) + { + // Convert the index to binary IN REVERSE and construct the scalar exponent + int index = i; + Scalar gScalar = Scalar.ONE; + Scalar hScalar = Invert(y).pow(i); + + for (int j = rounds-1; j >= 0; j--) + { + int J = w.length - j - 1; // because this is done in reverse bit order + int basePow = (int) Math.pow(2,j); // assumes we don't get too big + if (index / basePow == 0) // bit is zero + { + gScalar = gScalar.mul(Invert(w[J])); + hScalar = hScalar.mul(w[J]); + } + else // bit is one + { + gScalar = gScalar.mul(w[J]); + hScalar = hScalar.mul(Invert(w[J])); + index -= basePow; + } + } + + // Now compute the basepoint's scalar multiplication + // Each of these could be written as a multiexp operation instead + InnerProdG = InnerProdG.add(Gi[i].scalarMultiply(gScalar)); + InnerProdH = InnerProdH.add(Hi[i].scalarMultiply(hScalar)); + } + + // PAPER LINE 26 + Curve25519Point Pprime = P.add(H.scalarMultiply(Scalar.ZERO.sub(proof.mu))); + + for (int i = 0; i < rounds; i++) + { + Pprime = Pprime.add(proof.L[i].scalarMultiply(w[i].sq())); + Pprime = Pprime.add(proof.R[i].scalarMultiply(Invert(w[i]).sq())); + } + Pprime = Pprime.add(G.scalarMultiply(proof.t.mul(x_ip))); + + if (!Pprime.equals(InnerProdG.scalarMultiply(proof.a).add(InnerProdH.scalarMultiply(proof.b)).add(G.scalarMultiply(proof.a.mul(proof.b).mul(x_ip))))) + return false; + + return true; + } + + public static void main(String[] args) + { + // Number of bits in the range + N = 64; + logN = 6; // its log, manually + + // Set the curve base points + G = Curve25519Point.G; + H = Curve25519Point.hashToPoint(G); + Gi = new Curve25519Point[N]; + Hi = new Curve25519Point[N]; + for (int i = 0; i < N; i++) + { + Gi[i] = getHpnGLookup(i); + Hi[i] = getHpnGLookup(N+i); + } + + // Run a bunch of randomized trials + Random rando = new Random(); + int TRIALS = 250; + int count = 0; + + while (count < TRIALS) + { + long amount = rando.nextLong(); + if (amount > Math.pow(2,N)-1 || amount < 0) + continue; + + ProofTuple proof = PROVE(new Scalar(BigInteger.valueOf(amount)),randomScalar()); + if (!VERIFY(proof)) + System.out.println("Test failed"); + + count += 1; + } + } +} diff --git a/source-code/StringCT-java/src/how/monero/hodl/bulletproof/OptimizedLogBulletproof.java b/source-code/StringCT-java/src/how/monero/hodl/bulletproof/OptimizedLogBulletproof.java new file mode 100644 index 0000000..6b2acde --- /dev/null +++ b/source-code/StringCT-java/src/how/monero/hodl/bulletproof/OptimizedLogBulletproof.java @@ -0,0 +1,487 @@ +package how.monero.hodl.bulletproof; + +import how.monero.hodl.crypto.Curve25519Point; +import how.monero.hodl.crypto.Scalar; +import how.monero.hodl.crypto.CryptoUtil; +import java.math.BigInteger; +import java.util.Random; + +import static how.monero.hodl.crypto.Scalar.randomScalar; +import static how.monero.hodl.crypto.CryptoUtil.*; +import static how.monero.hodl.util.ByteUtil.*; + +public class OptimizedLogBulletproof +{ + private static int N; + private static int logN; + private static Curve25519Point G; + private static Curve25519Point H; + private static Curve25519Point[] Gi; + private static Curve25519Point[] Hi; + + public static class ProofTuple + { + private Curve25519Point V; + private Curve25519Point A; + private Curve25519Point S; + private Curve25519Point T1; + private Curve25519Point T2; + private Scalar taux; + private Scalar mu; + private Curve25519Point[] L; + private Curve25519Point[] R; + private Scalar a; + private Scalar b; + private Scalar t; + + public ProofTuple(Curve25519Point V, Curve25519Point A, Curve25519Point S, Curve25519Point T1, Curve25519Point T2, Scalar taux, Scalar mu, Curve25519Point[] L, Curve25519Point[] R, Scalar a, Scalar b, Scalar t) + { + this.V = V; + this.A = A; + this.S = S; + this.T1 = T1; + this.T2 = T2; + this.taux = taux; + this.mu = mu; + this.L = L; + this.R = R; + this.a = a; + this.b = b; + this.t = t; + } + } + + /* Given two scalar arrays, construct a vector commitment */ + public static Curve25519Point VectorExponent(Scalar[] a, Scalar[] b) + { + assert a.length == N && b.length == N; + + Curve25519Point Result = Curve25519Point.ZERO; + for (int i = 0; i < N; i++) + { + Result = Result.add(Gi[i].scalarMultiply(a[i])); + Result = Result.add(Hi[i].scalarMultiply(b[i])); + } + return Result; + } + + /* Compute a custom vector-scalar commitment */ + public static Curve25519Point VectorExponentCustom(Curve25519Point[] A, Curve25519Point[] B, Scalar[] a, Scalar[] b) + { + assert a.length == A.length && b.length == B.length && a.length == b.length; + + Curve25519Point Result = Curve25519Point.ZERO; + for (int i = 0; i < a.length; i++) + { + Result = Result.add(A[i].scalarMultiply(a[i])); + Result = Result.add(B[i].scalarMultiply(b[i])); + } + return Result; + } + + /* Given a scalar, construct a vector of powers */ + public static Scalar[] VectorPowers(Scalar x) + { + Scalar[] result = new Scalar[N]; + for (int i = 0; i < N; i++) + { + result[i] = x.pow(i); + } + return result; + } + + /* Given two scalar arrays, construct the inner product */ + public static Scalar InnerProduct(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar result = Scalar.ZERO; + for (int i = 0; i < a.length; i++) + { + result = result.add(a[i].mul(b[i])); + } + return result; + } + + /* Given two scalar arrays, construct the Hadamard product */ + public static Scalar[] Hadamard(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].mul(b[i]); + } + return result; + } + + /* Given two curvepoint arrays, construct the Hadamard product */ + public static Curve25519Point[] Hadamard2(Curve25519Point[] A, Curve25519Point[] B) + { + assert A.length == B.length; + + Curve25519Point[] Result = new Curve25519Point[A.length]; + for (int i = 0; i < A.length; i++) + { + Result[i] = A[i].add(B[i]); + } + return Result; + } + + /* Add two vectors */ + public static Scalar[] VectorAdd(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].add(b[i]); + } + return result; + } + + /* Subtract two vectors */ + public static Scalar[] VectorSubtract(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].sub(b[i]); + } + return result; + } + + /* Multiply a scalar and a vector */ + public static Scalar[] VectorScalar(Scalar[] a, Scalar x) + { + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].mul(x); + } + return result; + } + + /* Exponentiate a curve vector by a scalar */ + public static Curve25519Point[] VectorScalar2(Curve25519Point[] A, Scalar x) + { + Curve25519Point[] Result = new Curve25519Point[A.length]; + for (int i = 0; i < A.length; i++) + { + Result[i] = A[i].scalarMultiply(x); + } + return Result; + } + + /* Compute the inverse of a scalar, the stupid way */ + public static Scalar Invert(Scalar x) + { + Scalar inverse = new Scalar(x.toBigInteger().modInverse(CryptoUtil.l)); + + assert x.mul(inverse).equals(Scalar.ONE); + return inverse; + } + + /* Compute the slice of a curvepoint vector */ + public static Curve25519Point[] CurveSlice(Curve25519Point[] a, int start, int stop) + { + Curve25519Point[] Result = new Curve25519Point[stop-start]; + for (int i = start; i < stop; i++) + { + Result[i-start] = a[i]; + } + return Result; + } + + /* Compute the slice of a scalar vector */ + public static Scalar[] ScalarSlice(Scalar[] a, int start, int stop) + { + Scalar[] result = new Scalar[stop-start]; + for (int i = start; i < stop; i++) + { + result[i-start] = a[i]; + } + return result; + } + + /* Given a value v (0..2^N-1) and a mask gamma, construct a range proof */ + public static ProofTuple PROVE(Scalar v, Scalar gamma) + { + Curve25519Point V = G.scalarMultiply(v).add(H.scalarMultiply(gamma)); + + // PAPER LINES 36-37 + Scalar[] aL = new Scalar[N]; + Scalar[] aR = new Scalar[N]; + + BigInteger tempV = v.toBigInteger(); + for (int i = N-1; i >= 0; i--) + { + BigInteger basePow = BigInteger.valueOf(2).pow(i); + if (tempV.divide(basePow).equals(BigInteger.ZERO)) + { + aL[i] = Scalar.ZERO; + } + else + { + aL[i] = Scalar.ONE; + tempV = tempV.subtract(basePow); + } + + aR[i] = aL[i].sub(Scalar.ONE); + } + + // PAPER LINES 38-39 + Scalar alpha = randomScalar(); + Curve25519Point A = VectorExponent(aL,aR).add(H.scalarMultiply(alpha)); + + // PAPER LINES 40-42 + Scalar[] sL = new Scalar[N]; + Scalar[] sR = new Scalar[N]; + for (int i = 0; i < N; i++) + { + sL[i] = randomScalar(); + sR[i] = randomScalar(); + } + Scalar rho = randomScalar(); + Curve25519Point S = VectorExponent(sL,sR).add(H.scalarMultiply(rho)); + + // PAPER LINES 43-45 + Scalar y = hashToScalar(concat(A.toBytes(),S.toBytes())); + Scalar z = hashToScalar(y.bytes); + + // Polynomial construction before PAPER LINE 46 + Scalar t0 = Scalar.ZERO; + Scalar t1 = Scalar.ZERO; + Scalar t2 = Scalar.ZERO; + + t0 = t0.add(z.mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y)))); + t0 = t0.add(z.sq().mul(v)); + Scalar k = Scalar.ZERO; + k = k.sub(z.sq().mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y)))); + k = k.sub(z.pow(3).mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(Scalar.TWO)))); + t0 = t0.add(k); + + t1 = t1.add(InnerProduct(VectorSubtract(aL,VectorScalar(VectorPowers(Scalar.ONE),z)),Hadamard(VectorPowers(y),sR))); + t1 = t1.add(InnerProduct(sL,VectorAdd(Hadamard(VectorPowers(y),VectorAdd(aR,VectorScalar(VectorPowers(Scalar.ONE),z))),VectorScalar(VectorPowers(Scalar.TWO),z.sq())))); + + t2 = t2.add(InnerProduct(sL,Hadamard(VectorPowers(y),sR))); + + // PAPER LINES 47-48 + Scalar tau1 = randomScalar(); + Scalar tau2 = randomScalar(); + Curve25519Point T1 = G.scalarMultiply(t1).add(H.scalarMultiply(tau1)); + Curve25519Point T2 = G.scalarMultiply(t2).add(H.scalarMultiply(tau2)); + + // PAPER LINES 49-51 + Scalar x = hashToScalar(concat(z.bytes,T1.toBytes(),T2.toBytes())); + + // PAPER LINES 52-53 + Scalar taux = Scalar.ZERO; + taux = tau1.mul(x); + taux = taux.add(tau2.mul(x.sq())); + taux = taux.add(gamma.mul(z.sq())); + Scalar mu = x.mul(rho).add(alpha); + + // PAPER LINES 54-57 + Scalar[] l = new Scalar[N]; + Scalar[] r = new Scalar[N]; + + l = VectorAdd(VectorSubtract(aL,VectorScalar(VectorPowers(Scalar.ONE),z)),VectorScalar(sL,x)); + r = VectorAdd(Hadamard(VectorPowers(y),VectorAdd(aR,VectorAdd(VectorScalar(VectorPowers(Scalar.ONE),z),VectorScalar(sR,x)))),VectorScalar(VectorPowers(Scalar.TWO),z.sq())); + + Scalar t = InnerProduct(l,r); + + // PAPER LINES 32-33 + Scalar x_ip = hashToScalar(concat(x.bytes,taux.bytes,mu.bytes,t.bytes)); + + // These are used in the inner product rounds + int nprime = N; + Curve25519Point[] Gprime = new Curve25519Point[N]; + Curve25519Point[] Hprime = new Curve25519Point[N]; + Scalar[] aprime = new Scalar[N]; + Scalar[] bprime = new Scalar[N]; + for (int i = 0; i < N; i++) + { + Gprime[i] = Gi[i]; + Hprime[i] = Hi[i].scalarMultiply(Invert(y).pow(i)); + aprime[i] = l[i]; + bprime[i] = r[i]; + } + Curve25519Point[] L = new Curve25519Point[logN]; + Curve25519Point[] R = new Curve25519Point[logN]; + int round = 0; // track the index based on number of rounds + Scalar[] w = new Scalar[logN]; // this is the challenge x in the inner product protocol + + // PAPER LINE 13 + while (nprime > 1) + { + // PAPER LINE 15 + nprime /= 2; + + // PAPER LINES 16-17 + Scalar cL = InnerProduct(ScalarSlice(aprime,0,nprime),ScalarSlice(bprime,nprime,bprime.length)); + Scalar cR = InnerProduct(ScalarSlice(aprime,nprime,aprime.length),ScalarSlice(bprime,0,nprime)); + + // PAPER LINES 18-19 + L[round] = VectorExponentCustom(CurveSlice(Gprime,nprime,Gprime.length),CurveSlice(Hprime,0,nprime),ScalarSlice(aprime,0,nprime),ScalarSlice(bprime,nprime,bprime.length)).add(G.scalarMultiply(cL.mul(x_ip))); + R[round] = VectorExponentCustom(CurveSlice(Gprime,0,nprime),CurveSlice(Hprime,nprime,Hprime.length),ScalarSlice(aprime,nprime,aprime.length),ScalarSlice(bprime,0,nprime)).add(G.scalarMultiply(cR.mul(x_ip))); + + // PAPER LINES 21-22 + if (round == 0) + w[0] = hashToScalar(concat(L[0].toBytes(),R[0].toBytes())); + else + w[round] = hashToScalar(concat(w[round-1].bytes,L[round].toBytes(),R[round].toBytes())); + + // PAPER LINES 24-25 + Gprime = Hadamard2(VectorScalar2(CurveSlice(Gprime,0,nprime),Invert(w[round])),VectorScalar2(CurveSlice(Gprime,nprime,Gprime.length),w[round])); + Hprime = Hadamard2(VectorScalar2(CurveSlice(Hprime,0,nprime),w[round]),VectorScalar2(CurveSlice(Hprime,nprime,Hprime.length),Invert(w[round]))); + + // PAPER LINES 28-29 + aprime = VectorAdd(VectorScalar(ScalarSlice(aprime,0,nprime),w[round]),VectorScalar(ScalarSlice(aprime,nprime,aprime.length),Invert(w[round]))); + bprime = VectorAdd(VectorScalar(ScalarSlice(bprime,0,nprime),Invert(w[round])),VectorScalar(ScalarSlice(bprime,nprime,bprime.length),w[round])); + + round += 1; + } + + // PAPER LINE 58 (with inclusions from PAPER LINE 8 and PAPER LINE 20) + return new ProofTuple(V,A,S,T1,T2,taux,mu,L,R,aprime[0],bprime[0],t); + } + + /* Given a range proof, determine if it is valid */ + public static boolean VERIFY(ProofTuple proof) + { + // Reconstruct the challenges + Scalar y = hashToScalar(concat(proof.A.toBytes(),proof.S.toBytes())); + Scalar z = hashToScalar(y.bytes); + Scalar x = hashToScalar(concat(z.bytes,proof.T1.toBytes(),proof.T2.toBytes())); + Scalar x_ip = hashToScalar(concat(x.bytes,proof.taux.bytes,proof.mu.bytes,proof.t.bytes)); + + // PAPER LINE 61 + Curve25519Point L61Left = H.scalarMultiply(proof.taux).add(G.scalarMultiply(proof.t)); + + Scalar k = Scalar.ZERO; + k = k.sub(z.sq().mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y)))); + k = k.sub(z.pow(3).mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(Scalar.TWO)))); + + Curve25519Point L61Right = G.scalarMultiply(k.add(z.mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y))))); + L61Right = L61Right.add(proof.V.scalarMultiply(z.sq())); + L61Right = L61Right.add(proof.T1.scalarMultiply(x)); + L61Right = L61Right.add(proof.T2.scalarMultiply(x.sq())); + + if (!L61Right.equals(L61Left)) + return false; + + // PAPER LINE 62 + Curve25519Point P = Curve25519Point.ZERO; + P = P.add(proof.A); + P = P.add(proof.S.scalarMultiply(x)); + + // Compute the number of rounds for the inner product + int rounds = proof.L.length; + + // PAPER LINES 21-22 + // The inner product challenges are computed per round + Scalar[] w = new Scalar[rounds]; + w[0] = hashToScalar(concat(proof.L[0].toBytes(),proof.R[0].toBytes())); + if (rounds > 1) + { + for (int i = 1; i < rounds; i++) + { + w[i] = hashToScalar(concat(w[i-1].bytes,proof.L[i].toBytes(),proof.R[i].toBytes())); + } + } + + // Basically PAPER LINES 24-25 + // Compute the curvepoints from G[i] and H[i] + Curve25519Point InnerProdG = Curve25519Point.ZERO; + Curve25519Point InnerProdH = Curve25519Point.ZERO; + for (int i = 0; i < N; i++) + { + // Convert the index to binary IN REVERSE and construct the scalar exponent + int index = i; + Scalar gScalar = proof.a; + Scalar hScalar = proof.b.mul(Invert(y).pow(i)); + + for (int j = rounds-1; j >= 0; j--) + { + int J = w.length - j - 1; // because this is done in reverse bit order + int basePow = (int) Math.pow(2,j); // assumes we don't get too big + if (index / basePow == 0) // bit is zero + { + gScalar = gScalar.mul(Invert(w[J])); + hScalar = hScalar.mul(w[J]); + } + else // bit is one + { + gScalar = gScalar.mul(w[J]); + hScalar = hScalar.mul(Invert(w[J])); + index -= basePow; + } + } + + // Adjust the scalars using the exponents from PAPER LINE 62 + gScalar = gScalar.add(z); + hScalar = hScalar.sub(z.mul(y.pow(i)).add(z.sq().mul(Scalar.TWO.pow(i))).mul(Invert(y).pow(i))); + + // Now compute the basepoint's scalar multiplication + // Each of these could be written as a multiexp operation instead + InnerProdG = InnerProdG.add(Gi[i].scalarMultiply(gScalar)); + InnerProdH = InnerProdH.add(Hi[i].scalarMultiply(hScalar)); + } + + // PAPER LINE 26 + Curve25519Point Pprime = P.add(H.scalarMultiply(Scalar.ZERO.sub(proof.mu))); + + for (int i = 0; i < rounds; i++) + { + Pprime = Pprime.add(proof.L[i].scalarMultiply(w[i].sq())); + Pprime = Pprime.add(proof.R[i].scalarMultiply(Invert(w[i]).sq())); + } + Pprime = Pprime.add(G.scalarMultiply(proof.t.mul(x_ip))); + + if (!Pprime.equals(InnerProdG.add(InnerProdH).add(G.scalarMultiply(proof.a.mul(proof.b).mul(x_ip))))) + return false; + + return true; + } + + public static void main(String[] args) + { + // Number of bits in the range + N = 64; + logN = 6; // its log, manually + + // Set the curve base points + G = Curve25519Point.G; + H = Curve25519Point.hashToPoint(G); + Gi = new Curve25519Point[N]; + Hi = new Curve25519Point[N]; + for (int i = 0; i < N; i++) + { + Gi[i] = getHpnGLookup(i); + Hi[i] = getHpnGLookup(N+i); + } + + // Run a bunch of randomized trials + Random rando = new Random(); + int TRIALS = 250; + int count = 0; + + while (count < TRIALS) + { + long amount = rando.nextLong(); + if (amount > Math.pow(2,N)-1 || amount < 0) + continue; + + ProofTuple proof = PROVE(new Scalar(BigInteger.valueOf(amount)),randomScalar()); + if (!VERIFY(proof)) + System.out.println("Test failed"); + + count += 1; + } + } +} From 41c8b73f2b797e51e2f7d36c905dc151c75c57a9 Mon Sep 17 00:00:00 2001 From: b-g-goodell Date: Tue, 12 Dec 2017 11:15:56 -0700 Subject: [PATCH 02/11] Testing new versions --- source-code/Spectre/Block.py | 128 +++++++++-------- source-code/Spectre/RoBlocks.py | 245 ++++++++++++++++++++++++++++++++ 2 files changed, 313 insertions(+), 60 deletions(-) create mode 100644 source-code/Spectre/RoBlocks.py diff --git a/source-code/Spectre/Block.py b/source-code/Spectre/Block.py index 48f01d5..399fad7 100644 --- a/source-code/Spectre/Block.py +++ b/source-code/Spectre/Block.py @@ -1,77 +1,85 @@ import unittest import math -import numpy as np import copy from collections import deque import time +import hashlib class Block(object): - """ Fundamental object. Contains dict of blockIDs:(parent blocks) """ + """ + Fundamental object. Attributes: + data = payload dict with keys "timestamp" and "txns" and others + ident = string + parents = dict {blockID : parentBlock} + Functions: + addParents : takes dict {blockID : parentBlock} as input + and updates parents to include. + _recomputeIdent : recomputes identity + Usage: + b0 = Block() + b0.data = ... + b1 = Block() + b1.data = ... + b1.addParents({b0.ident:b0}) + + """ def __init__(self): - self.id = "" # string - self.timestamp = None # format tbd - self.data = None # payload - self.parents = {} # block ID : pointer to block - self.children = {} # block ID : pointer to block - def addChild(self, childIn): - if childIn not in self.children: - self.children.update({childIn.id:childIn}) - def addChildren(self, childrenIn): - for child in childrenIn: - self.addChild(childrenIn[child]) - def addParent(self, parentIn): - if parentIn not in self.parents: - self.parents.update({parentIn.id:parentIn}) - def addParents(self, parentsIn): - for parent in parentsIn: - self.addParent(parentsIn[parent]) - - + # Initialize with empty payload, no identity, and empty parents. + self.data = None + self.ident = hash(str(0)) + self.parents = None + self.addParents({}) + + def addParents(self, parentsIn): # dict of parents + if self.parents is None: + self.parents = parentsIn + else: + self.parents.update(parentsIn) + self._recomputeIdent() + + def _recomputeIdent(self): + m = str(0) + str(self.data) + str(self.parents) + self.ident = hash(m) + + class Test_Block(unittest.TestCase): def test_Block(self): + # b0 -> b1 -> {both b2, b3} -> b4... oh, and say b3 -> b5 also b0 = Block() - b0.id = "0" - self.assertTrue(b0.data is None) - self.assertTrue(len(b0.parents)==0) + b0.data = {"timestamp" : time.time()} + time.sleep(1) b1 = Block() - b1.parents.update({"0":b0}) - b1.id = "1" - for parentID in b1.parents: - b1.parents[parentID].children.update({b1.id:b1}) - self.assertTrue(b1.data is None) - self.assertTrue(len(b1.parents)==1) - self.assertTrue("0" in b1.parents) + b1.data = {"timestamp" : time.time(), "txns" : [1,2,3]} + b1.addParents({b0.ident:b0}) # updateIdent called with addParent. + time.sleep(1) b2 = Block() - b2.parents.update({"0":b0}) - b2.id = "2" - for parentID in b2.parents: - b2.parents[parentID].children.update({b2.id:b2}) - self.assertTrue(b2.data is None) - self.assertTrue(len(b2.parents)==1) - self.assertTrue("0" in b2.parents) - + b2.data = {"timestamp" : time.time(), "txns" : None} + b2.addParents({b1.ident:b1}) + time.sleep(1) + b3 = Block() - b3.parents.update({"1":b1, "2":b2}) - b3.id = "3" - for parentID in b3.parents: - b3.parents[parentID].children.update({b3.id:b3}) - self.assertTrue(b3.data is None) - self.assertTrue(len(b3.parents)==2) - self.assertTrue("1" in b3.parents) - self.assertTrue("2" in b3.parents) - self.assertFalse("0" in b3.parents) - + b3.data = {"timestamp" : time.time(), "txns" : None} + b3.addParents({b1.ident:b1}) + time.sleep(1) + b4 = Block() - b4.parents.update({"2":b2}) - b4.id = "4" - for parentID in b4.parents: - b4.parents[parentID].children.update({b4.id:b4}) - self.assertTrue(b4.data is None) - self.assertTrue(len(b4.parents)==1) - self.assertTrue("2" in b4.parents) - -suite = unittest.TestLoader().loadTestsFromTestCase(Test_Block) -unittest.TextTestRunner(verbosity=1).run(suite) - + b4.data = {"timestamp" : time.time()} # see how sloppy we can be wheeee + b4.addParents({b2.ident:b2, b3.ident:b3}) + time.sleep(1) + + b5 = Block() + b5.data = {"timestamp" : time.time(), "txns" : "stuff" } + b5.addParents({b3.ident:b3}) + + self.assertTrue(len(b1.parents)==1 and b0.ident in b1.parents) + self.assertTrue(len(b2.parents)==1 and b1.ident in b2.parents) + self.assertTrue(len(b3.parents)==1 and b1.ident in b3.parents) + self.assertTrue(len(b4.parents)==2) + self.assertTrue(b2.ident in b4.parents and b3.ident in b4.parents) + self.assertTrue(len(b5.parents)==1 and b3.ident in b5.parents) + + +#suite = unittest.TestLoader().loadTestsFromTestCase(Test_Block) +#unittest.TextTestRunner(verbosity=1).run(suite) diff --git a/source-code/Spectre/RoBlocks.py b/source-code/Spectre/RoBlocks.py new file mode 100644 index 0000000..dc42aa7 --- /dev/null +++ b/source-code/Spectre/RoBlocks.py @@ -0,0 +1,245 @@ +''' + A handler for Block.py that takes a collection of blocks (which + only reference parents) as input data. It uses a doubly-linked + tree to determine precedent relationships efficiently, and it can + use that precedence relationship to produce a reduced/robust pre- + cedence relationship as output (the spectre precedence relationship + between blocks. + + Another handler will extract a coherent/robust list of non-conflict- + ing transactions from a reduced/robust RoBlocks object. +''' +from Block import * + +class RoBlocks(object): + def __init__(self): + print("Initializing") + # Initialize a RoBlocks object. + self.data = None + self.blocks = {} # Set of blocks (which track parents) + self.family = {} # Doubly linked list tracks parent-and-child links + self.invDLL = {} # subset of blocks unlikely to be re-orged + self.roots = {} # dict {blockIdent : block} root blocks + self.leaves = {} + self.antichainCutoff = 600 # stop re-orging after this many layers + self.pendingVotes = {} + self.votes = {} + + def _addBlocks(self, blocksIn): + print("Adding Blocks") + # Take dict of {blockIdent : block} and call _addBlock on each. + for b in blocksIn: + self._addBlock(blocksIn[b]) + + def _addBlock(self, b): + print("Adding block") + # Take a single block b and add to self.blocks, record family + # relations, update leaf monitor, update root monitor if nec- + # essary + diffDict = {b.ident:b} + self.blocks.update(diffDict) + self.family.update({b.ident:{}}) + self.family[b.ident].update({"parents":b.parents, "children":{}}) + for parentIdent in b.parents: + if parentIdent not in self.family: + self.family.update({parentIdent:{}}) + if "parents" not in self.family[parentIdent]: + self.family[parentIdent].update({"parents":{}}) + if "children" not in self.family[parentIdent]: + self.family[parentIdent].update({"children":{}}) + self.family[parentIdent]["parents"].update(b.parents) + self.family[parentIdent]["children"].update(diffDict) + if parentIdent in self.leaves: + del self.leaves[parentIdent] + if len(b.parents)==0 and b.ident not in self.roots: + self.roots.update(diffDict) + self.leaves.update(diffDict) + + def inPast(self, x, y): + print("Testing if in past") + # Return true if y is an ancestor of x + q = deque() + for pid in self.blocks[x].parents: + if pid==y: + return True + break + q.append(pid) + while(len(q)>0): + nxtIdent = q.popleft() + if len(self.blocks[nxtIdent].parents) > 0: + for pid in self.blocks[nxtIdent].parents: + if pid==y: + return True + break + q.append(pid) + return False + + + def vote(self): + print("Voting") + # Compute partial spectre vote for top several layers of + # the dag. + (U, vids) = self.leafBackAntichain() + self.votes = {} + + q = deque() + self.pendingVotes = {} + for i in range(len(U)): + for leafId in U[i]: + if i > 0: + self.sumPendingVotes(leafId, vids) + + for x in U[i]: + if x != leafId: + q.append(x) + while(len(q)>0): + x = q.popleft() + if (leafId, leafId, x) not in self.votes: + self.votes.update({(leafId, leafId, x):1}) + else: + try: + assert self.votes[(leafId, leafId, x)]==1 + except AssertionError: + print("Woops, we found (leafId, leafId, x) as a key in self.votes while running vote(), and it should be +1, but it isn't:\n\n", (leafId, leafId, x), self.votes[(leafId, leafId, x)], "\n\n") + if (leafId, x, leafId) not in self.votes: + self.votes.update({(leafId, x, leafId):-1}) + else: + try: + assert self.votes[(leafId,x,leafId)]==-1 + except AssertionError: + print("Woops, we found (leafId, x, leafId) as a key in self.votes while running vote(), and it should be +1, but it isn't:\n\n", (leafId, x, leafId), self.votes[(leafId, x, leafId)], "\n\n") + self.transmitVote(leafId, leafId, x) + for pid in self.blocks[x].parents: + if not self.inPast(leafId, pid) and pid in vids and pid != leafId: + q.append(pid) + print(self.votes) + + def sumPendingVotes(self, blockId, vulnIds): + print("Summing pending votes") + # For a blockId, take all pending votes for vulnerable IDs (x,y) + # if the net is positive vote 1, if the net is negative vote -1 + # otherwise vote 0. + for x in vulnIds: + for y in vulnIds: + if (blockId, x, y) in self.pendingVotes: + if self.pendingVotes[(blockId, x, y)] > 0: + if (blockId, x, y) not in self.votes: + self.votes.update({(blockId, x, y):1}) + else: + try: + assert self.votes[(blockId,x,y)]==1 + except AssertionError: + print("Woops, we found (blockId, x, y) as a key in self.votes, and it should be +1, but it isn't:\n\n", (blockId, x, y), self.votes[(blockId, x,y)], "\n\n") + if (blockId, y, x) not in self.votes: + self.votes.update({(blockId, y, x):-1}) + else: + try: + assert self.votes[(blockId,y,x)]==-1 + except AssertionError: + print("Woops, we found (blockId, y, x) as a key in self.votes, and it should be -1, but it isn't:\n\n", (blockId, y, x), self.votes[(blockId, y,x)], "\n\n") + self.transmitVote(blockId, x, y) + elif self.pendingVotes[(blockId, x, y)] < 0: + if (blockId, x, y) not in self.votes: + self.votes.update({(blockId, x, y):-1}) + else: + try: + assert self.votes[(blockId,x,y)]==-1 + except AssertionError: + print("Woops, we found (blockId, x, y) as a key in self.votes, and it should be -1, but it isn't:\n\n", (blockId, x, y), self.votes[(blockId, x,y)], "\n\n") + + if (blockId, y, x) not in self.votes: + self.votes.update({(blockId, y, x):1}) + else: + try: + assert self.votes[(blockId,y,x)]==1 + except AssertionError: + print("Woops, we found (blockId, y, x) as a key in self.votes, and it should be +1, but it isn't:\n\n", (blockId, x, y), self.votes[(blockId, x,y)], "\n\n") + self.transmitVote(blockId, y, x) + else: + if (blockId, x, y) not in self.votes: + self.votes.update({(blockId, x, y):0}) + else: + try: + assert self.votes[(blockId,x,y)]==0 + except AssertionError: + print("Woops, we found (blockId, x, y) as a key in self.votes, and it should be 0, but it isn't:\n\n", (blockId, x, y), self.votes[(blockId, x,y)], "\n\n") + if (blockId, y, x) not in self.votes: + self.votes.update({(blockId, y, x):0}) + else: + try: + assert self.votes[(blockId,y,x)]==0 + except AssertionError: + print("Woops, we found (blockId, y, x) as a key in self.votes, and it should be 0, but it isn't:\n\n", (blockId, y, x), self.votes[(blockId, y,x)], "\n\n") + + + + + def transmitVote(self, v, x, y): + print("Transmitting votes") + q = deque() + for pid in self.blocks[v].parents: + q.append(pid) + while(len(q)>0): + print("Length of queue = ", len(q)) + nxtPid = q.popleft() + if (nxtPid, x, y) not in self.pendingVotes: + self.pendingVotes.update({(nxtPid,x,y):1}) + self.pendingVotes.update({(nxtPid,y,x):-1}) + else: + self.pendingVotes[(nxtPid,x,y)] += 1 + self.pendingVotes[(nxtPid,y,x)] -= 1 + if len(self.blocks[nxtPid].parents) > 0: + for pid in self.blocks[nxtPid].parents: + if pid != nxtPid: + q.append(pid) + + + + + def leafBackAntichain(self): + print("Computing antichain") + temp = copy.deepcopy(self) + decomposition = [] + vulnIdents = None + decomposition.append(temp.leaves) + vulnIdents = decomposition[-1] + temp = temp.pruneLeaves() + while(len(temp.blocks)>0 and len(decomposition) < self.antichainCutoff): + decomposition.append(temp.leaves) + for xid in decomposition[-1]: + if xid not in vulnIdents: + vulnIdents.update({xid:decomposition[-1][xid]}) + temp = temp.pruneLeaves() + return decomposition, vulnIdents + + def pruneLeaves(self): + print("Pruning leaves") + out = RoBlocks() + q = deque() + for rootIdent in self.roots: + q.append(rootIdent) + while(len(q)>0): + thisIdent = q.popleft() + if thisIdent not in self.leaves: + out._addBlock(self.blocks[thisIdent]) + for chIdent in self.family[thisIdent]["children"]: + q.append(chIdent) + return out + +class Test_RoBlock(unittest.TestCase): + def test_RoBlocks(self): + R = RoBlocks() + b = Block() + b.data = "zirconium encrusted tweezers" + b._recomputeIdent() + R._addBlock(b) + + b = Block() + b.data = "brontosaurus slippers cannot exist" + b.addParents(R.leaves) + R._addBlock(b) + + R.vote() + +suite = unittest.TestLoader().loadTestsFromTestCase(Test_RoBlock) +unittest.TextTestRunner(verbosity=1).run(suite) From 463d8f35e28640d5d58c21b1ec19aa40a726cbb4 Mon Sep 17 00:00:00 2001 From: b-g-goodell Date: Tue, 12 Dec 2017 11:18:22 -0700 Subject: [PATCH 03/11] Rem. dep. file --- source-code/Spectre/BlockDAG.py | 227 -------------------------------- 1 file changed, 227 deletions(-) delete mode 100644 source-code/Spectre/BlockDAG.py diff --git a/source-code/Spectre/BlockDAG.py b/source-code/Spectre/BlockDAG.py deleted file mode 100644 index 3f89be0..0000000 --- a/source-code/Spectre/BlockDAG.py +++ /dev/null @@ -1,227 +0,0 @@ -from Block import * - #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### -class BlockDAG(object): - """ Collection of >=1 block. """ - def __init__(self, params=None): - self.genesis = Block() - self.genesis.id = "0" - self.blocks = {self.genesis.id:self.genesis} - self.leaves = {self.genesis.id:self.genesis} - - # Blocks from top-down antichain subsets covering >= 1/2 of blockDAG - self.votBlocks = {self.genesis.id:self.genesis} - # Blocks from top-down antichain subsets "non-negl" likely to re-org - self.ordBlocks = {self.genesis.id:self.genesis} - - if params is not None: - self.security = params - else: - self.security = 10 - self.vote = {} - self.pending = {} - for blockZ in self.votBlocks: - for blockX in self.ordBlocks: - for blockY in self.ordBlocks: - self.vote.update({(blockZ,blockX,blockY):0}) - self.pending.update({(blockZ,blockX,blockY):0}) - - def computeVote(self, dagIn): - (canopy, fullCanopy) = self.pick(dagIn) - for layer in fullCanopy: - for blockZ in layer: - if blockZ not in dagIn.votBlocks: - continue - else: - for blockX in layer: - if blockX not in dagIn.ordBlocks: - continue - else: - for blockY in layer: - if blockY not in dagIn.ordBlocks: - continue - else: - if self.inPast(dagIn,blockY,blockZ) and self.inPast(dagIn,blockX,blockZ): - # then Z votes recursively - if blockZ not in dagIn.seenPasts: - dagIn.seenPasts.update({blockZ:dagIn.getPast(blockZ)}) - dagIn.seenVotes.update({blockZ:dagIn.vote(dagIn.seenPasts[blockZ])}) - dagIn.vote.update({(blockZ,blockX,blockY):dagIn.seenVotes[blockZ][(blockX,blockY)], (blockZ,blockY,blockX):dagIn.seenVotes[blockZ][(blockY,blockX)], (blockZ,blockX,blockZ):1, (blockZ, blockZ, blockX):-1, (blockZ, blockY, blockZ):1, (blockZ, blockZ, blockY):-1}) - elif self.inPast(dagIn, blockY, blockZ) and not self.inPast(dagIn, blockX, blockZ): - dagIn.vote.update({(blockZ,blockX,blockY):-1, (blockZ,blockY,blockX):1, (blockZ,blockX,blockZ):-1, (blockZ,blockZ,blockX):1, (blockZ,blockZ,blockY):-1, (blockZ,blockY,blockZ):1}) # Then Z votes Y < Z < X - elif not self.inPast(dagIn, blockY, blockZ) and self.inPast(dagIn, blockX, blockZ): - dagIn.vote.update({(blockZ,blockX,blockY):1, (blockZ,blockY,blockX):-1, (blockZ,blockX,blockZ):1, (blockZ,blockZ,blockX):-1, (blockZ,blockZ,blockY):1, (blockZ,blockY,blockZ):-1}) # Then Z votes X < Z < Y - else: - if dagIn.pending[(blockZ,blockX,blockY)] > 0: - dagIn.vote.update({(blockZ,blockX,blockY):1, (blockZ,blockY,blockX):-1, (blockZ, blockX, blockZ):-1, (blockZ, blockZ, blockX):1, (blockZ, blockY, blockZ):-1, (blockZ, blockZ, blockY):1}) - elif dagIn.pending[(blockZ,blockX,blockY)] < 0: - dagIn.vote.update({(blockZ,blockX,blockY):-1, (blockZ,blockY,blockX):1, (blockZ, blockX, blockZ):-1, (blockZ, blockZ, blockX):1, (blockZ, blockY, blockZ):-1, (blockZ, blockZ, blockY):1}) - else: - dagIn.vote.update({(blockZ,blockX,blockY):0, (blockZ,blockY,blockX):0, (blockZ, blockX, blockZ):-1, (blockZ, blockZ, blockX):1, (blockZ, blockY, blockZ):-1, (blockZ, blockZ, blockY):1}) - q = deque() - for p in dagIn.blocks[blockZ].parents: - if p in dagIn.votBlocks: - q.append(p) - while(len(q)>0): - nextBlock = q.popleft() - if (nextBlock, blockX, blockY) not in dagIn.pending: - dagIn.pending.update({(nextBlock, blockX,blockY):0}) - if (nextBlock, blockY, blockX) not in dagIn.pending: - dagIn.pending.update({(nextBlock, blockY,blockX):0}) - if dagIn.vote[(blockZ,blockX,blockY)] > 0: - dagIn.pending[(nextBlock,blockX,blockY)] += 1 - dagIn.pending[(nextBlock,blockY,blockX)] -= 1 - elif dagIn.vote[(blockZ,blockX,blockY)] < 0: - dagIn.pending[(nextBlock,blockX,blockY)] -= 1 - dagIn.pending[(nextBlock,blockY,blockX)] += 1 - for p in dagIn.blocks[nextBlock].parents: - if p in dagIn.votBlocks: - q.append(p) - totalVote = {} - for blockX in dagIn.ordBlocks: - for blockY in dagIn.ordBlocks: - if (blockX, blockY) not in totalVote: - totalVote.update({(blockX,blockY):0, (blockY,blockX):0}) - for blockZ in dagIn.votBlocks: - if dagIn.vote[(blockZ,blockX,blockY)] > 0: - totalVote[(blockX,blockY)] += 1 - elif dagIn.vote[(blockZ,blockX,blockY)] < 0: - totalVote[(blockX,blockY)] -= 1 - if totalVote[(blockX,blockY)] > 0: - totalVote[(blockX,blockY)] = 1 - elif totalVote[(blockX,blockY)] < 0: - totalVote[(blockX,blockY)] = -1 - return totalVote - - def pick(self, dagIn): - """ Pick voting blocks and orderable blocks """ - (canopy, fullCanopy) = self.antichain(dagIn) - dagIn.votBlocks = {} - dagIn.ordBlocks = {} - idx = 0 - count = len(canopy[idx]) - for block in canopy[idx]: - dagIn.votBlocks.update({block:dagIn.blocks[block]}) - dagIn.ordBlocks.update({blcok:dagIn.blocks[block]}) - numVoters = 1 - ((-len(dagIn.blocks))//2) - while(count < numVoters): - idx += 1 - count += len(canopy[idx]) - for block in canopy[idx]: - dagIn.votBlocks.update({block:dagIn.blocks[block]}) - if idx < self.security: - dagIn.ordBlocks.update({block:dagIn.blocks[block]}) - return (canopy, fullCanopy) - - - def makeBlock(self, idIn, parentsIn): - assert idIn not in self.blocks - newBlock = Block() - newBlock.id = idIn - newBlock.addParents(parentsIn) - self.blocks.update({newBlock.id:newBlock}) - for parent in parentsIn: - if parent in self.leaves: - del self.leaves[parent] - self.blocks[parent].addChild(newBlock) - self.leaves.update({newBlock.id:newBlock}) - - def pruneLeaves(self, dagIn): - result = BlockDAG() - result.genesis.id = dagIn.genesis.id - q = deque() - for child in dagIn.genesis.children: - if child not in dagIn.leaves: - q.append(child) - while(len(q)>0): - nextBlock = q.popleft() - result.makeBlock(nextBlock, dagIn.blocks[nextBlock].parents) - for child in dagIn.blocks[nextBlock].children: - if child not in dagIn.leaves: - q.append(child) - return result - - def antichain(self, dagIn): - canopy = [] - fullCanopy = [] - nextDag = dagIn - canopy.append(nextDag.leaves) - fullCanopy.append(nextDag.leaves) - while(len(nextDag.blocks)>1): - nextDag = dagIn.pruneLeaves(dagIn) - canopy.append(nextDag.leaves) - fullCanopy.append(fullCanopy[-1]) - for leaf in nextDag.leaves: - fullCanopy[-1].append(leaf) - nextDag = self.pruneLeaves(dagIn) - return (canopy, fullCanopy) - - def inPast(self, dagIn, y, x): - """ self.inPast(dag, y,x) if and only if y is in the past of x in dag """ - found = False - if y in dagIn.blocks[x].parents: - found = True - else: - q = deque() - for parent in dagIn.blocks[x].parents: - q.append(parent) - while(len(q)>0): - nextBlock = q.popleft() - if y in dagIn.blocks[nextBlock].parents: - found = True - break - else: - for parent in dagIn.blocks[nextBlock].parents: - q.append(parent) - return found - - def getPast(self, dagIn, block): - subdag = BlockDAG() - subdag.genesis = dagIn.genesis - q = deque() - for child in dagIn.genesis.children: - if self.inPast(dagIn,child,block): - q.append(child) - while len(q) > 0: - nextBlock = q.popleft() - subdag.makeBlock(dagIn.blocks[nextBlock]) - for child in dagIn.blocks[nextBlock].children: - if self.inPast(dagIn,child,block): - q.append(child) - return subdag - - -class Test_BlockDAG(unittest.TestCase): - def test_BlockDAG(self): - dag = BlockDAG() - self.assertTrue("0" in dag.blocks) - self.assertTrue("0" in dag.leaves) - self.assertTrue(len(dag.blocks)==1) - self.assertTrue(len(dag.leaves)==1) - b0 = dag.genesis - - dag.makeBlock("1",{"0":b0}) - b1 = dag.blocks["1"] - self.assertTrue("1" in dag.blocks) - self.assertTrue("1" in dag.leaves) - self.assertTrue("0" not in dag.leaves) - self.assertTrue(len(dag.blocks)==2) - self.assertTrue(len(dag.leaves)==1) - self.assertTrue("1" in b0.children) - self.assertTrue("1" in dag.genesis.children) - self.assertTrue("1" in dag.blocks[dag.genesis.id].children) - - dag.makeBlock("2", {"0":b0}) - b2 = dag.blocks["2"] - - dag.makeBlock("3", {"1":b1, "2":b2}) - b3 = dag.blocks["3"] - - dag.makeBlock("4", {"2":b2}) - b4 = dag.blocks["4"] - - print(dag.computeVote(dag)) - - - -suite = unittest.TestLoader().loadTestsFromTestCase(Test_BlockDAG) -unittest.TextTestRunner(verbosity=1).run(suite) From cdaeb98d1d83fe9c9330cadf0188159af7ea0186 Mon Sep 17 00:00:00 2001 From: Brandon Goodell Date: Tue, 12 Dec 2017 15:33:05 -0700 Subject: [PATCH 04/11] udpate ffs --- source-code/Spectre/{RoBlocks.py => BlockHandler.py} | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) rename source-code/Spectre/{RoBlocks.py => BlockHandler.py} (97%) diff --git a/source-code/Spectre/RoBlocks.py b/source-code/Spectre/BlockHandler.py similarity index 97% rename from source-code/Spectre/RoBlocks.py rename to source-code/Spectre/BlockHandler.py index dc42aa7..80bdf1f 100644 --- a/source-code/Spectre/RoBlocks.py +++ b/source-code/Spectre/BlockHandler.py @@ -7,14 +7,14 @@ between blocks. Another handler will extract a coherent/robust list of non-conflict- - ing transactions from a reduced/robust RoBlocks object. + ing transactions from a reduced/robust BlockHandler object. ''' from Block import * -class RoBlocks(object): +class BlockHandler(object): def __init__(self): print("Initializing") - # Initialize a RoBlocks object. + # Initialize a BlockHandler object. self.data = None self.blocks = {} # Set of blocks (which track parents) self.family = {} # Doubly linked list tracks parent-and-child links @@ -214,7 +214,7 @@ class RoBlocks(object): def pruneLeaves(self): print("Pruning leaves") - out = RoBlocks() + out = BlockHandler() q = deque() for rootIdent in self.roots: q.append(rootIdent) @@ -227,8 +227,8 @@ class RoBlocks(object): return out class Test_RoBlock(unittest.TestCase): - def test_RoBlocks(self): - R = RoBlocks() + def test_BlockHandler(self): + R = BlockHandler() b = Block() b.data = "zirconium encrusted tweezers" b._recomputeIdent() From ba5049ddccc97469835374c36a4fd1025b4491ec Mon Sep 17 00:00:00 2001 From: Brandon Goodell Date: Tue, 12 Dec 2017 15:33:24 -0700 Subject: [PATCH 05/11] udpate ffs again --- source-code/Spectre/Block.py | 13 +- source-code/Spectre/BlockHandler.py | 274 ++++++++-------------------- 2 files changed, 81 insertions(+), 206 deletions(-) diff --git a/source-code/Spectre/Block.py b/source-code/Spectre/Block.py index 399fad7..f92b81e 100644 --- a/source-code/Spectre/Block.py +++ b/source-code/Spectre/Block.py @@ -16,18 +16,15 @@ class Block(object): and updates parents to include. _recomputeIdent : recomputes identity Usage: - b0 = Block() - b0.data = ... - b1 = Block() - b1.data = ... - b1.addParents({b0.ident:b0}) + b0 = Block(dataIn = stuff, parentsIn = None) + b1 = Block(dataIn = otherStuff, parentsIn = { b0.ident : b0 }) """ - def __init__(self): + def __init__(self, dataIn=None, parentsIn=None): # Initialize with empty payload, no identity, and empty parents. - self.data = None + self.data = dataIn self.ident = hash(str(0)) - self.parents = None + self.parents = parentsIn self.addParents({}) def addParents(self, parentsIn): # dict of parents diff --git a/source-code/Spectre/BlockHandler.py b/source-code/Spectre/BlockHandler.py index 80bdf1f..c72dd23 100644 --- a/source-code/Spectre/BlockHandler.py +++ b/source-code/Spectre/BlockHandler.py @@ -13,7 +13,7 @@ from Block import * class BlockHandler(object): def __init__(self): - print("Initializing") + #print("Initializing") # Initialize a BlockHandler object. self.data = None self.blocks = {} # Set of blocks (which track parents) @@ -32,214 +32,92 @@ class BlockHandler(object): self._addBlock(blocksIn[b]) def _addBlock(self, b): - print("Adding block") + #print("Adding block") # Take a single block b and add to self.blocks, record family # relations, update leaf monitor, update root monitor if nec- # essary - diffDict = {b.ident:b} - self.blocks.update(diffDict) - self.family.update({b.ident:{}}) - self.family[b.ident].update({"parents":b.parents, "children":{}}) - for parentIdent in b.parents: - if parentIdent not in self.family: - self.family.update({parentIdent:{}}) - if "parents" not in self.family[parentIdent]: - self.family[parentIdent].update({"parents":{}}) - if "children" not in self.family[parentIdent]: - self.family[parentIdent].update({"children":{}}) - self.family[parentIdent]["parents"].update(b.parents) - self.family[parentIdent]["children"].update(diffDict) - if parentIdent in self.leaves: - del self.leaves[parentIdent] - if len(b.parents)==0 and b.ident not in self.roots: - self.roots.update(diffDict) - self.leaves.update(diffDict) - - def inPast(self, x, y): - print("Testing if in past") - # Return true if y is an ancestor of x - q = deque() - for pid in self.blocks[x].parents: - if pid==y: - return True - break - q.append(pid) - while(len(q)>0): - nxtIdent = q.popleft() - if len(self.blocks[nxtIdent].parents) > 0: - for pid in self.blocks[nxtIdent].parents: - if pid==y: - return True - break - q.append(pid) - return False - - - def vote(self): - print("Voting") - # Compute partial spectre vote for top several layers of - # the dag. - (U, vids) = self.leafBackAntichain() - self.votes = {} - - q = deque() - self.pendingVotes = {} - for i in range(len(U)): - for leafId in U[i]: - if i > 0: - self.sumPendingVotes(leafId, vids) - for x in U[i]: - if x != leafId: - q.append(x) - while(len(q)>0): - x = q.popleft() - if (leafId, leafId, x) not in self.votes: - self.votes.update({(leafId, leafId, x):1}) - else: - try: - assert self.votes[(leafId, leafId, x)]==1 - except AssertionError: - print("Woops, we found (leafId, leafId, x) as a key in self.votes while running vote(), and it should be +1, but it isn't:\n\n", (leafId, leafId, x), self.votes[(leafId, leafId, x)], "\n\n") - if (leafId, x, leafId) not in self.votes: - self.votes.update({(leafId, x, leafId):-1}) - else: - try: - assert self.votes[(leafId,x,leafId)]==-1 - except AssertionError: - print("Woops, we found (leafId, x, leafId) as a key in self.votes while running vote(), and it should be +1, but it isn't:\n\n", (leafId, x, leafId), self.votes[(leafId, x, leafId)], "\n\n") - self.transmitVote(leafId, leafId, x) - for pid in self.blocks[x].parents: - if not self.inPast(leafId, pid) and pid in vids and pid != leafId: - q.append(pid) - print(self.votes) - - def sumPendingVotes(self, blockId, vulnIds): - print("Summing pending votes") - # For a blockId, take all pending votes for vulnerable IDs (x,y) - # if the net is positive vote 1, if the net is negative vote -1 - # otherwise vote 0. - for x in vulnIds: - for y in vulnIds: - if (blockId, x, y) in self.pendingVotes: - if self.pendingVotes[(blockId, x, y)] > 0: - if (blockId, x, y) not in self.votes: - self.votes.update({(blockId, x, y):1}) - else: - try: - assert self.votes[(blockId,x,y)]==1 - except AssertionError: - print("Woops, we found (blockId, x, y) as a key in self.votes, and it should be +1, but it isn't:\n\n", (blockId, x, y), self.votes[(blockId, x,y)], "\n\n") - if (blockId, y, x) not in self.votes: - self.votes.update({(blockId, y, x):-1}) - else: - try: - assert self.votes[(blockId,y,x)]==-1 - except AssertionError: - print("Woops, we found (blockId, y, x) as a key in self.votes, and it should be -1, but it isn't:\n\n", (blockId, y, x), self.votes[(blockId, y,x)], "\n\n") - self.transmitVote(blockId, x, y) - elif self.pendingVotes[(blockId, x, y)] < 0: - if (blockId, x, y) not in self.votes: - self.votes.update({(blockId, x, y):-1}) - else: - try: - assert self.votes[(blockId,x,y)]==-1 - except AssertionError: - print("Woops, we found (blockId, x, y) as a key in self.votes, and it should be -1, but it isn't:\n\n", (blockId, x, y), self.votes[(blockId, x,y)], "\n\n") - - if (blockId, y, x) not in self.votes: - self.votes.update({(blockId, y, x):1}) - else: - try: - assert self.votes[(blockId,y,x)]==1 - except AssertionError: - print("Woops, we found (blockId, y, x) as a key in self.votes, and it should be +1, but it isn't:\n\n", (blockId, x, y), self.votes[(blockId, x,y)], "\n\n") - self.transmitVote(blockId, y, x) - else: - if (blockId, x, y) not in self.votes: - self.votes.update({(blockId, x, y):0}) - else: - try: - assert self.votes[(blockId,x,y)]==0 - except AssertionError: - print("Woops, we found (blockId, x, y) as a key in self.votes, and it should be 0, but it isn't:\n\n", (blockId, x, y), self.votes[(blockId, x,y)], "\n\n") - if (blockId, y, x) not in self.votes: - self.votes.update({(blockId, y, x):0}) - else: - try: - assert self.votes[(blockId,y,x)]==0 - except AssertionError: - print("Woops, we found (blockId, y, x) as a key in self.votes, and it should be 0, but it isn't:\n\n", (blockId, y, x), self.votes[(blockId, y,x)], "\n\n") - - - + diffDict = {copy.deepcopy(b.ident):copy.deepcopy(b)} + + try: + assert b.ident not in self.blocks + except AssertionError: + print("Woops, tried to add a block with ident in self.blocks, overwriting old block") + self.blocks.update(diffDict) + + try: + assert b.ident not in self.leaves + except AssertionError: + print("Woops, tried to add a block to leaf set that is already in the leafset, aborting.") + self.leaves.update(diffDict) # New block is always a leaf + + try: + assert b.ident not in self.family + except AssertionError: + print("woops, tried to add a block that already has a recorded family history, aborting.") + self.family.update({b.ident:{"parents":b.parents, "children":{}}}) # Add fam history fam + + # Now update each parent's family history to reflect the new child + if len(b.parents)>0: + for parentIdent in b.parents: + if parentIdent not in self.family: + # This should never occur. + print("Hey, what? confusedTravolta.gif... parentIdent not in self.family, parent not correct somehow.") + self.family.update({parentIdent:{}}) + + if "parents" not in self.family[parentIdent]: + # This should never occur. + print("Hey, what? confusedTravolta.gif... family history of parent lacks sub-dict for parentage, parent not correct somehow") + self.family[parentIdent].update({"parents":{}}) + + if "children" not in self.family[parentIdent]: + # This should never occur. + print("Hey, what? confusedTravolta.gif... family history of parent lacks sub-dict for children, parent not correct somehow") + self.family[parentIdent].update({"children":{}}) + + # Make sure grandparents are stored correctly (does nothing if already stored correctly) + self.family[parentIdent]["parents"].update(self.blocks[parentIdent].parents) + + # Update "children" sub-dict of family history of parent + self.family[parentIdent]["children"].update(diffDict) + + # If the parent was previously a leaf, it is no longer + if parentIdent in self.leaves: + del self.leaves[parentIdent] + + else: + if b.ident not in self.roots: + self.roots.update(diffDict) + self.leaves.update(diffDict) + self.family.update({b.ident:{"parents":{}, "children":{}}}) - def transmitVote(self, v, x, y): - print("Transmitting votes") - q = deque() - for pid in self.blocks[v].parents: - q.append(pid) - while(len(q)>0): - print("Length of queue = ", len(q)) - nxtPid = q.popleft() - if (nxtPid, x, y) not in self.pendingVotes: - self.pendingVotes.update({(nxtPid,x,y):1}) - self.pendingVotes.update({(nxtPid,y,x):-1}) - else: - self.pendingVotes[(nxtPid,x,y)] += 1 - self.pendingVotes[(nxtPid,y,x)] -= 1 - if len(self.blocks[nxtPid].parents) > 0: - for pid in self.blocks[nxtPid].parents: - if pid != nxtPid: - q.append(pid) - - - - - def leafBackAntichain(self): - print("Computing antichain") - temp = copy.deepcopy(self) - decomposition = [] - vulnIdents = None - decomposition.append(temp.leaves) - vulnIdents = decomposition[-1] - temp = temp.pruneLeaves() - while(len(temp.blocks)>0 and len(decomposition) < self.antichainCutoff): - decomposition.append(temp.leaves) - for xid in decomposition[-1]: - if xid not in vulnIdents: - vulnIdents.update({xid:decomposition[-1][xid]}) - temp = temp.pruneLeaves() - return decomposition, vulnIdents - - def pruneLeaves(self): - print("Pruning leaves") - out = BlockHandler() - q = deque() - for rootIdent in self.roots: - q.append(rootIdent) - while(len(q)>0): - thisIdent = q.popleft() - if thisIdent not in self.leaves: - out._addBlock(self.blocks[thisIdent]) - for chIdent in self.family[thisIdent]["children"]: - q.append(chIdent) - return out class Test_RoBlock(unittest.TestCase): def test_BlockHandler(self): R = BlockHandler() - b = Block() - b.data = "zirconium encrusted tweezers" - b._recomputeIdent() + b = Block(dataIn="zirconium encrusted tweezers", parentsIn={}) R._addBlock(b) - - b = Block() - b.data = "brontosaurus slippers cannot exist" - b.addParents(R.leaves) + diffDict = {copy.deepcopy(b.ident) : copy.deepcopy(b)} + #print("Differential ", diffDict) + b = Block(dataIn="brontosaurus slippers do not exist", parentsIn=copy.deepcopy(diffDict)) R._addBlock(b) - - R.vote() + #print("Blocks ", R.blocks) + #print("Family ", R.family) + #print("Leaves ", R.leaves) + self.assertEqual(len(R.blocks),2) + self.assertEqual(len(R.family),2) + self.assertEqual(len(R.leaves),1) + #print("Differential ", diffDict) + key, value = diffDict.popitem() + #print("Differential ", diffDict) + #print("Outputted values ", key, value) + #print("b.ident should = leaf.ident", b.ident) + self.assertTrue(key in R.blocks) + self.assertTrue(b.ident in R.blocks) + self.assertTrue(key in R.family[b.ident]["parents"]) + self.assertTrue(b.ident in R.family[key]["children"]) + + suite = unittest.TestLoader().loadTestsFromTestCase(Test_RoBlock) unittest.TextTestRunner(verbosity=1).run(suite) From 144594324623253dd96c4a6f2a802bdf0055aa0d Mon Sep 17 00:00:00 2001 From: Brandon Goodell Date: Wed, 13 Dec 2017 10:46:00 -0700 Subject: [PATCH 06/11] voteFor, pendingVotes, transmitVote working --- source-code/Spectre/Block.py | 9 +- source-code/Spectre/BlockHandler.py | 462 ++++++++++++++++++++++++---- 2 files changed, 410 insertions(+), 61 deletions(-) diff --git a/source-code/Spectre/Block.py b/source-code/Spectre/Block.py index f92b81e..42a7eed 100644 --- a/source-code/Spectre/Block.py +++ b/source-code/Spectre/Block.py @@ -20,18 +20,19 @@ class Block(object): b1 = Block(dataIn = otherStuff, parentsIn = { b0.ident : b0 }) """ - def __init__(self, dataIn=None, parentsIn=None): + def __init__(self, dataIn=None, parentsIn=[]): # Initialize with empty payload, no identity, and empty parents. self.data = dataIn self.ident = hash(str(0)) + assert type(parentsIn)==type([]) self.parents = parentsIn - self.addParents({}) + self._recomputeIdent() - def addParents(self, parentsIn): # dict of parents + def addParents(self, parentsIn=[]): # list of parentIdents if self.parents is None: self.parents = parentsIn else: - self.parents.update(parentsIn) + self.parents = self.parents + parentsIn self._recomputeIdent() def _recomputeIdent(self): diff --git a/source-code/Spectre/BlockHandler.py b/source-code/Spectre/BlockHandler.py index c72dd23..3cac181 100644 --- a/source-code/Spectre/BlockHandler.py +++ b/source-code/Spectre/BlockHandler.py @@ -10,6 +10,7 @@ ing transactions from a reduced/robust BlockHandler object. ''' from Block import * +import random class BlockHandler(object): def __init__(self): @@ -19,17 +20,52 @@ class BlockHandler(object): self.blocks = {} # Set of blocks (which track parents) self.family = {} # Doubly linked list tracks parent-and-child links self.invDLL = {} # subset of blocks unlikely to be re-orged - self.roots = {} # dict {blockIdent : block} root blocks - self.leaves = {} + self.roots = [] # list of root blockIdents + self.leaves = [] # list of leaf blockIdents + self.antichains = [] + self.vids = [] self.antichainCutoff = 600 # stop re-orging after this many layers self.pendingVotes = {} self.votes = {} - - def _addBlocks(self, blocksIn): - print("Adding Blocks") - # Take dict of {blockIdent : block} and call _addBlock on each. - for b in blocksIn: - self._addBlock(blocksIn[b]) + self.oldVotes = {} + def sumPendingVote(self, vid, touched): + for (xid,yid) in zip(self.vids,self.vids): + if (vid, xid, yid) in self.pendingVotes: + if self.pendingVotes[(vid,xid,yid)] > 0: + touched = self.voteFor((vid,xid,yid), touched) + elif self.pendingVotes[(vid,xid,yid)] <0: + touched = self.voteFor((vid,yid,xid), touched) + else: + self.votes.update({(vid,xid,yid): 0, (vid,yid,xid): 0}) + touched.update({(vid,xid,yid): True, (vid,yid,xid): True}) + return touched + + def voteFor(self, votingIdents, touched): + (vid, xid, yid) = votingIdents + self.votes.update({(vid,xid,yid):1, (vid,yid,xid):-1}) + touched.update({(vid,xid,yid):True, (vid,yid,xid):True}) + self.transmitVote((vid,xid,yid)) + return touched + + def transmitVote(self, votingIdents): + (vid, xid, yid) = votingIdents + q = deque() + for wid in self.blocks[vid].parents: + if wid in self.vids: + q.append(wid) + while(len(q)>0): + wid = q.popleft() + if (wid,xid,yid) not in self.pendingVotes: + self.pendingVotes.update({(wid,xid,yid):0}) + if (wid,yid,xid) not in self.pendingVotes: + self.pendingVotes.update({(wid,yid,xid):0}) + self.pendingVotes[(wid,xid,yid)]+=1 + self.pendingVotes[(wid,yid,xid)]-=1 + #print(self.blocks[wid].parents) + for pid in self.blocks[wid].parents: + if pid in self.vids: + q.append(pid) + def _addBlock(self, b): #print("Adding block") @@ -49,73 +85,385 @@ class BlockHandler(object): assert b.ident not in self.leaves except AssertionError: print("Woops, tried to add a block to leaf set that is already in the leafset, aborting.") - self.leaves.update(diffDict) # New block is always a leaf + self.leaves.append(b.ident) # New block is always a leaf try: assert b.ident not in self.family except AssertionError: print("woops, tried to add a block that already has a recorded family history, aborting.") - self.family.update({b.ident:{"parents":b.parents, "children":{}}}) # Add fam history fam + self.family.update({b.ident:{"parents":b.parents, "children":[]}}) + # Add fam history fam (new blocks have no children yet) # Now update each parent's family history to reflect the new child - if len(b.parents)>0: - for parentIdent in b.parents: - if parentIdent not in self.family: - # This should never occur. - print("Hey, what? confusedTravolta.gif... parentIdent not in self.family, parent not correct somehow.") - self.family.update({parentIdent:{}}) + if b.parents is not None: + if len(b.parents)>0: + for parentIdent in b.parents: + if parentIdent not in self.family: + # This should never occur. + print("Hey, what? confusedTravolta.gif... parentIdent not in self.family, parent not correct somehow.") + self.family.update({parentIdent:{}}) - if "parents" not in self.family[parentIdent]: - # This should never occur. - print("Hey, what? confusedTravolta.gif... family history of parent lacks sub-dict for parentage, parent not correct somehow") - self.family[parentIdent].update({"parents":{}}) + if "parents" not in self.family[parentIdent]: + # This should never occur. + print("Hey, what? confusedTravolta.gif... family history of parent lacks sub-dict for parentage, parent not correct somehow") + self.family[parentIdent].update({"parents":[]}) - if "children" not in self.family[parentIdent]: - # This should never occur. - print("Hey, what? confusedTravolta.gif... family history of parent lacks sub-dict for children, parent not correct somehow") - self.family[parentIdent].update({"children":{}}) + if "children" not in self.family[parentIdent]: + # This should never occur. + print("Hey, what? confusedTravolta.gif... family history of parent lacks sub-dict for children, parent not correct somehow") + self.family[parentIdent].update({"children":[]}) - # Make sure grandparents are stored correctly (does nothing if already stored correctly) - self.family[parentIdent]["parents"].update(self.blocks[parentIdent].parents) - - # Update "children" sub-dict of family history of parent - self.family[parentIdent]["children"].update(diffDict) + if self.blocks[parentIdent].parents is not None: + for pid in self.blocks[parentIdent].parents: + if pid not in self.family[parentIdent]["parents"]: + self.family[parentIdent]["parents"].append(pid) + #for p in self.blocks[parentIdent].parents: self.family[parentIdent]["parents"].append(p) + + # Update "children" sub-dict of family history of parent + self.family[parentIdent]["children"].append(b.ident) - # If the parent was previously a leaf, it is no longer - if parentIdent in self.leaves: - del self.leaves[parentIdent] + # If the parent was previously a leaf, it is no longer + if parentIdent in self.leaves: + self.leaves.remove(parentIdent) + + else: + if b.ident not in self.roots: + self.roots.append(b.ident) + if b.ident not in self.leaves: + self.leaves.append(b.ident) + if b.ident not in self.family: + self.family.update({b.ident:{"parents":{}, "children":{}}}) else: if b.ident not in self.roots: - self.roots.update(diffDict) - self.leaves.update(diffDict) + self.roots.append(b.ident) + if b.ident not in self.leaves: + self.leaves.append(b.ident) + if b.ident not in self.family: self.family.update({b.ident:{"parents":{}, "children":{}}}) - + + def _hasAncestor(self, xid, yid): + # Return true if y is an ancestor of x + assert xid in self.blocks + assert yid in self.blocks + q = deque() + found = False + if self.blocks[xid].parents is not None: + for pid in self.blocks[xid].parents: + if pid==yid: + found = True + break + q.append(pid) + while(len(q)>0 and not found): + xid = q.popleft() + if self.blocks[xid].parents is not None: + if len(self.blocks[xid].parents) > 0: + for pid in self.blocks[xid].parents: + if pid==yid: + found = True + q.append(pid) + return found + + def pruneLeaves(self): + #print("Pruning leaves") + out = BlockHandler() + q = deque() + for rootIdent in self.roots: + q.append(rootIdent) + while(len(q)>0): + thisIdent = q.popleft() + if thisIdent not in self.leaves: + out._addBlock(self.blocks[thisIdent]) + for chIdent in self.family[thisIdent]["children"]: + q.append(chIdent) + return out + def leafBackAntichain(self): + #print("Computing antichain") + temp = copy.deepcopy(self) + decomposition = [] + vulnIdents = [] + + decomposition.append([]) + for lid in temp.leaves: + decomposition[-1].append(lid) + vulnIdents = copy.deepcopy(decomposition[-1]) + temp = temp.pruneLeaves() + while(len(temp.blocks)>0 and len(decomposition) < self.antichainCutoff): + decomposition.append([]) + for lid in temp.leaves: + decomposition[-1].append(lid) + for xid in decomposition[-1]: + if xid not in vulnIdents: + vulnIdents.append(xid) + temp = temp.pruneLeaves() + return decomposition, vulnIdents + class Test_RoBlock(unittest.TestCase): - def test_BlockHandler(self): + def test_betterTest(self): R = BlockHandler() - b = Block(dataIn="zirconium encrusted tweezers", parentsIn={}) - R._addBlock(b) - diffDict = {copy.deepcopy(b.ident) : copy.deepcopy(b)} - #print("Differential ", diffDict) - b = Block(dataIn="brontosaurus slippers do not exist", parentsIn=copy.deepcopy(diffDict)) - R._addBlock(b) - #print("Blocks ", R.blocks) - #print("Family ", R.family) - #print("Leaves ", R.leaves) - self.assertEqual(len(R.blocks),2) - self.assertEqual(len(R.family),2) - self.assertEqual(len(R.leaves),1) - #print("Differential ", diffDict) - key, value = diffDict.popitem() - #print("Differential ", diffDict) - #print("Outputted values ", key, value) - #print("b.ident should = leaf.ident", b.ident) - self.assertTrue(key in R.blocks) - self.assertTrue(b.ident in R.blocks) - self.assertTrue(key in R.family[b.ident]["parents"]) - self.assertTrue(b.ident in R.family[key]["children"]) + self.assertTrue(R.data is None) + self.assertEqual(len(R.blocks),0) + self.assertEqual(type(R.blocks),type({})) + self.assertEqual(len(R.family),0) + self.assertEqual(type(R.family),type({})) + self.assertEqual(len(R.invDLL),0) + self.assertEqual(type(R.invDLL),type({})) + self.assertEqual(len(R.roots),0) + self.assertEqual(type(R.leaves),type([])) + self.assertEqual(len(R.leaves),0) + self.assertEqual(R.antichainCutoff,600) + self.assertEqual(type(R.roots),type([])) + self.assertEqual(len(R.pendingVotes),0) + self.assertEqual(type(R.pendingVotes),type({})) + self.assertEqual(len(R.votes),0) + self.assertEqual(type(R.votes),type({})) + + gen = Block() # genesis block + self.assertTrue(gen.data is None) + self.assertEqual(gen.parents,[]) + msg = str(0) + str(None) + str([]) + self.assertEqual(gen.ident, hash(msg)) + + block0 = gen + block1 = Block(parentsIn=[block0.ident], dataIn={"timestamp":time.time(), "txns":"pair of zircon encrusted tweezers"}) + block2 = Block(parentsIn=[block1.ident], dataIn={"timestamp":time.time(), "txns":"watch out for that yellow snow"}) + block3 = Block(parentsIn=[block1.ident], dataIn={"timestamp":time.time(), "txns":"he had the stank foot"}) + block4 = Block(parentsIn=[block2.ident, block3.ident], dataIn={"timestamp":time.time(), "txns":"come here fido"}) + block5 = Block(parentsIn=[block3.ident], dataIn={"timestamp":time.time(), "txns":"applied rotation on her sugar plum"}) + block6 = Block(parentsIn=[block5.ident], dataIn={"timestamp":time.time(), "txns":"listen to frank zappa for the love of all that is good"}) + R._addBlock(block0) + self.assertTrue(block0.ident in R.leaves) + self.assertTrue(block0.ident in R.roots) + + R._addBlock(block1) + self.assertTrue(block1.ident in R.leaves and block0.ident not in R.leaves) + R._addBlock(block2) + self.assertTrue(block2.ident in R.leaves and block1.ident not in R.leaves) + R._addBlock(block3) + self.assertTrue(block3.ident in R.leaves and block2.ident in R.leaves and block1.ident not in R.leaves) + + R._addBlock(block4) + self.assertTrue(block4.ident in R.leaves and block3.ident not in R.leaves and block2.ident not in R.leaves) + + R._addBlock(block5) + self.assertTrue(block4.ident in R.leaves and block5.ident in R.leaves and block3.ident not in R.leaves) + + R._addBlock(block6) + self.assertTrue(block4.ident in R.leaves and block6.ident in R.leaves and block5.ident not in R.leaves) + + self.assertEqual(len(R.blocks), 7) + self.assertEqual(len(R.family), 7) + self.assertEqual(len(R.invDLL), 0) + self.assertEqual(len(R.roots), 1) + self.assertEqual(len(R.leaves),2) + self.assertEqual(R.antichainCutoff, 600) + self.assertEqual(len(R.pendingVotes),0) + self.assertEqual(len(R.votes),0) + + self.assertTrue( R._hasAncestor(block6.ident, block0.ident) and not R._hasAncestor(block0.ident, block6.ident)) + self.assertTrue( R._hasAncestor(block5.ident, block0.ident) and not R._hasAncestor(block0.ident, block5.ident)) + self.assertTrue( R._hasAncestor(block4.ident, block0.ident) and not R._hasAncestor(block0.ident, block4.ident)) + self.assertTrue( R._hasAncestor(block3.ident, block0.ident) and not R._hasAncestor(block0.ident, block3.ident)) + self.assertTrue( R._hasAncestor(block2.ident, block0.ident) and not R._hasAncestor(block0.ident, block2.ident)) + self.assertTrue( R._hasAncestor(block1.ident, block0.ident) and not R._hasAncestor(block0.ident, block1.ident)) + + self.assertTrue( R._hasAncestor(block6.ident, block1.ident) and not R._hasAncestor(block1.ident, block6.ident)) + self.assertTrue( R._hasAncestor(block5.ident, block1.ident) and not R._hasAncestor(block1.ident, block5.ident)) + self.assertTrue( R._hasAncestor(block4.ident, block1.ident) and not R._hasAncestor(block1.ident, block4.ident)) + self.assertTrue( R._hasAncestor(block3.ident, block1.ident) and not R._hasAncestor(block1.ident, block3.ident)) + self.assertTrue( R._hasAncestor(block2.ident, block1.ident) and not R._hasAncestor(block1.ident, block2.ident)) + self.assertTrue(not R._hasAncestor(block0.ident, block1.ident) and R._hasAncestor(block1.ident, block0.ident)) + + self.assertTrue(not R._hasAncestor(block6.ident, block2.ident) and not R._hasAncestor(block2.ident, block6.ident)) + self.assertTrue(not R._hasAncestor(block5.ident, block2.ident) and not R._hasAncestor(block2.ident, block5.ident)) + self.assertTrue( R._hasAncestor(block4.ident, block2.ident) and not R._hasAncestor(block2.ident, block4.ident)) + self.assertTrue(not R._hasAncestor(block3.ident, block2.ident) and not R._hasAncestor(block2.ident, block3.ident)) + self.assertTrue(not R._hasAncestor(block1.ident, block2.ident) and R._hasAncestor(block2.ident, block1.ident)) + self.assertTrue(not R._hasAncestor(block0.ident, block2.ident) and R._hasAncestor(block2.ident, block0.ident)) + + self.assertTrue( R._hasAncestor(block6.ident, block3.ident) and not R._hasAncestor(block3.ident, block6.ident)) + self.assertTrue( R._hasAncestor(block5.ident, block3.ident) and not R._hasAncestor(block3.ident, block5.ident)) + self.assertTrue( R._hasAncestor(block4.ident, block3.ident) and not R._hasAncestor(block3.ident, block4.ident)) + self.assertTrue(not R._hasAncestor(block2.ident, block3.ident) and not R._hasAncestor(block3.ident, block2.ident)) + self.assertTrue(not R._hasAncestor(block1.ident, block3.ident) and R._hasAncestor(block3.ident, block1.ident)) + self.assertTrue(not R._hasAncestor(block0.ident, block3.ident) and R._hasAncestor(block3.ident, block0.ident)) + + self.assertTrue(not R._hasAncestor(block6.ident, block4.ident) and not R._hasAncestor(block4.ident, block6.ident)) + self.assertTrue(not R._hasAncestor(block5.ident, block4.ident) and not R._hasAncestor(block4.ident, block5.ident)) + self.assertTrue(not R._hasAncestor(block3.ident, block4.ident) and R._hasAncestor(block4.ident, block3.ident)) + self.assertTrue(not R._hasAncestor(block2.ident, block4.ident) and R._hasAncestor(block4.ident, block2.ident)) + self.assertTrue(not R._hasAncestor(block1.ident, block4.ident) and R._hasAncestor(block4.ident, block1.ident)) + self.assertTrue(not R._hasAncestor(block0.ident, block4.ident) and R._hasAncestor(block4.ident, block0.ident)) + + self.assertTrue( R._hasAncestor(block6.ident, block5.ident) and not R._hasAncestor(block5.ident, block6.ident)) + self.assertTrue(not R._hasAncestor(block4.ident, block5.ident) and not R._hasAncestor(block5.ident, block4.ident)) + self.assertTrue(not R._hasAncestor(block3.ident, block5.ident) and R._hasAncestor(block5.ident, block3.ident)) + self.assertTrue(not R._hasAncestor(block2.ident, block5.ident) and not R._hasAncestor(block5.ident, block2.ident)) + self.assertTrue(not R._hasAncestor(block1.ident, block5.ident) and R._hasAncestor(block5.ident, block1.ident)) + self.assertTrue(not R._hasAncestor(block0.ident, block5.ident) and R._hasAncestor(block5.ident, block0.ident)) + + self.assertTrue(not R._hasAncestor(block5.ident, block6.ident) and R._hasAncestor(block6.ident, block5.ident)) + self.assertTrue(not R._hasAncestor(block4.ident, block6.ident) and not R._hasAncestor(block6.ident, block4.ident)) + self.assertTrue(not R._hasAncestor(block3.ident, block6.ident) and R._hasAncestor(block6.ident, block3.ident)) + self.assertTrue(not R._hasAncestor(block2.ident, block6.ident) and not R._hasAncestor(block6.ident, block2.ident)) + self.assertTrue(not R._hasAncestor(block1.ident, block6.ident) and R._hasAncestor(block6.ident, block1.ident)) + self.assertTrue(not R._hasAncestor(block0.ident, block6.ident) and R._hasAncestor(block6.ident, block0.ident)) + + R = R.pruneLeaves() + + self.assertEqual(len(R.blocks), 5) + self.assertEqual(len(R.family), 5) + self.assertEqual(len(R.invDLL), 0) + self.assertEqual(len(R.roots), 1) + self.assertEqual(len(R.leaves),2) + self.assertEqual(R.antichainCutoff, 600) + self.assertEqual(len(R.pendingVotes),0) + self.assertEqual(len(R.votes),0) + + self.assertTrue( R._hasAncestor(block5.ident, block0.ident) and not R._hasAncestor(block0.ident, block5.ident)) + self.assertTrue( R._hasAncestor(block3.ident, block0.ident) and not R._hasAncestor(block0.ident, block3.ident)) + self.assertTrue( R._hasAncestor(block2.ident, block0.ident) and not R._hasAncestor(block0.ident, block2.ident)) + self.assertTrue( R._hasAncestor(block1.ident, block0.ident) and not R._hasAncestor(block0.ident, block1.ident)) + + self.assertTrue( R._hasAncestor(block5.ident, block1.ident) and not R._hasAncestor(block1.ident, block5.ident)) + self.assertTrue( R._hasAncestor(block3.ident, block1.ident) and not R._hasAncestor(block1.ident, block3.ident)) + self.assertTrue( R._hasAncestor(block2.ident, block1.ident) and not R._hasAncestor(block1.ident, block2.ident)) + self.assertTrue(not R._hasAncestor(block0.ident, block1.ident) and R._hasAncestor(block1.ident, block0.ident)) + + self.assertTrue(not R._hasAncestor(block5.ident, block2.ident) and not R._hasAncestor(block2.ident, block5.ident)) + self.assertTrue(not R._hasAncestor(block3.ident, block2.ident) and not R._hasAncestor(block2.ident, block3.ident)) + self.assertTrue(not R._hasAncestor(block1.ident, block2.ident) and R._hasAncestor(block2.ident, block1.ident)) + self.assertTrue(not R._hasAncestor(block0.ident, block2.ident) and R._hasAncestor(block2.ident, block0.ident)) + + self.assertTrue( R._hasAncestor(block5.ident, block3.ident) and not R._hasAncestor(block3.ident, block5.ident)) + self.assertTrue(not R._hasAncestor(block2.ident, block3.ident) and not R._hasAncestor(block3.ident, block2.ident)) + self.assertTrue(not R._hasAncestor(block1.ident, block3.ident) and R._hasAncestor(block3.ident, block1.ident)) + self.assertTrue(not R._hasAncestor(block0.ident, block3.ident) and R._hasAncestor(block3.ident, block0.ident)) + + self.assertTrue(not R._hasAncestor(block3.ident, block5.ident) and R._hasAncestor(block5.ident, block3.ident)) + self.assertTrue(not R._hasAncestor(block2.ident, block5.ident) and not R._hasAncestor(block5.ident, block2.ident)) + self.assertTrue(not R._hasAncestor(block1.ident, block5.ident) and R._hasAncestor(block5.ident, block1.ident)) + self.assertTrue(not R._hasAncestor(block0.ident, block5.ident) and R._hasAncestor(block5.ident, block0.ident)) + + + ## Formal unit tests for leafBackAntichain() to follow: visual inspection reveals this does what it says on the tin. + #R.vote() + #print(R.votes) + + def test_big_bertha(self): + R = BlockHandler() + gen = Block() # genesis block + msg = str(0) + str(None) + str([]) + block0 = gen + block1 = Block(parentsIn=[block0.ident], dataIn={"timestamp":time.time(), "txns":"pair of zircon encrusted tweezers"}) + block2 = Block(parentsIn=[block1.ident], dataIn={"timestamp":time.time(), "txns":"watch out for that yellow snow"}) + block3 = Block(parentsIn=[block1.ident], dataIn={"timestamp":time.time(), "txns":"he had the stank foot"}) + block4 = Block(parentsIn=[block2.ident, block3.ident], dataIn={"timestamp":time.time(), "txns":"come here fido"}) + block5 = Block(parentsIn=[block3.ident], dataIn={"timestamp":time.time(), "txns":"applied rotation on her sugar plum"}) + block6 = Block(parentsIn=[block5.ident], dataIn={"timestamp":time.time(), "txns":"listen to frank zappa for the love of all that is good"}) + R._addBlock(block0) + R._addBlock(block1) + R._addBlock(block2) + R._addBlock(block3) + R._addBlock(block4) + R._addBlock(block5) + R._addBlock(block6) + + # Testing voteFor + # Verify all roots have children + for rid in R.roots: + self.assertFalse(len(R.family[rid]["children"])==0) + + # Verify that all children of all roots have children and collect grandchildren idents + gc = [] + for rid in R.roots: + for cid in R.family[rid]["children"]: + self.assertFalse(len(R.family[cid]["children"]) == 0) + gc = gc + R.family[cid]["children"] + + # Pick a random grandchild of the root. + gcid = random.choice(gc) + + # Pick a random block with gcid in its past + vid = random.choice(list(R.blocks.keys())) + while(not R._hasAncestor(vid, gcid)): + vid = random.choice(list(R.blocks.keys())) + + # Pick a random pair of blocks for gcid and vid to vote on. + xid = random.choice(list(R.blocks.keys())) + yid = random.choice(list(R.blocks.keys())) + + # Have vid cast vote that xid < yid + R.voteFor((vid,xid,yid),{}) + # Verify that R.votes has correct entries + self.assertEqual(R.votes[(vid,xid,yid)], 1) + self.assertEqual(R.votes[(vid,yid,xid)],-1) + + # Check that for each ancestor of vid, that they received an appropriate pending vote + q = deque() + for pid in R.blocks[vid].parents: + if pid in R.vids: + q.append(pid) + while(len(q)>0): + wid = q.popleft() + self.assertEqual(R.pendingVotes[(wid,xid,yid)],1) + for pid in R.blocks[wid].parents: + if pid in R.vids: + q.append(pid) + + # Now we are going to mess around with how voting at gcid interacts with the above. + # First, we let gcid cast a vote that xid < yid and check that it propagates appropriately as above. + R.voteFor((gcid,xid,yid),{}) + self.assertEqual(R.votes[(gcid,xid,yid)],1) + self.assertEqual(R.votes[(gcid,yid,xid)],-1) + for pid in R.blocks[vid].parents: + if pid in R.vids: + q.append(gpid) + while(len(q)>0): + wid = q.popleft() + self.assertEqual(R.pendingVotes[(wid,xid,yid)],2) + self.assertEqual(R.pendingVotes[(wid,yid,xid)],-2) + for pid in R.blocks[wid].parents: + if pid in R.vids: + q.append(pid) + # Now we are going to have gcid cast the opposite vote. this should change what is stored in R.votes + # but also change pending votes below gcid + R.voteFor((gcid,yid,xid),{}) + self.assertEqual(R.votes[(gcid,xid,yid)],-1) + self.assertEqual(R.votes[(gcid,yid,xid)],1) + for pid in R.blocks[vid].parents: + if pid in R.vids: + q.append(gpid) + while(len(q)>0): + wid = q.popleft() + self.assertEqual(R.pendingVotes[(wid,xid,yid)],0) + self.assertEqual(R.pendingVotes[(wid,yid,yid)],0) + for pid in R.blocks[wid].parents: + if pid in R.vids: + q.append(pid) + # Do again, now pending votes should be negative + R.voteFor((gcid,yid,xid),{}) + self.assertEqual(R.votes[(gcid,xid,yid)],-1) + self.assertEqual(R.votes[(gcid,yid,xid)],1) + for pid in R.blocks[vid].parents: + if pid in R.vids: + q.append(gpid) + while(len(q)>0): + wid = q.popleft() + self.assertEqual(R.pendingVotes[(wid,xid,yid)],-1) + self.assertEqual(R.pendingVotes[(wid,yid,yid)],-1) + for pid in R.blocks[wid].parents: + if pid in R.vids: + q.append(pid) + + + + + + + #R.vote() + #print(R.votes) From fef1017fb7b03a877b89a093e6e8958edb623aa5 Mon Sep 17 00:00:00 2001 From: Brandon Goodell Date: Wed, 13 Dec 2017 10:53:39 -0700 Subject: [PATCH 07/11] fixed small errors in unit testing --- source-code/Spectre/BlockHandler.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/source-code/Spectre/BlockHandler.py b/source-code/Spectre/BlockHandler.py index 3cac181..407738c 100644 --- a/source-code/Spectre/BlockHandler.py +++ b/source-code/Spectre/BlockHandler.py @@ -417,7 +417,7 @@ class Test_RoBlock(unittest.TestCase): R.voteFor((gcid,xid,yid),{}) self.assertEqual(R.votes[(gcid,xid,yid)],1) self.assertEqual(R.votes[(gcid,yid,xid)],-1) - for pid in R.blocks[vid].parents: + for pid in R.blocks[gcid].parents: if pid in R.vids: q.append(gpid) while(len(q)>0): @@ -432,13 +432,13 @@ class Test_RoBlock(unittest.TestCase): R.voteFor((gcid,yid,xid),{}) self.assertEqual(R.votes[(gcid,xid,yid)],-1) self.assertEqual(R.votes[(gcid,yid,xid)],1) - for pid in R.blocks[vid].parents: + for pid in R.blocks[gcid].parents: if pid in R.vids: q.append(gpid) while(len(q)>0): wid = q.popleft() self.assertEqual(R.pendingVotes[(wid,xid,yid)],0) - self.assertEqual(R.pendingVotes[(wid,yid,yid)],0) + self.assertEqual(R.pendingVotes[(wid,yid,xid)],0) for pid in R.blocks[wid].parents: if pid in R.vids: q.append(pid) @@ -446,18 +446,23 @@ class Test_RoBlock(unittest.TestCase): R.voteFor((gcid,yid,xid),{}) self.assertEqual(R.votes[(gcid,xid,yid)],-1) self.assertEqual(R.votes[(gcid,yid,xid)],1) - for pid in R.blocks[vid].parents: + for pid in R.blocks[gcid].parents: if pid in R.vids: q.append(gpid) while(len(q)>0): wid = q.popleft() self.assertEqual(R.pendingVotes[(wid,xid,yid)],-1) - self.assertEqual(R.pendingVotes[(wid,yid,yid)],-1) + self.assertEqual(R.pendingVotes[(wid,yid,xid)],1) for pid in R.blocks[wid].parents: if pid in R.vids: q.append(pid) - + # Test sumPendingVotes + R.sumPendingVote(gcid, {}) + self.assertTrue((gcid,xid,yid) in R.votes) + self.assertTrue((gcid,yid,xid) in R.votes) + self.assertEqual(R.votes[(gcid,xid,yid)],-1) + self.assertEqual(R.votes[(gcid,yid,xid)],1) From de324cfcf743855e393069e2c10d191274a8f139 Mon Sep 17 00:00:00 2001 From: Brandon Goodell Date: Wed, 13 Dec 2017 11:09:19 -0700 Subject: [PATCH 08/11] vote function added, unit test must be written on results --- source-code/Spectre/BlockHandler.py | 46 +++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/source-code/Spectre/BlockHandler.py b/source-code/Spectre/BlockHandler.py index 407738c..691a21f 100644 --- a/source-code/Spectre/BlockHandler.py +++ b/source-code/Spectre/BlockHandler.py @@ -28,6 +28,38 @@ class BlockHandler(object): self.pendingVotes = {} self.votes = {} self.oldVotes = {} + + def vote(self): + U, V = self.leafBackAntichain() + self.antichains = U + touched = {} + for i in range(len(U)): + for vid in U[i]: # ID of voting block + touched = self.sumPendingVote(vid, touched) + for j in range(i+1): + for xid in U[j]: # Voting block compares self to xid + # Note if j=i, xid and vid are incomparable. + # If j < i, then xid may have vid as an ancestor. + # vid can never have xid as an ancestor. + # In all cases, vid votes that vid precedes xid + if xid==vid: + continue + else: + touched = self.voteFor((vid,vid,xid),touched) + # For each ancestor of xid that is not an ancestor of vid, + # we can apply the same! + q = deque() + for pid in self.blocks[xid].parents: + if pid in self.vids and not self._hasAncestor(vid,pid): + q.append(pid) + while(len(q)>0): + wid = q.popleft() + for pid in self.blocks[wid].parents: + if pid in self.vids and not self._hasAncestor(vid, pid): + q.append(pid) + touched = self.voteFor((vid,vid,wid),touched) + return touched + def sumPendingVote(self, vid, touched): for (xid,yid) in zip(self.vids,self.vids): if (vid, xid, yid) in self.pendingVotes: @@ -199,7 +231,7 @@ class BlockHandler(object): temp = temp.pruneLeaves() return decomposition, vulnIdents -class Test_RoBlock(unittest.TestCase): +class Test_BlockHandler(unittest.TestCase): def test_betterTest(self): R = BlockHandler() self.assertTrue(R.data is None) @@ -464,13 +496,21 @@ class Test_RoBlock(unittest.TestCase): self.assertEqual(R.votes[(gcid,xid,yid)],-1) self.assertEqual(R.votes[(gcid,yid,xid)],1) + touched = R.vote() + print("\n ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ====") + print("Antichain layers:\n") + for layer in R.antichains: + print(layer) + print("\n ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ====") + for key in R.votes: + print("key = ", key, ", vote = ", R.votes[key]) - + print("\n ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ====") #R.vote() #print(R.votes) -suite = unittest.TestLoader().loadTestsFromTestCase(Test_RoBlock) +suite = unittest.TestLoader().loadTestsFromTestCase(Test_BlockHandler) unittest.TextTestRunner(verbosity=1).run(suite) From 4b133452079cd9e19dd40c8d1cc69d5bbde8bb4f Mon Sep 17 00:00:00 2001 From: Brandon Goodell Date: Thu, 14 Dec 2017 00:24:20 -0700 Subject: [PATCH 09/11] Unit test written (but failed) for vote() --- source-code/Spectre/BlockHandler.py | 462 ++++++++++++++++++---------- 1 file changed, 307 insertions(+), 155 deletions(-) diff --git a/source-code/Spectre/BlockHandler.py b/source-code/Spectre/BlockHandler.py index 691a21f..7004fce 100644 --- a/source-code/Spectre/BlockHandler.py +++ b/source-code/Spectre/BlockHandler.py @@ -27,79 +27,9 @@ class BlockHandler(object): self.antichainCutoff = 600 # stop re-orging after this many layers self.pendingVotes = {} self.votes = {} - self.oldVotes = {} + self.totalVotes = {} - def vote(self): - U, V = self.leafBackAntichain() - self.antichains = U - touched = {} - for i in range(len(U)): - for vid in U[i]: # ID of voting block - touched = self.sumPendingVote(vid, touched) - for j in range(i+1): - for xid in U[j]: # Voting block compares self to xid - # Note if j=i, xid and vid are incomparable. - # If j < i, then xid may have vid as an ancestor. - # vid can never have xid as an ancestor. - # In all cases, vid votes that vid precedes xid - if xid==vid: - continue - else: - touched = self.voteFor((vid,vid,xid),touched) - # For each ancestor of xid that is not an ancestor of vid, - # we can apply the same! - q = deque() - for pid in self.blocks[xid].parents: - if pid in self.vids and not self._hasAncestor(vid,pid): - q.append(pid) - while(len(q)>0): - wid = q.popleft() - for pid in self.blocks[wid].parents: - if pid in self.vids and not self._hasAncestor(vid, pid): - q.append(pid) - touched = self.voteFor((vid,vid,wid),touched) - return touched - - def sumPendingVote(self, vid, touched): - for (xid,yid) in zip(self.vids,self.vids): - if (vid, xid, yid) in self.pendingVotes: - if self.pendingVotes[(vid,xid,yid)] > 0: - touched = self.voteFor((vid,xid,yid), touched) - elif self.pendingVotes[(vid,xid,yid)] <0: - touched = self.voteFor((vid,yid,xid), touched) - else: - self.votes.update({(vid,xid,yid): 0, (vid,yid,xid): 0}) - touched.update({(vid,xid,yid): True, (vid,yid,xid): True}) - return touched - - def voteFor(self, votingIdents, touched): - (vid, xid, yid) = votingIdents - self.votes.update({(vid,xid,yid):1, (vid,yid,xid):-1}) - touched.update({(vid,xid,yid):True, (vid,yid,xid):True}) - self.transmitVote((vid,xid,yid)) - return touched - - def transmitVote(self, votingIdents): - (vid, xid, yid) = votingIdents - q = deque() - for wid in self.blocks[vid].parents: - if wid in self.vids: - q.append(wid) - while(len(q)>0): - wid = q.popleft() - if (wid,xid,yid) not in self.pendingVotes: - self.pendingVotes.update({(wid,xid,yid):0}) - if (wid,yid,xid) not in self.pendingVotes: - self.pendingVotes.update({(wid,yid,xid):0}) - self.pendingVotes[(wid,xid,yid)]+=1 - self.pendingVotes[(wid,yid,xid)]-=1 - #print(self.blocks[wid].parents) - for pid in self.blocks[wid].parents: - if pid in self.vids: - q.append(pid) - - - def _addBlock(self, b): + def addBlock(self, b): #print("Adding block") # Take a single block b and add to self.blocks, record family # relations, update leaf monitor, update root monitor if nec- @@ -173,8 +103,9 @@ class BlockHandler(object): self.leaves.append(b.ident) if b.ident not in self.family: self.family.update({b.ident:{"parents":{}, "children":{}}}) - - def _hasAncestor(self, xid, yid): + pass + + def hasAncestor(self, xid, yid): # Return true if y is an ancestor of x assert xid in self.blocks assert yid in self.blocks @@ -195,7 +126,7 @@ class BlockHandler(object): found = True q.append(pid) return found - + def pruneLeaves(self): #print("Pruning leaves") out = BlockHandler() @@ -205,7 +136,7 @@ class BlockHandler(object): while(len(q)>0): thisIdent = q.popleft() if thisIdent not in self.leaves: - out._addBlock(self.blocks[thisIdent]) + out.addBlock(self.blocks[thisIdent]) for chIdent in self.family[thisIdent]["children"]: q.append(chIdent) return out @@ -231,6 +162,146 @@ class BlockHandler(object): temp = temp.pruneLeaves() return decomposition, vulnIdents + + def transmitVote(self, votingIdents): + (vid, xid, yid) = votingIdents + q = deque() + for wid in self.blocks[vid].parents: + if wid in self.vids: + q.append(wid) + while(len(q)>0): + wid = q.popleft() + if (wid,xid,yid) not in self.pendingVotes: + self.pendingVotes.update({(wid,xid,yid):0}) + if (wid,yid,xid) not in self.pendingVotes: + self.pendingVotes.update({(wid,yid,xid):0}) + self.pendingVotes[(wid,xid,yid)]+=1 + self.pendingVotes[(wid,yid,xid)]-=1 + #print(self.blocks[wid].parents) + for pid in self.blocks[wid].parents: + if pid in self.vids: + q.append(pid) + + def voteFor(self, votingIdents, touched): + (vid, xid, yid) = votingIdents + self.votes.update({(vid,xid,yid):1, (vid,yid,xid):-1}) + touched.update({(vid,xid,yid):True, (vid,yid,xid):True}) + self.transmitVote((vid,xid,yid)) + return touched + + def sumPendingVote(self, vid, touched): + pastR = self.pastOf(vid) + for xid in self.vids: + for yid in self.vids: + if (vid, xid, yid) in self.pendingVotes: + if self.pendingVotes[(vid,xid,yid)] > 0: + touched = self.voteFor((vid,xid,yid), touched) + elif self.pendingVotes[(vid,xid,yid)] <0: + touched = self.voteFor((vid,yid,xid), touched) + else: + self.votes.update({(vid,xid,yid): 0, (vid,yid,xid): 0}) + touched.update({(vid,xid,yid): True, (vid,yid,xid): True}) + #R = self.pastOf(vid) + #touched = R.vote(touched) + + return touched + + def vote(self,touchedIn={}): + U, V = self.leafBackAntichain() + self.antichains = U + self.vids = V + touched = touchedIn + for i in range(len(U)): + for vid in U[i]: # ID of voting block + touched = self.sumPendingVote(vid, touched) + for j in range(i+1): + for xid in U[j]: # Voting block compares self to xid + # Note if j=i, xid and vid are incomparable. + # If j < i, then xid may have vid as an ancestor. + # vid can never have xid as an ancestor. + # In all cases, vid votes that vid precedes xid + if xid==vid: + continue + else: + touched = self.voteFor((vid,vid,xid),touched) + # For each ancestor of xid that is not an ancestor of vid, + # we can apply the same! + q = deque() + for pid in self.blocks[xid].parents: + if pid in self.vids and not self.hasAncestor(vid,pid): + q.append(pid) + while(len(q)>0): + wid = q.popleft() + for pid in self.blocks[wid].parents: + if pid in self.vids and not self.hasAncestor(vid, pid): + q.append(pid) + touched = self.voteFor((vid,vid,wid),touched) + R = self.pastOf(vid) + R.vote() + for xid in R.blocks: + touched = self.voteFor((vid,xid,vid), touched) + for yid in R.blocks: + if (xid, yid) in R.totalVotes: + if R.totalVotes[(xid,yid)]: + touched = self.voteFor((vid,xid,yid), touched) + elif (yid, xid) in R.totalVotes: + if R.totalVotes[(yid,xid)]: + touched = self.voteFor((vid, yid, xid), touched) + self.computeTotalVotes() + + return touched + + def computeTotalVotes(self): + for xid in self.vids: + for yid in self.vids: + s = 0 + found = False + for vid in self.vids: + if (vid, xid, yid) in self.votes or (vid, yid, xid) in self.votes: + found = True + if self.votes[(vid, xid, yid)]==1: + s+= 1 + elif self.votes[(vid,yid,xid)]==-1: + s-= 1 + if found: + if s > 0: + self.totalVotes.update({(xid, yid):True, (yid,xid):False}) + elif s < 0: + self.totalVotes.update({(xid,yid):False, (yid,xid):True}) + elif s==0: + self.totalVotes.update({(xid,yid):False, (yid,xid):False}) + else: + if (xid,yid) in self.totalVotes: + del self.totalVotes[(xid,yid)] + if (yid,xid) in self.totalVotes: + del self.totalVotes[(yid,xid)] + + def pastOf(self, xid): + R = BlockHandler() + identsToAdd = {} + q = deque() + for pid in self.blocks[xid].parents: + q.append(pid) + while(len(q)>0): + yid = q.popleft() + if yid not in identsToAdd: + identsToAdd.update({yid:True}) + for pid in self.blocks[yid].parents: + q.append(pid) + for rid in self.roots: + if rid in identsToAdd: + q.append(rid) + while(len(q)>0): + yid = q.popleft() + if yid not in R.blocks: + R.addBlock(self.blocks[yid]) + for pid in self.family[yid]["children"]: + if pid in identsToAdd: + q.append(pid) + return R + + + class Test_BlockHandler(unittest.TestCase): def test_betterTest(self): R = BlockHandler() @@ -264,24 +335,24 @@ class Test_BlockHandler(unittest.TestCase): block4 = Block(parentsIn=[block2.ident, block3.ident], dataIn={"timestamp":time.time(), "txns":"come here fido"}) block5 = Block(parentsIn=[block3.ident], dataIn={"timestamp":time.time(), "txns":"applied rotation on her sugar plum"}) block6 = Block(parentsIn=[block5.ident], dataIn={"timestamp":time.time(), "txns":"listen to frank zappa for the love of all that is good"}) - R._addBlock(block0) + R.addBlock(block0) self.assertTrue(block0.ident in R.leaves) self.assertTrue(block0.ident in R.roots) - R._addBlock(block1) + R.addBlock(block1) self.assertTrue(block1.ident in R.leaves and block0.ident not in R.leaves) - R._addBlock(block2) + R.addBlock(block2) self.assertTrue(block2.ident in R.leaves and block1.ident not in R.leaves) - R._addBlock(block3) + R.addBlock(block3) self.assertTrue(block3.ident in R.leaves and block2.ident in R.leaves and block1.ident not in R.leaves) - R._addBlock(block4) + R.addBlock(block4) self.assertTrue(block4.ident in R.leaves and block3.ident not in R.leaves and block2.ident not in R.leaves) - R._addBlock(block5) + R.addBlock(block5) self.assertTrue(block4.ident in R.leaves and block5.ident in R.leaves and block3.ident not in R.leaves) - R._addBlock(block6) + R.addBlock(block6) self.assertTrue(block4.ident in R.leaves and block6.ident in R.leaves and block5.ident not in R.leaves) self.assertEqual(len(R.blocks), 7) @@ -293,54 +364,54 @@ class Test_BlockHandler(unittest.TestCase): self.assertEqual(len(R.pendingVotes),0) self.assertEqual(len(R.votes),0) - self.assertTrue( R._hasAncestor(block6.ident, block0.ident) and not R._hasAncestor(block0.ident, block6.ident)) - self.assertTrue( R._hasAncestor(block5.ident, block0.ident) and not R._hasAncestor(block0.ident, block5.ident)) - self.assertTrue( R._hasAncestor(block4.ident, block0.ident) and not R._hasAncestor(block0.ident, block4.ident)) - self.assertTrue( R._hasAncestor(block3.ident, block0.ident) and not R._hasAncestor(block0.ident, block3.ident)) - self.assertTrue( R._hasAncestor(block2.ident, block0.ident) and not R._hasAncestor(block0.ident, block2.ident)) - self.assertTrue( R._hasAncestor(block1.ident, block0.ident) and not R._hasAncestor(block0.ident, block1.ident)) + self.assertTrue( R.hasAncestor(block6.ident, block0.ident) and not R.hasAncestor(block0.ident, block6.ident)) + self.assertTrue( R.hasAncestor(block5.ident, block0.ident) and not R.hasAncestor(block0.ident, block5.ident)) + self.assertTrue( R.hasAncestor(block4.ident, block0.ident) and not R.hasAncestor(block0.ident, block4.ident)) + self.assertTrue( R.hasAncestor(block3.ident, block0.ident) and not R.hasAncestor(block0.ident, block3.ident)) + self.assertTrue( R.hasAncestor(block2.ident, block0.ident) and not R.hasAncestor(block0.ident, block2.ident)) + self.assertTrue( R.hasAncestor(block1.ident, block0.ident) and not R.hasAncestor(block0.ident, block1.ident)) - self.assertTrue( R._hasAncestor(block6.ident, block1.ident) and not R._hasAncestor(block1.ident, block6.ident)) - self.assertTrue( R._hasAncestor(block5.ident, block1.ident) and not R._hasAncestor(block1.ident, block5.ident)) - self.assertTrue( R._hasAncestor(block4.ident, block1.ident) and not R._hasAncestor(block1.ident, block4.ident)) - self.assertTrue( R._hasAncestor(block3.ident, block1.ident) and not R._hasAncestor(block1.ident, block3.ident)) - self.assertTrue( R._hasAncestor(block2.ident, block1.ident) and not R._hasAncestor(block1.ident, block2.ident)) - self.assertTrue(not R._hasAncestor(block0.ident, block1.ident) and R._hasAncestor(block1.ident, block0.ident)) + self.assertTrue( R.hasAncestor(block6.ident, block1.ident) and not R.hasAncestor(block1.ident, block6.ident)) + self.assertTrue( R.hasAncestor(block5.ident, block1.ident) and not R.hasAncestor(block1.ident, block5.ident)) + self.assertTrue( R.hasAncestor(block4.ident, block1.ident) and not R.hasAncestor(block1.ident, block4.ident)) + self.assertTrue( R.hasAncestor(block3.ident, block1.ident) and not R.hasAncestor(block1.ident, block3.ident)) + self.assertTrue( R.hasAncestor(block2.ident, block1.ident) and not R.hasAncestor(block1.ident, block2.ident)) + self.assertTrue(not R.hasAncestor(block0.ident, block1.ident) and R.hasAncestor(block1.ident, block0.ident)) - self.assertTrue(not R._hasAncestor(block6.ident, block2.ident) and not R._hasAncestor(block2.ident, block6.ident)) - self.assertTrue(not R._hasAncestor(block5.ident, block2.ident) and not R._hasAncestor(block2.ident, block5.ident)) - self.assertTrue( R._hasAncestor(block4.ident, block2.ident) and not R._hasAncestor(block2.ident, block4.ident)) - self.assertTrue(not R._hasAncestor(block3.ident, block2.ident) and not R._hasAncestor(block2.ident, block3.ident)) - self.assertTrue(not R._hasAncestor(block1.ident, block2.ident) and R._hasAncestor(block2.ident, block1.ident)) - self.assertTrue(not R._hasAncestor(block0.ident, block2.ident) and R._hasAncestor(block2.ident, block0.ident)) + self.assertTrue(not R.hasAncestor(block6.ident, block2.ident) and not R.hasAncestor(block2.ident, block6.ident)) + self.assertTrue(not R.hasAncestor(block5.ident, block2.ident) and not R.hasAncestor(block2.ident, block5.ident)) + self.assertTrue( R.hasAncestor(block4.ident, block2.ident) and not R.hasAncestor(block2.ident, block4.ident)) + self.assertTrue(not R.hasAncestor(block3.ident, block2.ident) and not R.hasAncestor(block2.ident, block3.ident)) + self.assertTrue(not R.hasAncestor(block1.ident, block2.ident) and R.hasAncestor(block2.ident, block1.ident)) + self.assertTrue(not R.hasAncestor(block0.ident, block2.ident) and R.hasAncestor(block2.ident, block0.ident)) - self.assertTrue( R._hasAncestor(block6.ident, block3.ident) and not R._hasAncestor(block3.ident, block6.ident)) - self.assertTrue( R._hasAncestor(block5.ident, block3.ident) and not R._hasAncestor(block3.ident, block5.ident)) - self.assertTrue( R._hasAncestor(block4.ident, block3.ident) and not R._hasAncestor(block3.ident, block4.ident)) - self.assertTrue(not R._hasAncestor(block2.ident, block3.ident) and not R._hasAncestor(block3.ident, block2.ident)) - self.assertTrue(not R._hasAncestor(block1.ident, block3.ident) and R._hasAncestor(block3.ident, block1.ident)) - self.assertTrue(not R._hasAncestor(block0.ident, block3.ident) and R._hasAncestor(block3.ident, block0.ident)) + self.assertTrue( R.hasAncestor(block6.ident, block3.ident) and not R.hasAncestor(block3.ident, block6.ident)) + self.assertTrue( R.hasAncestor(block5.ident, block3.ident) and not R.hasAncestor(block3.ident, block5.ident)) + self.assertTrue( R.hasAncestor(block4.ident, block3.ident) and not R.hasAncestor(block3.ident, block4.ident)) + self.assertTrue(not R.hasAncestor(block2.ident, block3.ident) and not R.hasAncestor(block3.ident, block2.ident)) + self.assertTrue(not R.hasAncestor(block1.ident, block3.ident) and R.hasAncestor(block3.ident, block1.ident)) + self.assertTrue(not R.hasAncestor(block0.ident, block3.ident) and R.hasAncestor(block3.ident, block0.ident)) - self.assertTrue(not R._hasAncestor(block6.ident, block4.ident) and not R._hasAncestor(block4.ident, block6.ident)) - self.assertTrue(not R._hasAncestor(block5.ident, block4.ident) and not R._hasAncestor(block4.ident, block5.ident)) - self.assertTrue(not R._hasAncestor(block3.ident, block4.ident) and R._hasAncestor(block4.ident, block3.ident)) - self.assertTrue(not R._hasAncestor(block2.ident, block4.ident) and R._hasAncestor(block4.ident, block2.ident)) - self.assertTrue(not R._hasAncestor(block1.ident, block4.ident) and R._hasAncestor(block4.ident, block1.ident)) - self.assertTrue(not R._hasAncestor(block0.ident, block4.ident) and R._hasAncestor(block4.ident, block0.ident)) + self.assertTrue(not R.hasAncestor(block6.ident, block4.ident) and not R.hasAncestor(block4.ident, block6.ident)) + self.assertTrue(not R.hasAncestor(block5.ident, block4.ident) and not R.hasAncestor(block4.ident, block5.ident)) + self.assertTrue(not R.hasAncestor(block3.ident, block4.ident) and R.hasAncestor(block4.ident, block3.ident)) + self.assertTrue(not R.hasAncestor(block2.ident, block4.ident) and R.hasAncestor(block4.ident, block2.ident)) + self.assertTrue(not R.hasAncestor(block1.ident, block4.ident) and R.hasAncestor(block4.ident, block1.ident)) + self.assertTrue(not R.hasAncestor(block0.ident, block4.ident) and R.hasAncestor(block4.ident, block0.ident)) - self.assertTrue( R._hasAncestor(block6.ident, block5.ident) and not R._hasAncestor(block5.ident, block6.ident)) - self.assertTrue(not R._hasAncestor(block4.ident, block5.ident) and not R._hasAncestor(block5.ident, block4.ident)) - self.assertTrue(not R._hasAncestor(block3.ident, block5.ident) and R._hasAncestor(block5.ident, block3.ident)) - self.assertTrue(not R._hasAncestor(block2.ident, block5.ident) and not R._hasAncestor(block5.ident, block2.ident)) - self.assertTrue(not R._hasAncestor(block1.ident, block5.ident) and R._hasAncestor(block5.ident, block1.ident)) - self.assertTrue(not R._hasAncestor(block0.ident, block5.ident) and R._hasAncestor(block5.ident, block0.ident)) + self.assertTrue( R.hasAncestor(block6.ident, block5.ident) and not R.hasAncestor(block5.ident, block6.ident)) + self.assertTrue(not R.hasAncestor(block4.ident, block5.ident) and not R.hasAncestor(block5.ident, block4.ident)) + self.assertTrue(not R.hasAncestor(block3.ident, block5.ident) and R.hasAncestor(block5.ident, block3.ident)) + self.assertTrue(not R.hasAncestor(block2.ident, block5.ident) and not R.hasAncestor(block5.ident, block2.ident)) + self.assertTrue(not R.hasAncestor(block1.ident, block5.ident) and R.hasAncestor(block5.ident, block1.ident)) + self.assertTrue(not R.hasAncestor(block0.ident, block5.ident) and R.hasAncestor(block5.ident, block0.ident)) - self.assertTrue(not R._hasAncestor(block5.ident, block6.ident) and R._hasAncestor(block6.ident, block5.ident)) - self.assertTrue(not R._hasAncestor(block4.ident, block6.ident) and not R._hasAncestor(block6.ident, block4.ident)) - self.assertTrue(not R._hasAncestor(block3.ident, block6.ident) and R._hasAncestor(block6.ident, block3.ident)) - self.assertTrue(not R._hasAncestor(block2.ident, block6.ident) and not R._hasAncestor(block6.ident, block2.ident)) - self.assertTrue(not R._hasAncestor(block1.ident, block6.ident) and R._hasAncestor(block6.ident, block1.ident)) - self.assertTrue(not R._hasAncestor(block0.ident, block6.ident) and R._hasAncestor(block6.ident, block0.ident)) + self.assertTrue(not R.hasAncestor(block5.ident, block6.ident) and R.hasAncestor(block6.ident, block5.ident)) + self.assertTrue(not R.hasAncestor(block4.ident, block6.ident) and not R.hasAncestor(block6.ident, block4.ident)) + self.assertTrue(not R.hasAncestor(block3.ident, block6.ident) and R.hasAncestor(block6.ident, block3.ident)) + self.assertTrue(not R.hasAncestor(block2.ident, block6.ident) and not R.hasAncestor(block6.ident, block2.ident)) + self.assertTrue(not R.hasAncestor(block1.ident, block6.ident) and R.hasAncestor(block6.ident, block1.ident)) + self.assertTrue(not R.hasAncestor(block0.ident, block6.ident) and R.hasAncestor(block6.ident, block0.ident)) R = R.pruneLeaves() @@ -353,30 +424,30 @@ class Test_BlockHandler(unittest.TestCase): self.assertEqual(len(R.pendingVotes),0) self.assertEqual(len(R.votes),0) - self.assertTrue( R._hasAncestor(block5.ident, block0.ident) and not R._hasAncestor(block0.ident, block5.ident)) - self.assertTrue( R._hasAncestor(block3.ident, block0.ident) and not R._hasAncestor(block0.ident, block3.ident)) - self.assertTrue( R._hasAncestor(block2.ident, block0.ident) and not R._hasAncestor(block0.ident, block2.ident)) - self.assertTrue( R._hasAncestor(block1.ident, block0.ident) and not R._hasAncestor(block0.ident, block1.ident)) + self.assertTrue( R.hasAncestor(block5.ident, block0.ident) and not R.hasAncestor(block0.ident, block5.ident)) + self.assertTrue( R.hasAncestor(block3.ident, block0.ident) and not R.hasAncestor(block0.ident, block3.ident)) + self.assertTrue( R.hasAncestor(block2.ident, block0.ident) and not R.hasAncestor(block0.ident, block2.ident)) + self.assertTrue( R.hasAncestor(block1.ident, block0.ident) and not R.hasAncestor(block0.ident, block1.ident)) - self.assertTrue( R._hasAncestor(block5.ident, block1.ident) and not R._hasAncestor(block1.ident, block5.ident)) - self.assertTrue( R._hasAncestor(block3.ident, block1.ident) and not R._hasAncestor(block1.ident, block3.ident)) - self.assertTrue( R._hasAncestor(block2.ident, block1.ident) and not R._hasAncestor(block1.ident, block2.ident)) - self.assertTrue(not R._hasAncestor(block0.ident, block1.ident) and R._hasAncestor(block1.ident, block0.ident)) + self.assertTrue( R.hasAncestor(block5.ident, block1.ident) and not R.hasAncestor(block1.ident, block5.ident)) + self.assertTrue( R.hasAncestor(block3.ident, block1.ident) and not R.hasAncestor(block1.ident, block3.ident)) + self.assertTrue( R.hasAncestor(block2.ident, block1.ident) and not R.hasAncestor(block1.ident, block2.ident)) + self.assertTrue(not R.hasAncestor(block0.ident, block1.ident) and R.hasAncestor(block1.ident, block0.ident)) - self.assertTrue(not R._hasAncestor(block5.ident, block2.ident) and not R._hasAncestor(block2.ident, block5.ident)) - self.assertTrue(not R._hasAncestor(block3.ident, block2.ident) and not R._hasAncestor(block2.ident, block3.ident)) - self.assertTrue(not R._hasAncestor(block1.ident, block2.ident) and R._hasAncestor(block2.ident, block1.ident)) - self.assertTrue(not R._hasAncestor(block0.ident, block2.ident) and R._hasAncestor(block2.ident, block0.ident)) + self.assertTrue(not R.hasAncestor(block5.ident, block2.ident) and not R.hasAncestor(block2.ident, block5.ident)) + self.assertTrue(not R.hasAncestor(block3.ident, block2.ident) and not R.hasAncestor(block2.ident, block3.ident)) + self.assertTrue(not R.hasAncestor(block1.ident, block2.ident) and R.hasAncestor(block2.ident, block1.ident)) + self.assertTrue(not R.hasAncestor(block0.ident, block2.ident) and R.hasAncestor(block2.ident, block0.ident)) - self.assertTrue( R._hasAncestor(block5.ident, block3.ident) and not R._hasAncestor(block3.ident, block5.ident)) - self.assertTrue(not R._hasAncestor(block2.ident, block3.ident) and not R._hasAncestor(block3.ident, block2.ident)) - self.assertTrue(not R._hasAncestor(block1.ident, block3.ident) and R._hasAncestor(block3.ident, block1.ident)) - self.assertTrue(not R._hasAncestor(block0.ident, block3.ident) and R._hasAncestor(block3.ident, block0.ident)) + self.assertTrue( R.hasAncestor(block5.ident, block3.ident) and not R.hasAncestor(block3.ident, block5.ident)) + self.assertTrue(not R.hasAncestor(block2.ident, block3.ident) and not R.hasAncestor(block3.ident, block2.ident)) + self.assertTrue(not R.hasAncestor(block1.ident, block3.ident) and R.hasAncestor(block3.ident, block1.ident)) + self.assertTrue(not R.hasAncestor(block0.ident, block3.ident) and R.hasAncestor(block3.ident, block0.ident)) - self.assertTrue(not R._hasAncestor(block3.ident, block5.ident) and R._hasAncestor(block5.ident, block3.ident)) - self.assertTrue(not R._hasAncestor(block2.ident, block5.ident) and not R._hasAncestor(block5.ident, block2.ident)) - self.assertTrue(not R._hasAncestor(block1.ident, block5.ident) and R._hasAncestor(block5.ident, block1.ident)) - self.assertTrue(not R._hasAncestor(block0.ident, block5.ident) and R._hasAncestor(block5.ident, block0.ident)) + self.assertTrue(not R.hasAncestor(block3.ident, block5.ident) and R.hasAncestor(block5.ident, block3.ident)) + self.assertTrue(not R.hasAncestor(block2.ident, block5.ident) and not R.hasAncestor(block5.ident, block2.ident)) + self.assertTrue(not R.hasAncestor(block1.ident, block5.ident) and R.hasAncestor(block5.ident, block1.ident)) + self.assertTrue(not R.hasAncestor(block0.ident, block5.ident) and R.hasAncestor(block5.ident, block0.ident)) ## Formal unit tests for leafBackAntichain() to follow: visual inspection reveals this does what it says on the tin. @@ -394,13 +465,15 @@ class Test_BlockHandler(unittest.TestCase): block4 = Block(parentsIn=[block2.ident, block3.ident], dataIn={"timestamp":time.time(), "txns":"come here fido"}) block5 = Block(parentsIn=[block3.ident], dataIn={"timestamp":time.time(), "txns":"applied rotation on her sugar plum"}) block6 = Block(parentsIn=[block5.ident], dataIn={"timestamp":time.time(), "txns":"listen to frank zappa for the love of all that is good"}) - R._addBlock(block0) - R._addBlock(block1) - R._addBlock(block2) - R._addBlock(block3) - R._addBlock(block4) - R._addBlock(block5) - R._addBlock(block6) + R.addBlock(block0) + R.addBlock(block1) + R.addBlock(block2) + R.addBlock(block3) + R.addBlock(block4) + R.addBlock(block5) + R.addBlock(block6) + + names = {0:block0.ident, 1:block1.ident, 2:block2.ident, 3:block3.ident, 4:block4.ident, 5:block5.ident, 6:block6.ident} # Testing voteFor # Verify all roots have children @@ -419,7 +492,7 @@ class Test_BlockHandler(unittest.TestCase): # Pick a random block with gcid in its past vid = random.choice(list(R.blocks.keys())) - while(not R._hasAncestor(vid, gcid)): + while(not R.hasAncestor(vid, gcid)): vid = random.choice(list(R.blocks.keys())) # Pick a random pair of blocks for gcid and vid to vote on. @@ -506,8 +579,87 @@ class Test_BlockHandler(unittest.TestCase): print("key = ", key, ", vote = ", R.votes[key]) print("\n ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ====") - - #R.vote() + for key in R.totalVotes: + print("key = ", key, ", vote = ", R.totalVotes[key]) + print("\n ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ====") + + self.assertTrue((names[0], names[1]) in R.totalVotes and (names[1], names[0]) in R.totalVotes) + self.assertTrue(R.totalVotes[(names[0], names[1])] and not R.totalVotes[(names[1], names[0])]) + + self.assertTrue((names[0], names[2]) in R.totalVotes and (names[2], names[0]) in R.totalVotes) + self.assertTrue(R.totalVotes[(names[0], names[2])] and not R.totalVotes[(names[2], names[0])]) + + self.assertTrue((names[0], names[3]) in R.totalVotes and (names[3], names[0]) in R.totalVotes) + self.assertTrue(R.totalVotes[(names[0], names[3])] and not R.totalVotes[(names[3], names[0])]) + + self.assertTrue((names[0], names[4]) in R.totalVotes and (names[4], names[0]) in R.totalVotes) + self.assertTrue(R.totalVotes[(names[0], names[4])] and not R.totalVotes[(names[4], names[0])]) + + self.assertTrue((names[0], names[5]) in R.totalVotes and (names[5], names[0]) in R.totalVotes) + self.assertTrue(R.totalVotes[(names[0], names[5])] and not R.totalVotes[(names[5], names[0])]) + + self.assertTrue((names[0], names[6]) in R.totalVotes and (names[6], names[0]) in R.totalVotes) + self.assertTrue(R.totalVotes[(names[0], names[6])] and not R.totalVotes[(names[6], names[0])]) + + #### #### #### #### + + self.assertTrue((names[1], names[2]) in R.totalVotes and (names[2], names[1]) in R.totalVotes) + self.assertTrue(R.totalVotes[(names[1], names[2])] and not R.totalVotes[(names[2], names[1])]) + + self.assertTrue((names[1], names[3]) in R.totalVotes and (names[3], names[1]) in R.totalVotes) + self.assertTrue(R.totalVotes[(names[1], names[3])] and not R.totalVotes[(names[3], names[1])]) + + self.assertTrue((names[1], names[4]) in R.totalVotes and (names[4], names[1]) in R.totalVotes) + self.assertTrue(R.totalVotes[(names[1], names[4])] and not R.totalVotes[(names[4], names[1])]) + + self.assertTrue((names[1], names[5]) in R.totalVotes and (names[5], names[1]) in R.totalVotes) + self.assertTrue(R.totalVotes[(names[1], names[5])] and not R.totalVotes[(names[5], names[1])]) + + self.assertTrue((names[1], names[6]) in R.totalVotes and (names[6], names[1]) in R.totalVotes) + self.assertTrue(R.totalVotes[(names[1], names[6])] and not R.totalVotes[(names[6], names[1])]) + + #### #### #### #### + + self.assertTrue((names[2], names[3]) in R.totalVotes and (names[3], names[2]) in R.totalVotes) + self.assertTrue(not R.totalVotes[(names[2], names[3])] and R.totalVotes[(names[3], names[2])]) + + self.assertTrue((names[2], names[4]) in R.totalVotes and (names[4], names[2]) in R.totalVotes) + self.assertTrue(R.totalVotes[(names[2], names[4])] and not R.totalVotes[(names[4], names[2])]) + + self.assertTrue((names[2], names[5]) in R.totalVotes and (names[5], names[2]) in R.totalVotes) + self.assertTrue(not R.totalVotes[(names[2], names[5])] and R.totalVotes[(names[5], names[2])]) + + self.assertTrue((names[2], names[6]) in R.totalVotes and (names[6], names[2]) in R.totalVotes) + #print("2,6 ", R.totalVotes[(names[2], names[6])]) + #print("6,2 ", R.totalVotes[(names[6], names[2])]) + self.assertTrue(not R.totalVotes[(names[2], names[6])] and R.totalVotes[(names[6], names[2])]) + + #### #### #### #### + + self.assertTrue((names[3], names[4]) in R.totalVotes and (names[4], names[3]) in R.totalVotes) + self.assertTrue(R.totalVotes[(names[3], names[4])] and not R.totalVotes[(names[4], names[3])]) + + self.assertTrue((names[3], names[5]) in R.totalVotes and (names[5], names[3]) in R.totalVotes) + self.assertTrue(R.totalVotes[(names[3], names[5])] and not R.totalVotes[(names[5], names[3])]) + + self.assertTrue((names[3], names[6]) in R.totalVotes and (names[6], names[3]) in R.totalVotes) + self.assertTrue(R.totalVotes[(names[3], names[6])] and not R.totalVotes[(names[6], names[3])]) + + #### #### #### #### + + self.assertTrue((names[4], names[5]) in R.totalVotes and (names[5], names[4]) in R.totalVotes) + self.assertTrue(not R.totalVotes[(names[4], names[5])] and R.totalVotes[(names[5], names[4])]) + + self.assertTrue((names[4], names[6]) in R.totalVotes and (names[6], names[4]) in R.totalVotes) + self.assertTrue(not R.totalVotes[(names[4], names[6])] and R.totalVotes[(names[6], names[4])]) + + #### #### #### #### + + self.assertTrue((names[5], names[6]) in R.totalVotes and (names[6], names[5]) in R.totalVotes) + self.assertTrue(R.totalVotes[(names[5], names[6])] and not R.totalVotes[(names[6], names[5])]) + + + #print(R.votes) From 898f462c4a2a011ad11d3ea93fda08e584742d10 Mon Sep 17 00:00:00 2001 From: Brandon Goodell Date: Mon, 18 Dec 2017 11:38:56 -0700 Subject: [PATCH 10/11] Creating BP folder and files --- source-code/BulletProofs/readme.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 source-code/BulletProofs/readme.md diff --git a/source-code/BulletProofs/readme.md b/source-code/BulletProofs/readme.md new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/source-code/BulletProofs/readme.md @@ -0,0 +1 @@ + From 69d16897582b825d8fdd2e922e22fe5319805576 Mon Sep 17 00:00:00 2001 From: Brandon Goodell Date: Mon, 18 Dec 2017 13:40:08 -0500 Subject: [PATCH 11/11] more robust computation of challenges --- .../BulletProofs/LinearBulletproof.java | 372 ++++++++++++ source-code/BulletProofs/LogBulletproof.java | 532 ++++++++++++++++++ .../BulletProofs/OptimizedLogBulletproof.java | 522 +++++++++++++++++ 3 files changed, 1426 insertions(+) create mode 100644 source-code/BulletProofs/LinearBulletproof.java create mode 100644 source-code/BulletProofs/LogBulletproof.java create mode 100644 source-code/BulletProofs/OptimizedLogBulletproof.java diff --git a/source-code/BulletProofs/LinearBulletproof.java b/source-code/BulletProofs/LinearBulletproof.java new file mode 100644 index 0000000..f33913e --- /dev/null +++ b/source-code/BulletProofs/LinearBulletproof.java @@ -0,0 +1,372 @@ +// NOTE: this interchanges the roles of G and H to match other code's behavior + +package how.monero.hodl.bulletproof; + +import how.monero.hodl.crypto.Curve25519Point; +import how.monero.hodl.crypto.Scalar; +import how.monero.hodl.crypto.CryptoUtil; +import how.monero.hodl.util.ByteUtil; +import java.math.BigInteger; +import how.monero.hodl.util.VarInt; +import java.util.Random; + +import static how.monero.hodl.crypto.Scalar.randomScalar; +import static how.monero.hodl.crypto.CryptoUtil.*; +import static how.monero.hodl.util.ByteUtil.*; + +public class LinearBulletproof +{ + private static int N; + private static Curve25519Point G; + private static Curve25519Point H; + private static Curve25519Point[] Gi; + private static Curve25519Point[] Hi; + + public static class ProofTuple + { + private Curve25519Point V; + private Curve25519Point A; + private Curve25519Point S; + private Curve25519Point T1; + private Curve25519Point T2; + private Scalar taux; + private Scalar mu; + private Scalar[] l; + private Scalar[] r; + + public ProofTuple(Curve25519Point V, Curve25519Point A, Curve25519Point S, Curve25519Point T1, Curve25519Point T2, Scalar taux, Scalar mu, Scalar[] l, Scalar[] r) + { + this.V = V; + this.A = A; + this.S = S; + this.T1 = T1; + this.T2 = T2; + this.taux = taux; + this.mu = mu; + this.l = l; + this.r = r; + } + } + + /* Given two scalar arrays, construct a vector commitment */ + public static Curve25519Point VectorExponent(Scalar[] a, Scalar[] b) + { + Curve25519Point Result = Curve25519Point.ZERO; + for (int i = 0; i < N; i++) + { + Result = Result.add(Gi[i].scalarMultiply(a[i])); + Result = Result.add(Hi[i].scalarMultiply(b[i])); + } + return Result; + } + + /* Given a scalar, construct a vector of powers */ + public static Scalar[] VectorPowers(Scalar x) + { + Scalar[] result = new Scalar[N]; + for (int i = 0; i < N; i++) + { + result[i] = x.pow(i); + } + return result; + } + + /* Given two scalar arrays, construct the inner product */ + public static Scalar InnerProduct(Scalar[] a, Scalar[] b) + { + Scalar result = Scalar.ZERO; + for (int i = 0; i < N; i++) + { + result = result.add(a[i].mul(b[i])); + } + return result; + } + + /* Given two scalar arrays, construct the Hadamard product */ + public static Scalar[] Hadamard(Scalar[] a, Scalar[] b) + { + Scalar[] result = new Scalar[N]; + for (int i = 0; i < N; i++) + { + result[i] = a[i].mul(b[i]); + } + return result; + } + + /* Add two vectors */ + public static Scalar[] VectorAdd(Scalar[] a, Scalar[] b) + { + Scalar[] result = new Scalar[N]; + for (int i = 0; i < N; i++) + { + result[i] = a[i].add(b[i]); + } + return result; + } + + /* Subtract two vectors */ + public static Scalar[] VectorSubtract(Scalar[] a, Scalar[] b) + { + Scalar[] result = new Scalar[N]; + for (int i = 0; i < N; i++) + { + result[i] = a[i].sub(b[i]); + } + return result; + } + + /* Multiply a scalar and a vector */ + public static Scalar[] VectorScalar(Scalar[] a, Scalar x) + { + Scalar[] result = new Scalar[N]; + for (int i = 0; i < N; i++) + { + result[i] = a[i].mul(x); + } + return result; + } + + /* Compute the inverse of a scalar, the stupid way */ + public static Scalar Invert(Scalar x) + { + Scalar inverse = new Scalar(x.toBigInteger().modInverse(CryptoUtil.l)); + assert x.mul(inverse).equals(Scalar.ONE); + + return inverse; + } + + /* Compute the value of k(y,z) */ + public static Scalar ComputeK(Scalar y, Scalar z) + { + Scalar result = Scalar.ZERO; + result = result.sub(z.sq().mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y)))); + result = result.sub(z.pow(3).mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(Scalar.TWO)))); + + return result; + } + + /* Given a value v (0..2^N-1) and a mask gamma, construct a range proof */ + public static ProofTuple PROVE(Scalar v, Scalar gamma) + { + Curve25519Point V = H.scalarMultiply(v).add(G.scalarMultiply(gamma)); + + // This hash is updated for Fiat-Shamir throughout the proof + Scalar hashCache = hashToScalar(V.toBytes()); + + // PAPER LINES 36-37 + Scalar[] aL = new Scalar[N]; + Scalar[] aR = new Scalar[N]; + + BigInteger tempV = v.toBigInteger(); + for (int i = N-1; i >= 0; i--) + { + BigInteger basePow = BigInteger.valueOf(2).pow(i); + if (tempV.divide(basePow).equals(BigInteger.ZERO)) + { + aL[i] = Scalar.ZERO; + } + else + { + aL[i] = Scalar.ONE; + tempV = tempV.subtract(basePow); + } + + aR[i] = aL[i].sub(Scalar.ONE); + } + + // DEBUG: Test to ensure this recovers the value + BigInteger test_aL = BigInteger.ZERO; + BigInteger test_aR = BigInteger.ZERO; + for (int i = 0; i < N; i++) + { + if (aL[i].equals(Scalar.ONE)) + test_aL = test_aL.add(BigInteger.valueOf(2).pow(i)); + if (aR[i].equals(Scalar.ZERO)) + test_aR = test_aR.add(BigInteger.valueOf(2).pow(i)); + } + assert test_aL.equals(v.toBigInteger()); + assert test_aR.equals(v.toBigInteger()); + + // PAPER LINES 38-39 + Scalar alpha = randomScalar(); + Curve25519Point A = VectorExponent(aL,aR).add(G.scalarMultiply(alpha)); + + // PAPER LINES 40-42 + Scalar[] sL = new Scalar[N]; + Scalar[] sR = new Scalar[N]; + for (int i = 0; i < N; i++) + { + sL[i] = randomScalar(); + sR[i] = randomScalar(); + } + Scalar rho = randomScalar(); + Curve25519Point S = VectorExponent(sL,sR).add(G.scalarMultiply(rho)); + + // PAPER LINES 43-45 + hashCache = hashToScalar(concat(hashCache.bytes,A.toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,S.toBytes())); + Scalar y = hashCache; + hashCache = hashToScalar(hashCache.bytes); + Scalar z = hashCache; + + Scalar t0 = Scalar.ZERO; + Scalar t1 = Scalar.ZERO; + Scalar t2 = Scalar.ZERO; + + t0 = t0.add(z.mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y)))); + t0 = t0.add(z.sq().mul(v)); + Scalar k = ComputeK(y,z); + t0 = t0.add(k); + + // DEBUG: Test the value of t0 has the correct form + Scalar test_t0 = Scalar.ZERO; + test_t0 = test_t0.add(InnerProduct(aL,Hadamard(aR,VectorPowers(y)))); + test_t0 = test_t0.add(z.mul(InnerProduct(VectorSubtract(aL,aR),VectorPowers(y)))); + test_t0 = test_t0.add(z.sq().mul(InnerProduct(VectorPowers(Scalar.TWO),aL))); + test_t0 = test_t0.add(k); + assert test_t0.equals(t0); + + t1 = t1.add(InnerProduct(VectorSubtract(aL,VectorScalar(VectorPowers(Scalar.ONE),z)),Hadamard(VectorPowers(y),sR))); + t1 = t1.add(InnerProduct(sL,VectorAdd(Hadamard(VectorPowers(y),VectorAdd(aR,VectorScalar(VectorPowers(Scalar.ONE),z))),VectorScalar(VectorPowers(Scalar.TWO),z.sq())))); + t2 = t2.add(InnerProduct(sL,Hadamard(VectorPowers(y),sR))); + + // PAPER LINES 47-48 + Scalar tau1 = randomScalar(); + Scalar tau2 = randomScalar(); + Curve25519Point T1 = H.scalarMultiply(t1).add(G.scalarMultiply(tau1)); + Curve25519Point T2 = H.scalarMultiply(t2).add(G.scalarMultiply(tau2)); + + // PAPER LINES 49-51 + hashCache = hashToScalar(concat(hashCache.bytes,z.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,T1.toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,T2.toBytes())); + Scalar x = hashCache; + + // PAPER LINES 52-53 + Scalar taux = Scalar.ZERO; + taux = tau1.mul(x); + taux = taux.add(tau2.mul(x.sq())); + taux = taux.add(gamma.mul(z.sq())); + Scalar mu = x.mul(rho).add(alpha); + + // PAPER LINES 54-57 + Scalar[] l = new Scalar[N]; + Scalar[] r = new Scalar[N]; + + l = VectorAdd(VectorSubtract(aL,VectorScalar(VectorPowers(Scalar.ONE),z)),VectorScalar(sL,x)); + r = VectorAdd(Hadamard(VectorPowers(y),VectorAdd(aR,VectorAdd(VectorScalar(VectorPowers(Scalar.ONE),z),VectorScalar(sR,x)))),VectorScalar(VectorPowers(Scalar.TWO),z.sq())); + + // DEBUG: Test if the l and r vectors match the polynomial forms + Scalar test_t = Scalar.ZERO; + test_t = test_t.add(t0).add(t1.mul(x)); + test_t = test_t.add(t2.mul(x.sq())); + assert test_t.equals(InnerProduct(l,r)); + + // PAPER LINE 58 + return new ProofTuple(V,A,S,T1,T2,taux,mu,l,r); + } + + /* Given a range proof, determine if it is valid */ + public static boolean VERIFY(ProofTuple proof) + { + // Reconstruct the challenges + Scalar hashCache = hashToScalar(proof.V.toBytes()); + hashCache = hashToScalar(concat(hashCache.bytes,proof.A.toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,proof.S.toBytes())); + Scalar y = hashCache; + hashCache = hashToScalar(hashCache.bytes); + Scalar z = hashCache; + hashCache = hashToScalar(concat(hashCache.bytes,z.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,proof.T1.toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,proof.T2.toBytes())); + Scalar x = hashCache; + + // PAPER LINE 60 + Scalar t = InnerProduct(proof.l,proof.r); + + // PAPER LINE 61 + Curve25519Point L61Left = G.scalarMultiply(proof.taux).add(H.scalarMultiply(t)); + + Scalar k = ComputeK(y,z); + + Curve25519Point L61Right = H.scalarMultiply(k.add(z.mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y))))); + L61Right = L61Right.add(proof.V.scalarMultiply(z.sq())); + L61Right = L61Right.add(proof.T1.scalarMultiply(x)); + L61Right = L61Right.add(proof.T2.scalarMultiply(x.sq())); + + if (!L61Right.equals(L61Left)) + { + return false; + } + + // PAPER LINE 62 + Curve25519Point P = Curve25519Point.ZERO; + P = P.add(proof.A); + P = P.add(proof.S.scalarMultiply(x)); + + Scalar[] Gexp = new Scalar[N]; + for (int i = 0; i < N; i++) + Gexp[i] = Scalar.ZERO.sub(z); + + Scalar[] Hexp = new Scalar[N]; + for (int i = 0; i < N; i++) + { + Hexp[i] = Scalar.ZERO; + Hexp[i] = Hexp[i].add(z.mul(y.pow(i))); + Hexp[i] = Hexp[i].add(z.sq().mul(Scalar.TWO.pow(i))); + Hexp[i] = Hexp[i].mul(Invert(y).pow(i)); + } + P = P.add(VectorExponent(Gexp,Hexp)); + + // PAPER LINE 63 + for (int i = 0; i < N; i++) + { + Hexp[i] = Scalar.ZERO; + Hexp[i] = Hexp[i].add(proof.r[i]); + Hexp[i] = Hexp[i].mul(Invert(y).pow(i)); + } + Curve25519Point L63Right = VectorExponent(proof.l,Hexp).add(G.scalarMultiply(proof.mu)); + + if (!L63Right.equals(P)) + { + return false; + } + + return true; + } + + public static void main(String[] args) + { + // Number of bits in the range + N = 64; + + // Set the curve base points + G = Curve25519Point.G; + H = Curve25519Point.hashToPoint(G); + Gi = new Curve25519Point[N]; + Hi = new Curve25519Point[N]; + for (int i = 0; i < N; i++) + { + Gi[i] = getHpnGLookup(2*i); + Hi[i] = getHpnGLookup(2*i+1); + } + + // Run a bunch of randomized trials + Random rando = new Random(); + int TRIALS = 250; + int count = 0; + + while (count < TRIALS) + { + long amount = rando.nextLong(); + if (amount > Math.pow(2,N)-1 || amount < 0) + continue; + + ProofTuple proof = PROVE(new Scalar(BigInteger.valueOf(amount)),randomScalar()); + if (!VERIFY(proof)) + System.out.println("Test failed"); + + count += 1; + } + } +} diff --git a/source-code/BulletProofs/LogBulletproof.java b/source-code/BulletProofs/LogBulletproof.java new file mode 100644 index 0000000..7061cd9 --- /dev/null +++ b/source-code/BulletProofs/LogBulletproof.java @@ -0,0 +1,532 @@ +// NOTE: this interchanges the roles of G and H to match other code's behavior + +package how.monero.hodl.bulletproof; + +import how.monero.hodl.crypto.Curve25519Point; +import how.monero.hodl.crypto.Scalar; +import how.monero.hodl.crypto.CryptoUtil; +import java.math.BigInteger; +import java.util.Random; + +import static how.monero.hodl.crypto.Scalar.randomScalar; +import static how.monero.hodl.crypto.CryptoUtil.*; +import static how.monero.hodl.util.ByteUtil.*; + +public class LogBulletproof +{ + private static int N; + private static int logN; + private static Curve25519Point G; + private static Curve25519Point H; + private static Curve25519Point[] Gi; + private static Curve25519Point[] Hi; + + public static class ProofTuple + { + private Curve25519Point V; + private Curve25519Point A; + private Curve25519Point S; + private Curve25519Point T1; + private Curve25519Point T2; + private Scalar taux; + private Scalar mu; + private Curve25519Point[] L; + private Curve25519Point[] R; + private Scalar a; + private Scalar b; + private Scalar t; + + public ProofTuple(Curve25519Point V, Curve25519Point A, Curve25519Point S, Curve25519Point T1, Curve25519Point T2, Scalar taux, Scalar mu, Curve25519Point[] L, Curve25519Point[] R, Scalar a, Scalar b, Scalar t) + { + this.V = V; + this.A = A; + this.S = S; + this.T1 = T1; + this.T2 = T2; + this.taux = taux; + this.mu = mu; + this.L = L; + this.R = R; + this.a = a; + this.b = b; + this.t = t; + } + } + + /* Given two scalar arrays, construct a vector commitment */ + public static Curve25519Point VectorExponent(Scalar[] a, Scalar[] b) + { + assert a.length == N && b.length == N; + + Curve25519Point Result = Curve25519Point.ZERO; + for (int i = 0; i < N; i++) + { + Result = Result.add(Gi[i].scalarMultiply(a[i])); + Result = Result.add(Hi[i].scalarMultiply(b[i])); + } + return Result; + } + + /* Compute a custom vector-scalar commitment */ + public static Curve25519Point VectorExponentCustom(Curve25519Point[] A, Curve25519Point[] B, Scalar[] a, Scalar[] b) + { + assert a.length == A.length && b.length == B.length && a.length == b.length; + + Curve25519Point Result = Curve25519Point.ZERO; + for (int i = 0; i < a.length; i++) + { + Result = Result.add(A[i].scalarMultiply(a[i])); + Result = Result.add(B[i].scalarMultiply(b[i])); + } + return Result; + } + + /* Given a scalar, construct a vector of powers */ + public static Scalar[] VectorPowers(Scalar x) + { + Scalar[] result = new Scalar[N]; + for (int i = 0; i < N; i++) + { + result[i] = x.pow(i); + } + return result; + } + + /* Given two scalar arrays, construct the inner product */ + public static Scalar InnerProduct(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar result = Scalar.ZERO; + for (int i = 0; i < a.length; i++) + { + result = result.add(a[i].mul(b[i])); + } + return result; + } + + /* Given two scalar arrays, construct the Hadamard product */ + public static Scalar[] Hadamard(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].mul(b[i]); + } + return result; + } + + /* Given two curvepoint arrays, construct the Hadamard product */ + public static Curve25519Point[] Hadamard2(Curve25519Point[] A, Curve25519Point[] B) + { + assert A.length == B.length; + + Curve25519Point[] Result = new Curve25519Point[A.length]; + for (int i = 0; i < A.length; i++) + { + Result[i] = A[i].add(B[i]); + } + return Result; + } + + /* Add two vectors */ + public static Scalar[] VectorAdd(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].add(b[i]); + } + return result; + } + + /* Subtract two vectors */ + public static Scalar[] VectorSubtract(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].sub(b[i]); + } + return result; + } + + /* Multiply a scalar and a vector */ + public static Scalar[] VectorScalar(Scalar[] a, Scalar x) + { + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].mul(x); + } + return result; + } + + /* Exponentiate a curve vector by a scalar */ + public static Curve25519Point[] VectorScalar2(Curve25519Point[] A, Scalar x) + { + Curve25519Point[] Result = new Curve25519Point[A.length]; + for (int i = 0; i < A.length; i++) + { + Result[i] = A[i].scalarMultiply(x); + } + return Result; + } + + /* Compute the inverse of a scalar, the stupid way */ + public static Scalar Invert(Scalar x) + { + Scalar inverse = new Scalar(x.toBigInteger().modInverse(CryptoUtil.l)); + + assert x.mul(inverse).equals(Scalar.ONE); + return inverse; + } + + /* Compute the slice of a curvepoint vector */ + public static Curve25519Point[] CurveSlice(Curve25519Point[] a, int start, int stop) + { + Curve25519Point[] Result = new Curve25519Point[stop-start]; + for (int i = start; i < stop; i++) + { + Result[i-start] = a[i]; + } + return Result; + } + + /* Compute the slice of a scalar vector */ + public static Scalar[] ScalarSlice(Scalar[] a, int start, int stop) + { + Scalar[] result = new Scalar[stop-start]; + for (int i = start; i < stop; i++) + { + result[i-start] = a[i]; + } + return result; + } + + /* Compute the value of k(y,z) */ + public static Scalar ComputeK(Scalar y, Scalar z) + { + Scalar result = Scalar.ZERO; + result = result.sub(z.sq().mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y)))); + result = result.sub(z.pow(3).mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(Scalar.TWO)))); + + return result; + } + + /* Given a value v (0..2^N-1) and a mask gamma, construct a range proof */ + public static ProofTuple PROVE(Scalar v, Scalar gamma) + { + Curve25519Point V = H.scalarMultiply(v).add(G.scalarMultiply(gamma)); + + // This hash is updated for Fiat-Shamir throughout the proof + Scalar hashCache = hashToScalar(V.toBytes()); + + // PAPER LINES 36-37 + Scalar[] aL = new Scalar[N]; + Scalar[] aR = new Scalar[N]; + + BigInteger tempV = v.toBigInteger(); + for (int i = N-1; i >= 0; i--) + { + BigInteger basePow = BigInteger.valueOf(2).pow(i); + if (tempV.divide(basePow).equals(BigInteger.ZERO)) + { + aL[i] = Scalar.ZERO; + } + else + { + aL[i] = Scalar.ONE; + tempV = tempV.subtract(basePow); + } + + aR[i] = aL[i].sub(Scalar.ONE); + } + + // PAPER LINES 38-39 + Scalar alpha = randomScalar(); + Curve25519Point A = VectorExponent(aL,aR).add(G.scalarMultiply(alpha)); + + // PAPER LINES 40-42 + Scalar[] sL = new Scalar[N]; + Scalar[] sR = new Scalar[N]; + for (int i = 0; i < N; i++) + { + sL[i] = randomScalar(); + sR[i] = randomScalar(); + } + Scalar rho = randomScalar(); + Curve25519Point S = VectorExponent(sL,sR).add(G.scalarMultiply(rho)); + + // PAPER LINES 43-45 + hashCache = hashToScalar(concat(hashCache.bytes,A.toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,S.toBytes())); + Scalar y = hashCache; + hashCache = hashToScalar(hashCache.bytes); + Scalar z = hashCache; + + // Polynomial construction before PAPER LINE 46 + Scalar t0 = Scalar.ZERO; + Scalar t1 = Scalar.ZERO; + Scalar t2 = Scalar.ZERO; + + t0 = t0.add(z.mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y)))); + t0 = t0.add(z.sq().mul(v)); + Scalar k = ComputeK(y,z); + t0 = t0.add(k); + + t1 = t1.add(InnerProduct(VectorSubtract(aL,VectorScalar(VectorPowers(Scalar.ONE),z)),Hadamard(VectorPowers(y),sR))); + t1 = t1.add(InnerProduct(sL,VectorAdd(Hadamard(VectorPowers(y),VectorAdd(aR,VectorScalar(VectorPowers(Scalar.ONE),z))),VectorScalar(VectorPowers(Scalar.TWO),z.sq())))); + + t2 = t2.add(InnerProduct(sL,Hadamard(VectorPowers(y),sR))); + + // PAPER LINES 47-48 + Scalar tau1 = randomScalar(); + Scalar tau2 = randomScalar(); + Curve25519Point T1 = H.scalarMultiply(t1).add(G.scalarMultiply(tau1)); + Curve25519Point T2 = H.scalarMultiply(t2).add(G.scalarMultiply(tau2)); + + // PAPER LINES 49-51 + hashCache = hashToScalar(concat(hashCache.bytes,z.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,T1.toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,T2.toBytes())); + Scalar x = hashCache; + + // PAPER LINES 52-53 + Scalar taux = Scalar.ZERO; + taux = tau1.mul(x); + taux = taux.add(tau2.mul(x.sq())); + taux = taux.add(gamma.mul(z.sq())); + Scalar mu = x.mul(rho).add(alpha); + + // PAPER LINES 54-57 + Scalar[] l = new Scalar[N]; + Scalar[] r = new Scalar[N]; + + l = VectorAdd(VectorSubtract(aL,VectorScalar(VectorPowers(Scalar.ONE),z)),VectorScalar(sL,x)); + r = VectorAdd(Hadamard(VectorPowers(y),VectorAdd(aR,VectorAdd(VectorScalar(VectorPowers(Scalar.ONE),z),VectorScalar(sR,x)))),VectorScalar(VectorPowers(Scalar.TWO),z.sq())); + + Scalar t = InnerProduct(l,r); + + // PAPER LINES 32-33 + hashCache = hashToScalar(concat(hashCache.bytes,x.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,taux.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,mu.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,t.bytes)); + Scalar x_ip = hashCache; + + // These are used in the inner product rounds + int nprime = N; + Curve25519Point[] Gprime = new Curve25519Point[N]; + Curve25519Point[] Hprime = new Curve25519Point[N]; + Scalar[] aprime = new Scalar[N]; + Scalar[] bprime = new Scalar[N]; + for (int i = 0; i < N; i++) + { + Gprime[i] = Gi[i]; + Hprime[i] = Hi[i].scalarMultiply(Invert(y).pow(i)); + aprime[i] = l[i]; + bprime[i] = r[i]; + } + Curve25519Point[] L = new Curve25519Point[logN]; + Curve25519Point[] R = new Curve25519Point[logN]; + int round = 0; // track the index based on number of rounds + Scalar[] w = new Scalar[logN]; // this is the challenge x in the inner product protocol + + // PAPER LINE 13 + while (nprime > 1) + { + // PAPER LINE 15 + nprime /= 2; + + // PAPER LINES 16-17 + Scalar cL = InnerProduct(ScalarSlice(aprime,0,nprime),ScalarSlice(bprime,nprime,bprime.length)); + Scalar cR = InnerProduct(ScalarSlice(aprime,nprime,aprime.length),ScalarSlice(bprime,0,nprime)); + + // PAPER LINES 18-19 + L[round] = VectorExponentCustom(CurveSlice(Gprime,nprime,Gprime.length),CurveSlice(Hprime,0,nprime),ScalarSlice(aprime,0,nprime),ScalarSlice(bprime,nprime,bprime.length)).add(H.scalarMultiply(cL.mul(x_ip))); + R[round] = VectorExponentCustom(CurveSlice(Gprime,0,nprime),CurveSlice(Hprime,nprime,Hprime.length),ScalarSlice(aprime,nprime,aprime.length),ScalarSlice(bprime,0,nprime)).add(H.scalarMultiply(cR.mul(x_ip))); + + // PAPER LINES 21-22 + hashCache = hashToScalar(concat(hashCache.bytes,L[round].toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,R[round].toBytes())); + w[round] = hashCache; + + // PAPER LINES 24-25 + Gprime = Hadamard2(VectorScalar2(CurveSlice(Gprime,0,nprime),Invert(w[round])),VectorScalar2(CurveSlice(Gprime,nprime,Gprime.length),w[round])); + Hprime = Hadamard2(VectorScalar2(CurveSlice(Hprime,0,nprime),w[round]),VectorScalar2(CurveSlice(Hprime,nprime,Hprime.length),Invert(w[round]))); + + // PAPER LINES 28-29 + aprime = VectorAdd(VectorScalar(ScalarSlice(aprime,0,nprime),w[round]),VectorScalar(ScalarSlice(aprime,nprime,aprime.length),Invert(w[round]))); + bprime = VectorAdd(VectorScalar(ScalarSlice(bprime,0,nprime),Invert(w[round])),VectorScalar(ScalarSlice(bprime,nprime,bprime.length),w[round])); + + round += 1; + } + + // PAPER LINE 58 (with inclusions from PAPER LINE 8 and PAPER LINE 20) + return new ProofTuple(V,A,S,T1,T2,taux,mu,L,R,aprime[0],bprime[0],t); + } + + /* Given a range proof, determine if it is valid */ + public static boolean VERIFY(ProofTuple proof) + { + // Reconstruct the challenges + Scalar hashCache = hashToScalar(proof.V.toBytes()); + hashCache = hashToScalar(concat(hashCache.bytes,proof.A.toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,proof.S.toBytes())); + Scalar y = hashCache; + hashCache = hashToScalar(hashCache.bytes); + Scalar z = hashCache; + hashCache = hashToScalar(concat(hashCache.bytes,z.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,proof.T1.toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,proof.T2.toBytes())); + Scalar x = hashCache; + hashCache = hashToScalar(concat(hashCache.bytes,x.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,proof.taux.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,proof.mu.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,proof.t.bytes)); + Scalar x_ip = hashCache; + + // PAPER LINE 61 + Curve25519Point L61Left = G.scalarMultiply(proof.taux).add(H.scalarMultiply(proof.t)); + + Scalar k = ComputeK(y,z); + + Curve25519Point L61Right = H.scalarMultiply(k.add(z.mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y))))); + L61Right = L61Right.add(proof.V.scalarMultiply(z.sq())); + L61Right = L61Right.add(proof.T1.scalarMultiply(x)); + L61Right = L61Right.add(proof.T2.scalarMultiply(x.sq())); + + if (!L61Right.equals(L61Left)) + return false; + + // PAPER LINE 62 + Curve25519Point P = Curve25519Point.ZERO; + P = P.add(proof.A); + P = P.add(proof.S.scalarMultiply(x)); + + Scalar[] Gexp = new Scalar[N]; + for (int i = 0; i < N; i++) + Gexp[i] = Scalar.ZERO.sub(z); + + Scalar[] Hexp = new Scalar[N]; + for (int i = 0; i < N; i++) + { + Hexp[i] = Scalar.ZERO; + Hexp[i] = Hexp[i].add(z.mul(y.pow(i))); + Hexp[i] = Hexp[i].add(z.sq().mul(Scalar.TWO.pow(i))); + Hexp[i] = Hexp[i].mul(Invert(y).pow(i)); + } + P = P.add(VectorExponent(Gexp,Hexp)); + + // Compute the number of rounds for the inner product + int rounds = proof.L.length; + + // PAPER LINES 21-22 + // The inner product challenges are computed per round + Scalar[] w = new Scalar[rounds]; + hashCache = hashToScalar(concat(hashCache.bytes,proof.L[0].toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,proof.R[0].toBytes())); + w[0] = hashCache; + if (rounds > 1) + { + for (int i = 1; i < rounds; i++) + { + hashCache = hashToScalar(concat(hashCache.bytes,proof.L[i].toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,proof.R[i].toBytes())); + w[i] = hashCache; + } + } + + // Basically PAPER LINES 24-25 + // Compute the curvepoints from G[i] and H[i] + Curve25519Point InnerProdG = Curve25519Point.ZERO; + Curve25519Point InnerProdH = Curve25519Point.ZERO; + for (int i = 0; i < N; i++) + { + // Convert the index to binary IN REVERSE and construct the scalar exponent + int index = i; + Scalar gScalar = Scalar.ONE; + Scalar hScalar = Invert(y).pow(i); + + for (int j = rounds-1; j >= 0; j--) + { + int J = w.length - j - 1; // because this is done in reverse bit order + int basePow = (int) Math.pow(2,j); // assumes we don't get too big + if (index / basePow == 0) // bit is zero + { + gScalar = gScalar.mul(Invert(w[J])); + hScalar = hScalar.mul(w[J]); + } + else // bit is one + { + gScalar = gScalar.mul(w[J]); + hScalar = hScalar.mul(Invert(w[J])); + index -= basePow; + } + } + + // Now compute the basepoint's scalar multiplication + // Each of these could be written as a multiexp operation instead + InnerProdG = InnerProdG.add(Gi[i].scalarMultiply(gScalar)); + InnerProdH = InnerProdH.add(Hi[i].scalarMultiply(hScalar)); + } + + // PAPER LINE 26 + Curve25519Point Pprime = P.add(G.scalarMultiply(Scalar.ZERO.sub(proof.mu))); + + for (int i = 0; i < rounds; i++) + { + Pprime = Pprime.add(proof.L[i].scalarMultiply(w[i].sq())); + Pprime = Pprime.add(proof.R[i].scalarMultiply(Invert(w[i]).sq())); + } + Pprime = Pprime.add(H.scalarMultiply(proof.t.mul(x_ip))); + + if (!Pprime.equals(InnerProdG.scalarMultiply(proof.a).add(InnerProdH.scalarMultiply(proof.b)).add(H.scalarMultiply(proof.a.mul(proof.b).mul(x_ip))))) + return false; + + return true; + } + + public static void main(String[] args) + { + // Number of bits in the range + N = 64; + logN = 6; // its log, manually + + // Set the curve base points + G = Curve25519Point.G; + H = Curve25519Point.hashToPoint(G); + Gi = new Curve25519Point[N]; + Hi = new Curve25519Point[N]; + for (int i = 0; i < N; i++) + { + Gi[i] = getHpnGLookup(2*i); + Hi[i] = getHpnGLookup(2*i+1); + } + + // Run a bunch of randomized trials + Random rando = new Random(); + int TRIALS = 250; + int count = 0; + + while (count < TRIALS) + { + long amount = rando.nextLong(); + if (amount > Math.pow(2,N)-1 || amount < 0) + continue; + + ProofTuple proof = PROVE(new Scalar(BigInteger.valueOf(amount)),randomScalar()); + if (!VERIFY(proof)) + System.out.println("Test failed"); + + count += 1; + } + } +} diff --git a/source-code/BulletProofs/OptimizedLogBulletproof.java b/source-code/BulletProofs/OptimizedLogBulletproof.java new file mode 100644 index 0000000..6752fc7 --- /dev/null +++ b/source-code/BulletProofs/OptimizedLogBulletproof.java @@ -0,0 +1,522 @@ +// NOTE: this interchanges the roles of G and H to match other code's behavior + +package how.monero.hodl.bulletproof; + +import how.monero.hodl.crypto.Curve25519Point; +import how.monero.hodl.crypto.Scalar; +import how.monero.hodl.crypto.CryptoUtil; +import java.math.BigInteger; +import java.util.Random; + +import static how.monero.hodl.crypto.Scalar.randomScalar; +import static how.monero.hodl.crypto.CryptoUtil.*; +import static how.monero.hodl.util.ByteUtil.*; + +public class OptimizedLogBulletproof +{ + private static int N; + private static int logN; + private static Curve25519Point G; + private static Curve25519Point H; + private static Curve25519Point[] Gi; + private static Curve25519Point[] Hi; + + public static class ProofTuple + { + private Curve25519Point V; + private Curve25519Point A; + private Curve25519Point S; + private Curve25519Point T1; + private Curve25519Point T2; + private Scalar taux; + private Scalar mu; + private Curve25519Point[] L; + private Curve25519Point[] R; + private Scalar a; + private Scalar b; + private Scalar t; + + public ProofTuple(Curve25519Point V, Curve25519Point A, Curve25519Point S, Curve25519Point T1, Curve25519Point T2, Scalar taux, Scalar mu, Curve25519Point[] L, Curve25519Point[] R, Scalar a, Scalar b, Scalar t) + { + this.V = V; + this.A = A; + this.S = S; + this.T1 = T1; + this.T2 = T2; + this.taux = taux; + this.mu = mu; + this.L = L; + this.R = R; + this.a = a; + this.b = b; + this.t = t; + } + } + + /* Given two scalar arrays, construct a vector commitment */ + public static Curve25519Point VectorExponent(Scalar[] a, Scalar[] b) + { + assert a.length == N && b.length == N; + + Curve25519Point Result = Curve25519Point.ZERO; + for (int i = 0; i < N; i++) + { + Result = Result.add(Gi[i].scalarMultiply(a[i])); + Result = Result.add(Hi[i].scalarMultiply(b[i])); + } + return Result; + } + + /* Compute a custom vector-scalar commitment */ + public static Curve25519Point VectorExponentCustom(Curve25519Point[] A, Curve25519Point[] B, Scalar[] a, Scalar[] b) + { + assert a.length == A.length && b.length == B.length && a.length == b.length; + + Curve25519Point Result = Curve25519Point.ZERO; + for (int i = 0; i < a.length; i++) + { + Result = Result.add(A[i].scalarMultiply(a[i])); + Result = Result.add(B[i].scalarMultiply(b[i])); + } + return Result; + } + + /* Given a scalar, construct a vector of powers */ + public static Scalar[] VectorPowers(Scalar x) + { + Scalar[] result = new Scalar[N]; + for (int i = 0; i < N; i++) + { + result[i] = x.pow(i); + } + return result; + } + + /* Given two scalar arrays, construct the inner product */ + public static Scalar InnerProduct(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar result = Scalar.ZERO; + for (int i = 0; i < a.length; i++) + { + result = result.add(a[i].mul(b[i])); + } + return result; + } + + /* Given two scalar arrays, construct the Hadamard product */ + public static Scalar[] Hadamard(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].mul(b[i]); + } + return result; + } + + /* Given two curvepoint arrays, construct the Hadamard product */ + public static Curve25519Point[] Hadamard2(Curve25519Point[] A, Curve25519Point[] B) + { + assert A.length == B.length; + + Curve25519Point[] Result = new Curve25519Point[A.length]; + for (int i = 0; i < A.length; i++) + { + Result[i] = A[i].add(B[i]); + } + return Result; + } + + /* Add two vectors */ + public static Scalar[] VectorAdd(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].add(b[i]); + } + return result; + } + + /* Subtract two vectors */ + public static Scalar[] VectorSubtract(Scalar[] a, Scalar[] b) + { + assert a.length == b.length; + + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].sub(b[i]); + } + return result; + } + + /* Multiply a scalar and a vector */ + public static Scalar[] VectorScalar(Scalar[] a, Scalar x) + { + Scalar[] result = new Scalar[a.length]; + for (int i = 0; i < a.length; i++) + { + result[i] = a[i].mul(x); + } + return result; + } + + /* Exponentiate a curve vector by a scalar */ + public static Curve25519Point[] VectorScalar2(Curve25519Point[] A, Scalar x) + { + Curve25519Point[] Result = new Curve25519Point[A.length]; + for (int i = 0; i < A.length; i++) + { + Result[i] = A[i].scalarMultiply(x); + } + return Result; + } + + /* Compute the inverse of a scalar, the stupid way */ + public static Scalar Invert(Scalar x) + { + Scalar inverse = new Scalar(x.toBigInteger().modInverse(CryptoUtil.l)); + + assert x.mul(inverse).equals(Scalar.ONE); + return inverse; + } + + /* Compute the slice of a curvepoint vector */ + public static Curve25519Point[] CurveSlice(Curve25519Point[] a, int start, int stop) + { + Curve25519Point[] Result = new Curve25519Point[stop-start]; + for (int i = start; i < stop; i++) + { + Result[i-start] = a[i]; + } + return Result; + } + + /* Compute the slice of a scalar vector */ + public static Scalar[] ScalarSlice(Scalar[] a, int start, int stop) + { + Scalar[] result = new Scalar[stop-start]; + for (int i = start; i < stop; i++) + { + result[i-start] = a[i]; + } + return result; + } + + /* Compute the value of k(y,z) */ + public static Scalar ComputeK(Scalar y, Scalar z) + { + Scalar result = Scalar.ZERO; + result = result.sub(z.sq().mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y)))); + result = result.sub(z.pow(3).mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(Scalar.TWO)))); + + return result; + } + + /* Given a value v (0..2^N-1) and a mask gamma, construct a range proof */ + public static ProofTuple PROVE(Scalar v, Scalar gamma) + { + Curve25519Point V = H.scalarMultiply(v).add(G.scalarMultiply(gamma)); + + // This hash is updated for Fiat-Shamir throughout the proof + Scalar hashCache = hashToScalar(V.toBytes()); + + // PAPER LINES 36-37 + Scalar[] aL = new Scalar[N]; + Scalar[] aR = new Scalar[N]; + + BigInteger tempV = v.toBigInteger(); + for (int i = N-1; i >= 0; i--) + { + BigInteger basePow = BigInteger.valueOf(2).pow(i); + if (tempV.divide(basePow).equals(BigInteger.ZERO)) + { + aL[i] = Scalar.ZERO; + } + else + { + aL[i] = Scalar.ONE; + tempV = tempV.subtract(basePow); + } + + aR[i] = aL[i].sub(Scalar.ONE); + } + + // PAPER LINES 38-39 + Scalar alpha = randomScalar(); + Curve25519Point A = VectorExponent(aL,aR).add(G.scalarMultiply(alpha)); + + // PAPER LINES 40-42 + Scalar[] sL = new Scalar[N]; + Scalar[] sR = new Scalar[N]; + for (int i = 0; i < N; i++) + { + sL[i] = randomScalar(); + sR[i] = randomScalar(); + } + Scalar rho = randomScalar(); + Curve25519Point S = VectorExponent(sL,sR).add(G.scalarMultiply(rho)); + + // PAPER LINES 43-45 + hashCache = hashToScalar(concat(hashCache.bytes,A.toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,S.toBytes())); + Scalar y = hashCache; + hashCache = hashToScalar(hashCache.bytes); + Scalar z = hashCache; + + // Polynomial construction before PAPER LINE 46 + Scalar t0 = Scalar.ZERO; + Scalar t1 = Scalar.ZERO; + Scalar t2 = Scalar.ZERO; + + t0 = t0.add(z.mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y)))); + t0 = t0.add(z.sq().mul(v)); + Scalar k = ComputeK(y,z); + t0 = t0.add(k); + + t1 = t1.add(InnerProduct(VectorSubtract(aL,VectorScalar(VectorPowers(Scalar.ONE),z)),Hadamard(VectorPowers(y),sR))); + t1 = t1.add(InnerProduct(sL,VectorAdd(Hadamard(VectorPowers(y),VectorAdd(aR,VectorScalar(VectorPowers(Scalar.ONE),z))),VectorScalar(VectorPowers(Scalar.TWO),z.sq())))); + + t2 = t2.add(InnerProduct(sL,Hadamard(VectorPowers(y),sR))); + + // PAPER LINES 47-48 + Scalar tau1 = randomScalar(); + Scalar tau2 = randomScalar(); + Curve25519Point T1 = H.scalarMultiply(t1).add(G.scalarMultiply(tau1)); + Curve25519Point T2 = H.scalarMultiply(t2).add(G.scalarMultiply(tau2)); + + // PAPER LINES 49-51 + hashCache = hashToScalar(concat(hashCache.bytes,z.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,T1.toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,T2.toBytes())); + Scalar x = hashCache; + + // PAPER LINES 52-53 + Scalar taux = Scalar.ZERO; + taux = tau1.mul(x); + taux = taux.add(tau2.mul(x.sq())); + taux = taux.add(gamma.mul(z.sq())); + Scalar mu = x.mul(rho).add(alpha); + + // PAPER LINES 54-57 + Scalar[] l = new Scalar[N]; + Scalar[] r = new Scalar[N]; + + l = VectorAdd(VectorSubtract(aL,VectorScalar(VectorPowers(Scalar.ONE),z)),VectorScalar(sL,x)); + r = VectorAdd(Hadamard(VectorPowers(y),VectorAdd(aR,VectorAdd(VectorScalar(VectorPowers(Scalar.ONE),z),VectorScalar(sR,x)))),VectorScalar(VectorPowers(Scalar.TWO),z.sq())); + + Scalar t = InnerProduct(l,r); + + // PAPER LINES 32-33 + hashCache = hashToScalar(concat(hashCache.bytes,x.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,taux.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,mu.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,t.bytes)); + Scalar x_ip = hashCache; + + // These are used in the inner product rounds + int nprime = N; + Curve25519Point[] Gprime = new Curve25519Point[N]; + Curve25519Point[] Hprime = new Curve25519Point[N]; + Scalar[] aprime = new Scalar[N]; + Scalar[] bprime = new Scalar[N]; + for (int i = 0; i < N; i++) + { + Gprime[i] = Gi[i]; + Hprime[i] = Hi[i].scalarMultiply(Invert(y).pow(i)); + aprime[i] = l[i]; + bprime[i] = r[i]; + } + Curve25519Point[] L = new Curve25519Point[logN]; + Curve25519Point[] R = new Curve25519Point[logN]; + int round = 0; // track the index based on number of rounds + Scalar[] w = new Scalar[logN]; // this is the challenge x in the inner product protocol + + // PAPER LINE 13 + while (nprime > 1) + { + // PAPER LINE 15 + nprime /= 2; + + // PAPER LINES 16-17 + Scalar cL = InnerProduct(ScalarSlice(aprime,0,nprime),ScalarSlice(bprime,nprime,bprime.length)); + Scalar cR = InnerProduct(ScalarSlice(aprime,nprime,aprime.length),ScalarSlice(bprime,0,nprime)); + + // PAPER LINES 18-19 + L[round] = VectorExponentCustom(CurveSlice(Gprime,nprime,Gprime.length),CurveSlice(Hprime,0,nprime),ScalarSlice(aprime,0,nprime),ScalarSlice(bprime,nprime,bprime.length)).add(H.scalarMultiply(cL.mul(x_ip))); + R[round] = VectorExponentCustom(CurveSlice(Gprime,0,nprime),CurveSlice(Hprime,nprime,Hprime.length),ScalarSlice(aprime,nprime,aprime.length),ScalarSlice(bprime,0,nprime)).add(H.scalarMultiply(cR.mul(x_ip))); + + // PAPER LINES 21-22 + hashCache = hashToScalar(concat(hashCache.bytes,L[round].toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,R[round].toBytes())); + w[round] = hashCache; + + // PAPER LINES 24-25 + Gprime = Hadamard2(VectorScalar2(CurveSlice(Gprime,0,nprime),Invert(w[round])),VectorScalar2(CurveSlice(Gprime,nprime,Gprime.length),w[round])); + Hprime = Hadamard2(VectorScalar2(CurveSlice(Hprime,0,nprime),w[round]),VectorScalar2(CurveSlice(Hprime,nprime,Hprime.length),Invert(w[round]))); + + // PAPER LINES 28-29 + aprime = VectorAdd(VectorScalar(ScalarSlice(aprime,0,nprime),w[round]),VectorScalar(ScalarSlice(aprime,nprime,aprime.length),Invert(w[round]))); + bprime = VectorAdd(VectorScalar(ScalarSlice(bprime,0,nprime),Invert(w[round])),VectorScalar(ScalarSlice(bprime,nprime,bprime.length),w[round])); + + round += 1; + } + + // PAPER LINE 58 (with inclusions from PAPER LINE 8 and PAPER LINE 20) + return new ProofTuple(V,A,S,T1,T2,taux,mu,L,R,aprime[0],bprime[0],t); + } + + /* Given a range proof, determine if it is valid */ + public static boolean VERIFY(ProofTuple proof) + { + // Reconstruct the challenges + Scalar hashCache = hashToScalar(proof.V.toBytes()); + hashCache = hashToScalar(concat(hashCache.bytes,proof.A.toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,proof.S.toBytes())); + Scalar y = hashCache; + hashCache = hashToScalar(hashCache.bytes); + Scalar z = hashCache; + hashCache = hashToScalar(concat(hashCache.bytes,z.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,proof.T1.toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,proof.T2.toBytes())); + Scalar x = hashCache; + hashCache = hashToScalar(concat(hashCache.bytes,x.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,proof.taux.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,proof.mu.bytes)); + hashCache = hashToScalar(concat(hashCache.bytes,proof.t.bytes)); + Scalar x_ip = hashCache; + + // PAPER LINE 61 + Curve25519Point L61Left = G.scalarMultiply(proof.taux).add(H.scalarMultiply(proof.t)); + + Scalar k = ComputeK(y,z); + + Curve25519Point L61Right = H.scalarMultiply(k.add(z.mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y))))); + L61Right = L61Right.add(proof.V.scalarMultiply(z.sq())); + L61Right = L61Right.add(proof.T1.scalarMultiply(x)); + L61Right = L61Right.add(proof.T2.scalarMultiply(x.sq())); + + if (!L61Right.equals(L61Left)) + return false; + + // PAPER LINE 62 + Curve25519Point P = Curve25519Point.ZERO; + P = P.add(proof.A); + P = P.add(proof.S.scalarMultiply(x)); + + // Compute the number of rounds for the inner product + int rounds = proof.L.length; + + // PAPER LINES 21-22 + // The inner product challenges are computed per round + Scalar[] w = new Scalar[rounds]; + hashCache = hashToScalar(concat(hashCache.bytes,proof.L[0].toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,proof.R[0].toBytes())); + w[0] = hashCache; + if (rounds > 1) + { + for (int i = 1; i < rounds; i++) + { + hashCache = hashToScalar(concat(hashCache.bytes,proof.L[i].toBytes())); + hashCache = hashToScalar(concat(hashCache.bytes,proof.R[i].toBytes())); + w[i] = hashCache; + } + } + + // Basically PAPER LINES 24-25 + // Compute the curvepoints from G[i] and H[i] + Curve25519Point InnerProdG = Curve25519Point.ZERO; + Curve25519Point InnerProdH = Curve25519Point.ZERO; + for (int i = 0; i < N; i++) + { + // Convert the index to binary IN REVERSE and construct the scalar exponent + int index = i; + Scalar gScalar = proof.a; + Scalar hScalar = proof.b.mul(Invert(y).pow(i)); + + for (int j = rounds-1; j >= 0; j--) + { + int J = w.length - j - 1; // because this is done in reverse bit order + int basePow = (int) Math.pow(2,j); // assumes we don't get too big + if (index / basePow == 0) // bit is zero + { + gScalar = gScalar.mul(Invert(w[J])); + hScalar = hScalar.mul(w[J]); + } + else // bit is one + { + gScalar = gScalar.mul(w[J]); + hScalar = hScalar.mul(Invert(w[J])); + index -= basePow; + } + } + + // Adjust the scalars using the exponents from PAPER LINE 62 + gScalar = gScalar.add(z); + hScalar = hScalar.sub(z.mul(y.pow(i)).add(z.sq().mul(Scalar.TWO.pow(i))).mul(Invert(y).pow(i))); + + // Now compute the basepoint's scalar multiplication + // Each of these could be written as a multiexp operation instead + InnerProdG = InnerProdG.add(Gi[i].scalarMultiply(gScalar)); + InnerProdH = InnerProdH.add(Hi[i].scalarMultiply(hScalar)); + } + + // PAPER LINE 26 + Curve25519Point Pprime = P.add(G.scalarMultiply(Scalar.ZERO.sub(proof.mu))); + + for (int i = 0; i < rounds; i++) + { + Pprime = Pprime.add(proof.L[i].scalarMultiply(w[i].sq())); + Pprime = Pprime.add(proof.R[i].scalarMultiply(Invert(w[i]).sq())); + } + Pprime = Pprime.add(H.scalarMultiply(proof.t.mul(x_ip))); + + if (!Pprime.equals(InnerProdG.add(InnerProdH).add(H.scalarMultiply(proof.a.mul(proof.b).mul(x_ip))))) + return false; + + return true; + } + + public static void main(String[] args) + { + // Number of bits in the range + N = 64; + logN = 6; // its log, manually + + // Set the curve base points + G = Curve25519Point.G; + H = Curve25519Point.hashToPoint(G); + Gi = new Curve25519Point[N]; + Hi = new Curve25519Point[N]; + for (int i = 0; i < N; i++) + { + Gi[i] = getHpnGLookup(2*i); + Hi[i] = getHpnGLookup(2*i+1); + } + + // Run a bunch of randomized trials + Random rando = new Random(); + int TRIALS = 250; + int count = 0; + + while (count < TRIALS) + { + long amount = rando.nextLong(); + if (amount > Math.pow(2,N)-1 || amount < 0) + continue; + + ProofTuple proof = PROVE(new Scalar(BigInteger.valueOf(amount)),randomScalar()); + if (!VERIFY(proof)) + System.out.println("Test failed"); + + count += 1; + } + } +}