From ee65e4df8fae23ada98065e9c6cbe90d1af2ece4 Mon Sep 17 00:00:00 2001 From: Luke Parker Date: Thu, 20 Apr 2023 01:11:39 -0400 Subject: [PATCH] Resolve #68 Notably speeds up monero-serai's build and CLSAG performance. --- crypto/dalek-ff-group/src/field.rs | 161 +++++++++++++---------------- crypto/dalek-ff-group/src/lib.rs | 39 +++---- 2 files changed, 90 insertions(+), 110 deletions(-) diff --git a/crypto/dalek-ff-group/src/field.rs b/crypto/dalek-ff-group/src/field.rs index 51a2e3a2..fb25281c 100644 --- a/crypto/dalek-ff-group/src/field.rs +++ b/crypto/dalek-ff-group/src/field.rs @@ -11,94 +11,84 @@ use subtle::{ ConditionallySelectable, }; -use crypto_bigint::{Integer, NonZero, Encoding, U256, U512}; +use crypto_bigint::{ + Integer, NonZero, Encoding, U256, U512, + modular::constant_mod::{ResidueParams, Residue}, + impl_modulus, +}; use group::ff::{Field, PrimeField, FieldBits, PrimeFieldBits}; -use crate::{u8_from_bool, constant_time, math_op, math, from_wrapper, from_uint}; +use crate::{u8_from_bool, constant_time, math_op, math}; -// 2^255 - 19 +// 2 ** 255 - 19 // Uses saturating_sub because checked_sub isn't available at compile time const MODULUS: U256 = U256::from_u8(1).shl_vartime(255).saturating_sub(&U256::from_u8(19)); const WIDE_MODULUS: U512 = U256::ZERO.concat(&MODULUS); +impl_modulus!( + FieldModulus, + U256, + // 2 ** 255 - 19 + "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed" +); +type ResidueType = Residue; + /// A constant-time implementation of the Ed25519 field. #[derive(Clone, Copy, PartialEq, Eq, Default, Debug)] -pub struct FieldElement(U256); +pub struct FieldElement(ResidueType); -/* -The following is a valid const definition of sqrt(-1) yet exceeds the const_eval_limit by 24x. -Accordingly, it'd only be usable on a nightly compiler with the following crate attributes: -#![feature(const_eval_limit)] -#![const_eval_limit = "24000000"] - -const SQRT_M1: FieldElement = { - // Formula from RFC-8032 (modp_sqrt_m1/sqrt8k5 z) - // 2 ** ((MODULUS - 1) // 4) % MODULUS - let base = U256::from_u8(2); - let exp = MODULUS.saturating_sub(&U256::from_u8(1)).wrapping_div(&U256::from_u8(4)); - - const fn mul(x: U256, y: U256) -> U256 { - let wide = U256::mul_wide(&x, &y); - let wide = U256::concat(&wide.1, &wide.0); - wide.wrapping_rem(&WIDE_MODULUS).split().1 - } - - // Perform the pow via multiply and square - let mut res = U256::ONE; - // Iterate from highest bit to lowest bit - let mut bit = 255; - loop { - if bit != 255 { - res = mul(res, res); - } - - // Reverse from little endian to big endian - if exp.bit_vartime(bit) == 1 { - res = mul(res, base); - } - - if bit == 0 { - break; - } - bit -= 1; - } - - FieldElement(res) -}; -*/ - -// Use a constant since we can't calculate it at compile-time without a nightly compiler -// Even without const_eval_limit, it'd take ~30s to calculate, which isn't worth it -const SQRT_M1: FieldElement = FieldElement(U256::from_be_hex( - "2b8324804fc1df0b2b4d00993dfbd7a72f431806ad2fe478c4ee1b274a0ea0b0", -)); +// Square root of -1. +// Formula from RFC-8032 (modp_sqrt_m1/sqrt8k5 z) +// 2 ** ((MODULUS - 1) // 4) % MODULUS +const SQRT_M1: FieldElement = FieldElement( + ResidueType::new(&U256::from_u8(2)) + .pow(&MODULUS.saturating_sub(&U256::ONE).wrapping_div(&U256::from_u8(4))), +); // Constant useful in calculating square roots (RFC-8032 sqrt8k5's exponent used to calculate y) -const MOD_3_8: FieldElement = - FieldElement(MODULUS.saturating_add(&U256::from_u8(3)).wrapping_div(&U256::from_u8(8))); +const MOD_3_8: FieldElement = FieldElement(ResidueType::new( + &MODULUS.saturating_add(&U256::from_u8(3)).wrapping_div(&U256::from_u8(8)), +)); // Constant useful in sqrt_ratio_i (sqrt(u / v)) -const MOD_5_8: FieldElement = FieldElement(MOD_3_8.0.saturating_sub(&U256::ONE)); +const MOD_5_8: FieldElement = FieldElement(ResidueType::sub(&MOD_3_8.0, &ResidueType::ONE)); -fn reduce(x: U512) -> U256 { - U256::from_le_slice(&x.rem(&NonZero::new(WIDE_MODULUS).unwrap()).to_le_bytes()[.. 32]) +fn reduce(x: U512) -> ResidueType { + ResidueType::new(&U256::from_le_slice( + &x.rem(&NonZero::new(WIDE_MODULUS).unwrap()).to_le_bytes()[.. 32], + )) } -constant_time!(FieldElement, U256); +constant_time!(FieldElement, ResidueType); math!( FieldElement, FieldElement, - |x, y| U256::add_mod(&x, &y, &MODULUS), - |x, y| U256::sub_mod(&x, &y, &MODULUS), - |x, y| reduce(U512::from(U256::mul_wide(&x, &y))) + |x: ResidueType, y: ResidueType| x.add(&y), + |x: ResidueType, y: ResidueType| x.sub(&y), + |x: ResidueType, y: ResidueType| x.mul(&y) ); -from_uint!(FieldElement, U256); + +macro_rules! from_wrapper { + ($uint: ident) => { + impl From<$uint> for FieldElement { + fn from(a: $uint) -> FieldElement { + Self(ResidueType::new(&U256::from(a))) + } + } + }; +} + +from_wrapper!(u8); +from_wrapper!(u16); +from_wrapper!(u32); +from_wrapper!(u64); +from_wrapper!(u128); impl Neg for FieldElement { type Output = Self; fn neg(self) -> Self::Output { - Self(self.0.neg_mod(&MODULUS)) + Self(self.0.neg()) } } @@ -110,8 +100,8 @@ impl<'a> Neg for &'a FieldElement { } impl Field for FieldElement { - const ZERO: Self = Self(U256::ZERO); - const ONE: Self = Self(U256::ONE); + const ZERO: Self = Self(ResidueType::ZERO); + const ONE: Self = Self(ResidueType::ONE); fn random(mut rng: impl RngCore) -> Self { let mut bytes = [0; 64]; @@ -120,14 +110,15 @@ impl Field for FieldElement { } fn square(&self) -> Self { - FieldElement(reduce(self.0.square())) + FieldElement(self.0.square()) } fn double(&self) -> Self { - FieldElement((self.0 << 1).rem(&NonZero::new(MODULUS).unwrap())) + FieldElement(self.0.add(&self.0)) } fn invert(&self) -> CtOption { - const NEG_2: FieldElement = FieldElement(MODULUS.saturating_sub(&U256::from_u8(2))); + const NEG_2: FieldElement = + FieldElement(ResidueType::new(&MODULUS.saturating_sub(&U256::from_u8(2)))); CtOption::new(self.pow(NEG_2), !self.is_zero()) } @@ -172,44 +163,39 @@ impl PrimeField for FieldElement { const NUM_BITS: u32 = 255; const CAPACITY: u32 = 254; - // 2.invert() - const TWO_INV: Self = FieldElement(U256::from_be_hex( - "3ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff7", - )); + const TWO_INV: Self = FieldElement(ResidueType::new(&U256::from_u8(2)).invert().0); // This was calculated with the method from the ff crate docs // SageMath GF(modulus).primitive_element() - const MULTIPLICATIVE_GENERATOR: Self = Self(U256::from_u8(2)); + const MULTIPLICATIVE_GENERATOR: Self = Self(ResidueType::new(&U256::from_u8(2))); // This was set per the specification in the ff crate docs // The number of leading zero bits in the little-endian bit representation of (modulus - 1) const S: u32 = 2; // This was calculated via the formula from the ff crate docs // Self::MULTIPLICATIVE_GENERATOR ** ((modulus - 1) >> Self::S) - const ROOT_OF_UNITY: Self = FieldElement(U256::from_be_hex( + const ROOT_OF_UNITY: Self = FieldElement(ResidueType::new(&U256::from_be_hex( "2b8324804fc1df0b2b4d00993dfbd7a72f431806ad2fe478c4ee1b274a0ea0b0", - )); + ))); // Self::ROOT_OF_UNITY.invert() - const ROOT_OF_UNITY_INV: Self = FieldElement(U256::from_be_hex( - "547cdb7fb03e20f4d4b2ff66c2042858d0bce7f952d01b873b11e4d8b5f15f3d", - )); + const ROOT_OF_UNITY_INV: Self = FieldElement(Self::ROOT_OF_UNITY.0.invert().0); // This was calculated via the formula from the ff crate docs // Self::MULTIPLICATIVE_GENERATOR ** (2 ** Self::S) - const DELTA: Self = FieldElement(U256::from_be_hex( + const DELTA: Self = FieldElement(ResidueType::new(&U256::from_be_hex( "0000000000000000000000000000000000000000000000000000000000000010", - )); + ))); fn from_repr(bytes: [u8; 32]) -> CtOption { - let res = Self(U256::from_le_bytes(bytes)); - CtOption::new(res, res.0.ct_lt(&MODULUS)) + let res = U256::from_le_bytes(bytes); + CtOption::new(Self(ResidueType::new(&res)), res.ct_lt(&MODULUS)) } fn to_repr(&self) -> [u8; 32] { - self.0.to_le_bytes() + self.0.retrieve().to_le_bytes() } fn is_odd(&self) -> Choice { - self.0.is_odd() + self.0.retrieve().is_odd() } fn from_u128(num: u128) -> Self { @@ -233,7 +219,7 @@ impl FieldElement { /// Interpret the value as a little-endian integer, square it, and reduce it into a FieldElement. pub fn from_square(value: [u8; 32]) -> FieldElement { let value = U256::from_le_bytes(value); - FieldElement(value) * FieldElement(value) + FieldElement(reduce(U512::from(value.mul_wide(&value)))) } /// Perform an exponentation. @@ -346,14 +332,15 @@ fn test_sqrt_m1() { // Test equivalence against the known constant value const SQRT_M1_MAGIC: U256 = U256::from_be_hex("2b8324804fc1df0b2b4d00993dfbd7a72f431806ad2fe478c4ee1b274a0ea0b0"); - assert_eq!(SQRT_M1.0, SQRT_M1_MAGIC); + assert_eq!(SQRT_M1.0.retrieve(), SQRT_M1_MAGIC); // Also test equivalence against the result of the formula from RFC-8032 (modp_sqrt_m1/sqrt8k5 z) // 2 ** ((MODULUS - 1) // 4) % MODULUS assert_eq!( SQRT_M1, - FieldElement::from(2u8) - .pow(FieldElement((FieldElement::ZERO - FieldElement::ONE).0.wrapping_div(&U256::from(4u8)))) + FieldElement::from(2u8).pow(FieldElement(ResidueType::new( + &(FieldElement::ZERO - FieldElement::ONE).0.retrieve().wrapping_div(&U256::from(4u8)) + ))) ); } diff --git a/crypto/dalek-ff-group/src/lib.rs b/crypto/dalek-ff-group/src/lib.rs index 3d67ebca..21ed119e 100644 --- a/crypto/dalek-ff-group/src/lib.rs +++ b/crypto/dalek-ff-group/src/lib.rs @@ -162,35 +162,28 @@ macro_rules! math_neg { }; } -macro_rules! from_wrapper { - ($wrapper: ident, $inner: ident, $uint: ident) => { - impl From<$uint> for $wrapper { - fn from(a: $uint) -> $wrapper { - Self($inner::from(a)) - } - } - }; -} -pub(crate) use from_wrapper; - -macro_rules! from_uint { - ($wrapper: ident, $inner: ident) => { - from_wrapper!($wrapper, $inner, u8); - from_wrapper!($wrapper, $inner, u16); - from_wrapper!($wrapper, $inner, u32); - from_wrapper!($wrapper, $inner, u64); - from_wrapper!($wrapper, $inner, u128); - }; -} -pub(crate) use from_uint; - /// Wrapper around the dalek Scalar type. #[derive(Clone, Copy, PartialEq, Eq, Default, Debug, Zeroize)] pub struct Scalar(pub DScalar); deref_borrow!(Scalar, DScalar); constant_time!(Scalar, DScalar); math_neg!(Scalar, Scalar, DScalar::add, DScalar::sub, DScalar::mul); -from_uint!(Scalar, DScalar); + +macro_rules! from_wrapper { + ($uint: ident) => { + impl From<$uint> for Scalar { + fn from(a: $uint) -> Scalar { + Scalar(DScalar::from(a)) + } + } + }; +} + +from_wrapper!(u8); +from_wrapper!(u16); +from_wrapper!(u32); +from_wrapper!(u64); +from_wrapper!(u128); // Ed25519 order/scalar modulus const MODULUS: U256 =