use thiserror::Error; use rand_core::{RngCore, CryptoRng}; use digest::Digest; use transcript::Transcript; use group::{ff::{Field, PrimeField, PrimeFieldBits}, prime::PrimeGroup}; use multiexp::BatchVerifier; use crate::Generators; pub mod scalar; use scalar::{scalar_convert, mutual_scalar_from_bytes}; pub(crate) mod schnorr; use schnorr::SchnorrPoK; pub(crate) mod aos; mod bits; use bits::{BitSignature, Bits}; #[cfg(feature = "serialize")] use std::io::{Read, Write}; #[cfg(feature = "serialize")] pub(crate) fn read_point(r: &mut R) -> std::io::Result { let mut repr = G::Repr::default(); r.read_exact(repr.as_mut())?; let point = G::from_bytes(&repr); if point.is_none().into() { Err(std::io::Error::new(std::io::ErrorKind::Other, "invalid point"))?; } Ok(point.unwrap()) } #[derive(Error, PartialEq, Eq, Debug)] pub enum DLEqError { #[error("invalid proof of knowledge")] InvalidProofOfKnowledge, #[error("invalid proof length")] InvalidProofLength, #[error("invalid challenge")] InvalidChallenge, #[error("invalid proof")] InvalidProof } // Debug would be such a dump of data this likely isn't helpful, but at least it's available to // anyone who wants it #[derive(Clone, PartialEq, Eq, Debug)] pub struct DLEqProof< G0: PrimeGroup, G1: PrimeGroup, const SIGNATURE: u8, const RING_LEN: usize, const REMAINDER_RING_LEN: usize > where G0::Scalar: PrimeFieldBits, G1::Scalar: PrimeFieldBits { bits: Vec>, remainder: Option>, poks: (SchnorrPoK, SchnorrPoK) } pub type ConciseLinearDLEq = DLEqProof< G0, G1, { BitSignature::ConciseLinear.to_u8() }, { BitSignature::ConciseLinear.ring_len() }, // There may not be a remainder, yet if there is, it'll be just one bit // A ring for one bit has a RING_LEN of 2 2 >; pub type EfficientLinearDLEq = DLEqProof< G0, G1, { BitSignature::EfficientLinear.to_u8() }, { BitSignature::EfficientLinear.ring_len() }, 0 >; impl< G0: PrimeGroup, G1: PrimeGroup, const SIGNATURE: u8, const RING_LEN: usize, const REMAINDER_RING_LEN: usize > DLEqProof where G0::Scalar: PrimeFieldBits, G1::Scalar: PrimeFieldBits { pub(crate) fn transcript( transcript: &mut T, generators: (Generators, Generators), keys: (G0, G1) ) { transcript.domain_separate(b"cross_group_dleq"); generators.0.transcript(transcript); generators.1.transcript(transcript); transcript.domain_separate(b"points"); transcript.append_message(b"point_0", keys.0.to_bytes().as_ref()); transcript.append_message(b"point_1", keys.1.to_bytes().as_ref()); } pub(crate) fn blinding_key( rng: &mut R, total: &mut F, last: bool ) -> F { let blinding_key = if last { -*total } else { F::random(&mut *rng) }; *total += blinding_key; blinding_key } fn reconstruct_keys(&self) -> (G0, G1) { let mut res = ( self.bits.iter().map(|bit| bit.commitments.0).sum::(), self.bits.iter().map(|bit| bit.commitments.1).sum::() ); if let Some(bit) = &self.remainder { res.0 += bit.commitments.0; res.1 += bit.commitments.1; } res } fn prove_internal( rng: &mut R, transcript: &mut T, generators: (Generators, Generators), f: (G0::Scalar, G1::Scalar) ) -> (Self, (G0::Scalar, G1::Scalar)) { Self::transcript( transcript, generators, ((generators.0.primary * f.0), (generators.1.primary * f.1)) ); let poks = ( SchnorrPoK::::prove(rng, transcript, generators.0.primary, f.0), SchnorrPoK::::prove(rng, transcript, generators.1.primary, f.1) ); let mut blinding_key_total = (G0::Scalar::zero(), G1::Scalar::zero()); let mut blinding_key = |rng: &mut R, last| { let blinding_key = ( Self::blinding_key(&mut *rng, &mut blinding_key_total.0, last), Self::blinding_key(&mut *rng, &mut blinding_key_total.1, last) ); if last { debug_assert_eq!(blinding_key_total.0, G0::Scalar::zero()); debug_assert_eq!(blinding_key_total.1, G1::Scalar::zero()); } blinding_key }; let capacity = usize::try_from(G0::Scalar::CAPACITY.min(G1::Scalar::CAPACITY)).unwrap(); let bits_per_group = BitSignature::from(SIGNATURE).bits(); let mut pow_2 = (generators.0.primary, generators.1.primary); let raw_bits = f.0.to_le_bits(); let mut bits = Vec::with_capacity(capacity); let mut these_bits: u8 = 0; for (i, bit) in raw_bits.iter().enumerate() { if i == capacity { break; } let bit = *bit as u8; debug_assert_eq!(bit | 1, 1); // Accumulate this bit these_bits |= bit << (i % bits_per_group); if (i % bits_per_group) == (bits_per_group - 1) { let last = i == (capacity - 1); let blinding_key = blinding_key(&mut *rng, last); bits.push( Bits::prove( &mut *rng, transcript, generators, i / bits_per_group, &mut pow_2, these_bits, blinding_key ) ); these_bits = 0; } } debug_assert_eq!(bits.len(), capacity / bits_per_group); let mut remainder = None; if capacity != ((capacity / bits_per_group) * bits_per_group) { let blinding_key = blinding_key(&mut *rng, true); remainder = Some( Bits::prove( &mut *rng, transcript, generators, capacity / bits_per_group, &mut pow_2, these_bits, blinding_key ) ); } let proof = DLEqProof { bits, remainder, poks }; debug_assert_eq!( proof.reconstruct_keys(), (generators.0.primary * f.0, generators.1.primary * f.1) ); (proof, f) } /// Prove the cross-Group Discrete Log Equality for the points derived from the scalar created as /// the output of the passed in Digest. Given the non-standard requirements to achieve /// uniformity, needing to be < 2^x instead of less than a prime moduli, this is the simplest way /// to safely and securely generate a Scalar, without risk of failure, nor bias /// It also ensures a lack of determinable relation between keys, guaranteeing security in the /// currently expected use case for this, atomic swaps, where each swap leaks the key. Knowing /// the relationship between keys would allow breaking all swaps after just one pub fn prove( rng: &mut R, transcript: &mut T, generators: (Generators, Generators), digest: D ) -> (Self, (G0::Scalar, G1::Scalar)) { Self::prove_internal( rng, transcript, generators, mutual_scalar_from_bytes(digest.finalize().as_ref()) ) } /// Prove the cross-Group Discrete Log Equality for the points derived from the scalar passed in, /// failing if it's not mutually valid. This allows for rejection sampling externally derived /// scalars until they're safely usable, as needed pub fn prove_without_bias( rng: &mut R, transcript: &mut T, generators: (Generators, Generators), f0: G0::Scalar ) -> Option<(Self, (G0::Scalar, G1::Scalar))> { scalar_convert(f0).map(|f1| Self::prove_internal(rng, transcript, generators, (f0, f1))) } /// Verify a cross-Group Discrete Log Equality statement, returning the points proven for pub fn verify( &self, rng: &mut R, transcript: &mut T, generators: (Generators, Generators) ) -> Result<(G0, G1), DLEqError> { let capacity = usize::try_from( G0::Scalar::CAPACITY.min(G1::Scalar::CAPACITY) ).unwrap(); let bits_per_group = BitSignature::from(SIGNATURE).bits(); let has_remainder = (capacity % bits_per_group) != 0; // These shouldn't be possible, as locally created and deserialized proofs should be properly // formed in these regards, yet it doesn't hurt to check and would be problematic if true if (self.bits.len() != (capacity / bits_per_group)) || ( (self.remainder.is_none() && has_remainder) || (self.remainder.is_some() && !has_remainder) ) { return Err(DLEqError::InvalidProofLength); } let keys = self.reconstruct_keys(); Self::transcript(transcript, generators, keys); let batch_capacity = match BitSignature::from(SIGNATURE) { BitSignature::ConciseLinear => 3, BitSignature::EfficientLinear => (self.bits.len() + 1) * 3 }; let mut batch = (BatchVerifier::new(batch_capacity), BatchVerifier::new(batch_capacity)); self.poks.0.verify(&mut *rng, transcript, generators.0.primary, keys.0, &mut batch.0); self.poks.1.verify(&mut *rng, transcript, generators.1.primary, keys.1, &mut batch.1); let mut pow_2 = (generators.0.primary, generators.1.primary); for (i, bits) in self.bits.iter().enumerate() { bits.verify(&mut *rng, transcript, generators, &mut batch, i, &mut pow_2)?; } if let Some(bit) = &self.remainder { bit.verify(&mut *rng, transcript, generators, &mut batch, self.bits.len(), &mut pow_2)?; } if (!batch.0.verify_vartime()) || (!batch.1.verify_vartime()) { Err(DLEqError::InvalidProof)?; } Ok(keys) } #[cfg(feature = "serialize")] pub fn serialize(&self, w: &mut W) -> std::io::Result<()> { for bit in &self.bits { bit.serialize(w)?; } if let Some(bit) = &self.remainder { bit.serialize(w)?; } self.poks.0.serialize(w)?; self.poks.1.serialize(w) } #[cfg(feature = "serialize")] pub fn deserialize(r: &mut R) -> std::io::Result { let capacity = usize::try_from( G0::Scalar::CAPACITY.min(G1::Scalar::CAPACITY) ).unwrap(); let bits_per_group = BitSignature::from(SIGNATURE).bits(); let mut bits = Vec::with_capacity(capacity / bits_per_group); for _ in 0 .. (capacity / bits_per_group) { bits.push(Bits::deserialize(r)?); } let mut remainder = None; if (capacity % bits_per_group) != 0 { remainder = Some(Bits::deserialize(r)?); } Ok( DLEqProof { bits, remainder, poks: (SchnorrPoK::deserialize(r)?, SchnorrPoK::deserialize(r)?) } ) } }