From b8db677d4cacc139fc04c0956e7a639c33325b49 Mon Sep 17 00:00:00 2001 From: Luke Parker Date: Thu, 15 Dec 2022 19:23:42 -0500 Subject: [PATCH] Impl pow_vartime and sqrt on ed libs --- crypto/dalek-ff-group/src/field.rs | 56 +++++++++++++++++-- crypto/dalek-ff-group/src/lib.rs | 89 +++++++++++++++++++++++++++++- crypto/ed448/src/backend.rs | 16 +++++- crypto/frost/src/algorithm.rs | 1 + 4 files changed, 152 insertions(+), 10 deletions(-) diff --git a/crypto/dalek-ff-group/src/field.rs b/crypto/dalek-ff-group/src/field.rs index 1c89e63f..06fb057c 100644 --- a/crypto/dalek-ff-group/src/field.rs +++ b/crypto/dalek-ff-group/src/field.rs @@ -96,17 +96,32 @@ impl Field for FieldElement { fn sqrt(&self) -> CtOption { let tv1 = self.pow(MOD_3_8); let tv2 = tv1 * SQRT_M1; - CtOption::new(Self::conditional_select(&tv2, &tv1, tv1.square().ct_eq(self)), 1.into()) + let candidate = Self::conditional_select(&tv2, &tv1, tv1.square().ct_eq(self)); + CtOption::new(candidate, candidate.square().ct_eq(self)) } fn is_zero(&self) -> Choice { self.0.ct_eq(&U256::ZERO) } + fn cube(&self) -> Self { self.square() * self } - fn pow_vartime>(&self, _exp: S) -> Self { - unimplemented!() + + fn pow_vartime>(&self, exp: S) -> Self { + let mut sum = Self::one(); + let mut accum = *self; + for (_, num) in exp.as_ref().iter().enumerate() { + let mut num = *num; + for _ in 0 .. 64 { + if (num & 1) == 1 { + sum *= accum; + } + num >>= 1; + accum *= accum; + } + } + sum } } @@ -155,13 +170,13 @@ impl FieldElement { } pub fn pow(&self, other: FieldElement) -> FieldElement { - let mut table = [FieldElement(U256::ONE); 16]; + let mut table = [FieldElement::one(); 16]; table[1] = *self; for i in 2 .. 16 { table[i] = table[i - 1] * self; } - let mut res = FieldElement(U256::ONE); + let mut res = FieldElement::one(); let mut bits = 0; for (i, bit) in other.to_le_bits().iter().rev().enumerate() { bits <<= 1; @@ -248,6 +263,37 @@ fn test_mul() { assert_eq!(SQRT_M1.square(), -FieldElement::one()); } +#[test] +fn test_sqrt() { + assert_eq!(FieldElement::zero().sqrt().unwrap(), FieldElement::zero()); + assert_eq!(FieldElement::one().sqrt().unwrap(), FieldElement::one()); + for _ in 0 .. 10 { + let mut elem; + while { + elem = FieldElement::random(&mut rand_core::OsRng); + elem.sqrt().is_none().into() + } {} + assert_eq!(elem.sqrt().unwrap().square(), elem); + } +} + +#[test] +fn test_pow() { + let base = FieldElement::from(0b11100101u64); + assert_eq!(base.pow(FieldElement::zero()), FieldElement::one()); + assert_eq!(base.pow_vartime(&[]), FieldElement::one()); + assert_eq!(base.pow_vartime(&[0]), FieldElement::one()); + assert_eq!(base.pow_vartime(&[0, 0]), FieldElement::one()); + + assert_eq!(base.pow(FieldElement::one()), base); + assert_eq!(base.pow_vartime(&[1]), base); + assert_eq!(base.pow_vartime(&[1, 0]), base); + + let one_65 = FieldElement::from(u64::MAX) + FieldElement::one(); + assert_eq!(base.pow_vartime(&[0, 1]), base.pow(one_65)); + assert_eq!(base.pow_vartime(&[1, 1]), base.pow(one_65 + FieldElement::one())); +} + #[test] fn test_sqrt_ratio_i() { let zero = FieldElement::zero(); diff --git a/crypto/dalek-ff-group/src/lib.rs b/crypto/dalek-ff-group/src/lib.rs index 249c7577..4007b6f2 100644 --- a/crypto/dalek-ff-group/src/lib.rs +++ b/crypto/dalek-ff-group/src/lib.rs @@ -14,6 +14,7 @@ use digest::{consts::U64, Digest, HashMarker}; use subtle::{Choice, CtOption}; +use crypto_bigint::{Encoding, U256}; pub use curve25519_dalek as dalek; use dalek::{ @@ -175,7 +176,37 @@ constant_time!(Scalar, DScalar); math_neg!(Scalar, Scalar, DScalar::add, DScalar::sub, DScalar::mul); from_uint!(Scalar, DScalar); +const MODULUS: U256 = + U256::from_be_hex("1000000000000000000000000000000014def9dea2f79cd65812631a5cf5d3ed"); + impl Scalar { + pub fn pow(&self, other: Scalar) -> Scalar { + let mut table = [Scalar::one(); 16]; + table[1] = *self; + for i in 2 .. 16 { + table[i] = table[i - 1] * self; + } + + let mut res = Scalar::one(); + let mut bits = 0; + for (i, bit) in other.to_le_bits().iter().rev().enumerate() { + bits <<= 1; + let bit = u8::from(*bit); + bits |= bit; + + if ((i + 1) % 4) == 0 { + if i != 3 { + for _ in 0 .. 4 { + res *= res; + } + } + res *= table[usize::from(bits)]; + bits = 0; + } + } + res + } + /// Perform wide reduction on a 64-byte array to create a Scalar without bias. pub fn from_bytes_mod_order_wide(bytes: &[u8; 64]) -> Scalar { Self(DScalar::from_bytes_mod_order_wide(bytes)) @@ -214,7 +245,16 @@ impl Field for Scalar { CtOption::new(Self(self.0.invert()), !self.is_zero()) } fn sqrt(&self) -> CtOption { - unimplemented!() + let mod_3_8 = MODULUS.saturating_add(&U256::from_u8(3)).wrapping_div(&U256::from_u8(8)); + let mod_3_8 = Scalar::from_repr(mod_3_8.to_le_bytes()).unwrap(); + + let sqrt_m1 = MODULUS.saturating_sub(&U256::from_u8(1)).wrapping_div(&U256::from_u8(4)); + let sqrt_m1 = Scalar::one().double().pow(Scalar::from_repr(sqrt_m1.to_le_bytes()).unwrap()); + + let tv1 = self.pow(mod_3_8); + let tv2 = tv1 * sqrt_m1; + let candidate = Self::conditional_select(&tv2, &tv1, tv1.square().ct_eq(self)); + CtOption::new(candidate, candidate.square().ct_eq(self)) } fn is_zero(&self) -> Choice { self.0.ct_eq(&DScalar::zero()) @@ -222,8 +262,20 @@ impl Field for Scalar { fn cube(&self) -> Self { *self * self * self } - fn pow_vartime>(&self, _exp: S) -> Self { - unimplemented!() + fn pow_vartime>(&self, exp: S) -> Self { + let mut sum = Self::one(); + let mut accum = *self; + for (_, num) in exp.as_ref().iter().enumerate() { + let mut num = *num; + for _ in 0 .. 64 { + if (num & 1) == 1 { + sum *= accum; + } + num >>= 1; + accum *= accum; + } + } + sum } } @@ -396,3 +448,34 @@ dalek_group!( RISTRETTO_BASEPOINT_POINT, RISTRETTO_BASEPOINT_TABLE ); + +#[test] +fn test_sqrt() { + assert_eq!(Scalar::zero().sqrt().unwrap(), Scalar::zero()); + assert_eq!(Scalar::one().sqrt().unwrap(), Scalar::one()); + for _ in 0 .. 10 { + let mut elem; + while { + elem = Scalar::random(&mut rand_core::OsRng); + elem.sqrt().is_none().into() + } {} + assert_eq!(elem.sqrt().unwrap().square(), elem); + } +} + +#[test] +fn test_pow() { + let base = Scalar::from(0b11100101u64); + assert_eq!(base.pow(Scalar::zero()), Scalar::one()); + assert_eq!(base.pow_vartime(&[]), Scalar::one()); + assert_eq!(base.pow_vartime(&[0]), Scalar::one()); + assert_eq!(base.pow_vartime(&[0, 0]), Scalar::one()); + + assert_eq!(base.pow(Scalar::one()), base); + assert_eq!(base.pow_vartime(&[1]), base); + assert_eq!(base.pow_vartime(&[1, 0]), base); + + let one_65 = Scalar::from(u64::MAX) + Scalar::one(); + assert_eq!(base.pow_vartime(&[0, 1]), base.pow(one_65)); + assert_eq!(base.pow_vartime(&[1, 1]), base.pow(one_65 + Scalar::one())); +} diff --git a/crypto/ed448/src/backend.rs b/crypto/ed448/src/backend.rs index a46a1067..b63902f7 100644 --- a/crypto/ed448/src/backend.rs +++ b/crypto/ed448/src/backend.rs @@ -113,8 +113,20 @@ macro_rules! field { fn cube(&self) -> Self { self.square() * self } - fn pow_vartime>(&self, _exp: S) -> Self { - unimplemented!() + fn pow_vartime>(&self, exp: S) -> Self { + let mut sum = Self::one(); + let mut accum = *self; + for (_, num) in exp.as_ref().iter().enumerate() { + let mut num = *num; + for _ in 0 .. 64 { + if (num & 1) == 1 { + sum *= accum; + } + num >>= 1; + accum *= accum; + } + } + sum } } diff --git a/crypto/frost/src/algorithm.rs b/crypto/frost/src/algorithm.rs index d5cd9835..ad0dff5d 100644 --- a/crypto/frost/src/algorithm.rs +++ b/crypto/frost/src/algorithm.rs @@ -108,6 +108,7 @@ impl Transcript for IetfTranscript { self.0.clone() } + // FROST won't use this and this shouldn't be used outside of FROST fn rng_seed(&mut self, _: &'static [u8]) -> [u8; 32] { unimplemented!() }