diff --git a/crypto/multiexp/src/batch.rs b/crypto/multiexp/src/batch.rs index 5584ba69..24384596 100644 --- a/crypto/multiexp/src/batch.rs +++ b/crypto/multiexp/src/batch.rs @@ -22,11 +22,13 @@ where /// A batch verifier intended to verify a series of statements are each equivalent to zero. #[allow(clippy::type_complexity)] #[derive(Clone, Zeroize)] -pub struct BatchVerifier( - Zeroizing)>>, -) +pub struct BatchVerifier where - ::Scalar: PrimeFieldBits + Zeroize; + ::Scalar: PrimeFieldBits + Zeroize, +{ + split: u64, + statements: Zeroizing)>>, +} impl BatchVerifier where @@ -35,7 +37,7 @@ where /// Create a new batch verifier, expected to verify the following amount of statements. /// This is a size hint and is not required to be accurate. pub fn new(capacity: usize) -> BatchVerifier { - BatchVerifier(Zeroizing::new(Vec::with_capacity(capacity))) + BatchVerifier { split: 0, statements: Zeroizing::new(Vec::with_capacity(capacity)) } } /// Queue a statement for batch verification. @@ -45,8 +47,15 @@ where id: Id, pairs: I, ) { + // If this is the first time we're queueing a value, grab a u64 (AKA 64 bits) to determine + // which side to handle odd splits with during blame (preventing malicious actors from gaming + // the system by maximizing the round length) + if self.statements.len() == 0 { + self.split = rng.next_u64(); + } + // Define a unique scalar factor for this set of variables so individual items can't overlap - let u = if self.0.is_empty() { + let u = if self.statements.is_empty() { G::Scalar::one() } else { let mut weight; @@ -88,19 +97,21 @@ where weight }; - self.0.push((id, pairs.into_iter().map(|(scalar, point)| (scalar * u, point)).collect())); + self + .statements + .push((id, pairs.into_iter().map(|(scalar, point)| (scalar * u, point)).collect())); } /// Perform batch verification, returning a boolean of if the statements equaled zero. #[must_use] pub fn verify(&self) -> bool { - multiexp(&flat(&self.0)).is_identity().into() + multiexp(&flat(&self.statements)).is_identity().into() } /// Perform batch verification in variable time. #[must_use] pub fn verify_vartime(&self) -> bool { - multiexp_vartime(&flat(&self.0)).is_identity().into() + multiexp_vartime(&flat(&self.statements)).is_identity().into() } /// Perform a binary search to identify which statement does not equal 0, returning None if all @@ -108,9 +119,21 @@ where /// multiple are invalid. // A constant time variant may be beneficial for robust protocols pub fn blame_vartime(&self) -> Option { - let mut slice = self.0.as_slice(); + let mut slice = self.statements.as_slice(); + let mut split_side = self.split; + while slice.len() > 1 { - let split = slice.len() / 2; + let mut split = slice.len() / 2; + // If there's an odd number of elements, this can be gamed to increase the amount of rounds + // For [0, 1, 2], where 2 is invalid, this will take three rounds + // ([0], [1, 2]), then ([1], [2]), before just 2 + // If 1 and 2 were valid, this would've only taken 2 rounds to complete + // To prevent this from being gamed, if there's an odd number of elements, randomize which + // side the split occurs on + if slice.len() % 2 == 1 { + split += usize::try_from(split_side & 1).unwrap(); + split_side >>= 1; + } if multiexp_vartime(&flat(&slice[.. split])).is_identity().into() { slice = &slice[split ..];