Optimize Monero BPs

This commit is contained in:
Luke Parker 2024-07-12 02:18:57 -04:00
parent 2c165e19ae
commit c59be46e2f
No known key found for this signature in database
4 changed files with 32 additions and 30 deletions

1
Cargo.lock generated
View file

@ -4781,7 +4781,6 @@ dependencies = [
"monero-primitives", "monero-primitives",
"rand_core", "rand_core",
"std-shims", "std-shims",
"subtle",
"thiserror", "thiserror",
"zeroize", "zeroize",
] ]

View file

@ -22,7 +22,6 @@ thiserror = { version = "1", default-features = false, optional = true }
rand_core = { version = "0.6", default-features = false } rand_core = { version = "0.6", default-features = false }
zeroize = { version = "^1.5", default-features = false, features = ["zeroize_derive"] } zeroize = { version = "^1.5", default-features = false, features = ["zeroize_derive"] }
subtle = { version = "^2.4", default-features = false }
# Cryptographic dependencies # Cryptographic dependencies
curve25519-dalek = { version = "4", default-features = false, features = ["alloc", "zeroize"] } curve25519-dalek = { version = "4", default-features = false, features = ["alloc", "zeroize"] }
@ -47,7 +46,6 @@ std = [
"rand_core/std", "rand_core/std",
"zeroize/std", "zeroize/std",
"subtle/std",
"monero-io/std", "monero-io/std",
"monero-generators/std", "monero-generators/std",

View file

@ -167,19 +167,18 @@ impl<'a> AggregateRangeStatement<'a> {
let (y, z) = Self::transcript_A_S(transcript, A, S); let (y, z) = Self::transcript_A_S(transcript, A, S);
transcript = z; transcript = z;
let z = ScalarVector::powers(z, 3 + padded_pow_of_2);
let twos = ScalarVector::powers(Scalar::from(2u8), COMMITMENT_BITS); let twos = ScalarVector::powers(Scalar::from(2u8), COMMITMENT_BITS);
let l = [aL - z, sL]; let l = [aL - z[1], sL];
let y_pow_n = ScalarVector::powers(y, aR.len()); let y_pow_n = ScalarVector::powers(y, aR.len());
let mut r = [((aR + z) * &y_pow_n), sR * &y_pow_n]; let mut r = [((aR + z[1]) * &y_pow_n), sR * &y_pow_n];
{ {
let mut z_current = z * z;
for j in 0 .. padded_pow_of_2 { for j in 0 .. padded_pow_of_2 {
for i in 0 .. COMMITMENT_BITS { for i in 0 .. COMMITMENT_BITS {
r[0].0[(j * COMMITMENT_BITS) + i] += z_current * twos[i]; r[0].0[(j * COMMITMENT_BITS) + i] += z[2 + j] * twos[i];
} }
z_current *= z;
} }
} }
let t1 = (l[0].clone().inner_product(&r[1])) + (r[0].clone().inner_product(&l[1])); let t1 = (l[0].clone().inner_product(&r[1])) + (r[0].clone().inner_product(&l[1]));
@ -216,10 +215,8 @@ impl<'a> AggregateRangeStatement<'a> {
let t_hat = l.clone().inner_product(&r); let t_hat = l.clone().inner_product(&r);
let mut tau_x = ((tau_2 * x) + tau_1) * x; let mut tau_x = ((tau_2 * x) + tau_1) * x;
{ {
let mut z_current = z * z; for (i, commitment) in witness.commitments.iter().enumerate() {
for commitment in &witness.commitments { tau_x += z[2 + i] * commitment.mask;
tau_x += z_current * commitment.mask;
z_current *= z;
} }
} }
let mu = alpha + (rho * x); let mu = alpha + (rho * x);
@ -268,6 +265,7 @@ impl<'a> AggregateRangeStatement<'a> {
let (y, z) = Self::transcript_A_S(transcript, proof.A, proof.S); let (y, z) = Self::transcript_A_S(transcript, proof.A, proof.S);
transcript = z; transcript = z;
let z = ScalarVector::powers(z, 3 + padded_pow_of_2);
transcript = Self::transcript_T12(transcript, proof.T1, proof.T2); transcript = Self::transcript_T12(transcript, proof.T1, proof.T2);
let x = transcript; let x = transcript;
transcript = Self::transcript_tau_x_mu_t_hat(transcript, proof.tau_x, proof.mu, proof.t_hat); transcript = Self::transcript_tau_x_mu_t_hat(transcript, proof.tau_x, proof.mu, proof.t_hat);
@ -293,18 +291,14 @@ impl<'a> AggregateRangeStatement<'a> {
// These will now sum to 0 if equal // These will now sum to 0 if equal
let weight = -weight; let weight = -weight;
verifier.0.h += weight * (z - (z * z)) * y_pow_n.sum(); verifier.0.h += weight * (z[1] - (z[2])) * y_pow_n.sum();
let mut z_current = z * z; for (i, commitment) in commitments.iter().enumerate() {
for commitment in &commitments { verifier.0.other.push((weight * z[2 + i], *commitment));
verifier.0.other.push((weight * z_current, *commitment));
z_current *= z;
} }
let mut z_current = z * z * z; for i in 0 .. padded_pow_of_2 {
for _ in 0 .. padded_pow_of_2 { verifier.0.h -= weight * z[3 + i] * twos.clone().sum();
verifier.0.h -= weight * z_current * twos.clone().sum();
z_current *= z;
} }
verifier.0.other.push((weight * x, proof.T1)); verifier.0.other.push((weight * x, proof.T1));
verifier.0.other.push((weight * (x * x), proof.T2)); verifier.0.other.push((weight * (x * x), proof.T2));
@ -315,22 +309,23 @@ impl<'a> AggregateRangeStatement<'a> {
// 66 // 66
verifier.0.other.push((ip_weight, proof.A)); verifier.0.other.push((ip_weight, proof.A));
verifier.0.other.push((ip_weight * x, proof.S)); verifier.0.other.push((ip_weight * x, proof.S));
// TODO: g_sum // We can replace these with a g_sum, h_sum scalar in the batch verifier
// It'd trade `2 * ip_rows` scalar additions (per proof) for one scalar addition and an
// additional term in the MSM
let ip_z = ip_weight * z[1];
for i in 0 .. ip_rows { for i in 0 .. ip_rows {
verifier.0.g_bold[i] += ip_weight * -z; verifier.0.h_bold[i] += ip_z;
} }
// TODO: h_sum let neg_ip_z = -ip_z;
for i in 0 .. ip_rows { for i in 0 .. ip_rows {
verifier.0.h_bold[i] += ip_weight * z; verifier.0.g_bold[i] += neg_ip_z;
} }
{ {
let mut z_current = z * z;
for j in 0 .. padded_pow_of_2 { for j in 0 .. padded_pow_of_2 {
for i in 0 .. COMMITMENT_BITS { for i in 0 .. COMMITMENT_BITS {
let full_i = (j * COMMITMENT_BITS) + i; let full_i = (j * COMMITMENT_BITS) + i;
verifier.0.h_bold[full_i] += ip_weight * y_inv_pow_n[full_i] * z_current * twos[i]; verifier.0.h_bold[full_i] += ip_weight * y_inv_pow_n[full_i] * z[2 + j] * twos[i];
} }
z_current *= z;
} }
} }
verifier.0.h += ip_weight * x_ip * proof.t_hat; verifier.0.h += ip_weight * x_ip * proof.t_hat;

View file

@ -166,6 +166,17 @@ impl WipStatement {
let mut g_bold = PointVector(g_bold); let mut g_bold = PointVector(g_bold);
let mut h_bold = PointVector(h_bold); let mut h_bold = PointVector(h_bold);
let mut y_inv = {
let mut i = 1;
let mut to_invert = vec![];
while i < g_bold.len() {
to_invert.push(y[i - 1]);
i *= 2;
}
Scalar::batch_invert(&mut to_invert);
to_invert
};
// Check P has the expected relationship // Check P has the expected relationship
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@ -219,8 +230,7 @@ impl WipStatement {
let c_l = a1.clone().weighted_inner_product(&b2, &y); let c_l = a1.clone().weighted_inner_product(&b2, &y);
let c_r = (a2.clone() * y_n_hat).weighted_inner_product(&b1, &y); let c_r = (a2.clone() * y_n_hat).weighted_inner_product(&b1, &y);
// TODO: Calculate these with a batch inversion let y_inv_n_hat = y_inv.pop().unwrap();
let y_inv_n_hat = y_n_hat.invert();
let mut L_terms = (a1.clone() * y_inv_n_hat) let mut L_terms = (a1.clone() * y_inv_n_hat)
.0 .0