From 5629c94b8bacfbb1aadee5820403206402cfa1aa Mon Sep 17 00:00:00 2001 From: Luke Parker Date: Sat, 2 Mar 2024 17:15:16 -0500 Subject: [PATCH] Reconcile the two copies of scalar_vector.rs in monero-serai --- .../src/ringct/bulletproofs/original.rs | 41 ++-- .../plus/aggregate_range_proof.rs | 26 ++- .../src/ringct/bulletproofs/plus/mod.rs | 3 +- .../ringct/bulletproofs/plus/scalar_vector.rs | 114 ---------- .../plus/weighted_inner_product.rs | 21 +- .../src/ringct/bulletproofs/scalar_vector.rs | 194 ++++++++++-------- .../plus/weighted_inner_product.rs | 3 +- 7 files changed, 164 insertions(+), 238 deletions(-) delete mode 100644 coins/monero/src/ringct/bulletproofs/plus/scalar_vector.rs diff --git a/coins/monero/src/ringct/bulletproofs/original.rs b/coins/monero/src/ringct/bulletproofs/original.rs index 5e50c02e..0e841080 100644 --- a/coins/monero/src/ringct/bulletproofs/original.rs +++ b/coins/monero/src/ringct/bulletproofs/original.rs @@ -9,7 +9,7 @@ use curve25519_dalek::{scalar::Scalar as DalekScalar, edwards::EdwardsPoint as D use group::{ff::Field, Group}; use dalek_ff_group::{ED25519_BASEPOINT_POINT as G, Scalar, EdwardsPoint}; -use multiexp::BatchVerifier; +use multiexp::{BatchVerifier, multiexp}; use crate::{Commitment, ringct::bulletproofs::core::*}; @@ -17,7 +17,20 @@ include!(concat!(env!("OUT_DIR"), "/generators.rs")); static IP12_CELL: OnceLock = OnceLock::new(); pub(crate) fn IP12() -> Scalar { - *IP12_CELL.get_or_init(|| inner_product(&ScalarVector(vec![Scalar::ONE; N]), TWO_N())) + *IP12_CELL.get_or_init(|| ScalarVector(vec![Scalar::ONE; N]).inner_product(TWO_N())) +} + +pub(crate) fn hadamard_fold( + l: &[EdwardsPoint], + r: &[EdwardsPoint], + a: Scalar, + b: Scalar, +) -> Vec { + let mut res = Vec::with_capacity(l.len() / 2); + for i in 0 .. l.len() { + res.push(multiexp(&[(a, l[i]), (b, r[i])])); + } + res } #[derive(Clone, PartialEq, Eq, Debug)] @@ -57,7 +70,7 @@ impl OriginalStruct { let mut cache = hash_to_scalar(&y.to_bytes()); let z = cache; - let l0 = &aL - z; + let l0 = aL - z; let l1 = sL; let mut zero_twos = Vec::with_capacity(MN); @@ -69,12 +82,12 @@ impl OriginalStruct { } let yMN = ScalarVector::powers(y, MN); - let r0 = (&(aR + z) * &yMN) + ScalarVector(zero_twos); - let r1 = yMN * sR; + let r0 = ((aR + z) * &yMN) + &ScalarVector(zero_twos); + let r1 = yMN * &sR; let (T1, T2, x, mut taux) = { - let t1 = inner_product(&l0, &r1) + inner_product(&l1, &r0); - let t2 = inner_product(&l1, &r1); + let t1 = l0.clone().inner_product(&r1) + r0.clone().inner_product(&l1); + let t2 = l1.clone().inner_product(&r1); let mut tau1 = Scalar::random(&mut *rng); let mut tau2 = Scalar::random(&mut *rng); @@ -100,10 +113,10 @@ impl OriginalStruct { taux += zpow[i + 2] * gamma; } - let l = &l0 + &(l1 * x); - let r = &r0 + &(r1 * x); + let l = l0 + &(l1 * x); + let r = r0 + &(r1 * x); - let t = inner_product(&l, &r); + let t = l.clone().inner_product(&r); let x_ip = hash_cache(&mut cache, &[x.to_bytes(), taux.to_bytes(), mu.to_bytes(), t.to_bytes()]); @@ -126,8 +139,8 @@ impl OriginalStruct { let (aL, aR) = a.split(); let (bL, bR) = b.split(); - let cL = inner_product(&aL, &bR); - let cR = inner_product(&aR, &bL); + let cL = aL.clone().inner_product(&bR); + let cR = aR.clone().inner_product(&bL); let (G_L, G_R) = G_proof.split_at(aL.len()); let (H_L, H_R) = H_proof.split_at(aL.len()); @@ -140,8 +153,8 @@ impl OriginalStruct { let w = hash_cache(&mut cache, &[L_i.compress().to_bytes(), R_i.compress().to_bytes()]); let winv = w.invert().unwrap(); - a = (aL * w) + (aR * winv); - b = (bL * winv) + (bR * w); + a = (aL * w) + &(aR * winv); + b = (bL * winv) + &(bR * w); if a.len() != 1 { G_proof = hadamard_fold(G_L, G_R, winv, w); diff --git a/coins/monero/src/ringct/bulletproofs/plus/aggregate_range_proof.rs b/coins/monero/src/ringct/bulletproofs/plus/aggregate_range_proof.rs index 859cb1e4..af5c0275 100644 --- a/coins/monero/src/ringct/bulletproofs/plus/aggregate_range_proof.rs +++ b/coins/monero/src/ringct/bulletproofs/plus/aggregate_range_proof.rs @@ -112,7 +112,7 @@ impl AggregateRangeStatement { let mut d = ScalarVector::new(mn); for j in 1 ..= V.len() { z_pow.push(z.pow(Scalar::from(2 * u64::try_from(j).unwrap()))); // TODO: Optimize this - d = d.add_vec(&Self::d_j(j, V.len()).mul(z_pow[j - 1])); + d = d + &(Self::d_j(j, V.len()) * (z_pow[j - 1])); } let mut ascending_y = ScalarVector(vec![y]); @@ -124,7 +124,8 @@ impl AggregateRangeStatement { let mut descending_y = ascending_y.clone(); descending_y.0.reverse(); - let d_descending_y = d.mul_vec(&descending_y); + let d_descending_y = d.clone() * &descending_y; + let d_descending_y_plus_z = d_descending_y + z; let y_mn_plus_one = descending_y[0] * y; @@ -135,9 +136,9 @@ impl AggregateRangeStatement { let neg_z = -z; let mut A_terms = Vec::with_capacity((generators.len() * 2) + 2); - for (i, d_y_z) in d_descending_y.add(z).0.drain(..).enumerate() { + for (i, d_y_z) in d_descending_y_plus_z.0.iter().enumerate() { A_terms.push((neg_z, generators.generator(GeneratorsList::GBold1, i))); - A_terms.push((d_y_z, generators.generator(GeneratorsList::HBold1, i))); + A_terms.push((*d_y_z, generators.generator(GeneratorsList::HBold1, i))); } A_terms.push((y_mn_plus_one, commitment_accum)); A_terms.push(( @@ -145,7 +146,14 @@ impl AggregateRangeStatement { Generators::g(), )); - (y, d_descending_y, y_mn_plus_one, z, ScalarVector(z_pow), A + multiexp_vartime(&A_terms)) + ( + y, + d_descending_y_plus_z, + y_mn_plus_one, + z, + ScalarVector(z_pow), + A + multiexp_vartime(&A_terms), + ) } pub(crate) fn prove( @@ -191,7 +199,7 @@ impl AggregateRangeStatement { a_l.0.append(&mut u64_decompose(*witness.values.get(j - 1).unwrap_or(&0)).0); } - let a_r = a_l.sub(Scalar::ONE); + let a_r = a_l.clone() - Scalar::ONE; let alpha = Scalar::random(&mut *rng); @@ -209,11 +217,11 @@ impl AggregateRangeStatement { // Multiply by INV_EIGHT per earlier commentary A.0 *= crate::INV_EIGHT(); - let (y, d_descending_y, y_mn_plus_one, z, z_pow, A_hat) = + let (y, d_descending_y_plus_z, y_mn_plus_one, z, z_pow, A_hat) = Self::compute_A_hat(PointVector(V), &generators, &mut transcript, A); - let a_l = a_l.sub(z); - let a_r = a_r.add_vec(&d_descending_y).add(z); + let a_l = a_l - z; + let a_r = a_r + &d_descending_y_plus_z; let mut alpha = alpha; for j in 1 ..= witness.gammas.len() { alpha += z_pow[j - 1] * witness.gammas[j - 1] * y_mn_plus_one; diff --git a/coins/monero/src/ringct/bulletproofs/plus/mod.rs b/coins/monero/src/ringct/bulletproofs/plus/mod.rs index 6a2d7b9c..30417821 100644 --- a/coins/monero/src/ringct/bulletproofs/plus/mod.rs +++ b/coins/monero/src/ringct/bulletproofs/plus/mod.rs @@ -3,8 +3,7 @@ use group::Group; use dalek_ff_group::{Scalar, EdwardsPoint}; -mod scalar_vector; -pub(crate) use scalar_vector::{ScalarVector, weighted_inner_product}; +pub(crate) use crate::ringct::bulletproofs::scalar_vector::ScalarVector; mod point_vector; pub(crate) use point_vector::PointVector; diff --git a/coins/monero/src/ringct/bulletproofs/plus/scalar_vector.rs b/coins/monero/src/ringct/bulletproofs/plus/scalar_vector.rs deleted file mode 100644 index 7bc0c3f4..00000000 --- a/coins/monero/src/ringct/bulletproofs/plus/scalar_vector.rs +++ /dev/null @@ -1,114 +0,0 @@ -use core::{ - borrow::Borrow, - ops::{Index, IndexMut}, -}; -use std_shims::vec::Vec; - -use zeroize::Zeroize; - -use group::ff::Field; -use dalek_ff_group::Scalar; - -#[derive(Clone, PartialEq, Eq, Debug, Zeroize)] -pub(crate) struct ScalarVector(pub(crate) Vec); - -impl Index for ScalarVector { - type Output = Scalar; - fn index(&self, index: usize) -> &Scalar { - &self.0[index] - } -} - -impl IndexMut for ScalarVector { - fn index_mut(&mut self, index: usize) -> &mut Scalar { - &mut self.0[index] - } -} - -impl ScalarVector { - pub(crate) fn new(len: usize) -> Self { - ScalarVector(vec![Scalar::ZERO; len]) - } - - pub(crate) fn add(&self, scalar: impl Borrow) -> Self { - let mut res = self.clone(); - for val in &mut res.0 { - *val += scalar.borrow(); - } - res - } - - pub(crate) fn sub(&self, scalar: impl Borrow) -> Self { - let mut res = self.clone(); - for val in &mut res.0 { - *val -= scalar.borrow(); - } - res - } - - pub(crate) fn mul(&self, scalar: impl Borrow) -> Self { - let mut res = self.clone(); - for val in &mut res.0 { - *val *= scalar.borrow(); - } - res - } - - pub(crate) fn add_vec(&self, vector: &Self) -> Self { - debug_assert_eq!(self.len(), vector.len()); - let mut res = self.clone(); - for (i, val) in res.0.iter_mut().enumerate() { - *val += vector.0[i]; - } - res - } - - pub(crate) fn mul_vec(&self, vector: &Self) -> Self { - debug_assert_eq!(self.len(), vector.len()); - let mut res = self.clone(); - for (i, val) in res.0.iter_mut().enumerate() { - *val *= vector.0[i]; - } - res - } - - pub(crate) fn inner_product(&self, vector: &Self) -> Scalar { - self.mul_vec(vector).sum() - } - - pub(crate) fn powers(x: Scalar, len: usize) -> Self { - debug_assert!(len != 0); - - let mut res = Vec::with_capacity(len); - res.push(Scalar::ONE); - res.push(x); - for i in 2 .. len { - res.push(res[i - 1] * x); - } - res.truncate(len); - ScalarVector(res) - } - - pub(crate) fn sum(mut self) -> Scalar { - self.0.drain(..).sum() - } - - pub(crate) fn len(&self) -> usize { - self.0.len() - } - - pub(crate) fn split(mut self) -> (Self, Self) { - debug_assert!(self.len() > 1); - let r = self.0.split_off(self.0.len() / 2); - debug_assert_eq!(self.len(), r.len()); - (self, ScalarVector(r)) - } -} - -pub(crate) fn weighted_inner_product( - a: &ScalarVector, - b: &ScalarVector, - y: &ScalarVector, -) -> Scalar { - a.inner_product(&b.mul_vec(y)) -} diff --git a/coins/monero/src/ringct/bulletproofs/plus/weighted_inner_product.rs b/coins/monero/src/ringct/bulletproofs/plus/weighted_inner_product.rs index 1bc1e85d..09bb6748 100644 --- a/coins/monero/src/ringct/bulletproofs/plus/weighted_inner_product.rs +++ b/coins/monero/src/ringct/bulletproofs/plus/weighted_inner_product.rs @@ -4,7 +4,7 @@ use rand_core::{RngCore, CryptoRng}; use zeroize::{Zeroize, ZeroizeOnDrop}; -use multiexp::{multiexp, multiexp_vartime, BatchVerifier}; +use multiexp::{BatchVerifier, multiexp, multiexp_vartime}; use group::{ ff::{Field, PrimeField}, GroupEncoding, @@ -12,8 +12,7 @@ use group::{ use dalek_ff_group::{Scalar, EdwardsPoint}; use crate::ringct::bulletproofs::plus::{ - ScalarVector, PointVector, GeneratorsList, Generators, padded_pow_of_2, weighted_inner_product, - transcript::*, + ScalarVector, PointVector, GeneratorsList, Generators, padded_pow_of_2, transcript::*, }; // Figure 1 @@ -219,7 +218,7 @@ impl WipStatement { .zip(g_bold.0.iter().copied()) .chain(witness.b.0.iter().copied().zip(h_bold.0.iter().copied())) .collect::>(); - P_terms.push((weighted_inner_product(&witness.a, &witness.b, &y), g)); + P_terms.push((witness.a.clone().weighted_inner_product(&witness.b, &y), g)); P_terms.push((witness.alpha, h)); debug_assert_eq!(multiexp(&P_terms), P); P_terms.zeroize(); @@ -258,14 +257,13 @@ impl WipStatement { let d_l = Scalar::random(&mut *rng); let d_r = Scalar::random(&mut *rng); - let c_l = weighted_inner_product(&a1, &b2, &y); - let c_r = weighted_inner_product(&(a2.mul(y_n_hat)), &b1, &y); + let c_l = a1.clone().weighted_inner_product(&b2, &y); + let c_r = (a2.clone() * y_n_hat).weighted_inner_product(&b1, &y); // TODO: Calculate these with a batch inversion let y_inv_n_hat = y_n_hat.invert().unwrap(); - let mut L_terms = a1 - .mul(y_inv_n_hat) + let mut L_terms = (a1.clone() * y_inv_n_hat) .0 .drain(..) .zip(g_bold2.0.iter().copied()) @@ -277,8 +275,7 @@ impl WipStatement { L_vec.push(L); L_terms.zeroize(); - let mut R_terms = a2 - .mul(y_n_hat) + let mut R_terms = (a2.clone() * y_n_hat) .0 .drain(..) .zip(g_bold1.0.iter().copied()) @@ -294,8 +291,8 @@ impl WipStatement { (e, inv_e, e_square, inv_e_square, g_bold, h_bold) = Self::next_G_H(&mut transcript, g_bold1, g_bold2, h_bold1, h_bold2, L, R, y_inv_n_hat); - a = a1.mul(e).add_vec(&a2.mul(y_n_hat * inv_e)); - b = b1.mul(inv_e).add_vec(&b2.mul(e)); + a = (a1 * e) + &(a2 * (y_n_hat * inv_e)); + b = (b1 * inv_e) + &(b2 * e); alpha += (d_l * e_square) + (d_r * inv_e_square); debug_assert_eq!(g_bold.len(), a.len()); diff --git a/coins/monero/src/ringct/bulletproofs/scalar_vector.rs b/coins/monero/src/ringct/bulletproofs/scalar_vector.rs index 6f94f228..e6288367 100644 --- a/coins/monero/src/ringct/bulletproofs/scalar_vector.rs +++ b/coins/monero/src/ringct/bulletproofs/scalar_vector.rs @@ -1,85 +1,17 @@ -use core::ops::{Add, Sub, Mul, Index}; +use core::{ + borrow::Borrow, + ops::{Index, IndexMut, Add, Sub, Mul}, +}; use std_shims::vec::Vec; use zeroize::{Zeroize, ZeroizeOnDrop}; use group::ff::Field; use dalek_ff_group::{Scalar, EdwardsPoint}; - use multiexp::multiexp; #[derive(Clone, PartialEq, Eq, Debug, Zeroize, ZeroizeOnDrop)] pub(crate) struct ScalarVector(pub(crate) Vec); -macro_rules! math_op { - ($Op: ident, $op: ident, $f: expr) => { - #[allow(clippy::redundant_closure_call)] - impl $Op for ScalarVector { - type Output = ScalarVector; - fn $op(self, b: Scalar) -> ScalarVector { - ScalarVector(self.0.iter().map(|a| $f((a, &b))).collect()) - } - } - - #[allow(clippy::redundant_closure_call)] - impl $Op for &ScalarVector { - type Output = ScalarVector; - fn $op(self, b: Scalar) -> ScalarVector { - ScalarVector(self.0.iter().map(|a| $f((a, &b))).collect()) - } - } - - #[allow(clippy::redundant_closure_call)] - impl $Op for ScalarVector { - type Output = ScalarVector; - fn $op(self, b: ScalarVector) -> ScalarVector { - debug_assert_eq!(self.len(), b.len()); - ScalarVector(self.0.iter().zip(b.0.iter()).map($f).collect()) - } - } - - #[allow(clippy::redundant_closure_call)] - impl $Op<&ScalarVector> for &ScalarVector { - type Output = ScalarVector; - fn $op(self, b: &ScalarVector) -> ScalarVector { - debug_assert_eq!(self.len(), b.len()); - ScalarVector(self.0.iter().zip(b.0.iter()).map($f).collect()) - } - } - }; -} -math_op!(Add, add, |(a, b): (&Scalar, &Scalar)| *a + *b); -math_op!(Sub, sub, |(a, b): (&Scalar, &Scalar)| *a - *b); -math_op!(Mul, mul, |(a, b): (&Scalar, &Scalar)| *a * *b); - -impl ScalarVector { - pub(crate) fn new(len: usize) -> ScalarVector { - ScalarVector(vec![Scalar::ZERO; len]) - } - - pub(crate) fn powers(x: Scalar, len: usize) -> ScalarVector { - debug_assert!(len != 0); - - let mut res = Vec::with_capacity(len); - res.push(Scalar::ONE); - for i in 1 .. len { - res.push(res[i - 1] * x); - } - ScalarVector(res) - } - - pub(crate) fn sum(mut self) -> Scalar { - self.0.drain(..).sum() - } - - pub(crate) fn len(&self) -> usize { - self.0.len() - } - - pub(crate) fn split(self) -> (ScalarVector, ScalarVector) { - let (l, r) = self.0.split_at(self.0.len() / 2); - (ScalarVector(l.to_vec()), ScalarVector(r.to_vec())) - } -} impl Index for ScalarVector { type Output = Scalar; @@ -87,28 +19,120 @@ impl Index for ScalarVector { &self.0[index] } } +impl IndexMut for ScalarVector { + fn index_mut(&mut self, index: usize) -> &mut Scalar { + &mut self.0[index] + } +} -pub(crate) fn inner_product(a: &ScalarVector, b: &ScalarVector) -> Scalar { - (a * b).sum() +impl> Add for ScalarVector { + type Output = ScalarVector; + fn add(mut self, scalar: S) -> ScalarVector { + for s in &mut self.0 { + *s += scalar.borrow(); + } + self + } +} +impl> Sub for ScalarVector { + type Output = ScalarVector; + fn sub(mut self, scalar: S) -> ScalarVector { + for s in &mut self.0 { + *s -= scalar.borrow(); + } + self + } +} +impl> Mul for ScalarVector { + type Output = ScalarVector; + fn mul(mut self, scalar: S) -> ScalarVector { + for s in &mut self.0 { + *s *= scalar.borrow(); + } + self + } +} + +impl Add<&ScalarVector> for ScalarVector { + type Output = ScalarVector; + fn add(mut self, other: &ScalarVector) -> ScalarVector { + debug_assert_eq!(self.len(), other.len()); + for (s, o) in self.0.iter_mut().zip(other.0.iter()) { + *s += o; + } + self + } +} +impl Sub<&ScalarVector> for ScalarVector { + type Output = ScalarVector; + fn sub(mut self, other: &ScalarVector) -> ScalarVector { + debug_assert_eq!(self.len(), other.len()); + for (s, o) in self.0.iter_mut().zip(other.0.iter()) { + *s -= o; + } + self + } +} +impl Mul<&ScalarVector> for ScalarVector { + type Output = ScalarVector; + fn mul(mut self, other: &ScalarVector) -> ScalarVector { + debug_assert_eq!(self.len(), other.len()); + for (s, o) in self.0.iter_mut().zip(other.0.iter()) { + *s *= o; + } + self + } } impl Mul<&[EdwardsPoint]> for &ScalarVector { type Output = EdwardsPoint; fn mul(self, b: &[EdwardsPoint]) -> EdwardsPoint { debug_assert_eq!(self.len(), b.len()); - multiexp(&self.0.iter().copied().zip(b.iter().copied()).collect::>()) + let mut multiexp_args = self.0.iter().copied().zip(b.iter().copied()).collect::>(); + let res = multiexp(&multiexp_args); + multiexp_args.zeroize(); + res } } -pub(crate) fn hadamard_fold( - l: &[EdwardsPoint], - r: &[EdwardsPoint], - a: Scalar, - b: Scalar, -) -> Vec { - let mut res = Vec::with_capacity(l.len() / 2); - for i in 0 .. l.len() { - res.push(multiexp(&[(a, l[i]), (b, r[i])])); +impl ScalarVector { + pub(crate) fn new(len: usize) -> Self { + ScalarVector(vec![Scalar::ZERO; len]) + } + + pub(crate) fn powers(x: Scalar, len: usize) -> Self { + debug_assert!(len != 0); + + let mut res = Vec::with_capacity(len); + res.push(Scalar::ONE); + res.push(x); + for i in 2 .. len { + res.push(res[i - 1] * x); + } + res.truncate(len); + ScalarVector(res) + } + + pub(crate) fn len(&self) -> usize { + self.0.len() + } + + pub(crate) fn sum(mut self) -> Scalar { + self.0.drain(..).sum() + } + + pub(crate) fn inner_product(self, vector: &Self) -> Scalar { + (self * vector).sum() + } + + pub(crate) fn weighted_inner_product(self, vector: &Self, y: &Self) -> Scalar { + (self * vector * y).sum() + } + + pub(crate) fn split(mut self) -> (Self, Self) { + debug_assert!(self.len() > 1); + let r = self.0.split_off(self.0.len() / 2); + debug_assert_eq!(self.len(), r.len()); + (self, ScalarVector(r)) } - res } diff --git a/coins/monero/src/tests/bulletproofs/plus/weighted_inner_product.rs b/coins/monero/src/tests/bulletproofs/plus/weighted_inner_product.rs index 7db2ecc8..b0890cf8 100644 --- a/coins/monero/src/tests/bulletproofs/plus/weighted_inner_product.rs +++ b/coins/monero/src/tests/bulletproofs/plus/weighted_inner_product.rs @@ -9,7 +9,6 @@ use dalek_ff_group::{Scalar, EdwardsPoint}; use crate::ringct::bulletproofs::plus::{ ScalarVector, PointVector, GeneratorsList, Generators, weighted_inner_product::{WipStatement, WipWitness}, - weighted_inner_product, }; #[test] @@ -68,7 +67,7 @@ fn test_weighted_inner_product() { #[allow(non_snake_case)] let P = g_bold.multiexp(&a) + h_bold.multiexp(&b) + - (g * weighted_inner_product(&a, &b, &y_vec)) + + (g * a.clone().weighted_inner_product(&b, &y_vec)) + (h * alpha); let statement = WipStatement::new(generators, P, y);