From 7a68b065e000414be898df8fda202b59de964086 Mon Sep 17 00:00:00 2001 From: Luke Parker Date: Wed, 10 Jul 2024 20:56:53 -0400 Subject: [PATCH] Redo the Bulletproofs impl Uses the IP-impl from the FCMP++ work. --- coins/monero/ringct/bulletproofs/src/lib.rs | 73 ++- .../src/original/inner_product.rs | 303 +++++++++ .../ringct/bulletproofs/src/original/mod.rs | 599 ++++++++---------- .../src/plus/aggregate_range_proof.rs | 18 +- .../ringct/bulletproofs/src/plus/mod.rs | 6 +- .../src/plus/weighted_inner_product.rs | 3 - .../src/{plus => }/point_vector.rs | 19 +- .../ringct/bulletproofs/src/tests/mod.rs | 61 +- .../src/tests/original/inner_product.rs | 75 +++ .../bulletproofs/src/tests/original/mod.rs | 62 ++ .../src/tests/plus/aggregate_range_proof.rs | 4 +- coins/monero/wallet/src/send/tx.rs | 2 +- 12 files changed, 794 insertions(+), 431 deletions(-) create mode 100644 coins/monero/ringct/bulletproofs/src/original/inner_product.rs rename coins/monero/ringct/bulletproofs/src/{plus => }/point_vector.rs (74%) create mode 100644 coins/monero/ringct/bulletproofs/src/tests/original/inner_product.rs create mode 100644 coins/monero/ringct/bulletproofs/src/tests/original/mod.rs diff --git a/coins/monero/ringct/bulletproofs/src/lib.rs b/coins/monero/ringct/bulletproofs/src/lib.rs index d6e47d75..2a789575 100644 --- a/coins/monero/ringct/bulletproofs/src/lib.rs +++ b/coins/monero/ringct/bulletproofs/src/lib.rs @@ -20,6 +20,8 @@ pub use monero_generators::MAX_COMMITMENTS; use monero_primitives::Commitment; pub(crate) mod scalar_vector; +pub(crate) mod point_vector; + pub(crate) mod core; use crate::core::LOG_COMMITMENT_BITS; @@ -28,10 +30,16 @@ use batch_verifier::{BulletproofsBatchVerifier, BulletproofsPlusBatchVerifier}; pub use batch_verifier::BatchVerifier; pub(crate) mod original; -use crate::original::OriginalStruct; +use crate::original::{ + IpProof, AggregateRangeStatement as OriginalStatement, AggregateRangeWitness as OriginalWitness, + AggregateRangeProof as OriginalProof, +}; pub(crate) mod plus; -use crate::plus::*; +use crate::plus::{ + WipProof, AggregateRangeStatement as PlusStatement, AggregateRangeWitness as PlusWitness, + AggregateRangeProof as PlusProof, +}; #[cfg(test)] mod tests; @@ -55,9 +63,9 @@ pub enum BulletproofError { #[derive(Clone, PartialEq, Eq, Debug)] pub enum Bulletproof { /// A Bulletproof. - Original(OriginalStruct), + Original(OriginalProof), /// A Bulletproof+. - Plus(AggregateRangeProof), + Plus(PlusProof), } impl Bulletproof { @@ -100,7 +108,7 @@ impl Bulletproof { /// Prove the list of commitments are within [0 .. 2^64) with an aggregate Bulletproof. pub fn prove( rng: &mut R, - outputs: &[Commitment], + outputs: Vec, ) -> Result { if outputs.is_empty() { Err(BulletproofError::NoCommitments)?; @@ -108,7 +116,13 @@ impl Bulletproof { if outputs.len() > MAX_COMMITMENTS { Err(BulletproofError::TooManyCommitments)?; } - Ok(Bulletproof::Original(OriginalStruct::prove(rng, outputs))) + let commitments = outputs.iter().map(Commitment::calculate).collect::>(); + Ok(Bulletproof::Original( + OriginalStatement::new(&commitments) + .unwrap() + .prove(rng, OriginalWitness::new(outputs).unwrap()) + .unwrap(), + )) } /// Prove the list of commitments are within [0 .. 2^64) with an aggregate Bulletproof+. @@ -122,10 +136,11 @@ impl Bulletproof { if outputs.len() > MAX_COMMITMENTS { Err(BulletproofError::TooManyCommitments)?; } + let commitments = outputs.iter().map(Commitment::calculate).collect::>(); Ok(Bulletproof::Plus( - AggregateRangeStatement::new(outputs.iter().map(Commitment::calculate).collect()) + PlusStatement::new(&commitments) .unwrap() - .prove(rng, &Zeroizing::new(AggregateRangeWitness::new(outputs).unwrap())) + .prove(rng, &Zeroizing::new(PlusWitness::new(outputs).unwrap())) .unwrap(), )) } @@ -136,14 +151,17 @@ impl Bulletproof { match self { Bulletproof::Original(bp) => { let mut verifier = BulletproofsBatchVerifier::default(); - if !bp.verify(rng, &mut verifier, commitments) { + let Some(statement) = OriginalStatement::new(commitments) else { + return false; + }; + if !statement.verify(rng, &mut verifier, bp.clone()) { return false; } verifier.verify() } Bulletproof::Plus(bp) => { let mut verifier = BulletproofsPlusBatchVerifier::default(); - let Some(statement) = AggregateRangeStatement::new(commitments.to_vec()) else { + let Some(statement) = PlusStatement::new(commitments) else { return false; }; if !statement.verify(rng, &mut verifier, bp.clone()) { @@ -170,9 +188,14 @@ impl Bulletproof { commitments: &[EdwardsPoint], ) -> bool { match self { - Bulletproof::Original(bp) => bp.verify(rng, &mut verifier.original, commitments), + Bulletproof::Original(bp) => { + let Some(statement) = OriginalStatement::new(commitments) else { + return false; + }; + statement.verify(rng, &mut verifier.original, bp.clone()) + } Bulletproof::Plus(bp) => { - let Some(statement) = AggregateRangeStatement::new(commitments.to_vec()) else { + let Some(statement) = PlusStatement::new(commitments) else { return false; }; statement.verify(rng, &mut verifier.plus, bp.clone()) @@ -193,11 +216,11 @@ impl Bulletproof { write_point(&bp.T2, w)?; write_scalar(&bp.tau_x, w)?; write_scalar(&bp.mu, w)?; - specific_write_vec(&bp.L, w)?; - specific_write_vec(&bp.R, w)?; - write_scalar(&bp.a, w)?; - write_scalar(&bp.b, w)?; - write_scalar(&bp.t, w) + specific_write_vec(&bp.ip.L, w)?; + specific_write_vec(&bp.ip.R, w)?; + write_scalar(&bp.ip.a, w)?; + write_scalar(&bp.ip.b, w)?; + write_scalar(&bp.t_hat, w) } Bulletproof::Plus(bp) => { @@ -234,24 +257,26 @@ impl Bulletproof { /// Read a Bulletproof. pub fn read(r: &mut R) -> io::Result { - Ok(Bulletproof::Original(OriginalStruct { + Ok(Bulletproof::Original(OriginalProof { A: read_point(r)?, S: read_point(r)?, T1: read_point(r)?, T2: read_point(r)?, tau_x: read_scalar(r)?, mu: read_scalar(r)?, - L: read_vec(read_point, r)?, - R: read_vec(read_point, r)?, - a: read_scalar(r)?, - b: read_scalar(r)?, - t: read_scalar(r)?, + ip: IpProof { + L: read_vec(read_point, r)?, + R: read_vec(read_point, r)?, + a: read_scalar(r)?, + b: read_scalar(r)?, + }, + t_hat: read_scalar(r)?, })) } /// Read a Bulletproof+. pub fn read_plus(r: &mut R) -> io::Result { - Ok(Bulletproof::Plus(AggregateRangeProof { + Ok(Bulletproof::Plus(PlusProof { A: read_point(r)?, wip: WipProof { A: read_point(r)?, diff --git a/coins/monero/ringct/bulletproofs/src/original/inner_product.rs b/coins/monero/ringct/bulletproofs/src/original/inner_product.rs new file mode 100644 index 00000000..be8f3a83 --- /dev/null +++ b/coins/monero/ringct/bulletproofs/src/original/inner_product.rs @@ -0,0 +1,303 @@ +use std_shims::{vec, vec::Vec}; + +use zeroize::Zeroize; + +use curve25519_dalek::{Scalar, EdwardsPoint}; + +use monero_generators::H; +use monero_primitives::{INV_EIGHT, keccak256_to_scalar}; +use crate::{ + core::{multiexp_vartime, challenge_products}, + scalar_vector::ScalarVector, + point_vector::PointVector, + BulletproofsBatchVerifier, +}; + +/// An error from proving/verifying Inner-Product statements. +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub(crate) enum IpError { + IncorrectAmountOfGenerators, + DifferingLrLengths, +} + +/// The Bulletproofs Inner-Product statement. +/// +/// This is for usage with Protocol 2 from the Bulletproofs paper. +#[derive(Clone, Debug)] +pub(crate) struct IpStatement { + // Weights for h_bold + h_bold_weights: ScalarVector, + // u as the discrete logarithm of G + u: Scalar, +} + +/// The witness for the Bulletproofs Inner-Product statement. +#[derive(Clone, Debug)] +pub(crate) struct IpWitness { + // a + a: ScalarVector, + // b + b: ScalarVector, +} + +impl IpWitness { + /// Construct a new witness for an Inner-Product statement. + /// + /// This functions return None if the lengths of a, b are mismatched, not a power of two, or are + /// empty. + pub(crate) fn new(a: ScalarVector, b: ScalarVector) -> Option { + if a.0.is_empty() || (a.len() != b.len()) { + None?; + } + + let mut power_of_2 = 1; + while power_of_2 < a.len() { + power_of_2 <<= 1; + } + if power_of_2 != a.len() { + None?; + } + + Some(Self { a, b }) + } +} + +/// A proof for the Bulletproofs Inner-Product statement. +#[derive(Clone, PartialEq, Eq, Debug, Zeroize)] +pub(crate) struct IpProof { + pub(crate) L: Vec, + pub(crate) R: Vec, + pub(crate) a: Scalar, + pub(crate) b: Scalar, +} + +impl IpStatement { + /// Create a new Inner-Product statement which won't transcript P. + /// + /// This MUST only be called when P is deterministic to already transcripted elements. + pub(crate) fn new_without_P_transcript(h_bold_weights: ScalarVector, u: Scalar) -> Self { + Self { h_bold_weights, u } + } + + // Transcript a round of the protocol + fn transcript_L_R(transcript: Scalar, L: EdwardsPoint, R: EdwardsPoint) -> Scalar { + let mut transcript = transcript.to_bytes().to_vec(); + transcript.extend(L.compress().to_bytes()); + transcript.extend(R.compress().to_bytes()); + keccak256_to_scalar(transcript) + } + + /// Prove for this Inner-Product statement. + /// + /// Returns an error if this statement couldn't be proven for (such as if the witness isn't + /// consistent). + pub(crate) fn prove( + self, + mut transcript: Scalar, + witness: IpWitness, + ) -> Result { + let generators = crate::original::GENERATORS(); + let g_bold_slice = &generators.G[.. witness.a.len()]; + let h_bold_slice = &generators.H[.. witness.a.len()]; + + let (mut g_bold, mut h_bold, u, mut a, mut b) = { + let IpStatement { h_bold_weights, u } = self; + let u = H() * u; + + // Ensure we have the exact amount of weights + if h_bold_weights.len() != g_bold_slice.len() { + Err(IpError::IncorrectAmountOfGenerators)?; + } + // Acquire a local copy of the generators + let g_bold = PointVector(g_bold_slice.to_vec()); + let h_bold = PointVector(h_bold_slice.to_vec()).mul_vec(&h_bold_weights); + + let IpWitness { a, b } = witness; + + (g_bold, h_bold, u, a, b) + }; + + let mut L_vec = vec![]; + let mut R_vec = vec![]; + + // `else: (n > 1)` case, lines 18-35 of the Bulletproofs paper + // This interprets `g_bold.len()` as `n` + while g_bold.len() > 1 { + // Split a, b, g_bold, h_bold as needed for lines 20-24 + let (a1, a2) = a.clone().split(); + let (b1, b2) = b.clone().split(); + + let (g_bold1, g_bold2) = g_bold.split(); + let (h_bold1, h_bold2) = h_bold.split(); + + let n_hat = g_bold1.len(); + + // Sanity + debug_assert_eq!(a1.len(), n_hat); + debug_assert_eq!(a2.len(), n_hat); + debug_assert_eq!(b1.len(), n_hat); + debug_assert_eq!(b2.len(), n_hat); + debug_assert_eq!(g_bold1.len(), n_hat); + debug_assert_eq!(g_bold2.len(), n_hat); + debug_assert_eq!(h_bold1.len(), n_hat); + debug_assert_eq!(h_bold2.len(), n_hat); + + // cl, cr, lines 21-22 + let cl = a1.clone().inner_product(&b2); + let cr = a2.clone().inner_product(&b1); + + let L = { + let mut L_terms = Vec::with_capacity(1 + (2 * g_bold1.len())); + for (a, g) in a1.0.iter().zip(g_bold2.0.iter()) { + L_terms.push((*a, *g)); + } + for (b, h) in b2.0.iter().zip(h_bold1.0.iter()) { + L_terms.push((*b, *h)); + } + L_terms.push((cl, u)); + // Uses vartime since this isn't a ZK proof + multiexp_vartime(&L_terms) + }; + L_vec.push(L * INV_EIGHT()); + + let R = { + let mut R_terms = Vec::with_capacity(1 + (2 * g_bold1.len())); + for (a, g) in a2.0.iter().zip(g_bold1.0.iter()) { + R_terms.push((*a, *g)); + } + for (b, h) in b1.0.iter().zip(h_bold2.0.iter()) { + R_terms.push((*b, *h)); + } + R_terms.push((cr, u)); + multiexp_vartime(&R_terms) + }; + R_vec.push(R * INV_EIGHT()); + + // Now that we've calculate L, R, transcript them to receive x (26-27) + transcript = Self::transcript_L_R(transcript, *L_vec.last().unwrap(), *R_vec.last().unwrap()); + let x = transcript; + let x_inv = x.invert(); + + // The prover and verifier now calculate the following (28-31) + g_bold = PointVector(Vec::with_capacity(g_bold1.len())); + for (a, b) in g_bold1.0.into_iter().zip(g_bold2.0.into_iter()) { + g_bold.0.push(multiexp_vartime(&[(x_inv, a), (x, b)])); + } + h_bold = PointVector(Vec::with_capacity(h_bold1.len())); + for (a, b) in h_bold1.0.into_iter().zip(h_bold2.0.into_iter()) { + h_bold.0.push(multiexp_vartime(&[(x, a), (x_inv, b)])); + } + + // 32-34 + a = (a1 * x) + &(a2 * x_inv); + b = (b1 * x_inv) + &(b2 * x); + } + + // `if n = 1` case from line 14-17 + + // Sanity + debug_assert_eq!(g_bold.len(), 1); + debug_assert_eq!(h_bold.len(), 1); + debug_assert_eq!(a.len(), 1); + debug_assert_eq!(b.len(), 1); + + // We simply send a/b + Ok(IpProof { L: L_vec, R: R_vec, a: a[0], b: b[0] }) + } + + /// Queue an Inner-Product proof for batch verification. + /// + /// This will return Err if there is an error. This will return Ok if the proof was successfully + /// queued for batch verification. The caller is required to verify the batch in order to ensure + /// the proof is actually correct. + pub(crate) fn verify( + self, + verifier: &mut BulletproofsBatchVerifier, + ip_rows: usize, + mut transcript: Scalar, + verifier_weight: Scalar, + proof: IpProof, + ) -> Result<(), IpError> { + let generators = crate::original::GENERATORS(); + let g_bold_slice = &generators.G[.. ip_rows]; + let h_bold_slice = &generators.H[.. ip_rows]; + + let IpStatement { h_bold_weights, u } = self; + + // Verify the L/R lengths + { + // Calculate the discrete log w.r.t. 2 for the amount of generators present + let mut lr_len = 0; + while (1 << lr_len) < g_bold_slice.len() { + lr_len += 1; + } + + // This proof has less/more terms than the passed in generators are for + if proof.L.len() != lr_len { + Err(IpError::IncorrectAmountOfGenerators)?; + } + if proof.L.len() != proof.R.len() { + Err(IpError::DifferingLrLengths)?; + } + } + + // Again, we start with the `else: (n > 1)` case + + // We need x, x_inv per lines 25-27 for lines 28-31 + let mut xs = Vec::with_capacity(proof.L.len()); + for (L, R) in proof.L.iter().zip(proof.R.iter()) { + transcript = Self::transcript_L_R(transcript, *L, *R); + xs.push(transcript); + } + + // We calculate their inverse in batch + let mut x_invs = xs.clone(); + Scalar::batch_invert(&mut x_invs); + + // Now, with x and x_inv, we need to calculate g_bold', h_bold', P' + // + // For the sake of performance, we solely want to calculate all of these in terms of scalings + // for g_bold, h_bold, P, and don't want to actually perform intermediary scalings of the + // points + // + // L and R are easy, as it's simply x**2, x**-2 + // + // For the series of g_bold, h_bold, we use the `challenge_products` function + // For how that works, please see its own documentation + let product_cache = { + let mut challenges = Vec::with_capacity(proof.L.len()); + + let x_iter = xs.into_iter().zip(x_invs); + let lr_iter = proof.L.into_iter().zip(proof.R); + for ((x, x_inv), (L, R)) in x_iter.zip(lr_iter) { + challenges.push((x, x_inv)); + verifier.0.other.push((verifier_weight * (x * x), L.mul_by_cofactor())); + verifier.0.other.push((verifier_weight * (x_inv * x_inv), R.mul_by_cofactor())); + } + + challenge_products(&challenges) + }; + + // And now for the `if n = 1` case + let c = proof.a * proof.b; + + // The multiexp of these terms equate to the final permutation of P + // We now add terms for a * g_bold' + b * h_bold' b + c * u, with the scalars negative such + // that the terms sum to 0 for an honest prover + + // The g_bold * a term case from line 16 + #[allow(clippy::needless_range_loop)] + for i in 0 .. g_bold_slice.len() { + verifier.0.g_bold[i] -= verifier_weight * product_cache[i] * proof.a; + } + // The h_bold * b term case from line 16 + for i in 0 .. h_bold_slice.len() { + verifier.0.h_bold[i] -= + verifier_weight * product_cache[product_cache.len() - 1 - i] * proof.b * h_bold_weights[i]; + } + // The c * u term case from line 16 + verifier.0.h -= verifier_weight * c * u; + + Ok(()) + } +} diff --git a/coins/monero/ringct/bulletproofs/src/original/mod.rs b/coins/monero/ringct/bulletproofs/src/original/mod.rs index 43fa7cfe..10d63be4 100644 --- a/coins/monero/ringct/bulletproofs/src/original/mod.rs +++ b/coins/monero/ringct/bulletproofs/src/original/mod.rs @@ -1,395 +1,344 @@ -use std_shims::{vec, vec::Vec, sync::OnceLock}; +use std_shims::{sync::OnceLock, vec::Vec}; use rand_core::{RngCore, CryptoRng}; + use zeroize::Zeroize; -use subtle::{Choice, ConditionallySelectable}; -use curve25519_dalek::{ - constants::{ED25519_BASEPOINT_POINT, ED25519_BASEPOINT_TABLE}, - scalar::Scalar, - edwards::EdwardsPoint, -}; +use curve25519_dalek::{constants::ED25519_BASEPOINT_POINT, Scalar, EdwardsPoint}; -use monero_generators::{H, Generators}; -use monero_primitives::{INV_EIGHT, Commitment, keccak256_to_scalar}; +use monero_generators::{H, Generators, MAX_COMMITMENTS, COMMITMENT_BITS}; +use monero_primitives::{Commitment, INV_EIGHT, keccak256_to_scalar}; +use crate::{core::multiexp, scalar_vector::ScalarVector, BulletproofsBatchVerifier}; -use crate::{core::*, ScalarVector, batch_verifier::BulletproofsBatchVerifier}; +pub(crate) mod inner_product; +use inner_product::*; +pub(crate) use inner_product::IpProof; include!(concat!(env!("OUT_DIR"), "/generators.rs")); -static TWO_N_CELL: OnceLock = OnceLock::new(); -fn TWO_N() -> &'static ScalarVector { - TWO_N_CELL.get_or_init(|| ScalarVector::powers(Scalar::from(2u8), COMMITMENT_BITS)) +#[derive(Clone, Debug)] +pub(crate) struct AggregateRangeStatement<'a> { + commitments: &'a [EdwardsPoint], } -static IP12_CELL: OnceLock = OnceLock::new(); -fn IP12() -> Scalar { - *IP12_CELL.get_or_init(|| ScalarVector(vec![Scalar::ONE; COMMITMENT_BITS]).inner_product(TWO_N())) +#[derive(Clone, Debug)] +pub(crate) struct AggregateRangeWitness { + commitments: Vec, } -fn MN(outputs: usize) -> (usize, usize, usize) { - let mut logM = 0; - let mut M; - while { - M = 1 << logM; - (M <= MAX_COMMITMENTS) && (M < outputs) - } { - logM += 1; - } - - (logM + LOG_COMMITMENT_BITS, M, M * COMMITMENT_BITS) -} - -fn bit_decompose(commitments: &[Commitment]) -> (ScalarVector, ScalarVector) { - let (_, M, MN) = MN(commitments.len()); - - let sv = commitments.iter().map(|c| Scalar::from(c.amount)).collect::>(); - let mut aL = ScalarVector::new(MN); - let mut aR = ScalarVector::new(MN); - - for j in 0 .. M { - for i in (0 .. COMMITMENT_BITS).rev() { - let bit = - if j < sv.len() { Choice::from((sv[j][i / 8] >> (i % 8)) & 1) } else { Choice::from(0) }; - aL.0[(j * COMMITMENT_BITS) + i] = - Scalar::conditional_select(&Scalar::ZERO, &Scalar::ONE, bit); - aR.0[(j * COMMITMENT_BITS) + i] = - Scalar::conditional_select(&-Scalar::ONE, &Scalar::ZERO, bit); - } - } - - (aL, aR) -} - -fn hash_commitments>( - commitments: C, -) -> (Scalar, Vec) { - let V = commitments.into_iter().map(|c| c * INV_EIGHT()).collect::>(); - (keccak256_to_scalar(V.iter().flat_map(|V| V.compress().to_bytes()).collect::>()), V) -} - -fn alpha_rho( - rng: &mut R, - generators: &Generators, - aL: &ScalarVector, - aR: &ScalarVector, -) -> (Scalar, EdwardsPoint) { - fn vector_exponent(generators: &Generators, a: &ScalarVector, b: &ScalarVector) -> EdwardsPoint { - debug_assert_eq!(a.len(), b.len()); - (a * &generators.G[.. a.len()]) + (b * &generators.H[.. b.len()]) - } - - let ar = Scalar::random(rng); - (ar, (vector_exponent(generators, aL, aR) + (ED25519_BASEPOINT_TABLE * &ar)) * INV_EIGHT()) -} - -fn LR_statements( - a: &ScalarVector, - G_i: &[EdwardsPoint], - b: &ScalarVector, - H_i: &[EdwardsPoint], - cL: Scalar, - U: EdwardsPoint, -) -> Vec<(Scalar, EdwardsPoint)> { - let mut res = a - .0 - .iter() - .copied() - .zip(G_i.iter().copied()) - .chain(b.0.iter().copied().zip(H_i.iter().copied())) - .collect::>(); - res.push((cL, U)); - res -} - -fn hash_cache(cache: &mut Scalar, mash: &[[u8; 32]]) -> Scalar { - let slice = - &[cache.to_bytes().as_ref(), mash.iter().copied().flatten().collect::>().as_ref()] - .concat(); - *cache = keccak256_to_scalar(slice); - *cache -} - -fn hadamard_fold( - l: &[EdwardsPoint], - r: &[EdwardsPoint], - a: Scalar, - b: Scalar, -) -> Vec { - let mut res = Vec::with_capacity(l.len() / 2); - for i in 0 .. l.len() { - res.push(multiexp(&[(a, l[i]), (b, r[i])])); - } - res -} - -/// Internal structure representing a Bulletproof, as defined by Monero.. -#[doc(hidden)] -#[derive(Clone, PartialEq, Eq, Debug)] -pub struct OriginalStruct { +#[derive(Clone, PartialEq, Eq, Debug, Zeroize)] +pub struct AggregateRangeProof { pub(crate) A: EdwardsPoint, pub(crate) S: EdwardsPoint, pub(crate) T1: EdwardsPoint, pub(crate) T2: EdwardsPoint, pub(crate) tau_x: Scalar, pub(crate) mu: Scalar, - pub(crate) L: Vec, - pub(crate) R: Vec, - pub(crate) a: Scalar, - pub(crate) b: Scalar, - pub(crate) t: Scalar, + pub(crate) t_hat: Scalar, + pub(crate) ip: IpProof, } -impl OriginalStruct { - pub(crate) fn prove( - rng: &mut R, - commitments: &[Commitment], - ) -> OriginalStruct { - let (logMN, M, MN) = MN(commitments.len()); - - let (aL, aR) = bit_decompose(commitments); - let commitments_points = commitments.iter().map(Commitment::calculate).collect::>(); - let (mut cache, _) = hash_commitments(commitments_points.clone()); - - let (sL, sR) = - ScalarVector((0 .. (MN * 2)).map(|_| Scalar::random(&mut *rng)).collect::>()).split(); - - let generators = GENERATORS(); - let (mut alpha, A) = alpha_rho(&mut *rng, generators, &aL, &aR); - let (mut rho, S) = alpha_rho(&mut *rng, generators, &sL, &sR); - - let y = hash_cache(&mut cache, &[A.compress().to_bytes(), S.compress().to_bytes()]); - let mut cache = keccak256_to_scalar(y.to_bytes()); - let z = cache; - - let l0 = aL - z; - let l1 = sL; - - let mut zero_twos = Vec::with_capacity(MN); - let zpow = ScalarVector::powers(z, M + 2); - for j in 0 .. M { - for i in 0 .. COMMITMENT_BITS { - zero_twos.push(zpow[j + 2] * TWO_N()[i]); - } +impl<'a> AggregateRangeStatement<'a> { + pub(crate) fn new(commitments: &'a [EdwardsPoint]) -> Option { + if commitments.is_empty() || (commitments.len() > MAX_COMMITMENTS) { + None?; } + Some(Self { commitments }) + } +} - let yMN = ScalarVector::powers(y, MN); - let r0 = ((aR + z) * &yMN) + &ScalarVector(zero_twos); - let r1 = yMN * &sR; +impl AggregateRangeWitness { + pub(crate) fn new(commitments: Vec) -> Option { + if commitments.is_empty() || (commitments.len() > MAX_COMMITMENTS) { + None?; + } + Some(Self { commitments }) + } +} - let (T1, T2, x, mut tau_x) = { - let t1 = l0.clone().inner_product(&r1) + r0.clone().inner_product(&l1); - let t2 = l1.clone().inner_product(&r1); +impl<'a> AggregateRangeStatement<'a> { + fn initial_transcript(&self) -> (Scalar, Vec) { + let V = self.commitments.iter().map(|c| c * INV_EIGHT()).collect::>(); + (keccak256_to_scalar(V.iter().flat_map(|V| V.compress().to_bytes()).collect::>()), V) + } - let mut tau1 = Scalar::random(&mut *rng); - let mut tau2 = Scalar::random(&mut *rng); + fn transcript_A_S(transcript: Scalar, A: EdwardsPoint, S: EdwardsPoint) -> (Scalar, Scalar) { + let mut buf = Vec::with_capacity(96); + buf.extend(transcript.to_bytes()); + buf.extend(A.compress().to_bytes()); + buf.extend(S.compress().to_bytes()); + let y = keccak256_to_scalar(buf); + let z = keccak256_to_scalar(y.to_bytes()); + (y, z) + } - let T1 = multiexp(&[(t1, H()), (tau1, ED25519_BASEPOINT_POINT)]) * INV_EIGHT(); - let T2 = multiexp(&[(t2, H()), (tau2, ED25519_BASEPOINT_POINT)]) * INV_EIGHT(); + fn transcript_T12(transcript: Scalar, T1: EdwardsPoint, T2: EdwardsPoint) -> Scalar { + let mut buf = Vec::with_capacity(128); + buf.extend(transcript.to_bytes()); + buf.extend(transcript.to_bytes()); + buf.extend(T1.compress().to_bytes()); + buf.extend(T2.compress().to_bytes()); + keccak256_to_scalar(buf) + } - let x = - hash_cache(&mut cache, &[z.to_bytes(), T1.compress().to_bytes(), T2.compress().to_bytes()]); + fn transcript_tau_x_mu_t_hat( + transcript: Scalar, + tau_x: Scalar, + mu: Scalar, + t_hat: Scalar, + ) -> Scalar { + let mut buf = Vec::with_capacity(128); + buf.extend(transcript.to_bytes()); + buf.extend(transcript.to_bytes()); + buf.extend(tau_x.to_bytes()); + buf.extend(mu.to_bytes()); + buf.extend(t_hat.to_bytes()); + keccak256_to_scalar(buf) + } - let tau_x = (tau2 * (x * x)) + (tau1 * x); - - tau1.zeroize(); - tau2.zeroize(); - (T1, T2, x, tau_x) + #[allow(clippy::needless_pass_by_value)] + pub(crate) fn prove( + self, + rng: &mut (impl RngCore + CryptoRng), + witness: AggregateRangeWitness, + ) -> Option { + if self.commitments != witness.commitments.iter().map(Commitment::calculate).collect::>() + { + None? }; - let mu = (x * rho) + alpha; - alpha.zeroize(); - rho.zeroize(); + let generators = GENERATORS(); - for (i, gamma) in commitments.iter().map(|c| c.mask).enumerate() { - tau_x += zpow[i + 2] * gamma; + let (mut transcript, _) = self.initial_transcript(); + + // Find out the padded amount of commitments + let mut padded_pow_of_2 = 1; + while padded_pow_of_2 < witness.commitments.len() { + padded_pow_of_2 <<= 1; } - let l = l0 + &(l1 * x); - let r = r0 + &(r1 * x); - - let t = l.clone().inner_product(&r); - - let x_ip = - hash_cache(&mut cache, &[x.to_bytes(), tau_x.to_bytes(), mu.to_bytes(), t.to_bytes()]); - - let mut a = l; - let mut b = r; - - let yinv = y.invert(); - let yinvpow = ScalarVector::powers(yinv, MN); - - let mut G_proof = generators.G[.. a.len()].to_vec(); - let mut H_proof = generators.H[.. a.len()].to_vec(); - H_proof.iter_mut().zip(yinvpow.0.iter()).for_each(|(this_H, yinvpow)| *this_H *= yinvpow); - let U = H() * x_ip; - - let mut L = Vec::with_capacity(logMN); - let mut R = Vec::with_capacity(logMN); - - while a.len() != 1 { - let (aL, aR) = a.split(); - let (bL, bR) = b.split(); - - let cL = aL.clone().inner_product(&bR); - let cR = aR.clone().inner_product(&bL); - - let (G_L, G_R) = G_proof.split_at(aL.len()); - let (H_L, H_R) = H_proof.split_at(aL.len()); - - let L_i = multiexp(&LR_statements(&aL, G_R, &bR, H_L, cL, U)) * INV_EIGHT(); - let R_i = multiexp(&LR_statements(&aR, G_L, &bL, H_R, cR, U)) * INV_EIGHT(); - L.push(L_i); - R.push(R_i); - - let w = hash_cache(&mut cache, &[L_i.compress().to_bytes(), R_i.compress().to_bytes()]); - let w_inv = w.invert(); - - a = (aL * w) + &(aR * w_inv); - b = (bL * w_inv) + &(bR * w); - - if a.len() != 1 { - G_proof = hadamard_fold(G_L, G_R, w_inv, w); - H_proof = hadamard_fold(H_L, H_R, w, w_inv); + let mut aL = ScalarVector::new(padded_pow_of_2 * COMMITMENT_BITS); + for (i, commitment) in witness.commitments.iter().enumerate() { + let mut amount = commitment.amount; + for j in 0 .. COMMITMENT_BITS { + aL[(i * COMMITMENT_BITS) + j] = Scalar::from(amount & 1); + amount >>= 1; } } + let aR = aL.clone() - Scalar::ONE; - let res = OriginalStruct { A, S, T1, T2, tau_x, mu, L, R, a: a[0], b: b[0], t }; + let alpha = Scalar::random(&mut *rng); + let A = { + let mut terms = Vec::with_capacity(1 + (2 * aL.len())); + terms.push((alpha, ED25519_BASEPOINT_POINT)); + for (aL, G) in aL.0.iter().zip(&generators.G) { + terms.push((*aL, *G)); + } + for (aR, H) in aR.0.iter().zip(&generators.H) { + terms.push((*aR, *H)); + } + let res = multiexp(&terms) * INV_EIGHT(); + terms.zeroize(); + res + }; + + let mut sL = ScalarVector::new(padded_pow_of_2 * COMMITMENT_BITS); + let mut sR = ScalarVector::new(padded_pow_of_2 * COMMITMENT_BITS); + for i in 0 .. (padded_pow_of_2 * COMMITMENT_BITS) { + sL[i] = Scalar::random(&mut *rng); + sR[i] = Scalar::random(&mut *rng); + } + let rho = Scalar::random(&mut *rng); + + let S = { + let mut terms = Vec::with_capacity(1 + (2 * sL.len())); + terms.push((rho, ED25519_BASEPOINT_POINT)); + for (sL, G) in sL.0.iter().zip(&generators.G) { + terms.push((*sL, *G)); + } + for (sR, H) in sR.0.iter().zip(&generators.H) { + terms.push((*sR, *H)); + } + let res = multiexp(&terms) * INV_EIGHT(); + terms.zeroize(); + res + }; + + let (y, z) = Self::transcript_A_S(transcript, A, S); + transcript = z; + + let twos = ScalarVector::powers(Scalar::from(2u8), COMMITMENT_BITS); + + let l = [aL - z, sL]; + let y_pow_n = ScalarVector::powers(y, aR.len()); + let mut r = [((aR + z) * &y_pow_n), sR * &y_pow_n]; + { + let mut z_current = z * z; + for j in 0 .. padded_pow_of_2 { + for i in 0 .. COMMITMENT_BITS { + r[0].0[(j * COMMITMENT_BITS) + i] += z_current * twos[i]; + } + z_current *= z; + } + } + let t1 = (l[0].clone().inner_product(&r[1])) + (r[0].clone().inner_product(&l[1])); + let t2 = l[1].clone().inner_product(&r[1]); + + let tau_1 = Scalar::random(&mut *rng); + let T1 = { + let mut T1_terms = [(t1, H()), (tau_1, ED25519_BASEPOINT_POINT)]; + for term in &mut T1_terms { + term.0 *= INV_EIGHT(); + } + let T1 = multiexp(&T1_terms); + T1_terms.zeroize(); + T1 + }; + let tau_2 = Scalar::random(&mut *rng); + let T2 = { + let mut T2_terms = [(t2, H()), (tau_2, ED25519_BASEPOINT_POINT)]; + for term in &mut T2_terms { + term.0 *= INV_EIGHT(); + } + let T2 = multiexp(&T2_terms); + T2_terms.zeroize(); + T2 + }; + + transcript = Self::transcript_T12(transcript, T1, T2); + let x = transcript; + + let [l0, l1] = l; + let l = l0 + &(l1 * x); + let [r0, r1] = r; + let r = r0 + &(r1 * x); + let t_hat = l.clone().inner_product(&r); + let mut tau_x = ((tau_2 * x) + tau_1) * x; + { + let mut z_current = z * z; + for commitment in &witness.commitments { + tau_x += z_current * commitment.mask; + z_current *= z; + } + } + let mu = alpha + (rho * x); + + let y_inv_pow_n = ScalarVector::powers(y.invert(), l.len()); + + transcript = Self::transcript_tau_x_mu_t_hat(transcript, tau_x, mu, t_hat); + let x_ip = transcript; + + let ip = IpStatement::new_without_P_transcript(y_inv_pow_n, x_ip) + .prove(transcript, IpWitness::new(l, r).unwrap()) + .unwrap(); + + let res = AggregateRangeProof { A, S, T1, T2, tau_x, mu, t_hat, ip }; #[cfg(debug_assertions)] { let mut verifier = BulletproofsBatchVerifier::default(); - debug_assert!(res.verify(rng, &mut verifier, &commitments_points)); + debug_assert!(self.verify(rng, &mut verifier, res.clone())); debug_assert!(verifier.verify()); } - - res + Some(res) } #[must_use] - pub(crate) fn verify( - &self, - rng: &mut R, + pub(crate) fn verify( + self, + rng: &mut (impl RngCore + CryptoRng), verifier: &mut BulletproofsBatchVerifier, - commitments: &[EdwardsPoint], + mut proof: AggregateRangeProof, ) -> bool { - // Verify commitments are valid - if commitments.is_empty() || (commitments.len() > MAX_COMMITMENTS) { - return false; + let mut padded_pow_of_2 = 1; + while padded_pow_of_2 < self.commitments.len() { + padded_pow_of_2 <<= 1; + } + let ip_rows = padded_pow_of_2 * COMMITMENT_BITS; + + while verifier.0.g_bold.len() < ip_rows { + verifier.0.g_bold.push(Scalar::ZERO); + verifier.0.h_bold.push(Scalar::ZERO); } - // Verify L and R are properly sized - if self.L.len() != self.R.len() { - return false; + let (mut transcript, mut commitments) = self.initial_transcript(); + for commitment in &mut commitments { + *commitment = commitment.mul_by_cofactor(); } - let (logMN, M, MN) = MN(commitments.len()); - if self.L.len() != logMN { - return false; - } + let (y, z) = Self::transcript_A_S(transcript, proof.A, proof.S); + transcript = z; + transcript = Self::transcript_T12(transcript, proof.T1, proof.T2); + let x = transcript; + transcript = Self::transcript_tau_x_mu_t_hat(transcript, proof.tau_x, proof.mu, proof.t_hat); + let x_ip = transcript; - // Rebuild all challenges - let (mut cache, commitments) = hash_commitments(commitments.iter().copied()); - let y = hash_cache(&mut cache, &[self.A.compress().to_bytes(), self.S.compress().to_bytes()]); + proof.A = proof.A.mul_by_cofactor(); + proof.S = proof.S.mul_by_cofactor(); + proof.T1 = proof.T1.mul_by_cofactor(); + proof.T2 = proof.T2.mul_by_cofactor(); - let z = keccak256_to_scalar(y.to_bytes()); - cache = z; + let y_pow_n = ScalarVector::powers(y, ip_rows); + let y_inv_pow_n = ScalarVector::powers(y.invert(), ip_rows); - let x = hash_cache( - &mut cache, - &[z.to_bytes(), self.T1.compress().to_bytes(), self.T2.compress().to_bytes()], - ); + let twos = ScalarVector::powers(Scalar::from(2u8), COMMITMENT_BITS); - let x_ip = hash_cache( - &mut cache, - &[x.to_bytes(), self.tau_x.to_bytes(), self.mu.to_bytes(), self.t.to_bytes()], - ); - - let mut w_and_w_inv = Vec::with_capacity(logMN); - for (L, R) in self.L.iter().zip(&self.R) { - let w = hash_cache(&mut cache, &[L.compress().to_bytes(), R.compress().to_bytes()]); - let w_inv = w.invert(); - w_and_w_inv.push((w, w_inv)); - } - - // Convert the proof from * INV_EIGHT to its actual form - let normalize = |point: &EdwardsPoint| point.mul_by_cofactor(); - - let L = self.L.iter().map(normalize).collect::>(); - let R = self.R.iter().map(normalize).collect::>(); - let T1 = normalize(&self.T1); - let T2 = normalize(&self.T2); - let A = normalize(&self.A); - let S = normalize(&self.S); - - let commitments = commitments.iter().map(EdwardsPoint::mul_by_cofactor).collect::>(); - - // Verify it - let zpow = ScalarVector::powers(z, M + 3); - - // First multiexp + // 65 { - let verifier_weight = Scalar::random(rng); + let weight = Scalar::random(&mut *rng); + verifier.0.h += weight * proof.t_hat; + verifier.0.g += weight * proof.tau_x; - let ip1y = ScalarVector::powers(y, M * COMMITMENT_BITS).sum(); - let mut k = -(zpow[2] * ip1y); - for j in 1 ..= M { - k -= zpow[j + 2] * IP12(); - } - let y1 = self.t - ((z * ip1y) + k); - verifier.0.h -= verifier_weight * y1; + // Now that we've accumulated the lhs, negate the weight and accumulate the rhs + // These will now sum to 0 if equal + let weight = -weight; - verifier.0.g -= verifier_weight * self.tau_x; + verifier.0.h += weight * (z - (z * z)) * y_pow_n.sum(); - for (j, commitment) in commitments.iter().enumerate() { - verifier.0.other.push((verifier_weight * zpow[j + 2], *commitment)); + let mut z_current = z * z; + for commitment in &commitments { + verifier.0.other.push((weight * z_current, *commitment)); + z_current *= z; } - verifier.0.other.push((verifier_weight * x, T1)); - verifier.0.other.push((verifier_weight * (x * x), T2)); + let mut z_current = z * z * z; + for _ in 0 .. padded_pow_of_2 { + verifier.0.h -= weight * z_current * twos.clone().sum(); + z_current *= z; + } + verifier.0.other.push((weight * x, proof.T1)); + verifier.0.other.push((weight * (x * x), proof.T2)); } - // Second multiexp + let ip_weight = Scalar::random(&mut *rng); + + // 66 + verifier.0.other.push((ip_weight, proof.A)); + verifier.0.other.push((ip_weight * x, proof.S)); + // TODO: g_sum + for i in 0 .. ip_rows { + verifier.0.g_bold[i] += ip_weight * -z; + } + // TODO: h_sum + for i in 0 .. ip_rows { + verifier.0.h_bold[i] += ip_weight * z; + } { - let verifier_weight = Scalar::random(rng); - let z3 = (self.t - (self.a * self.b)) * x_ip; - verifier.0.h += verifier_weight * z3; - verifier.0.g -= verifier_weight * self.mu; - - verifier.0.other.push((verifier_weight, A)); - verifier.0.other.push((verifier_weight * x, S)); - - { - let ypow = ScalarVector::powers(y, MN); - let yinv = y.invert(); - let yinvpow = ScalarVector::powers(yinv, MN); - - let w_cache = challenge_products(&w_and_w_inv); - - while verifier.0.g_bold.len() < MN { - verifier.0.g_bold.push(Scalar::ZERO); + let mut z_current = z * z; + for j in 0 .. padded_pow_of_2 { + for i in 0 .. COMMITMENT_BITS { + let full_i = (j * COMMITMENT_BITS) + i; + verifier.0.h_bold[full_i] += ip_weight * y_inv_pow_n[full_i] * z_current * twos[i]; } - while verifier.0.h_bold.len() < MN { - verifier.0.h_bold.push(Scalar::ZERO); - } - - for i in 0 .. MN { - let g = (self.a * w_cache[i]) + z; - verifier.0.g_bold[i] -= verifier_weight * g; - - let mut h = self.b * yinvpow[i] * w_cache[(!i) & (MN - 1)]; - h -= ((zpow[(i / COMMITMENT_BITS) + 2] * TWO_N()[i % COMMITMENT_BITS]) + (z * ypow[i])) * - yinvpow[i]; - verifier.0.h_bold[i] -= verifier_weight * h; - } - } - - for i in 0 .. logMN { - verifier.0.other.push((verifier_weight * (w_and_w_inv[i].0 * w_and_w_inv[i].0), L[i])); - verifier.0.other.push((verifier_weight * (w_and_w_inv[i].1 * w_and_w_inv[i].1), R[i])); + z_current *= z; } } + verifier.0.h += ip_weight * x_ip * proof.t_hat; - true + // 67, 68 + verifier.0.g += ip_weight * -proof.mu; + let res = IpStatement::new_without_P_transcript(y_inv_pow_n, x_ip) + .verify(verifier, ip_rows, transcript, ip_weight, proof.ip); + res.is_ok() } } diff --git a/coins/monero/ringct/bulletproofs/src/plus/aggregate_range_proof.rs b/coins/monero/ringct/bulletproofs/src/plus/aggregate_range_proof.rs index 2f39c7d3..e3d4bc92 100644 --- a/coins/monero/ringct/bulletproofs/src/plus/aggregate_range_proof.rs +++ b/coins/monero/ringct/bulletproofs/src/plus/aggregate_range_proof.rs @@ -20,15 +20,9 @@ use crate::{ // Figure 3 of the Bulletproofs+ Paper #[derive(Clone, Debug)] -pub(crate) struct AggregateRangeStatement { +pub(crate) struct AggregateRangeStatement<'a> { generators: BpPlusGenerators, - V: Vec, -} - -impl Zeroize for AggregateRangeStatement { - fn zeroize(&mut self) { - self.V.zeroize(); - } + V: &'a [EdwardsPoint], } #[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)] @@ -61,8 +55,8 @@ struct AHatComputation { A_hat: EdwardsPoint, } -impl AggregateRangeStatement { - pub(crate) fn new(V: Vec) -> Option { +impl<'a> AggregateRangeStatement<'a> { + pub(crate) fn new(V: &'a [EdwardsPoint]) -> Option { if V.is_empty() || (V.len() > MAX_COMMITMENTS) { return None; } @@ -180,7 +174,7 @@ impl AggregateRangeStatement { // Commitments aren't transmitted INV_EIGHT though, so this multiplies by INV_EIGHT to enable // clearing its cofactor without mutating the value // For some reason, these values are transcripted * INV_EIGHT, not as transmitted - let V = V.into_iter().map(|V| V * INV_EIGHT()).collect::>(); + let V = V.iter().map(|V| V * INV_EIGHT()).collect::>(); let mut transcript = initial_transcript(V.iter()); let mut V = V.iter().map(EdwardsPoint::mul_by_cofactor).collect::>(); @@ -248,7 +242,7 @@ impl AggregateRangeStatement { ) -> bool { let Self { generators, V } = self; - let V = V.into_iter().map(|V| V * INV_EIGHT()).collect::>(); + let V = V.iter().map(|V| V * INV_EIGHT()).collect::>(); let mut transcript = initial_transcript(V.iter()); let V = V.iter().map(EdwardsPoint::mul_by_cofactor).collect::>(); diff --git a/coins/monero/ringct/bulletproofs/src/plus/mod.rs b/coins/monero/ringct/bulletproofs/src/plus/mod.rs index ec7ca6a7..92bff236 100644 --- a/coins/monero/ringct/bulletproofs/src/plus/mod.rs +++ b/coins/monero/ringct/bulletproofs/src/plus/mod.rs @@ -6,10 +6,7 @@ use curve25519_dalek::{constants::ED25519_BASEPOINT_POINT, scalar::Scalar, edwar use monero_generators::{H, Generators}; -pub(crate) use crate::scalar_vector::ScalarVector; - -mod point_vector; -pub(crate) use point_vector::PointVector; +pub(crate) use crate::{scalar_vector::ScalarVector, point_vector::PointVector}; pub(crate) mod transcript; pub(crate) mod weighted_inner_product; @@ -31,7 +28,6 @@ pub(crate) enum GeneratorsList { HBold, } -// TODO: Table these #[derive(Clone, Debug)] pub(crate) struct BpPlusGenerators { g_bold: &'static [EdwardsPoint], diff --git a/coins/monero/ringct/bulletproofs/src/plus/weighted_inner_product.rs b/coins/monero/ringct/bulletproofs/src/plus/weighted_inner_product.rs index abd7ea29..2a3bbe6c 100644 --- a/coins/monero/ringct/bulletproofs/src/plus/weighted_inner_product.rs +++ b/coins/monero/ringct/bulletproofs/src/plus/weighted_inner_product.rs @@ -107,9 +107,6 @@ impl WipStatement { // Prover's variant of the shared code block to calculate G/H/P when n > 1 // Returns each permutation of G/H since the prover needs to do operation on each permutation // P is dropped as it's unused in the prover's path - // TODO: It'd still probably be faster to keep in terms of the original generators, both between - // the reduced amount of group operations and the potential tabling of the generators under - // multiexp #[allow(clippy::too_many_arguments)] fn next_G_H( transcript: &mut Scalar, diff --git a/coins/monero/ringct/bulletproofs/src/plus/point_vector.rs b/coins/monero/ringct/bulletproofs/src/point_vector.rs similarity index 74% rename from coins/monero/ringct/bulletproofs/src/plus/point_vector.rs rename to coins/monero/ringct/bulletproofs/src/point_vector.rs index f9b52a61..c2635038 100644 --- a/coins/monero/ringct/bulletproofs/src/plus/point_vector.rs +++ b/coins/monero/ringct/bulletproofs/src/point_vector.rs @@ -1,14 +1,16 @@ use core::ops::{Index, IndexMut}; use std_shims::vec::Vec; -use zeroize::{Zeroize, ZeroizeOnDrop}; +use zeroize::Zeroize; use curve25519_dalek::edwards::EdwardsPoint; -#[cfg(test)] -use crate::{core::multiexp, plus::ScalarVector}; +use crate::scalar_vector::ScalarVector; -#[derive(Clone, PartialEq, Eq, Debug, Zeroize, ZeroizeOnDrop)] +#[cfg(test)] +use crate::core::multiexp; + +#[derive(Clone, PartialEq, Eq, Debug, Zeroize)] pub(crate) struct PointVector(pub(crate) Vec); impl Index for PointVector { @@ -25,6 +27,15 @@ impl IndexMut for PointVector { } impl PointVector { + pub(crate) fn mul_vec(&self, vector: &ScalarVector) -> Self { + assert_eq!(self.len(), vector.len()); + let mut res = self.clone(); + for (i, val) in res.0.iter_mut().enumerate() { + *val *= vector.0[i]; + } + res + } + #[cfg(test)] pub(crate) fn multiexp(&self, vector: &ScalarVector) -> EdwardsPoint { debug_assert_eq!(self.len(), vector.len()); diff --git a/coins/monero/ringct/bulletproofs/src/tests/mod.rs b/coins/monero/ringct/bulletproofs/src/tests/mod.rs index 45a04362..fa4c8939 100644 --- a/coins/monero/ringct/bulletproofs/src/tests/mod.rs +++ b/coins/monero/ringct/bulletproofs/src/tests/mod.rs @@ -1,62 +1,13 @@ -use hex_literal::hex; -use rand_core::OsRng; +use rand_core::{RngCore, OsRng}; use curve25519_dalek::scalar::Scalar; -use monero_io::decompress_point; use monero_primitives::Commitment; +use crate::{batch_verifier::BatchVerifier, Bulletproof, BulletproofError}; -use crate::{batch_verifier::BatchVerifier, original::OriginalStruct, Bulletproof, BulletproofError}; - +mod original; mod plus; -#[test] -fn bulletproofs_vector() { - let scalar = |scalar| Scalar::from_canonical_bytes(scalar).unwrap(); - let point = |point| decompress_point(point).unwrap(); - - // Generated from Monero - assert!(Bulletproof::Original(OriginalStruct { - A: point(hex!("ef32c0b9551b804decdcb107eb22aa715b7ce259bf3c5cac20e24dfa6b28ac71")), - S: point(hex!("e1285960861783574ee2b689ae53622834eb0b035d6943103f960cd23e063fa0")), - T1: point(hex!("4ea07735f184ba159d0e0eb662bac8cde3eb7d39f31e567b0fbda3aa23fe5620")), - T2: point(hex!("b8390aa4b60b255630d40e592f55ec6b7ab5e3a96bfcdcd6f1cd1d2fc95f441e")), - tau_x: scalar(hex!("5957dba8ea9afb23d6e81cc048a92f2d502c10c749dc1b2bd148ae8d41ec7107")), - mu: scalar(hex!("923023b234c2e64774b820b4961f7181f6c1dc152c438643e5a25b0bf271bc02")), - L: vec![ - point(hex!("c45f656316b9ebf9d357fb6a9f85b5f09e0b991dd50a6e0ae9b02de3946c9d99")), - point(hex!("9304d2bf0f27183a2acc58cc755a0348da11bd345485fda41b872fee89e72aac")), - point(hex!("1bb8b71925d155dd9569f64129ea049d6149fdc4e7a42a86d9478801d922129b")), - point(hex!("5756a7bf887aa72b9a952f92f47182122e7b19d89e5dd434c747492b00e1c6b7")), - point(hex!("6e497c910d102592830555356af5ff8340e8d141e3fb60ea24cfa587e964f07d")), - point(hex!("f4fa3898e7b08e039183d444f3d55040f3c790ed806cb314de49f3068bdbb218")), - point(hex!("0bbc37597c3ead517a3841e159c8b7b79a5ceaee24b2a9a20350127aab428713")), - ], - R: vec![ - point(hex!("609420ba1702781692e84accfd225adb3d077aedc3cf8125563400466b52dbd9")), - point(hex!("fb4e1d079e7a2b0ec14f7e2a3943bf50b6d60bc346a54fcf562fb234b342abf8")), - point(hex!("6ae3ac97289c48ce95b9c557289e82a34932055f7f5e32720139824fe81b12e5")), - point(hex!("d071cc2ffbdab2d840326ad15f68c01da6482271cae3cf644670d1632f29a15c")), - point(hex!("e52a1754b95e1060589ba7ce0c43d0060820ebfc0d49dc52884bc3c65ad18af5")), - point(hex!("41573b06140108539957df71aceb4b1816d2409ce896659aa5c86f037ca5e851")), - point(hex!("a65970b2cc3c7b08b2b5b739dbc8e71e646783c41c625e2a5b1535e3d2e0f742")), - ], - a: scalar(hex!("0077c5383dea44d3cd1bc74849376bd60679612dc4b945255822457fa0c0a209")), - b: scalar(hex!("fe80cf5756473482581e1d38644007793ddc66fdeb9404ec1689a907e4863302")), - t: scalar(hex!("40dfb08e09249040df997851db311bd6827c26e87d6f0f332c55be8eef10e603")) - }) - .verify( - &mut OsRng, - &[ - // For some reason, these vectors are * INV_EIGHT - point(hex!("8e8f23f315edae4f6c2f948d9a861e0ae32d356b933cd11d2f0e031ac744c41f")) - .mul_by_cofactor(), - point(hex!("2829cbd025aa54cd6e1b59a032564f22f0b2e5627f7f2c4297f90da438b5510f")) - .mul_by_cofactor(), - ] - )); -} - macro_rules! bulletproofs_tests { ($name: ident, $max: ident, $plus: literal) => { #[test] @@ -65,13 +16,13 @@ macro_rules! bulletproofs_tests { let mut verifier = BatchVerifier::new(); for i in 1 ..= 16 { let commitments = (1 ..= i) - .map(|i| Commitment::new(Scalar::random(&mut OsRng), u64::try_from(i).unwrap())) + .map(|_| Commitment::new(Scalar::random(&mut OsRng), OsRng.next_u64())) .collect::>(); let bp = if $plus { Bulletproof::prove_plus(&mut OsRng, commitments.clone()).unwrap() } else { - Bulletproof::prove(&mut OsRng, &commitments).unwrap() + Bulletproof::prove(&mut OsRng, commitments.clone()).unwrap() }; let commitments = commitments.iter().map(Commitment::calculate).collect::>(); @@ -92,7 +43,7 @@ macro_rules! bulletproofs_tests { (if $plus { Bulletproof::prove_plus(&mut OsRng, commitments) } else { - Bulletproof::prove(&mut OsRng, &commitments) + Bulletproof::prove(&mut OsRng, commitments) }) .unwrap_err(), BulletproofError::TooManyCommitments, diff --git a/coins/monero/ringct/bulletproofs/src/tests/original/inner_product.rs b/coins/monero/ringct/bulletproofs/src/tests/original/inner_product.rs new file mode 100644 index 00000000..98aa842f --- /dev/null +++ b/coins/monero/ringct/bulletproofs/src/tests/original/inner_product.rs @@ -0,0 +1,75 @@ +// The inner product relation is P = sum(g_bold * a, h_bold * b, g * (a * b)) + +use rand_core::OsRng; + +use curve25519_dalek::Scalar; + +use monero_generators::H; + +use crate::{ + scalar_vector::ScalarVector, + point_vector::PointVector, + original::{ + GENERATORS, + inner_product::{IpStatement, IpWitness}, + }, + BulletproofsBatchVerifier, +}; + +#[test] +fn test_zero_inner_product() { + let statement = + IpStatement::new_without_P_transcript(ScalarVector(vec![Scalar::ONE; 1]), Scalar::ONE); + let witness = IpWitness::new(ScalarVector::new(1), ScalarVector::new(1)).unwrap(); + + let transcript = Scalar::random(&mut OsRng); + let proof = statement.clone().prove(transcript, witness).unwrap(); + + let mut verifier = BulletproofsBatchVerifier::default(); + verifier.0.g_bold = vec![Scalar::ZERO; 1]; + verifier.0.h_bold = vec![Scalar::ZERO; 1]; + statement.verify(&mut verifier, 1, transcript, Scalar::random(&mut OsRng), proof).unwrap(); + assert!(verifier.verify()); +} + +#[test] +fn test_inner_product() { + // P = sum(g_bold * a, h_bold * b, g * u * ) + let generators = GENERATORS(); + let mut verifier = BulletproofsBatchVerifier::default(); + verifier.0.g_bold = vec![Scalar::ZERO; 32]; + verifier.0.h_bold = vec![Scalar::ZERO; 32]; + for i in [1, 2, 4, 8, 16, 32] { + let g = H(); + let mut g_bold = vec![]; + let mut h_bold = vec![]; + for i in 0 .. i { + g_bold.push(generators.G[i]); + h_bold.push(generators.H[i]); + } + let g_bold = PointVector(g_bold); + let h_bold = PointVector(h_bold); + + let mut a = ScalarVector::new(i); + let mut b = ScalarVector::new(i); + + for i in 0 .. i { + a[i] = Scalar::random(&mut OsRng); + b[i] = Scalar::random(&mut OsRng); + } + + let P = g_bold.multiexp(&a) + h_bold.multiexp(&b) + (g * a.clone().inner_product(&b)); + + let statement = + IpStatement::new_without_P_transcript(ScalarVector(vec![Scalar::ONE; i]), Scalar::ONE); + let witness = IpWitness::new(a, b).unwrap(); + + let transcript = Scalar::random(&mut OsRng); + let proof = statement.clone().prove(transcript, witness).unwrap(); + + let weight = Scalar::random(&mut OsRng); + verifier.0.other.push((weight, P)); + statement.verify(&mut verifier, i, transcript, weight, proof).unwrap(); + } + assert!(verifier.verify()); +} diff --git a/coins/monero/ringct/bulletproofs/src/tests/original/mod.rs b/coins/monero/ringct/bulletproofs/src/tests/original/mod.rs new file mode 100644 index 00000000..c0010b4f --- /dev/null +++ b/coins/monero/ringct/bulletproofs/src/tests/original/mod.rs @@ -0,0 +1,62 @@ +use hex_literal::hex; +use rand_core::OsRng; + +use curve25519_dalek::scalar::Scalar; + +use monero_io::decompress_point; + +use crate::{ + original::{IpProof, AggregateRangeProof as OriginalProof}, + Bulletproof, +}; + +mod inner_product; + +#[test] +fn bulletproofs_vector() { + let scalar = |scalar| Scalar::from_canonical_bytes(scalar).unwrap(); + let point = |point| decompress_point(point).unwrap(); + + // Generated from Monero + assert!(Bulletproof::Original(OriginalProof { + A: point(hex!("ef32c0b9551b804decdcb107eb22aa715b7ce259bf3c5cac20e24dfa6b28ac71")), + S: point(hex!("e1285960861783574ee2b689ae53622834eb0b035d6943103f960cd23e063fa0")), + T1: point(hex!("4ea07735f184ba159d0e0eb662bac8cde3eb7d39f31e567b0fbda3aa23fe5620")), + T2: point(hex!("b8390aa4b60b255630d40e592f55ec6b7ab5e3a96bfcdcd6f1cd1d2fc95f441e")), + tau_x: scalar(hex!("5957dba8ea9afb23d6e81cc048a92f2d502c10c749dc1b2bd148ae8d41ec7107")), + mu: scalar(hex!("923023b234c2e64774b820b4961f7181f6c1dc152c438643e5a25b0bf271bc02")), + ip: IpProof { + L: vec![ + point(hex!("c45f656316b9ebf9d357fb6a9f85b5f09e0b991dd50a6e0ae9b02de3946c9d99")), + point(hex!("9304d2bf0f27183a2acc58cc755a0348da11bd345485fda41b872fee89e72aac")), + point(hex!("1bb8b71925d155dd9569f64129ea049d6149fdc4e7a42a86d9478801d922129b")), + point(hex!("5756a7bf887aa72b9a952f92f47182122e7b19d89e5dd434c747492b00e1c6b7")), + point(hex!("6e497c910d102592830555356af5ff8340e8d141e3fb60ea24cfa587e964f07d")), + point(hex!("f4fa3898e7b08e039183d444f3d55040f3c790ed806cb314de49f3068bdbb218")), + point(hex!("0bbc37597c3ead517a3841e159c8b7b79a5ceaee24b2a9a20350127aab428713")), + ], + R: vec![ + point(hex!("609420ba1702781692e84accfd225adb3d077aedc3cf8125563400466b52dbd9")), + point(hex!("fb4e1d079e7a2b0ec14f7e2a3943bf50b6d60bc346a54fcf562fb234b342abf8")), + point(hex!("6ae3ac97289c48ce95b9c557289e82a34932055f7f5e32720139824fe81b12e5")), + point(hex!("d071cc2ffbdab2d840326ad15f68c01da6482271cae3cf644670d1632f29a15c")), + point(hex!("e52a1754b95e1060589ba7ce0c43d0060820ebfc0d49dc52884bc3c65ad18af5")), + point(hex!("41573b06140108539957df71aceb4b1816d2409ce896659aa5c86f037ca5e851")), + point(hex!("a65970b2cc3c7b08b2b5b739dbc8e71e646783c41c625e2a5b1535e3d2e0f742")), + ], + a: scalar(hex!("0077c5383dea44d3cd1bc74849376bd60679612dc4b945255822457fa0c0a209")), + b: scalar(hex!("fe80cf5756473482581e1d38644007793ddc66fdeb9404ec1689a907e4863302")), + }, + t_hat: scalar(hex!("40dfb08e09249040df997851db311bd6827c26e87d6f0f332c55be8eef10e603")) + }) + .verify( + &mut OsRng, + &[ + // For some reason, these vectors are * INV_EIGHT + point(hex!("8e8f23f315edae4f6c2f948d9a861e0ae32d356b933cd11d2f0e031ac744c41f")) + .mul_by_cofactor(), + point(hex!("2829cbd025aa54cd6e1b59a032564f22f0b2e5627f7f2c4297f90da438b5510f")) + .mul_by_cofactor(), + ] + )); +} diff --git a/coins/monero/ringct/bulletproofs/src/tests/plus/aggregate_range_proof.rs b/coins/monero/ringct/bulletproofs/src/tests/plus/aggregate_range_proof.rs index fc5d429e..ba5d0543 100644 --- a/coins/monero/ringct/bulletproofs/src/tests/plus/aggregate_range_proof.rs +++ b/coins/monero/ringct/bulletproofs/src/tests/plus/aggregate_range_proof.rs @@ -17,8 +17,8 @@ fn test_aggregate_range_proof() { for _ in 0 .. m { commitments.push(Commitment::new(Scalar::random(&mut OsRng), OsRng.next_u64())); } - let commitment_points = commitments.iter().map(Commitment::calculate).collect(); - let statement = AggregateRangeStatement::new(commitment_points).unwrap(); + let commitment_points = commitments.iter().map(Commitment::calculate).collect::>(); + let statement = AggregateRangeStatement::new(&commitment_points).unwrap(); let witness = AggregateRangeWitness::new(commitments).unwrap(); let proof = statement.clone().prove(&mut OsRng, &witness).unwrap(); diff --git a/coins/monero/wallet/src/send/tx.rs b/coins/monero/wallet/src/send/tx.rs index b4ea3970..70bb8d62 100644 --- a/coins/monero/wallet/src/send/tx.rs +++ b/coins/monero/wallet/src/send/tx.rs @@ -272,7 +272,7 @@ impl SignableTransactionWithKeyImages { let bulletproof = { let mut bp_rng = self.intent.seeded_rng(b"bulletproof"); (match self.intent.rct_type { - RctType::ClsagBulletproof => Bulletproof::prove(&mut bp_rng, &bp_commitments), + RctType::ClsagBulletproof => Bulletproof::prove(&mut bp_rng, bp_commitments), RctType::ClsagBulletproofPlus => Bulletproof::prove_plus(&mut bp_rng, bp_commitments), _ => panic!("unsupported RctType"), })