From 7890827a481419f4a5f66e879a1614ef8a3d1f07 Mon Sep 17 00:00:00 2001 From: Luke Parker Date: Thu, 30 Jun 2022 09:30:24 -0400 Subject: [PATCH] Implement variable-sized windows into multiexp Closes https://github.com/serai-dex/serai/issues/17 by using the PrimeFieldBits API to do so. Should greatly speed up small batches, along with batches in the hundreds. Saves almost a full second on the cross-group DLEq proof. --- crypto/dleq/Cargo.toml | 4 +- crypto/dleq/src/cross_group/mod.rs | 26 ++--- crypto/frost/Cargo.toml | 5 +- crypto/frost/src/curve/dalek.rs | 2 - crypto/frost/src/curve/kp256.rs | 2 - crypto/frost/src/curve/mod.rs | 8 +- crypto/frost/src/key_gen.rs | 4 +- crypto/frost/src/schnorr.rs | 2 +- crypto/frost/src/tests/curve.rs | 7 +- crypto/multiexp/Cargo.toml | 7 ++ crypto/multiexp/src/batch.rs | 22 ++-- crypto/multiexp/src/lib.rs | 157 +++++++++++++++++++++++++---- crypto/multiexp/src/pippenger.rs | 66 +++++------- crypto/multiexp/src/straus.rs | 66 +++++------- crypto/multiexp/src/tests/mod.rs | 112 ++++++++++++++++++++ 15 files changed, 342 insertions(+), 148 deletions(-) create mode 100644 crypto/multiexp/src/tests/mod.rs diff --git a/crypto/dleq/Cargo.toml b/crypto/dleq/Cargo.toml index f8d26a25..de5338b7 100644 --- a/crypto/dleq/Cargo.toml +++ b/crypto/dleq/Cargo.toml @@ -10,10 +10,12 @@ edition = "2021" thiserror = "1" rand_core = "0.6" +transcript = { package = "flexible-transcript", path = "../transcript", version = "0.1" } + ff = "0.12" group = "0.12" -transcript = { package = "flexible-transcript", path = "../transcript", version = "0.1" } +multiexp = { path = "../multiexp" } [dev-dependencies] hex-literal = "0.3" diff --git a/crypto/dleq/src/cross_group/mod.rs b/crypto/dleq/src/cross_group/mod.rs index e8146c4b..498d5f9f 100644 --- a/crypto/dleq/src/cross_group/mod.rs +++ b/crypto/dleq/src/cross_group/mod.rs @@ -83,7 +83,8 @@ pub struct DLEqProof { poks: (SchnorrPoK, SchnorrPoK) } -impl DLEqProof { +impl DLEqProof + where G0::Scalar: PrimeFieldBits, G1::Scalar: PrimeFieldBits { fn initialize_transcript( transcript: &mut T, generators: (Generators, Generators), @@ -134,13 +135,17 @@ impl DLEqProof { } // TODO: Use multiexp here after https://github.com/serai-dex/serai/issues/17 - fn reconstruct_key(commitments: impl Iterator) -> G { + fn reconstruct_key( + commitments: impl Iterator + ) -> G where G::Scalar: PrimeFieldBits { let mut pow_2 = G::Scalar::one(); - commitments.fold(G::identity(), |key, commitment| { - let res = key + (commitment * pow_2); - pow_2 = pow_2.double(); - res - }) + multiexp::multiexp_vartime( + &commitments.map(|commitment| { + let res = (pow_2, commitment); + pow_2 = pow_2.double(); + res + }).collect::>() + ) } fn reconstruct_keys(&self) -> (G0, G1) { @@ -169,10 +174,7 @@ impl DLEqProof { transcript: &mut T, generators: (Generators, Generators), f: G0::Scalar - ) -> ( - Self, - (G0::Scalar, G1::Scalar) - ) where G0::Scalar: PrimeFieldBits, G1::Scalar: PrimeFieldBits { + ) -> (Self, (G0::Scalar, G1::Scalar)) { // At least one bit will be dropped from either field element, making it irrelevant which one // we get a random element in let f = scalar_normalize::<_, G1::Scalar>(f); @@ -262,7 +264,7 @@ impl DLEqProof { &self, transcript: &mut T, generators: (Generators, Generators) - ) -> Result<(G0, G1), DLEqError> where G0::Scalar: PrimeFieldBits, G1::Scalar: PrimeFieldBits { + ) -> Result<(G0, G1), DLEqError> { let capacity = G0::Scalar::CAPACITY.min(G1::Scalar::CAPACITY); if self.bits.len() != capacity.try_into().unwrap() { return Err(DLEqError::InvalidProofLength); diff --git a/crypto/frost/Cargo.toml b/crypto/frost/Cargo.toml index 73dbc52d..436c3966 100644 --- a/crypto/frost/Cargo.toml +++ b/crypto/frost/Cargo.toml @@ -16,11 +16,12 @@ hex = "0.4" sha2 = { version = "0.10", optional = true } +ff = "0.12" group = "0.12" elliptic-curve = { version = "0.12", features = ["hash2curve"], optional = true } -p256 = { version = "0.11", features = ["arithmetic", "hash2curve"], optional = true } -k256 = { version = "0.11", features = ["arithmetic", "hash2curve"], optional = true } +p256 = { version = "0.11", features = ["arithmetic", "bits", "hash2curve"], optional = true } +k256 = { version = "0.11", features = ["arithmetic", "bits", "hash2curve"], optional = true } dalek-ff-group = { path = "../dalek-ff-group", version = "0.1", optional = true } transcript = { package = "flexible-transcript", path = "../transcript", version = "0.1" } diff --git a/crypto/frost/src/curve/dalek.rs b/crypto/frost/src/curve/dalek.rs index 07515eee..40e6c252 100644 --- a/crypto/frost/src/curve/dalek.rs +++ b/crypto/frost/src/curve/dalek.rs @@ -35,8 +35,6 @@ macro_rules! dalek_curve { const GENERATOR: Self::G = $POINT; const GENERATOR_TABLE: Self::T = &$TABLE; - const LITTLE_ENDIAN: bool = true; - fn random_nonce(secret: Self::F, rng: &mut R) -> Self::F { let mut seed = vec![0; 32]; rng.fill_bytes(&mut seed); diff --git a/crypto/frost/src/curve/kp256.rs b/crypto/frost/src/curve/kp256.rs index 278e4eaa..9b1874d8 100644 --- a/crypto/frost/src/curve/kp256.rs +++ b/crypto/frost/src/curve/kp256.rs @@ -29,8 +29,6 @@ macro_rules! kp_curve { const GENERATOR: Self::G = $lib::ProjectivePoint::GENERATOR; const GENERATOR_TABLE: Self::G = $lib::ProjectivePoint::GENERATOR; - const LITTLE_ENDIAN: bool = false; - fn random_nonce(secret: Self::F, rng: &mut R) -> Self::F { let mut seed = vec![0; 32]; rng.fill_bytes(&mut seed); diff --git a/crypto/frost/src/curve/mod.rs b/crypto/frost/src/curve/mod.rs index 2de31a2a..e08e2faf 100644 --- a/crypto/frost/src/curve/mod.rs +++ b/crypto/frost/src/curve/mod.rs @@ -4,7 +4,8 @@ use thiserror::Error; use rand_core::{RngCore, CryptoRng}; -use group::{ff::PrimeField, Group, GroupOps, prime::PrimeGroup}; +use ff::{PrimeField, PrimeFieldBits}; +use group::{Group, GroupOps, prime::PrimeGroup}; #[cfg(any(test, feature = "dalek"))] mod dalek; @@ -40,7 +41,7 @@ pub enum CurveError { pub trait Curve: Clone + Copy + PartialEq + Eq + Debug { /// Scalar field element type // This is available via G::Scalar yet `C::G::Scalar` is ambiguous, forcing horrific accesses - type F: PrimeField; + type F: PrimeField + PrimeFieldBits; /// Group element type type G: Group + GroupOps + PrimeGroup; /// Precomputed table type @@ -57,9 +58,6 @@ pub trait Curve: Clone + Copy + PartialEq + Eq + Debug { /// If there isn't a precomputed table available, the generator itself should be used const GENERATOR_TABLE: Self::T; - /// If little endian is used for the scalar field's Repr - const LITTLE_ENDIAN: bool; - /// Securely generate a random nonce. H4 from the IETF draft fn random_nonce(secret: Self::F, rng: &mut R) -> Self::F; diff --git a/crypto/frost/src/key_gen.rs b/crypto/frost/src/key_gen.rs index 4f2832e5..e5b0f76f 100644 --- a/crypto/frost/src/key_gen.rs +++ b/crypto/frost/src/key_gen.rs @@ -224,7 +224,7 @@ fn complete_r2( res }; - let mut batch = BatchVerifier::new(shares.len(), C::LITTLE_ENDIAN); + let mut batch = BatchVerifier::new(shares.len()); for (l, share) in &shares { if *l == params.i() { continue; @@ -254,7 +254,7 @@ fn complete_r2( // Calculate each user's verification share let mut verification_shares = HashMap::new(); for i in 1 ..= params.n() { - verification_shares.insert(i, multiexp_vartime(&exponential(i, &stripes), C::LITTLE_ENDIAN)); + verification_shares.insert(i, multiexp_vartime(&exponential(i, &stripes))); } // Removing this check would enable optimizing the above from t + (n * t) to t + ((n - 1) * t) debug_assert_eq!(C::GENERATOR_TABLE * secret_share, verification_shares[¶ms.i()]); diff --git a/crypto/frost/src/schnorr.rs b/crypto/frost/src/schnorr.rs index 9424fd28..af9ff808 100644 --- a/crypto/frost/src/schnorr.rs +++ b/crypto/frost/src/schnorr.rs @@ -46,7 +46,7 @@ pub(crate) fn batch_verify( triplets: &[(u16, C::G, C::F, SchnorrSignature)] ) -> Result<(), u16> { let mut values = [(C::F::one(), C::GENERATOR); 3]; - let mut batch = BatchVerifier::new(triplets.len(), C::LITTLE_ENDIAN); + let mut batch = BatchVerifier::new(triplets.len()); for triple in triplets { // s = r + ca // sG == R + cA diff --git a/crypto/frost/src/tests/curve.rs b/crypto/frost/src/tests/curve.rs index 48dd78de..092ee50f 100644 --- a/crypto/frost/src/tests/curve.rs +++ b/crypto/frost/src/tests/curve.rs @@ -21,7 +21,8 @@ pub fn test_curve(rng: &mut R) { // TODO: Test the Curve functions themselves // Test successful multiexp, with enough pairs to trigger its variety of algorithms - // TODO: This should probably be under multiexp + // Multiexp has its own tests, yet only against k256 and Ed25519 (which should be sufficient + // as-is to prove multiexp), and this doesn't hurt { let mut pairs = Vec::with_capacity(1000); let mut sum = C::G::identity(); @@ -30,8 +31,8 @@ pub fn test_curve(rng: &mut R) { pairs.push((C::F::random(&mut *rng), C::GENERATOR * C::F::random(&mut *rng))); sum += pairs[pairs.len() - 1].1 * pairs[pairs.len() - 1].0; } - assert_eq!(multiexp::multiexp(&pairs, C::LITTLE_ENDIAN), sum); - assert_eq!(multiexp::multiexp_vartime(&pairs, C::LITTLE_ENDIAN), sum); + assert_eq!(multiexp::multiexp(&pairs), sum); + assert_eq!(multiexp::multiexp_vartime(&pairs), sum); } } diff --git a/crypto/multiexp/Cargo.toml b/crypto/multiexp/Cargo.toml index c4c73690..0342f0ee 100644 --- a/crypto/multiexp/Cargo.toml +++ b/crypto/multiexp/Cargo.toml @@ -9,9 +9,16 @@ keywords = ["multiexp", "ff", "group"] edition = "2021" [dependencies] +ff = "0.12" group = "0.12" rand_core = { version = "0.6", optional = true } +[dev-dependencies] +rand_core = "0.6" + +k256 = { version = "0.11", features = ["bits"] } +dalek-ff-group = { path = "../dalek-ff-group" } + [features] batch = ["rand_core"] diff --git a/crypto/multiexp/src/batch.rs b/crypto/multiexp/src/batch.rs index 6962ea86..5b5d65fb 100644 --- a/crypto/multiexp/src/batch.rs +++ b/crypto/multiexp/src/batch.rs @@ -1,16 +1,17 @@ use rand_core::{RngCore, CryptoRng}; -use group::{ff::Field, Group}; +use ff::{Field, PrimeFieldBits}; +use group::Group; use crate::{multiexp, multiexp_vartime}; #[cfg(feature = "batch")] -pub struct BatchVerifier(Vec<(Id, Vec<(G::Scalar, G)>)>, bool); +pub struct BatchVerifier(Vec<(Id, Vec<(G::Scalar, G)>)>); #[cfg(feature = "batch")] -impl BatchVerifier { - pub fn new(capacity: usize, endian: bool) -> BatchVerifier { - BatchVerifier(Vec::with_capacity(capacity), endian) +impl BatchVerifier where ::Scalar: PrimeFieldBits { + pub fn new(capacity: usize) -> BatchVerifier { + BatchVerifier(Vec::with_capacity(capacity)) } pub fn queue< @@ -28,15 +29,13 @@ impl BatchVerifier { pub fn verify(&self) -> bool { multiexp( - &self.0.iter().flat_map(|pairs| pairs.1.iter()).cloned().collect::>(), - self.1 + &self.0.iter().flat_map(|pairs| pairs.1.iter()).cloned().collect::>() ).is_identity().into() } pub fn verify_vartime(&self) -> bool { multiexp_vartime( - &self.0.iter().flat_map(|pairs| pairs.1.iter()).cloned().collect::>(), - self.1 + &self.0.iter().flat_map(|pairs| pairs.1.iter()).cloned().collect::>() ).is_identity().into() } @@ -46,8 +45,7 @@ impl BatchVerifier { while slice.len() > 1 { let split = slice.len() / 2; if multiexp_vartime( - &slice[.. split].iter().flat_map(|pairs| pairs.1.iter()).cloned().collect::>(), - self.1 + &slice[.. split].iter().flat_map(|pairs| pairs.1.iter()).cloned().collect::>() ).is_identity().into() { slice = &slice[split ..]; } else { @@ -56,7 +54,7 @@ impl BatchVerifier { } slice.get(0).filter( - |(_, value)| !bool::from(multiexp_vartime(value, self.1).is_identity()) + |(_, value)| !bool::from(multiexp_vartime(value).is_identity()) ).map(|(id, _)| *id) } diff --git a/crypto/multiexp/src/lib.rs b/crypto/multiexp/src/lib.rs index 51651e64..ca1b6495 100644 --- a/crypto/multiexp/src/lib.rs +++ b/crypto/multiexp/src/lib.rs @@ -1,3 +1,4 @@ +use ff::PrimeFieldBits; use group::Group; mod straus; @@ -11,39 +12,151 @@ mod batch; #[cfg(feature = "batch")] pub use batch::BatchVerifier; -#[derive(Clone, Copy, PartialEq, Eq, Debug)] -enum Algorithm { - Straus, - Pippenger +#[cfg(test)] +mod tests; + +pub(crate) fn prep_bits( + pairs: &[(G::Scalar, G)], + window: u8 +) -> Vec> where G::Scalar: PrimeFieldBits { + let w_usize = usize::from(window); + + let mut groupings = vec![]; + for pair in pairs { + let p = groupings.len(); + let bits = pair.0.to_le_bits(); + groupings.push(vec![0; (bits.len() + (w_usize - 1)) / w_usize]); + + for (i, bit) in bits.into_iter().enumerate() { + let bit = bit as u8; + debug_assert_eq!(bit | 1, 1); + groupings[p][i / w_usize] |= bit << (i % w_usize); + } + } + + groupings } -fn algorithm(pairs: usize) -> Algorithm { - // TODO: Replace this with an actual formula determining which will use less additions - // Right now, Straus is used until 600, instead of the far more accurate 300, as Pippenger - // operates per byte instead of per nibble, and therefore requires a much longer series to be - // performant - // Technically, 800 is dalek's number for when to use byte Pippenger, yet given Straus's own - // implementation limitations... - if pairs < 600 { - Algorithm::Straus +pub(crate) fn prep_tables( + pairs: &[(G::Scalar, G)], + window: u8 +) -> Vec> { + let mut tables = Vec::with_capacity(pairs.len()); + for pair in pairs { + let p = tables.len(); + tables.push(vec![G::identity(); 2_usize.pow(window.into())]); + let mut accum = G::identity(); + for i in 1 .. tables[p].len() { + accum += pair.1; + tables[p][i] = accum; + } + } + tables +} + +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +enum Algorithm { + Straus(u8), + Pippenger(u8) +} + +/* +Release (with runs 20, so all of these are off by 20x): + +k256 +Straus 3 is more efficient at 5 with 678µs per +Straus 4 is more efficient at 10 with 530µs per +Straus 5 is more efficient at 35 with 467µs per + +Pippenger 5 is more efficient at 125 with 431µs per +Pippenger 6 is more efficient at 275 with 349µs per +Pippenger 7 is more efficient at 375 with 360µs per + +dalek +Straus 3 is more efficient at 5 with 519µs per +Straus 4 is more efficient at 10 with 376µs per +Straus 5 is more efficient at 170 with 330µs per + +Pippenger 5 is more efficient at 125 with 305µs per +Pippenger 6 is more efficient at 275 with 250µs per +Pippenger 7 is more efficient at 450 with 205µs per +Pippenger 8 is more efficient at 800 with 213µs per + +Debug (with runs 5, so...): + +k256 +Straus 3 is more efficient at 5 with 2532µs per +Straus 4 is more efficient at 10 with 1930µs per +Straus 5 is more efficient at 80 with 1632µs per + +Pippenger 5 is more efficient at 150 with 1441µs per +Pippenger 6 is more efficient at 300 with 1235µs per +Pippenger 7 is more efficient at 475 with 1182µs per +Pippenger 8 is more efficient at 625 with 1170µs per + +dalek: +Straus 3 is more efficient at 5 with 971µs per +Straus 4 is more efficient at 10 with 782µs per +Straus 5 is more efficient at 75 with 778µs per +Straus 6 is more efficient at 165 with 867µs per + +Pippenger 5 is more efficient at 125 with 677µs per +Pippenger 6 is more efficient at 250 with 655µs per +Pippenger 7 is more efficient at 475 with 500µs per +Pippenger 8 is more efficient at 875 with 499µs per +*/ +fn algorithm(len: usize) -> Algorithm { + #[cfg(not(debug_assertions))] + if len < 10 { + // Straus 2 never showed a performance benefit, even with just 2 elements + Algorithm::Straus(3) + } else if len < 20 { + Algorithm::Straus(4) + } else if len < 50 { + Algorithm::Straus(5) + } else if len < 100 { + Algorithm::Pippenger(4) + } else if len < 125 { + Algorithm::Pippenger(5) + } else if len < 275 { + Algorithm::Pippenger(6) + } else if len < 400 { + Algorithm::Pippenger(7) } else { - Algorithm::Pippenger + Algorithm::Pippenger(8) + } + + #[cfg(debug_assertions)] + if len < 10 { + Algorithm::Straus(3) + } else if len < 80 { + Algorithm::Straus(4) + } else if len < 100 { + Algorithm::Straus(5) + } else if len < 125 { + Algorithm::Pippenger(4) + } else if len < 275 { + Algorithm::Pippenger(5) + } else if len < 475 { + Algorithm::Pippenger(6) + } else if len < 750 { + Algorithm::Pippenger(7) + } else { + Algorithm::Pippenger(8) } } // Performs a multiexp, automatically selecting the optimal algorithm based on amount of pairs -// Takes in an iterator of scalars and points, with a boolean for if the scalars are little endian -// encoded in their Reprs or not -pub fn multiexp(pairs: &[(G::Scalar, G)], little: bool) -> G { +pub fn multiexp(pairs: &[(G::Scalar, G)]) -> G where G::Scalar: PrimeFieldBits { match algorithm(pairs.len()) { - Algorithm::Straus => straus(pairs, little), - Algorithm::Pippenger => pippenger(pairs, little) + Algorithm::Straus(window) => straus(pairs, window), + Algorithm::Pippenger(window) => pippenger(pairs, window) } } -pub fn multiexp_vartime(pairs: &[(G::Scalar, G)], little: bool) -> G { +pub fn multiexp_vartime(pairs: &[(G::Scalar, G)]) -> G where G::Scalar: PrimeFieldBits { match algorithm(pairs.len()) { - Algorithm::Straus => straus_vartime(pairs, little), - Algorithm::Pippenger => pippenger_vartime(pairs, little) + Algorithm::Straus(window) => straus_vartime(pairs, window), + Algorithm::Pippenger(window) => pippenger_vartime(pairs, window) } } diff --git a/crypto/multiexp/src/pippenger.rs b/crypto/multiexp/src/pippenger.rs index b812c922..cfc24f1b 100644 --- a/crypto/multiexp/src/pippenger.rs +++ b/crypto/multiexp/src/pippenger.rs @@ -1,42 +1,23 @@ -use group::{ff::PrimeField, Group}; +use ff::PrimeFieldBits; +use group::Group; -fn prep(pairs: &[(G::Scalar, G)], little: bool) -> (Vec>, Vec) { - let mut res = vec![]; - let mut points = vec![]; - for pair in pairs { - let p = res.len(); - res.push(vec![]); - { - let mut repr = pair.0.to_repr(); - let bytes = repr.as_mut(); - if !little { - bytes.reverse(); - } +use crate::prep_bits; - res[p].resize(bytes.len(), 0); - for i in 0 .. bytes.len() { - res[p][i] = bytes[i]; - } - } - - points.push(pair.1); - } - - (res, points) -} - -pub(crate) fn pippenger(pairs: &[(G::Scalar, G)], little: bool) -> G { - let (bytes, points) = prep(pairs, little); +pub(crate) fn pippenger( + pairs: &[(G::Scalar, G)], + window: u8 +) -> G where G::Scalar: PrimeFieldBits { + let bits = prep_bits(pairs, window); let mut res = G::identity(); - for n in (0 .. bytes[0].len()).rev() { - for _ in 0 .. 8 { + for n in (0 .. bits[0].len()).rev() { + for _ in 0 .. window { res = res.double(); } - let mut buckets = [G::identity(); 256]; - for p in 0 .. bytes.len() { - buckets[usize::from(bytes[p][n])] += points[p]; + let mut buckets = vec![G::identity(); 2_usize.pow(window.into())]; + for p in 0 .. bits.len() { + buckets[usize::from(bits[p][n])] += pairs[p].1; } let mut intermediate_sum = G::identity(); @@ -49,22 +30,25 @@ pub(crate) fn pippenger(pairs: &[(G::Scalar, G)], little: bool) -> G { res } -pub(crate) fn pippenger_vartime(pairs: &[(G::Scalar, G)], little: bool) -> G { - let (bytes, points) = prep(pairs, little); +pub(crate) fn pippenger_vartime( + pairs: &[(G::Scalar, G)], + window: u8 +) -> G where G::Scalar: PrimeFieldBits { + let bits = prep_bits(pairs, window); let mut res = G::identity(); - for n in (0 .. bytes[0].len()).rev() { - if n != (bytes[0].len() - 1) { - for _ in 0 .. 8 { + for n in (0 .. bits[0].len()).rev() { + if n != (bits[0].len() - 1) { + for _ in 0 .. window { res = res.double(); } } - let mut buckets = [G::identity(); 256]; - for p in 0 .. bytes.len() { - let nibble = usize::from(bytes[p][n]); + let mut buckets = vec![G::identity(); 2_usize.pow(window.into())]; + for p in 0 .. bits.len() { + let nibble = usize::from(bits[p][n]); if nibble != 0 { - buckets[nibble] += points[p]; + buckets[nibble] += pairs[p].1; } } diff --git a/crypto/multiexp/src/straus.rs b/crypto/multiexp/src/straus.rs index b8660f1b..e2955d94 100644 --- a/crypto/multiexp/src/straus.rs +++ b/crypto/multiexp/src/straus.rs @@ -1,66 +1,46 @@ -use group::{ff::PrimeField, Group}; +use ff::PrimeFieldBits; +use group::Group; -fn prep(pairs: &[(G::Scalar, G)], little: bool) -> (Vec>, Vec<[G; 16]>) { - let mut nibbles = vec![]; - let mut tables = vec![]; - for pair in pairs { - let p = nibbles.len(); - nibbles.push(vec![]); - { - let mut repr = pair.0.to_repr(); - let bytes = repr.as_mut(); - if !little { - bytes.reverse(); - } +use crate::{prep_bits, prep_tables}; - nibbles[p].resize(bytes.len() * 2, 0); - for i in 0 .. bytes.len() { - nibbles[p][i * 2] = bytes[i] & 0b1111; - nibbles[p][(i * 2) + 1] = (bytes[i] >> 4) & 0b1111; - } - } - - tables.push([G::identity(); 16]); - let mut accum = G::identity(); - for i in 1 .. 16 { - accum += pair.1; - tables[p][i] = accum; - } - } - - (nibbles, tables) -} - -pub(crate) fn straus(pairs: &[(G::Scalar, G)], little: bool) -> G { - let (nibbles, tables) = prep(pairs, little); +pub(crate) fn straus( + pairs: &[(G::Scalar, G)], + window: u8 +) -> G where G::Scalar: PrimeFieldBits { + let groupings = prep_bits(pairs, window); + let tables = prep_tables(pairs, window); let mut res = G::identity(); - for b in (0 .. nibbles[0].len()).rev() { - for _ in 0 .. 4 { + for b in (0 .. groupings[0].len()).rev() { + for _ in 0 .. window { res = res.double(); } for s in 0 .. tables.len() { - res += tables[s][usize::from(nibbles[s][b])]; + res += tables[s][usize::from(groupings[s][b])]; } } res } -pub(crate) fn straus_vartime(pairs: &[(G::Scalar, G)], little: bool) -> G { - let (nibbles, tables) = prep(pairs, little); +pub(crate) fn straus_vartime( + pairs: &[(G::Scalar, G)], + window: u8 +) -> G where G::Scalar: PrimeFieldBits { + let groupings = prep_bits(pairs, window); + let tables = prep_tables(pairs, window); let mut res = G::identity(); - for b in (0 .. nibbles[0].len()).rev() { - if b != (nibbles[0].len() - 1) { - for _ in 0 .. 4 { + for b in (0 .. groupings[0].len()).rev() { + if b != (groupings[0].len() - 1) { + for _ in 0 .. window { res = res.double(); } } for s in 0 .. tables.len() { - if nibbles[s][b] != 0 { - res += tables[s][usize::from(nibbles[s][b])]; + if groupings[s][b] != 0 { + res += tables[s][usize::from(groupings[s][b])]; } } } diff --git a/crypto/multiexp/src/tests/mod.rs b/crypto/multiexp/src/tests/mod.rs new file mode 100644 index 00000000..628c52c8 --- /dev/null +++ b/crypto/multiexp/src/tests/mod.rs @@ -0,0 +1,112 @@ +use std::time::Instant; + +use rand_core::OsRng; + +use ff::{Field, PrimeFieldBits}; +use group::Group; + +use k256::ProjectivePoint; +use dalek_ff_group::EdwardsPoint; + +use crate::{straus, pippenger, multiexp, multiexp_vartime}; + +#[allow(dead_code)] +fn benchmark_internal(straus_bool: bool) where G::Scalar: PrimeFieldBits { + let runs: usize = 20; + + let mut start = 0; + let mut increment: usize = 5; + let mut total: usize = 250; + let mut current = 2; + + if !straus_bool { + start = 100; + increment = 25; + total = 1000; + current = 4; + }; + + let mut pairs = Vec::with_capacity(total); + let mut sum = G::identity(); + + for _ in 0 .. start { + pairs.push((G::Scalar::random(&mut OsRng), G::generator() * G::Scalar::random(&mut OsRng))); + sum += pairs[pairs.len() - 1].1 * pairs[pairs.len() - 1].0; + } + + for _ in 0 .. (total / increment) { + for _ in 0 .. increment { + pairs.push((G::Scalar::random(&mut OsRng), G::generator() * G::Scalar::random(&mut OsRng))); + sum += pairs[pairs.len() - 1].1 * pairs[pairs.len() - 1].0; + } + + let now = Instant::now(); + for _ in 0 .. runs { + if straus_bool { + assert_eq!(straus(&pairs, current), sum); + } else { + assert_eq!(pippenger(&pairs, current), sum); + } + } + let current_per = now.elapsed().as_micros() / u128::try_from(pairs.len()).unwrap(); + + let now = Instant::now(); + for _ in 0 .. runs { + if straus_bool { + assert_eq!(straus(&pairs, current + 1), sum); + } else { + assert_eq!(pippenger(&pairs, current + 1), sum); + } + } + let next_per = now.elapsed().as_micros() / u128::try_from(pairs.len()).unwrap(); + + if next_per < current_per { + current += 1; + println!( + "{} {} is more efficient at {} with {}µs per", + if straus_bool { "Straus" } else { "Pippenger" }, current, pairs.len(), next_per + ); + if current >= 8 { + return; + } + } + } +} + +fn test_multiexp() where G::Scalar: PrimeFieldBits { + let mut pairs = Vec::with_capacity(1000); + let mut sum = G::identity(); + for _ in 0 .. 10 { + for _ in 0 .. 100 { + pairs.push((G::Scalar::random(&mut OsRng), G::generator() * G::Scalar::random(&mut OsRng))); + sum += pairs[pairs.len() - 1].1 * pairs[pairs.len() - 1].0; + } + assert_eq!(multiexp(&pairs), sum); + assert_eq!(multiexp_vartime(&pairs), sum); + } +} + +#[test] +fn test_secp256k1() { + test_multiexp::(); +} + +#[test] +fn test_ed25519() { + test_multiexp::(); +} + +#[test] +#[ignore] +fn benchmark() { + // Activate the processor's boost clock + for _ in 0 .. 30 { + test_multiexp::(); + } + + benchmark_internal::(true); + benchmark_internal::(false); + + benchmark_internal::(true); + benchmark_internal::(false); +}