From c59be46e2fb765b98f37867bac7626f446ef81fe Mon Sep 17 00:00:00 2001 From: Luke Parker Date: Fri, 12 Jul 2024 02:18:57 -0400 Subject: [PATCH] Optimize Monero BPs --- Cargo.lock | 1 - coins/monero/ringct/bulletproofs/Cargo.toml | 2 - .../ringct/bulletproofs/src/original/mod.rs | 45 +++++++++---------- .../src/plus/weighted_inner_product.rs | 14 +++++- 4 files changed, 32 insertions(+), 30 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9cc94f37..d1ad4689 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4781,7 +4781,6 @@ dependencies = [ "monero-primitives", "rand_core", "std-shims", - "subtle", "thiserror", "zeroize", ] diff --git a/coins/monero/ringct/bulletproofs/Cargo.toml b/coins/monero/ringct/bulletproofs/Cargo.toml index 121fd883..f5b44622 100644 --- a/coins/monero/ringct/bulletproofs/Cargo.toml +++ b/coins/monero/ringct/bulletproofs/Cargo.toml @@ -22,7 +22,6 @@ thiserror = { version = "1", default-features = false, optional = true } rand_core = { version = "0.6", default-features = false } zeroize = { version = "^1.5", default-features = false, features = ["zeroize_derive"] } -subtle = { version = "^2.4", default-features = false } # Cryptographic dependencies curve25519-dalek = { version = "4", default-features = false, features = ["alloc", "zeroize"] } @@ -47,7 +46,6 @@ std = [ "rand_core/std", "zeroize/std", - "subtle/std", "monero-io/std", "monero-generators/std", diff --git a/coins/monero/ringct/bulletproofs/src/original/mod.rs b/coins/monero/ringct/bulletproofs/src/original/mod.rs index 10d63be4..18fac4d6 100644 --- a/coins/monero/ringct/bulletproofs/src/original/mod.rs +++ b/coins/monero/ringct/bulletproofs/src/original/mod.rs @@ -167,19 +167,18 @@ impl<'a> AggregateRangeStatement<'a> { let (y, z) = Self::transcript_A_S(transcript, A, S); transcript = z; + let z = ScalarVector::powers(z, 3 + padded_pow_of_2); let twos = ScalarVector::powers(Scalar::from(2u8), COMMITMENT_BITS); - let l = [aL - z, sL]; + let l = [aL - z[1], sL]; let y_pow_n = ScalarVector::powers(y, aR.len()); - let mut r = [((aR + z) * &y_pow_n), sR * &y_pow_n]; + let mut r = [((aR + z[1]) * &y_pow_n), sR * &y_pow_n]; { - let mut z_current = z * z; for j in 0 .. padded_pow_of_2 { for i in 0 .. COMMITMENT_BITS { - r[0].0[(j * COMMITMENT_BITS) + i] += z_current * twos[i]; + r[0].0[(j * COMMITMENT_BITS) + i] += z[2 + j] * twos[i]; } - z_current *= z; } } let t1 = (l[0].clone().inner_product(&r[1])) + (r[0].clone().inner_product(&l[1])); @@ -216,10 +215,8 @@ impl<'a> AggregateRangeStatement<'a> { let t_hat = l.clone().inner_product(&r); let mut tau_x = ((tau_2 * x) + tau_1) * x; { - let mut z_current = z * z; - for commitment in &witness.commitments { - tau_x += z_current * commitment.mask; - z_current *= z; + for (i, commitment) in witness.commitments.iter().enumerate() { + tau_x += z[2 + i] * commitment.mask; } } let mu = alpha + (rho * x); @@ -268,6 +265,7 @@ impl<'a> AggregateRangeStatement<'a> { let (y, z) = Self::transcript_A_S(transcript, proof.A, proof.S); transcript = z; + let z = ScalarVector::powers(z, 3 + padded_pow_of_2); transcript = Self::transcript_T12(transcript, proof.T1, proof.T2); let x = transcript; transcript = Self::transcript_tau_x_mu_t_hat(transcript, proof.tau_x, proof.mu, proof.t_hat); @@ -293,18 +291,14 @@ impl<'a> AggregateRangeStatement<'a> { // These will now sum to 0 if equal let weight = -weight; - verifier.0.h += weight * (z - (z * z)) * y_pow_n.sum(); + verifier.0.h += weight * (z[1] - (z[2])) * y_pow_n.sum(); - let mut z_current = z * z; - for commitment in &commitments { - verifier.0.other.push((weight * z_current, *commitment)); - z_current *= z; + for (i, commitment) in commitments.iter().enumerate() { + verifier.0.other.push((weight * z[2 + i], *commitment)); } - let mut z_current = z * z * z; - for _ in 0 .. padded_pow_of_2 { - verifier.0.h -= weight * z_current * twos.clone().sum(); - z_current *= z; + for i in 0 .. padded_pow_of_2 { + verifier.0.h -= weight * z[3 + i] * twos.clone().sum(); } verifier.0.other.push((weight * x, proof.T1)); verifier.0.other.push((weight * (x * x), proof.T2)); @@ -315,22 +309,23 @@ impl<'a> AggregateRangeStatement<'a> { // 66 verifier.0.other.push((ip_weight, proof.A)); verifier.0.other.push((ip_weight * x, proof.S)); - // TODO: g_sum + // We can replace these with a g_sum, h_sum scalar in the batch verifier + // It'd trade `2 * ip_rows` scalar additions (per proof) for one scalar addition and an + // additional term in the MSM + let ip_z = ip_weight * z[1]; for i in 0 .. ip_rows { - verifier.0.g_bold[i] += ip_weight * -z; + verifier.0.h_bold[i] += ip_z; } - // TODO: h_sum + let neg_ip_z = -ip_z; for i in 0 .. ip_rows { - verifier.0.h_bold[i] += ip_weight * z; + verifier.0.g_bold[i] += neg_ip_z; } { - let mut z_current = z * z; for j in 0 .. padded_pow_of_2 { for i in 0 .. COMMITMENT_BITS { let full_i = (j * COMMITMENT_BITS) + i; - verifier.0.h_bold[full_i] += ip_weight * y_inv_pow_n[full_i] * z_current * twos[i]; + verifier.0.h_bold[full_i] += ip_weight * y_inv_pow_n[full_i] * z[2 + j] * twos[i]; } - z_current *= z; } } verifier.0.h += ip_weight * x_ip * proof.t_hat; diff --git a/coins/monero/ringct/bulletproofs/src/plus/weighted_inner_product.rs b/coins/monero/ringct/bulletproofs/src/plus/weighted_inner_product.rs index 2a3bbe6c..4c838840 100644 --- a/coins/monero/ringct/bulletproofs/src/plus/weighted_inner_product.rs +++ b/coins/monero/ringct/bulletproofs/src/plus/weighted_inner_product.rs @@ -166,6 +166,17 @@ impl WipStatement { let mut g_bold = PointVector(g_bold); let mut h_bold = PointVector(h_bold); + let mut y_inv = { + let mut i = 1; + let mut to_invert = vec![]; + while i < g_bold.len() { + to_invert.push(y[i - 1]); + i *= 2; + } + Scalar::batch_invert(&mut to_invert); + to_invert + }; + // Check P has the expected relationship #[cfg(debug_assertions)] { @@ -219,8 +230,7 @@ impl WipStatement { let c_l = a1.clone().weighted_inner_product(&b2, &y); let c_r = (a2.clone() * y_n_hat).weighted_inner_product(&b1, &y); - // TODO: Calculate these with a batch inversion - let y_inv_n_hat = y_n_hat.invert(); + let y_inv_n_hat = y_inv.pop().unwrap(); let mut L_terms = (a1.clone() * y_inv_n_hat) .0