From 15d6be16783064a3e94780e1fd79861534c6aef8 Mon Sep 17 00:00:00 2001 From: Luke Parker Date: Thu, 2 Mar 2023 01:13:07 -0500 Subject: [PATCH] 3.7.1 Deduplicate flattening/zeroize code While the prior intent was to avoid zeroizing for vartime verification, which is assumed to not have any private data, this simplifies the code and promotes safety. --- crypto/multiexp/src/batch.rs | 59 +++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 28 deletions(-) diff --git a/crypto/multiexp/src/batch.rs b/crypto/multiexp/src/batch.rs index 3fd07733..5584ba69 100644 --- a/crypto/multiexp/src/batch.rs +++ b/crypto/multiexp/src/batch.rs @@ -1,16 +1,32 @@ use rand_core::{RngCore, CryptoRng}; -use zeroize::Zeroize; +use zeroize::{Zeroize, Zeroizing}; use ff::{Field, PrimeFieldBits}; use group::Group; use crate::{multiexp, multiexp_vartime}; +// Flatten the contained statements to a single Vec. +// Wrapped in Zeroizing in case any of the included statements contain private values. +#[allow(clippy::type_complexity)] +fn flat( + slice: &[(Id, Vec<(G::Scalar, G)>)], +) -> Zeroizing> +where + ::Scalar: PrimeFieldBits + Zeroize, +{ + Zeroizing::new(slice.iter().flat_map(|pairs| pairs.1.iter()).cloned().collect::>()) +} + /// A batch verifier intended to verify a series of statements are each equivalent to zero. #[allow(clippy::type_complexity)] #[derive(Clone, Zeroize)] -pub struct BatchVerifier(Vec<(Id, Vec<(G::Scalar, G)>)>); +pub struct BatchVerifier( + Zeroizing)>>, +) +where + ::Scalar: PrimeFieldBits + Zeroize; impl BatchVerifier where @@ -19,7 +35,7 @@ where /// Create a new batch verifier, expected to verify the following amount of statements. /// This is a size hint and is not required to be accurate. pub fn new(capacity: usize) -> BatchVerifier { - BatchVerifier(Vec::with_capacity(capacity)) + BatchVerifier(Zeroizing::new(Vec::with_capacity(capacity))) } /// Queue a statement for batch verification. @@ -71,31 +87,20 @@ where } {} weight }; + self.0.push((id, pairs.into_iter().map(|(scalar, point)| (scalar * u, point)).collect())); } /// Perform batch verification, returning a boolean of if the statements equaled zero. #[must_use] - pub fn verify_core(&self) -> bool { - let mut flat = self.0.iter().flat_map(|pairs| pairs.1.iter()).cloned().collect::>(); - let res = multiexp(&flat).is_identity().into(); - flat.zeroize(); - res - } - - /// Perform batch verification, zeroizing the statements verified. - pub fn verify(mut self) -> bool { - let res = self.verify_core(); - self.zeroize(); - res + pub fn verify(&self) -> bool { + multiexp(&flat(&self.0)).is_identity().into() } /// Perform batch verification in variable time. #[must_use] pub fn verify_vartime(&self) -> bool { - multiexp_vartime(&self.0.iter().flat_map(|pairs| pairs.1.iter()).cloned().collect::>()) - .is_identity() - .into() + multiexp_vartime(&flat(&self.0)).is_identity().into() } /// Perform a binary search to identify which statement does not equal 0, returning None if all @@ -106,12 +111,8 @@ where let mut slice = self.0.as_slice(); while slice.len() > 1 { let split = slice.len() / 2; - if multiexp_vartime( - &slice[.. split].iter().flat_map(|pairs| pairs.1.iter()).cloned().collect::>(), - ) - .is_identity() - .into() - { + + if multiexp_vartime(&flat(&slice[.. split])).is_identity().into() { slice = &slice[split ..]; } else { slice = &slice[.. split]; @@ -126,10 +127,12 @@ where /// Perform constant time batch verification, and if verification fails, identify one faulty /// statement in variable time. - pub fn verify_with_vartime_blame(mut self) -> Result<(), Id> { - let res = if self.verify_core() { Ok(()) } else { Err(self.blame_vartime().unwrap()) }; - self.zeroize(); - res + pub fn verify_with_vartime_blame(&self) -> Result<(), Id> { + if self.verify() { + Ok(()) + } else { + Err(self.blame_vartime().unwrap()) + } } /// Perform variable time batch verification, and if verification fails, identify one faulty