// research-lab/source-code/BulletProofs/LinearBulletproof.java
// 2017-12-18 13:40:08 -05:00
//
// 372 lines
// 12 KiB
// Java

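// NOTE: relative to the Bulletproofs paper, this interchanges the roles of G and H to match other code's behavior: here H is the base for committed values and G is the base for blinding factors (V = v*H + gamma*G)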
package how.monero.hodl.bulletproof;
import how.monero.hodl.crypto.Curve25519Point;
import how.monero.hodl.crypto.Scalar;
import how.monero.hodl.crypto.CryptoUtil;
import how.monero.hodl.util.ByteUtil;
import java.math.BigInteger;
import how.monero.hodl.util.VarInt;
import java.util.Random;
import static how.monero.hodl.crypto.Scalar.randomScalar;
import static how.monero.hodl.crypto.CryptoUtil.*;
import static how.monero.hodl.util.ByteUtil.*;
/**
 * Linear-size Bulletproof range proof (prover and verifier) over Curve25519.
 *
 * <p>{@link #PROVE} commits to a value {@code v} as {@code V = v*H + gamma*G} and proves that
 * {@code v} lies in {@code [0, 2^N - 1]}. The proof transmits the full {@code l} and {@code r}
 * vectors instead of running an inner-product argument, so its size grows linearly with
 * {@code N}. Challenges are derived non-interactively (Fiat-Shamir) by hashing the transcript.
 *
 * <p>All parameters ({@code N}, {@code G}, {@code H}, {@code Gi}, {@code Hi}) are static and are
 * only initialized by {@link #main}; every other method assumes they have been set up.
 */
public class LinearBulletproof
{
    /** Number of bits in the range; also the length of every proof vector. */
    private static int N;
    /** Base point for blinding factors (gamma, alpha, rho, tau1, tau2). */
    private static Curve25519Point G;
    /** Base point for committed values (v, t1, t2, t). */
    private static Curve25519Point H;
    /** Per-bit generators paired with the left vectors (aL, sL, l). */
    private static Curve25519Point[] Gi;
    /** Per-bit generators paired with the right vectors (aR, sR, r). */
    private static Curve25519Point[] Hi;

    /** Range-proof transcript produced by {@link #PROVE} and checked by {@link #VERIFY}. */
    public static class ProofTuple
    {
        private final Curve25519Point V;   // commitment to the value: v*H + gamma*G
        private final Curve25519Point A;   // commitment to the bit vectors aL, aR
        private final Curve25519Point S;   // commitment to the blinding vectors sL, sR
        private final Curve25519Point T1;  // commitment to t1, the coefficient of x in t(x)
        private final Curve25519Point T2;  // commitment to t2, the coefficient of x^2 in t(x)
        private final Scalar taux;         // blinding factor opening t(x) against V, T1, T2
        private final Scalar mu;           // blinding factor opening A + x*S
        private final Scalar[] l;          // left vector l(x); VERIFY requires length N
        private final Scalar[] r;          // right vector r(x); VERIFY requires length N

        /** Bundles the proof components; the arrays are stored by reference, not copied. */
        public ProofTuple(Curve25519Point V, Curve25519Point A, Curve25519Point S, Curve25519Point T1, Curve25519Point T2, Scalar taux, Scalar mu, Scalar[] l, Scalar[] r)
        {
            this.V = V;
            this.A = A;
            this.S = S;
            this.T1 = T1;
            this.T2 = T2;
            this.taux = taux;
            this.mu = mu;
            this.l = l;
            this.r = r;
        }
    }

    /**
     * Builds the vector commitment {@code sum_i (a[i]*Gi[i] + b[i]*Hi[i])} over the first N entries.
     *
     * @param a exponents for {@code Gi}; must have at least N entries
     * @param b exponents for {@code Hi}; must have at least N entries
     * @return the combined curve point
     */
    public static Curve25519Point VectorExponent(Scalar[] a, Scalar[] b)
    {
        Curve25519Point result = Curve25519Point.ZERO;
        for (int i = 0; i < N; i++)
        {
            result = result.add(Gi[i].scalarMultiply(a[i]));
            result = result.add(Hi[i].scalarMultiply(b[i]));
        }
        return result;
    }

    /** Returns the N-element vector {@code [x^0, x^1, ..., x^(N-1)]}. */
    public static Scalar[] VectorPowers(Scalar x)
    {
        Scalar[] result = new Scalar[N];
        for (int i = 0; i < N; i++)
        {
            result[i] = x.pow(i);
        }
        return result;
    }

    /** Returns the inner product {@code sum_i a[i]*b[i]} over the first N entries. */
    public static Scalar InnerProduct(Scalar[] a, Scalar[] b)
    {
        Scalar result = Scalar.ZERO;
        for (int i = 0; i < N; i++)
        {
            result = result.add(a[i].mul(b[i]));
        }
        return result;
    }

    /** Returns the entry-wise (Hadamard) product {@code a[i]*b[i]} over the first N entries. */
    public static Scalar[] Hadamard(Scalar[] a, Scalar[] b)
    {
        Scalar[] result = new Scalar[N];
        for (int i = 0; i < N; i++)
        {
            result[i] = a[i].mul(b[i]);
        }
        return result;
    }

    /** Returns the entry-wise sum {@code a[i]+b[i]} over the first N entries. */
    public static Scalar[] VectorAdd(Scalar[] a, Scalar[] b)
    {
        Scalar[] result = new Scalar[N];
        for (int i = 0; i < N; i++)
        {
            result[i] = a[i].add(b[i]);
        }
        return result;
    }

    /** Returns the entry-wise difference {@code a[i]-b[i]} over the first N entries. */
    public static Scalar[] VectorSubtract(Scalar[] a, Scalar[] b)
    {
        Scalar[] result = new Scalar[N];
        for (int i = 0; i < N; i++)
        {
            result[i] = a[i].sub(b[i]);
        }
        return result;
    }

    /** Returns the vector {@code a[i]*x} over the first N entries. */
    public static Scalar[] VectorScalar(Scalar[] a, Scalar x)
    {
        Scalar[] result = new Scalar[N];
        for (int i = 0; i < N; i++)
        {
            result[i] = a[i].mul(x);
        }
        return result;
    }

    /**
     * Computes the multiplicative inverse of a scalar modulo {@code CryptoUtil.l} using
     * {@link BigInteger#modInverse} (simple, not constant-time).
     *
     * @throws ArithmeticException if {@code x} has no inverse, e.g. when it is zero
     */
    public static Scalar Invert(Scalar x)
    {
        Scalar inverse = new Scalar(x.toBigInteger().modInverse(CryptoUtil.l));
        assert x.mul(inverse).equals(Scalar.ONE);
        return inverse;
    }

    /** Computes {@code k(y,z) = -z^2 * <1^N, y^N> - z^3 * <1^N, 2^N>}. */
    public static Scalar ComputeK(Scalar y, Scalar z)
    {
        Scalar result = Scalar.ZERO;
        result = result.sub(z.sq().mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(y))));
        result = result.sub(z.pow(3).mul(InnerProduct(VectorPowers(Scalar.ONE),VectorPowers(Scalar.TWO))));
        return result;
    }

    /**
     * Constructs a range proof that the committed value lies in {@code [0, 2^N - 1]}.
     *
     * @param v the secret value to prove in range
     * @param gamma the blinding factor for the commitment {@code V = v*H + gamma*G}
     * @return the proof, including the commitment {@code V}
     * @throws IllegalArgumentException if {@code v} does not fit in N bits
     */
    public static ProofTuple PROVE(Scalar v, Scalar gamma)
    {
        BigInteger value = v.toBigInteger();
        // Without this check an out-of-range v yields a bogus bit decomposition that is only
        // caught by the DEBUG asserts below, which are disabled unless the JVM runs with -ea.
        // The value itself is secret, so it is deliberately kept out of the message.
        if (value.signum() < 0 || value.bitLength() > N)
        {
            throw new IllegalArgumentException("value does not fit in " + N + " bits");
        }

        Curve25519Point V = H.scalarMultiply(v).add(G.scalarMultiply(gamma));
        // This hash is updated for Fiat-Shamir throughout the proof
        Scalar hashCache = hashToScalar(V.toBytes());

        // Loop-invariant vectors 1^N and 2^N, computed once
        Scalar[] ones = VectorPowers(Scalar.ONE);
        Scalar[] twos = VectorPowers(Scalar.TWO);

        // PAPER LINES 36-37: aL holds the bits of v (least significant first), aR = aL - 1^N
        Scalar[] aL = new Scalar[N];
        Scalar[] aR = new Scalar[N];
        for (int i = 0; i < N; i++)
        {
            aL[i] = value.testBit(i) ? Scalar.ONE : Scalar.ZERO;
            aR[i] = aL[i].sub(Scalar.ONE);
        }

        // DEBUG: Test to ensure this recovers the value
        BigInteger test_aL = BigInteger.ZERO;
        BigInteger test_aR = BigInteger.ZERO;
        for (int i = 0; i < N; i++)
        {
            if (aL[i].equals(Scalar.ONE))
            {
                test_aL = test_aL.setBit(i);
            }
            if (aR[i].equals(Scalar.ZERO))
            {
                test_aR = test_aR.setBit(i);
            }
        }
        assert test_aL.equals(value);
        assert test_aR.equals(value);

        // PAPER LINES 38-39
        Scalar alpha = randomScalar();
        Curve25519Point A = VectorExponent(aL, aR).add(G.scalarMultiply(alpha));

        // PAPER LINES 40-42
        Scalar[] sL = new Scalar[N];
        Scalar[] sR = new Scalar[N];
        for (int i = 0; i < N; i++)
        {
            sL[i] = randomScalar();
            sR[i] = randomScalar();
        }
        Scalar rho = randomScalar();
        Curve25519Point S = VectorExponent(sL, sR).add(G.scalarMultiply(rho));

        // PAPER LINES 43-45
        hashCache = hashToScalar(concat(hashCache.bytes, A.toBytes()));
        hashCache = hashToScalar(concat(hashCache.bytes, S.toBytes()));
        Scalar y = hashCache;
        hashCache = hashToScalar(hashCache.bytes);
        Scalar z = hashCache;

        Scalar[] yPowers = VectorPowers(y);
        Scalar[] zOnes = VectorScalar(ones, z);

        // Coefficients of t(x) = <l(x), r(x)> = t0 + t1*x + t2*x^2
        Scalar k = ComputeK(y, z);
        Scalar t0 = z.mul(InnerProduct(ones, yPowers)).add(z.sq().mul(v)).add(k);

        // DEBUG: Test the value of t0 has the correct form
        Scalar test_t0 = InnerProduct(aL, Hadamard(aR, yPowers))
                .add(z.mul(InnerProduct(VectorSubtract(aL, aR), yPowers)))
                .add(z.sq().mul(InnerProduct(twos, aL)))
                .add(k);
        assert test_t0.equals(t0);

        Scalar t1 = InnerProduct(VectorSubtract(aL, zOnes), Hadamard(yPowers, sR))
                .add(InnerProduct(sL, VectorAdd(Hadamard(yPowers, VectorAdd(aR, zOnes)), VectorScalar(twos, z.sq()))));
        Scalar t2 = InnerProduct(sL, Hadamard(yPowers, sR));

        // PAPER LINES 47-48
        Scalar tau1 = randomScalar();
        Scalar tau2 = randomScalar();
        Curve25519Point T1 = H.scalarMultiply(t1).add(G.scalarMultiply(tau1));
        Curve25519Point T2 = H.scalarMultiply(t2).add(G.scalarMultiply(tau2));

        // PAPER LINES 49-51
        hashCache = hashToScalar(concat(hashCache.bytes, z.bytes));
        hashCache = hashToScalar(concat(hashCache.bytes, T1.toBytes()));
        hashCache = hashToScalar(concat(hashCache.bytes, T2.toBytes()));
        Scalar x = hashCache;

        // PAPER LINES 52-53
        Scalar taux = tau1.mul(x).add(tau2.mul(x.sq())).add(gamma.mul(z.sq()));
        Scalar mu = x.mul(rho).add(alpha);

        // PAPER LINES 54-57
        Scalar[] l = VectorAdd(VectorSubtract(aL, zOnes), VectorScalar(sL, x));
        Scalar[] r = VectorAdd(Hadamard(yPowers, VectorAdd(aR, VectorAdd(zOnes, VectorScalar(sR, x)))), VectorScalar(twos, z.sq()));

        // DEBUG: Test if the l and r vectors match the polynomial forms
        Scalar test_t = t0.add(t1.mul(x)).add(t2.mul(x.sq()));
        assert test_t.equals(InnerProduct(l, r));

        // PAPER LINE 58
        return new ProofTuple(V, A, S, T1, T2, taux, mu, l, r);
    }

    /**
     * Checks a range proof.
     *
     * @param proof the proof to check; may be {@code null} or malformed
     * @return {@code true} if the proof verifies; {@code false} if it is malformed (missing
     *     components or vectors whose length is not N) or fails either verification equation
     */
    public static boolean VERIFY(ProofTuple proof)
    {
        // A malformed proof must be rejected rather than crash the verifier
        if (!isWellFormed(proof))
        {
            return false;
        }

        // Reconstruct the challenges exactly as PROVE derived them
        Scalar hashCache = hashToScalar(proof.V.toBytes());
        hashCache = hashToScalar(concat(hashCache.bytes, proof.A.toBytes()));
        hashCache = hashToScalar(concat(hashCache.bytes, proof.S.toBytes()));
        Scalar y = hashCache;
        hashCache = hashToScalar(hashCache.bytes);
        Scalar z = hashCache;
        hashCache = hashToScalar(concat(hashCache.bytes, z.bytes));
        hashCache = hashToScalar(concat(hashCache.bytes, proof.T1.toBytes()));
        hashCache = hashToScalar(concat(hashCache.bytes, proof.T2.toBytes()));
        Scalar x = hashCache;

        Scalar[] ones = VectorPowers(Scalar.ONE);
        Scalar[] twos = VectorPowers(Scalar.TWO);
        Scalar[] yPowers = VectorPowers(y);
        // Powers of y^-1 rescale Hi; invert once instead of once per entry per loop
        Scalar[] yInvPowers = VectorPowers(Invert(y));

        // PAPER LINE 60
        Scalar t = InnerProduct(proof.l, proof.r);

        // PAPER LINE 61
        Curve25519Point L61Left = G.scalarMultiply(proof.taux).add(H.scalarMultiply(t));
        Scalar k = ComputeK(y, z);
        Curve25519Point L61Right = H.scalarMultiply(k.add(z.mul(InnerProduct(ones, yPowers))))
                .add(proof.V.scalarMultiply(z.sq()))
                .add(proof.T1.scalarMultiply(x))
                .add(proof.T2.scalarMultiply(x.sq()));
        if (!L61Right.equals(L61Left))
        {
            return false;
        }

        // PAPER LINE 62
        Scalar[] Gexp = new Scalar[N];
        Scalar[] Hexp = new Scalar[N];
        for (int i = 0; i < N; i++)
        {
            Gexp[i] = Scalar.ZERO.sub(z);
            Hexp[i] = z.mul(yPowers[i]).add(z.sq().mul(twos[i])).mul(yInvPowers[i]);
        }
        Curve25519Point P = proof.A.add(proof.S.scalarMultiply(x)).add(VectorExponent(Gexp, Hexp));

        // PAPER LINE 63
        Curve25519Point L63Right = VectorExponent(proof.l, Hadamard(proof.r, yInvPowers)).add(G.scalarMultiply(proof.mu));
        return L63Right.equals(P);
    }

    /** Returns true when every component of {@code proof} is non-null and l, r have exactly N entries. */
    private static boolean isWellFormed(ProofTuple proof)
    {
        if (proof == null || proof.V == null || proof.A == null || proof.S == null
                || proof.T1 == null || proof.T2 == null || proof.taux == null || proof.mu == null
                || proof.l == null || proof.r == null
                || proof.l.length != N || proof.r.length != N)
        {
            return false;
        }
        for (int i = 0; i < N; i++)
        {
            if (proof.l[i] == null || proof.r[i] == null)
            {
                return false;
            }
        }
        return true;
    }

    /**
     * Sets up generators for 64-bit ranges, then runs randomized prove/verify round trips and
     * prints a line for every proof that fails to verify.
     */
    public static void main(String[] args)
    {
        // Number of bits in the range
        setup(64);

        // Run a bunch of randomized trials
        Random rando = new Random();
        int TRIALS = 250;
        for (int count = 0; count < TRIALS; count++)
        {
            // Uniform over the whole range [0, 2^N - 1]. Rejection-sampling a signed long could
            // never produce values of 2^63 or more, and it discarded about half of its draws.
            BigInteger amount = new BigInteger(N, rando);
            ProofTuple proof = PROVE(new Scalar(amount), randomScalar());
            if (!VERIFY(proof))
            {
                System.out.println("Test failed for amount " + amount);
            }
        }
    }

    /** Sets the range size to {@code bits} and derives the base points G, H and the generators Gi, Hi. */
    private static void setup(int bits)
    {
        N = bits;
        // Set the curve base points
        G = Curve25519Point.G;
        H = Curve25519Point.hashToPoint(G);
        Gi = new Curve25519Point[N];
        Hi = new Curve25519Point[N];
        for (int i = 0; i < N; i++)
        {
            Gi[i] = getHpnGLookup(2*i);
            Hi[i] = getHpnGLookup(2*i+1);
        }
    }
}