Constant-time divisors (#617)

* WIP constant-time implementation of the ec-divisors library

* Fix misc logic errors in poly.rs

* Remove accidentally committed test statements

* Fix ConstantTimeEq for CoefficientIndex

* Correct the iterations formula

x**3 / (0 y + x**1) would prior be considered indivisible with iterations = 0.
It is divisible however. The amount of iterations should be the amount of
coefficients within the numerator *excluding the coefficient for y**0 x**0*.

* Poly PartialEq, conditional_select_poly which checks poly structure equivalence

If the first passed argument is smaller than the latter, it's padded to the
necessary length.

Also adds code to trim the remainder as the remainder is the value modulo, so
it's very important it remains concise and workable.

* Fix the line function

It selected the case if both were identity before selecting the case if either
were identity, the latter overwriting the former.

* Final fixes re: ct_get

1) Our quotient structure does need to be of size equal to the numerator
   entirely to prevent out-of-bounds reads on it
2) We need to get from yx_coefficients if of length >=, so if the length is 1
   we can read y_pow=1 from it. If y_pow=0, and its length is 0 so it has no
   inner Vecs, we need to fall back with the guard y_pow != 0.

* Add a trim algorithm to lib.rs to prevent Polys from becoming unbearably gigantic

Our Poly algorithm is incredibly leaky. While it presumably should be improved,
we can take advantage of our known structure while constructing divisors (and
the small modulus) to simply trim out the zero coefficients leaked. This
maintains Polys in a manageable size.

* Move constant-time scalar mul gadget divisor creation from dkg to ec-divisors

Anyone creating a divisor for the scalar mul gadget should use constant time
code, so this code should at least be in the EC gadgets crate It's of
non-trivial complexity to deal with otherwise.

* Remove unsafe, cache timing attacks from ec-divisors
This commit is contained in:
Luke Parker 2024-09-24 14:27:05 -07:00 committed by GitHub
parent 2c8af04781
commit 251a6e96e8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 763 additions and 368 deletions

3
Cargo.lock generated
View file

@ -2196,7 +2196,6 @@ dependencies = [
"schnorr-signatures",
"secq256k1",
"std-shims",
"subtle",
"thiserror",
"zeroize",
]
@ -2293,10 +2292,12 @@ name = "ec-divisors"
version = "0.1.0"
dependencies = [
"dalek-ff-group",
"ff",
"group",
"hex",
"pasta_curves",
"rand_core",
"subtle",
"zeroize",
]

View file

@ -35,7 +35,7 @@ impl_modulus!(
type ResidueType = Residue<FieldModulus, { FieldModulus::LIMBS }>;
/// A constant-time implementation of the Ed25519 field.
#[derive(Clone, Copy, PartialEq, Eq, Default, Debug)]
#[derive(Clone, Copy, PartialEq, Eq, Default, Debug, Zeroize)]
pub struct FieldElement(ResidueType);
// Square root of -1.

View file

@ -37,7 +37,6 @@ schnorr = { package = "schnorr-signatures", path = "../schnorr", version = "^0.5
dleq = { path = "../dleq", version = "^0.4.1", default-features = false }
# eVRF DKG dependencies
subtle = { version = "2", default-features = false, features = ["std"], optional = true }
generic-array = { version = "1", default-features = false, features = ["alloc"], optional = true }
blake2 = { version = "0.10", default-features = false, features = ["std"], optional = true }
rand_chacha = { version = "0.3", default-features = false, features = ["std"], optional = true }
@ -82,7 +81,6 @@ borsh = ["dep:borsh"]
evrf = [
"std",
"dep:subtle",
"dep:generic-array",
"dep:blake2",

View file

@ -1,6 +1,5 @@
use core::{marker::PhantomData, ops::Deref, fmt};
use subtle::*;
use zeroize::{Zeroize, Zeroizing};
use rand_core::{RngCore, CryptoRng, SeedableRng};
@ -10,10 +9,7 @@ use generic_array::{typenum::Unsigned, ArrayLength, GenericArray};
use blake2::{Digest, Blake2s256};
use ciphersuite::{
group::{
ff::{Field, PrimeField, PrimeFieldBits},
Group, GroupEncoding,
},
group::{ff::Field, Group, GroupEncoding},
Ciphersuite,
};
@ -24,7 +20,7 @@ use generalized_bulletproofs::{
};
use generalized_bulletproofs_circuit_abstraction::*;
use ec_divisors::{DivisorCurve, new_divisor};
use ec_divisors::{DivisorCurve, ScalarDecomposition};
use generalized_bulletproofs_ec_gadgets::*;
/// A pair of curves to perform the eVRF with.
@ -309,147 +305,6 @@ impl<C: EvrfCurve> Evrf<C> {
debug_assert!(challenged_generators.next().is_none());
}
/// Convert a scalar to a sequence of coefficients for the polynomial 2**i, where the sum of the
/// coefficients is F::NUM_BITS.
///
/// Despite the name, the returned coefficients are not guaranteed to be bits (0 or 1).
///
/// This scalar will presumably be used in a discrete log proof. That requires calculating a
/// divisor which is variable time to the amount of points interpolated. Since the amount of
/// points interpolated is equal to the sum of the coefficients in the polynomial, we need all
/// scalars to have a constant sum of their coefficients (instead of one variable to its bits).
///
/// We achieve this by finding the highest non-0 coefficient, decrementing it, and increasing the
/// immediately less significant coefficient by 2. This increases the sum of the coefficients by
/// 1 (-1+2=1).
fn scalar_to_bits(scalar: &<C::EmbeddedCurve as Ciphersuite>::F) -> Vec<u64> {
let num_bits = u64::from(<<C as EvrfCurve>::EmbeddedCurve as Ciphersuite>::F::NUM_BITS);
// Obtain the bits of the private key
let num_bits_usize = usize::try_from(num_bits).unwrap();
let mut decomposition = vec![0; num_bits_usize];
for (i, bit) in scalar.to_le_bits().into_iter().take(num_bits_usize).enumerate() {
let bit = u64::from(u8::from(bit));
decomposition[i] = bit;
}
// The following algorithm only works if the value of the scalar exceeds num_bits
// If it isn't, we increase it by the modulus such that it does exceed num_bits
{
let mut less_than_num_bits = Choice::from(0);
for i in 0 .. num_bits {
less_than_num_bits |= scalar.ct_eq(&<C::EmbeddedCurve as Ciphersuite>::F::from(i));
}
let mut decomposition_of_modulus = vec![0; num_bits_usize];
// Decompose negative one
for (i, bit) in (-<C::EmbeddedCurve as Ciphersuite>::F::ONE)
.to_le_bits()
.into_iter()
.take(num_bits_usize)
.enumerate()
{
let bit = u64::from(u8::from(bit));
decomposition_of_modulus[i] = bit;
}
// Increment it by one
decomposition_of_modulus[0] += 1;
// Add the decomposition onto the decomposition of the modulus
for i in 0 .. num_bits_usize {
let new_decomposition = <_>::conditional_select(
&decomposition[i],
&(decomposition[i] + decomposition_of_modulus[i]),
less_than_num_bits,
);
decomposition[i] = new_decomposition;
}
}
// Calculcate the sum of the coefficients
let mut sum_of_coefficients: u64 = 0;
for decomposition in &decomposition {
sum_of_coefficients += *decomposition;
}
/*
Now, because we added a log2(k)-bit number to a k-bit number, we may have our sum of
coefficients be *too high*. We attempt to reduce the sum of the coefficients accordingly.
This algorithm is guaranteed to complete as expected. Take the sequence `222`. `222` becomes
`032` becomes `013`. Even if the next coefficient in the sequence is `2`, the third
coefficient will be reduced once and the next coefficient (`2`, increased to `3`) will only
be eligible for reduction once. This demonstrates, even for a worst case of log2(k) `2`s
followed by `1`s (as possible if the modulus is a Mersenne prime), the log2(k) `2`s can be
reduced as necessary so long as there is a single coefficient after (requiring the entire
sequence be at least of length log2(k) + 1). For a 2-bit number, log2(k) + 1 == 2, so this
holds for any odd prime field.
To fully type out the demonstration for the Mersenne prime 3, with scalar to encode 1 (the
highest value less than the number of bits):
10 - Little-endian bits of 1
21 - Little-endian bits of 1, plus the modulus
02 - After one reduction, where the sum of the coefficients does in fact equal 2 (the target)
*/
{
let mut log2_num_bits = 0;
while (1 << log2_num_bits) < num_bits {
log2_num_bits += 1;
}
for _ in 0 .. log2_num_bits {
// If the sum of coefficients is the amount of bits, we're done
let mut done = sum_of_coefficients.ct_eq(&num_bits);
for i in 0 .. (num_bits_usize - 1) {
let should_act = (!done) & decomposition[i].ct_gt(&1);
// Subtract 2 from this coefficient
let amount_to_sub = <_>::conditional_select(&0, &2, should_act);
decomposition[i] -= amount_to_sub;
// Add 1 to the next coefficient
let amount_to_add = <_>::conditional_select(&0, &1, should_act);
decomposition[i + 1] += amount_to_add;
// Also update the sum of coefficients
sum_of_coefficients -= <_>::conditional_select(&0, &1, should_act);
// If we updated the coefficients this loop iter, we're done for this loop iter
done |= should_act;
}
}
}
for _ in 0 .. num_bits {
// If the sum of coefficients is the amount of bits, we're done
let mut done = sum_of_coefficients.ct_eq(&num_bits);
// Find the highest coefficient currently non-zero
for i in (1 .. decomposition.len()).rev() {
// If this is non-zero, we should decrement this coefficient if we haven't already
// decremented a coefficient this round
let is_non_zero = !(0.ct_eq(&decomposition[i]));
let should_act = (!done) & is_non_zero;
// Update this coefficient and the prior coefficient
let amount_to_sub = <_>::conditional_select(&0, &1, should_act);
decomposition[i] -= amount_to_sub;
let amount_to_add = <_>::conditional_select(&0, &2, should_act);
// i must be at least 1, so i - 1 will be at least 0 (meaning it's safe to index with)
decomposition[i - 1] += amount_to_add;
// Also update the sum of coefficients
sum_of_coefficients += <_>::conditional_select(&0, &1, should_act);
// If we updated the coefficients this loop iter, we're done for this loop iter
done |= should_act;
}
}
debug_assert!(bool::from(decomposition.iter().sum::<u64>().ct_eq(&num_bits)));
decomposition
}
/// Prove a point on an elliptic curve had its discrete logarithm generated via an eVRF.
pub(crate) fn prove(
rng: &mut (impl RngCore + CryptoRng),
@ -471,11 +326,9 @@ impl<C: EvrfCurve> Evrf<C> {
// A function to calculate a divisor and push it onto the tape
// This defines a vec, divisor_points, outside of the fn to reuse its allocation
let mut divisor_points =
Vec::with_capacity((<C::EmbeddedCurve as Ciphersuite>::F::NUM_BITS as usize) + 1);
let mut divisor =
|vector_commitment_tape: &mut Vec<_>,
dlog: &[u64],
dlog: &ScalarDecomposition<<<C as EvrfCurve>::EmbeddedCurve as Ciphersuite>::F>,
push_generator: bool,
generator: <<C as EvrfCurve>::EmbeddedCurve as Ciphersuite>::G,
dh: <<C as EvrfCurve>::EmbeddedCurve as Ciphersuite>::G| {
@ -484,24 +337,7 @@ impl<C: EvrfCurve> Evrf<C> {
generator_tables.push(GeneratorTable::new(&curve_spec, x, y));
}
{
let mut generator = generator;
for coefficient in dlog {
let mut coefficient = *coefficient;
while coefficient != 0 {
coefficient -= 1;
divisor_points.push(generator);
}
generator = generator.double();
}
debug_assert_eq!(
dlog.iter().sum::<u64>(),
u64::from(<C::EmbeddedCurve as Ciphersuite>::F::NUM_BITS)
);
}
divisor_points.push(-dh);
let mut divisor = new_divisor(&divisor_points).unwrap().normalize_x_coefficient();
divisor_points.zeroize();
let mut divisor = dlog.scalar_mul_divisor(generator).normalize_x_coefficient();
vector_commitment_tape.push(divisor.zero_coefficient);
@ -540,11 +376,12 @@ impl<C: EvrfCurve> Evrf<C> {
let evrf_public_key;
let mut actual_coefficients = Vec::with_capacity(coefficients);
{
let mut dlog = Self::scalar_to_bits(evrf_private_key);
let dlog =
ScalarDecomposition::<<C::EmbeddedCurve as Ciphersuite>::F>::new(**evrf_private_key);
let points = Self::transcript_to_points(transcript, coefficients);
// Start by pushing the discrete logarithm onto the tape
for coefficient in &dlog {
for coefficient in dlog.decomposition() {
vector_commitment_tape.push(<_>::from(*coefficient));
}
@ -573,8 +410,6 @@ impl<C: EvrfCurve> Evrf<C> {
actual_coefficients.push(res);
}
debug_assert_eq!(actual_coefficients.len(), coefficients);
dlog.zeroize();
}
// Now do the ECDHs for the encryption
@ -595,14 +430,15 @@ impl<C: EvrfCurve> Evrf<C> {
break;
}
}
let mut dlog = Self::scalar_to_bits(&ecdh_private_key);
let dlog =
ScalarDecomposition::<<C::EmbeddedCurve as Ciphersuite>::F>::new(ecdh_private_key);
let ecdh_commitment = <C::EmbeddedCurve as Ciphersuite>::generator() * ecdh_private_key;
ecdh_commitments.push(ecdh_commitment);
ecdh_commitments_xy.last_mut().unwrap()[j] =
<<C::EmbeddedCurve as Ciphersuite>::G as DivisorCurve>::to_xy(ecdh_commitment).unwrap();
// Start by pushing the discrete logarithm onto the tape
for coefficient in &dlog {
for coefficient in dlog.decomposition() {
vector_commitment_tape.push(<_>::from(*coefficient));
}
@ -625,7 +461,6 @@ impl<C: EvrfCurve> Evrf<C> {
*res += dh_x;
ecdh_private_key.zeroize();
dlog.zeroize();
}
encryption_masks.push(res);
}

View file

@ -14,9 +14,11 @@ rustdoc-args = ["--cfg", "docsrs"]
[dependencies]
rand_core = { version = "0.6", default-features = false }
zeroize = { version = "^1.5", default-features = false, features = ["zeroize_derive"] }
zeroize = { version = "^1.5", default-features = false, features = ["std", "zeroize_derive"] }
group = "0.13"
subtle = { version = "2", default-features = false, features = ["std"] }
ff = { version = "0.13", default-features = false, features = ["std", "bits"] }
group = { version = "0.13", default-features = false }
hex = { version = "0.4", optional = true }
dalek-ff-group = { path = "../../dalek-ff-group", features = ["std"], optional = true }

View file

@ -3,21 +3,24 @@
#![deny(missing_docs)]
#![allow(non_snake_case)]
use subtle::{Choice, ConstantTimeEq, ConstantTimeGreater, ConditionallySelectable};
use zeroize::{Zeroize, ZeroizeOnDrop};
use group::{
ff::{Field, PrimeField},
ff::{Field, PrimeField, PrimeFieldBits},
Group,
};
mod poly;
pub use poly::*;
pub use poly::Poly;
#[cfg(test)]
mod tests;
/// A curve usable with this library.
pub trait DivisorCurve: Group {
pub trait DivisorCurve: Group + ConstantTimeEq + ConditionallySelectable {
/// An element of the field this curve is defined over.
type FieldElement: PrimeField;
type FieldElement: Zeroize + PrimeField + ConditionallySelectable;
/// The A in the curve equation y^2 = x^3 + A x + B.
fn a() -> Self::FieldElement;
@ -72,46 +75,89 @@ pub(crate) fn slope_intercept<C: DivisorCurve>(a: C, b: C) -> (C::FieldElement,
}
// The line interpolating two points.
fn line<C: DivisorCurve>(a: C, mut b: C) -> Poly<C::FieldElement> {
// If they're both the point at infinity, we simply set the line to one
if bool::from(a.is_identity() & b.is_identity()) {
return Poly {
y_coefficients: vec![],
yx_coefficients: vec![],
x_coefficients: vec![],
zero_coefficient: C::FieldElement::ONE,
};
fn line<C: DivisorCurve>(a: C, b: C) -> Poly<C::FieldElement> {
#[derive(Clone, Copy)]
struct LinesRes<F: ConditionallySelectable> {
y_coefficient: F,
x_coefficient: F,
zero_coefficient: F,
}
impl<F: ConditionallySelectable> ConditionallySelectable for LinesRes<F> {
fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
Self {
y_coefficient: <_>::conditional_select(&a.y_coefficient, &b.y_coefficient, choice),
x_coefficient: <_>::conditional_select(&a.x_coefficient, &b.x_coefficient, choice),
zero_coefficient: <_>::conditional_select(&a.zero_coefficient, &b.zero_coefficient, choice),
}
}
}
let a_is_identity = a.is_identity();
let b_is_identity = b.is_identity();
// If they're both the point at infinity, we simply set the line to one
let both_are_identity = a_is_identity & b_is_identity;
let if_both_are_identity = LinesRes {
y_coefficient: C::FieldElement::ZERO,
x_coefficient: C::FieldElement::ZERO,
zero_coefficient: C::FieldElement::ONE,
};
// If either point is the point at infinity, or these are additive inverses, the line is
// `1 * x - x`. The first `x` is a term in the polynomial, the `x` is the `x` coordinate of these
// points (of which there is one, as the second point is either at infinity or has a matching `x`
// coordinate).
if bool::from(a.is_identity() | b.is_identity()) || (a == -b) {
let (x, _) = C::to_xy(if !bool::from(a.is_identity()) { a } else { b }).unwrap();
return Poly {
y_coefficients: vec![],
yx_coefficients: vec![],
x_coefficients: vec![C::FieldElement::ONE],
let one_is_identity = a_is_identity | b_is_identity;
let additive_inverses = a.ct_eq(&-b);
let one_is_identity_or_additive_inverses = one_is_identity | additive_inverses;
let if_one_is_identity_or_additive_inverses = {
// If both are identity, set `a` to the generator so we can safely evaluate the following
// (which we won't select at the end of this function)
let a = <_>::conditional_select(&a, &C::generator(), both_are_identity);
// If `a` is identity, this selects `b`. If `a` isn't identity, this selects `a`
let non_identity = <_>::conditional_select(&a, &b, a.is_identity());
let (x, _) = C::to_xy(non_identity).unwrap();
LinesRes {
y_coefficient: C::FieldElement::ZERO,
x_coefficient: C::FieldElement::ONE,
zero_coefficient: -x,
};
}
}
};
// The following calculation assumes neither point is the point at infinity
// If either are, we use a prior result
// To ensure we can calculcate a result here, set any points at infinity to the generator
let a = <_>::conditional_select(&a, &C::generator(), a_is_identity);
let b = <_>::conditional_select(&b, &C::generator(), b_is_identity);
// It also assumes a, b aren't additive inverses which is also covered by a prior result
let b = <_>::conditional_select(&b, &a.double(), additive_inverses);
// If the points are equal, we use the line interpolating the sum of these points with the point
// at infinity
if a == b {
b = -a.double();
}
let b = <_>::conditional_select(&b, &-a.double(), a.ct_eq(&b));
let (slope, intercept) = slope_intercept::<C>(a, b);
// Section 4 of the proofs explicitly state the line `L = y - lambda * x - mu`
// y - (slope * x) - intercept
Poly {
y_coefficients: vec![C::FieldElement::ONE],
yx_coefficients: vec![],
x_coefficients: vec![-slope],
let mut res = LinesRes {
y_coefficient: C::FieldElement::ONE,
x_coefficient: -slope,
zero_coefficient: -intercept,
};
res = <_>::conditional_select(
&res,
&if_one_is_identity_or_additive_inverses,
one_is_identity_or_additive_inverses,
);
res = <_>::conditional_select(&res, &if_both_are_identity, both_are_identity);
Poly {
y_coefficients: vec![res.y_coefficient],
yx_coefficients: vec![],
x_coefficients: vec![res.x_coefficient],
zero_coefficient: res.zero_coefficient,
}
}
@ -121,36 +167,65 @@ fn line<C: DivisorCurve>(a: C, mut b: C) -> Poly<C::FieldElement> {
/// - No points were passed in
/// - The points don't sum to the point at infinity
/// - A passed in point was the point at infinity
///
/// If the arguments were valid, this function executes in an amount of time constant to the amount
/// of points.
#[allow(clippy::new_ret_no_self)]
pub fn new_divisor<C: DivisorCurve>(points: &[C]) -> Option<Poly<C::FieldElement>> {
// A single point is either the point at infinity, or this doesn't sum to the point at infinity
// Both cause us to return None
if points.len() < 2 {
None?;
}
if points.iter().sum::<C>() != C::identity() {
// No points were passed in, this is the point at infinity, or the single point isn't infinity
// and accordingly doesn't sum to infinity. All three cause us to return None
// Checks a bit other than the first bit is set, meaning this is >= 2
let mut invalid_args = (points.len() & (!1)).ct_eq(&0);
// The points don't sum to the point at infinity
invalid_args |= !points.iter().sum::<C>().is_identity();
// A point was the point at identity
for point in points {
invalid_args |= point.is_identity();
}
if bool::from(invalid_args) {
None?;
}
let points_len = points.len();
// Create the initial set of divisors
let mut divs = vec![];
let mut iter = points.iter().copied();
while let Some(a) = iter.next() {
if a == C::identity() {
None?;
}
let b = iter.next();
if b == Some(C::identity()) {
None?;
}
// Draw the line between those points
divs.push((a + b.unwrap_or(C::identity()), line::<C>(a, b.unwrap_or(-a))));
// These unwraps are branching on the length of the iterator, not violating the constant-time
// priorites desired
divs.push((2, a + b.unwrap_or(C::identity()), line::<C>(a, b.unwrap_or(-a))));
}
let modulus = C::divisor_modulus();
// Our Poly algorithm is leaky and will create an excessive amount of y x**j and x**j
// coefficients which are zero, yet as our implementation is constant time, still come with
// an immense performance cost. This code truncates the coefficients we know are zero.
let trim = |divisor: &mut Poly<_>, points_len: usize| {
// We should only be trimming divisors reduced by the modulus
debug_assert!(divisor.yx_coefficients.len() <= 1);
if divisor.yx_coefficients.len() == 1 {
let truncate_to = ((points_len + 1) / 2).saturating_sub(2);
#[cfg(debug_assertions)]
for p in truncate_to .. divisor.yx_coefficients[0].len() {
debug_assert_eq!(divisor.yx_coefficients[0][p], <C::FieldElement as Field>::ZERO);
}
divisor.yx_coefficients[0].truncate(truncate_to);
}
{
let truncate_to = points_len / 2;
#[cfg(debug_assertions)]
for p in truncate_to .. divisor.x_coefficients.len() {
debug_assert_eq!(divisor.x_coefficients[p], <C::FieldElement as Field>::ZERO);
}
divisor.x_coefficients.truncate(truncate_to);
}
};
// Pair them off until only one remains
while divs.len() > 1 {
let mut next_divs = vec![];
@ -159,23 +234,208 @@ pub fn new_divisor<C: DivisorCurve>(points: &[C]) -> Option<Poly<C::FieldElement
next_divs.push(divs.pop().unwrap());
}
while let Some((a, a_div)) = divs.pop() {
let (b, b_div) = divs.pop().unwrap();
while let Some((a_points, a, a_div)) = divs.pop() {
let (b_points, b, b_div) = divs.pop().unwrap();
let points = a_points + b_points;
// Merge the two divisors
let numerator = a_div.mul_mod(b_div, &modulus).mul_mod(line::<C>(a, b), &modulus);
let denominator = line::<C>(a, -a).mul_mod(line::<C>(b, -b), &modulus);
let (q, r) = numerator.div_rem(&denominator);
assert_eq!(r, Poly::zero());
let numerator = a_div.mul_mod(&b_div, &modulus).mul_mod(&line::<C>(a, b), &modulus);
let denominator = line::<C>(a, -a).mul_mod(&line::<C>(b, -b), &modulus);
let (mut q, r) = numerator.div_rem(&denominator);
debug_assert_eq!(r, Poly::zero());
next_divs.push((a + b, q));
trim(&mut q, 1 + points);
next_divs.push((points, a + b, q));
}
divs = next_divs;
}
// Return the unified divisor
Some(divs.remove(0).1)
let mut divisor = divs.remove(0).2;
trim(&mut divisor, points_len);
Some(divisor)
}
/// The decomposition of a scalar.
///
/// The decomposition ($d$) of a scalar ($s$) has the following two properties:
///
/// - $\sum^{\mathsf{NUM_BITS} - 1}_{i=0} d_i * 2^i = s$
/// - $\sum^{\mathsf{NUM_BITS} - 1}_{i=0} d_i = \mathsf{NUM_BITS}$
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct ScalarDecomposition<F: Zeroize + PrimeFieldBits> {
scalar: F,
decomposition: Vec<u64>,
}
impl<F: Zeroize + PrimeFieldBits> ScalarDecomposition<F> {
/// Decompose a scalar.
pub fn new(scalar: F) -> Self {
/*
We need the sum of the coefficients to equal F::NUM_BITS. The scalar's bits will be less than
F::NUM_BITS. Accordingly, we need to increment the sum of the coefficients without
incrementing the scalar represented. We do this by finding the highest non-0 coefficient,
decrementing it, and increasing the immediately less significant coefficient by 2. This
increases the sum of the coefficients by 1 (-1+2=1).
*/
let num_bits = u64::from(F::NUM_BITS);
// Obtain the bits of the scalar
let num_bits_usize = usize::try_from(num_bits).unwrap();
let mut decomposition = vec![0; num_bits_usize];
for (i, bit) in scalar.to_le_bits().into_iter().take(num_bits_usize).enumerate() {
let bit = u64::from(u8::from(bit));
decomposition[i] = bit;
}
// The following algorithm only works if the value of the scalar exceeds num_bits
// If it isn't, we increase it by the modulus such that it does exceed num_bits
{
let mut less_than_num_bits = Choice::from(0);
for i in 0 .. num_bits {
less_than_num_bits |= scalar.ct_eq(&F::from(i));
}
let mut decomposition_of_modulus = vec![0; num_bits_usize];
// Decompose negative one
for (i, bit) in (-F::ONE).to_le_bits().into_iter().take(num_bits_usize).enumerate() {
let bit = u64::from(u8::from(bit));
decomposition_of_modulus[i] = bit;
}
// Increment it by one
decomposition_of_modulus[0] += 1;
// Add the decomposition onto the decomposition of the modulus
for i in 0 .. num_bits_usize {
let new_decomposition = <_>::conditional_select(
&decomposition[i],
&(decomposition[i] + decomposition_of_modulus[i]),
less_than_num_bits,
);
decomposition[i] = new_decomposition;
}
}
// Calculcate the sum of the coefficients
let mut sum_of_coefficients: u64 = 0;
for decomposition in &decomposition {
sum_of_coefficients += *decomposition;
}
/*
Now, because we added a log2(k)-bit number to a k-bit number, we may have our sum of
coefficients be *too high*. We attempt to reduce the sum of the coefficients accordingly.
This algorithm is guaranteed to complete as expected. Take the sequence `222`. `222` becomes
`032` becomes `013`. Even if the next coefficient in the sequence is `2`, the third
coefficient will be reduced once and the next coefficient (`2`, increased to `3`) will only
be eligible for reduction once. This demonstrates, even for a worst case of log2(k) `2`s
followed by `1`s (as possible if the modulus is a Mersenne prime), the log2(k) `2`s can be
reduced as necessary so long as there is a single coefficient after (requiring the entire
sequence be at least of length log2(k) + 1). For a 2-bit number, log2(k) + 1 == 2, so this
holds for any odd prime field.
To fully type out the demonstration for the Mersenne prime 3, with scalar to encode 1 (the
highest value less than the number of bits):
10 - Little-endian bits of 1
21 - Little-endian bits of 1, plus the modulus
02 - After one reduction, where the sum of the coefficients does in fact equal 2 (the target)
*/
{
let mut log2_num_bits = 0;
while (1 << log2_num_bits) < num_bits {
log2_num_bits += 1;
}
for _ in 0 .. log2_num_bits {
// If the sum of coefficients is the amount of bits, we're done
let mut done = sum_of_coefficients.ct_eq(&num_bits);
for i in 0 .. (num_bits_usize - 1) {
let should_act = (!done) & decomposition[i].ct_gt(&1);
// Subtract 2 from this coefficient
let amount_to_sub = <_>::conditional_select(&0, &2, should_act);
decomposition[i] -= amount_to_sub;
// Add 1 to the next coefficient
let amount_to_add = <_>::conditional_select(&0, &1, should_act);
decomposition[i + 1] += amount_to_add;
// Also update the sum of coefficients
sum_of_coefficients -= <_>::conditional_select(&0, &1, should_act);
// If we updated the coefficients this loop iter, we're done for this loop iter
done |= should_act;
}
}
}
for _ in 0 .. num_bits {
// If the sum of coefficients is the amount of bits, we're done
let mut done = sum_of_coefficients.ct_eq(&num_bits);
// Find the highest coefficient currently non-zero
for i in (1 .. decomposition.len()).rev() {
// If this is non-zero, we should decrement this coefficient if we haven't already
// decremented a coefficient this round
let is_non_zero = !(0.ct_eq(&decomposition[i]));
let should_act = (!done) & is_non_zero;
// Update this coefficient and the prior coefficient
let amount_to_sub = <_>::conditional_select(&0, &1, should_act);
decomposition[i] -= amount_to_sub;
let amount_to_add = <_>::conditional_select(&0, &2, should_act);
// i must be at least 1, so i - 1 will be at least 0 (meaning it's safe to index with)
decomposition[i - 1] += amount_to_add;
// Also update the sum of coefficients
sum_of_coefficients += <_>::conditional_select(&0, &1, should_act);
// If we updated the coefficients this loop iter, we're done for this loop iter
done |= should_act;
}
}
debug_assert!(bool::from(decomposition.iter().sum::<u64>().ct_eq(&num_bits)));
ScalarDecomposition { scalar, decomposition }
}
/// The decomposition of the scalar.
pub fn decomposition(&self) -> &[u64] {
&self.decomposition
}
/// A divisor to prove a scalar multiplication.
///
/// The divisor will interpolate $d_i$ instances of $2^i \cdot G$ with $-(s \cdot G)$.
///
/// This function executes in constant time with regards to the scalar.
///
/// This function MAY panic if this scalar is zero.
pub fn scalar_mul_divisor<C: Zeroize + DivisorCurve<Scalar = F>>(
&self,
mut generator: C,
) -> Poly<C::FieldElement> {
// The following for loop is constant time to the sum of `dlog`'s elements
let mut divisor_points =
Vec::with_capacity(usize::try_from(<C::Scalar as PrimeField>::NUM_BITS).unwrap());
divisor_points.push(-generator * self.scalar);
for coefficient in &self.decomposition {
let mut coefficient = *coefficient;
while coefficient != 0 {
coefficient -= 1;
divisor_points.push(generator);
}
generator = generator.double();
}
let res = new_divisor(&divisor_points).unwrap();
divisor_points.zeroize();
res
}
}
#[cfg(any(test, feature = "pasta"))]

View file

@ -1,25 +1,112 @@
use core::ops::{Add, Neg, Sub, Mul, Rem};
use zeroize::Zeroize;
use subtle::{Choice, ConstantTimeEq, ConstantTimeGreater, ConditionallySelectable};
use zeroize::{Zeroize, ZeroizeOnDrop};
use group::ff::PrimeField;
/// A structure representing a Polynomial with x**i, y**i, and y**i * x**j terms.
#[derive(Clone, PartialEq, Eq, Debug, Zeroize)]
pub struct Poly<F: PrimeField + From<u64>> {
/// c[i] * y ** (i + 1)
#[derive(Clone, Copy, PartialEq, Debug)]
struct CoefficientIndex {
y_pow: u64,
x_pow: u64,
}
impl ConditionallySelectable for CoefficientIndex {
fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
Self {
y_pow: <_>::conditional_select(&a.y_pow, &b.y_pow, choice),
x_pow: <_>::conditional_select(&a.x_pow, &b.x_pow, choice),
}
}
}
impl ConstantTimeEq for CoefficientIndex {
fn ct_eq(&self, other: &Self) -> Choice {
self.y_pow.ct_eq(&other.y_pow) & self.x_pow.ct_eq(&other.x_pow)
}
}
impl ConstantTimeGreater for CoefficientIndex {
fn ct_gt(&self, other: &Self) -> Choice {
self.y_pow.ct_gt(&other.y_pow) |
(self.y_pow.ct_eq(&other.y_pow) & self.x_pow.ct_gt(&other.x_pow))
}
}
/// A structure representing a Polynomial with x^i, y^i, and y^i * x^j terms.
#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
pub struct Poly<F: From<u64> + Zeroize + PrimeField> {
/// c\[i] * y^(i + 1)
pub y_coefficients: Vec<F>,
/// c[i][j] * y ** (i + 1) x ** (j + 1)
/// c\[i]\[j] * y^(i + 1) x^(j + 1)
pub yx_coefficients: Vec<Vec<F>>,
/// c[i] * x ** (i + 1)
/// c\[i] * x^(i + 1)
pub x_coefficients: Vec<F>,
/// Coefficient for x ** 0, y ** 0, and x ** 0 y ** 0 (the coefficient for 1)
/// Coefficient for x^0, y^0, and x^0 y^0 (the coefficient for 1)
pub zero_coefficient: F,
}
impl<F: PrimeField + From<u64>> Poly<F> {
impl<F: From<u64> + Zeroize + PrimeField> PartialEq for Poly<F> {
// This is not constant time and is not meant to be
fn eq(&self, b: &Poly<F>) -> bool {
{
let mutual_y_coefficients = self.y_coefficients.len().min(b.y_coefficients.len());
if self.y_coefficients[.. mutual_y_coefficients] != b.y_coefficients[.. mutual_y_coefficients]
{
return false;
}
for coeff in &self.y_coefficients[mutual_y_coefficients ..] {
if *coeff != F::ZERO {
return false;
}
}
for coeff in &b.y_coefficients[mutual_y_coefficients ..] {
if *coeff != F::ZERO {
return false;
}
}
}
{
for (i, yx_coeffs) in self.yx_coefficients.iter().enumerate() {
for (j, coeff) in yx_coeffs.iter().enumerate() {
if coeff != b.yx_coefficients.get(i).unwrap_or(&vec![]).get(j).unwrap_or(&F::ZERO) {
return false;
}
}
}
// Run from the other perspective in case other is longer than self
for (i, yx_coeffs) in b.yx_coefficients.iter().enumerate() {
for (j, coeff) in yx_coeffs.iter().enumerate() {
if coeff != self.yx_coefficients.get(i).unwrap_or(&vec![]).get(j).unwrap_or(&F::ZERO) {
return false;
}
}
}
}
{
let mutual_x_coefficients = self.x_coefficients.len().min(b.x_coefficients.len());
if self.x_coefficients[.. mutual_x_coefficients] != b.x_coefficients[.. mutual_x_coefficients]
{
return false;
}
for coeff in &self.x_coefficients[mutual_x_coefficients ..] {
if *coeff != F::ZERO {
return false;
}
}
for coeff in &b.x_coefficients[mutual_x_coefficients ..] {
if *coeff != F::ZERO {
return false;
}
}
}
self.zero_coefficient == b.zero_coefficient
}
}
impl<F: From<u64> + Zeroize + PrimeField> Poly<F> {
/// A polynomial for zero.
pub fn zero() -> Self {
pub(crate) fn zero() -> Self {
Poly {
y_coefficients: vec![],
yx_coefficients: vec![],
@ -27,37 +114,9 @@ impl<F: PrimeField + From<u64>> Poly<F> {
zero_coefficient: F::ZERO,
}
}
/// The amount of terms in the polynomial.
#[allow(clippy::len_without_is_empty)]
#[must_use]
pub fn len(&self) -> usize {
self.y_coefficients.len() +
self.yx_coefficients.iter().map(Vec::len).sum::<usize>() +
self.x_coefficients.len() +
usize::from(u8::from(self.zero_coefficient != F::ZERO))
}
// Remove high-order zero terms, allowing the length of the vectors to equal the amount of terms.
pub(crate) fn tidy(&mut self) {
let tidy = |vec: &mut Vec<F>| {
while vec.last() == Some(&F::ZERO) {
vec.pop();
}
};
tidy(&mut self.y_coefficients);
for vec in self.yx_coefficients.iter_mut() {
tidy(vec);
}
while self.yx_coefficients.last() == Some(&vec![]) {
self.yx_coefficients.pop();
}
tidy(&mut self.x_coefficients);
}
}
impl<F: PrimeField + From<u64>> Add<&Self> for Poly<F> {
impl<F: From<u64> + Zeroize + PrimeField> Add<&Self> for Poly<F> {
type Output = Self;
fn add(mut self, other: &Self) -> Self {
@ -91,12 +150,11 @@ impl<F: PrimeField + From<u64>> Add<&Self> for Poly<F> {
}
self.zero_coefficient += other.zero_coefficient;
self.tidy();
self
}
}
impl<F: PrimeField + From<u64>> Neg for Poly<F> {
impl<F: From<u64> + Zeroize + PrimeField> Neg for Poly<F> {
type Output = Self;
fn neg(mut self) -> Self {
@ -117,7 +175,7 @@ impl<F: PrimeField + From<u64>> Neg for Poly<F> {
}
}
impl<F: PrimeField + From<u64>> Sub for Poly<F> {
impl<F: From<u64> + Zeroize + PrimeField> Sub for Poly<F> {
type Output = Self;
fn sub(self, other: Self) -> Self {
@ -125,14 +183,10 @@ impl<F: PrimeField + From<u64>> Sub for Poly<F> {
}
}
impl<F: PrimeField + From<u64>> Mul<F> for Poly<F> {
impl<F: From<u64> + Zeroize + PrimeField> Mul<F> for Poly<F> {
type Output = Self;
fn mul(mut self, scalar: F) -> Self {
if scalar == F::ZERO {
return Poly::zero();
}
for y_coeff in self.y_coefficients.iter_mut() {
*y_coeff *= scalar;
}
@ -149,7 +203,7 @@ impl<F: PrimeField + From<u64>> Mul<F> for Poly<F> {
}
}
impl<F: PrimeField + From<u64>> Poly<F> {
impl<F: From<u64> + Zeroize + PrimeField> Poly<F> {
#[must_use]
fn shift_by_x(mut self, power_of_x: usize) -> Self {
if power_of_x == 0 {
@ -203,17 +257,17 @@ impl<F: PrimeField + From<u64>> Poly<F> {
self.zero_coefficient = F::ZERO;
// Move the x coefficients
self.yx_coefficients[power_of_y - 1] = self.x_coefficients;
std::mem::swap(&mut self.yx_coefficients[power_of_y - 1], &mut self.x_coefficients);
self.x_coefficients = vec![];
self
}
}
impl<F: PrimeField + From<u64>> Mul for Poly<F> {
impl<F: From<u64> + Zeroize + PrimeField> Mul<&Poly<F>> for Poly<F> {
type Output = Self;
fn mul(self, other: Self) -> Self {
fn mul(self, other: &Self) -> Self {
let mut res = self.clone() * other.zero_coefficient;
for (i, y_coeff) in other.y_coefficients.iter().enumerate() {
@ -233,94 +287,320 @@ impl<F: PrimeField + From<u64>> Mul for Poly<F> {
res = res + &scaled.shift_by_x(i + 1);
}
res.tidy();
res
}
}
impl<F: PrimeField + From<u64>> Poly<F> {
impl<F: From<u64> + Zeroize + PrimeField> Poly<F> {
// The leading y coefficient and associated x coefficient.
fn leading_coefficient(&self) -> (usize, usize) {
if self.y_coefficients.len() > self.yx_coefficients.len() {
(self.y_coefficients.len(), 0)
} else if !self.yx_coefficients.is_empty() {
(self.yx_coefficients.len(), self.yx_coefficients.last().unwrap().len())
} else {
(0, self.x_coefficients.len())
}
}
/// Returns the highest non-zero coefficient greater than the specified coefficient.
///
/// If no non-zero coefficient is greater than the specified coefficient, this will return
/// (0, 0).
fn greater_than_or_equal_coefficient(
&self,
greater_than_or_equal: &CoefficientIndex,
) -> CoefficientIndex {
let mut leading_coefficient = CoefficientIndex { y_pow: 0, x_pow: 0 };
for (y_pow_sub_one, coeff) in self.y_coefficients.iter().enumerate() {
let y_pow = u64::try_from(y_pow_sub_one + 1).unwrap();
let coeff_is_non_zero = !coeff.is_zero();
let potential = CoefficientIndex { y_pow, x_pow: 0 };
leading_coefficient = <_>::conditional_select(
&leading_coefficient,
&potential,
coeff_is_non_zero &
potential.ct_gt(&leading_coefficient) &
(potential.ct_gt(greater_than_or_equal) | potential.ct_eq(greater_than_or_equal)),
);
}
for (y_pow_sub_one, yx_coefficients) in self.yx_coefficients.iter().enumerate() {
let y_pow = u64::try_from(y_pow_sub_one + 1).unwrap();
for (x_pow_sub_one, coeff) in yx_coefficients.iter().enumerate() {
let x_pow = u64::try_from(x_pow_sub_one + 1).unwrap();
let coeff_is_non_zero = !coeff.is_zero();
let potential = CoefficientIndex { y_pow, x_pow };
leading_coefficient = <_>::conditional_select(
&leading_coefficient,
&potential,
coeff_is_non_zero &
potential.ct_gt(&leading_coefficient) &
(potential.ct_gt(greater_than_or_equal) | potential.ct_eq(greater_than_or_equal)),
);
}
}
for (x_pow_sub_one, coeff) in self.x_coefficients.iter().enumerate() {
let x_pow = u64::try_from(x_pow_sub_one + 1).unwrap();
let coeff_is_non_zero = !coeff.is_zero();
let potential = CoefficientIndex { y_pow: 0, x_pow };
leading_coefficient = <_>::conditional_select(
&leading_coefficient,
&potential,
coeff_is_non_zero &
potential.ct_gt(&leading_coefficient) &
(potential.ct_gt(greater_than_or_equal) | potential.ct_eq(greater_than_or_equal)),
);
}
leading_coefficient
}
/// Perform multiplication mod `modulus`.
#[must_use]
pub fn mul_mod(self, other: Self, modulus: &Self) -> Self {
((self % modulus) * (other % modulus)) % modulus
pub(crate) fn mul_mod(self, other: &Self, modulus: &Self) -> Self {
(self * other) % modulus
}
/// Perform division, returning the result and remainder.
///
/// Panics upon division by zero, with undefined behavior if a non-tidy divisor is used.
/// This function is constant time to the structure of the numerator and denominator. The actual
/// value of the coefficients will not introduce timing differences.
///
/// Panics upon division by a polynomial where all coefficients are zero.
#[must_use]
pub fn div_rem(self, divisor: &Self) -> (Self, Self) {
// The leading y coefficient and associated x coefficient.
let leading_y = |poly: &Self| -> (_, _) {
if poly.y_coefficients.len() > poly.yx_coefficients.len() {
(poly.y_coefficients.len(), 0)
} else if !poly.yx_coefficients.is_empty() {
(poly.yx_coefficients.len(), poly.yx_coefficients.last().unwrap().len())
} else {
(0, poly.x_coefficients.len())
pub(crate) fn div_rem(self, denominator: &Self) -> (Self, Self) {
// These functions have undefined behavior if this isn't a valid index for this poly
fn ct_get<F: From<u64> + Zeroize + PrimeField>(
poly: &Poly<F>,
index: CoefficientIndex,
) -> F {
let mut res = poly.zero_coefficient;
for (y_pow_sub_one, coeff) in poly.y_coefficients.iter().enumerate() {
res = <_>::conditional_select(&res, coeff, index.ct_eq(&CoefficientIndex { y_pow: (y_pow_sub_one + 1).try_into().unwrap(), x_pow: 0 }));
}
};
let (div_y, div_x) = leading_y(divisor);
// If this divisor is actually a scalar, don't perform long division
if (div_y == 0) && (div_x == 0) {
return (self * divisor.zero_coefficient.invert().unwrap(), Poly::zero());
for (y_pow_sub_one, coeffs) in poly.yx_coefficients.iter().enumerate() {
for (x_pow_sub_one, coeff) in coeffs.iter().enumerate() {
res = <_>::conditional_select(&res, coeff, index.ct_eq(&CoefficientIndex { y_pow: (y_pow_sub_one + 1).try_into().unwrap(), x_pow: (x_pow_sub_one + 1).try_into().unwrap() }));
}
}
for (x_pow_sub_one, coeff) in poly.x_coefficients.iter().enumerate() {
res = <_>::conditional_select(&res, coeff, index.ct_eq(&CoefficientIndex { y_pow: 0, x_pow: (x_pow_sub_one + 1).try_into().unwrap() }));
}
res
}
// Remove leading terms until the value is less than the divisor
let mut quotient: Poly<F> = Poly::zero();
let mut remainder = self.clone();
loop {
// If there's nothing left to divide, return
if remainder == Poly::zero() {
break;
fn ct_set<F: From<u64> + Zeroize + PrimeField>(
poly: &mut Poly<F>,
index: CoefficientIndex,
value: F,
) {
for (y_pow_sub_one, coeff) in poly.y_coefficients.iter_mut().enumerate() {
*coeff = <_>::conditional_select(coeff, &value, index.ct_eq(&CoefficientIndex { y_pow: (y_pow_sub_one + 1).try_into().unwrap(), x_pow: 0 }));
}
let (rem_y, rem_x) = leading_y(&remainder);
if (rem_y < div_y) || (rem_x < div_x) {
break;
for (y_pow_sub_one, coeffs) in poly.yx_coefficients.iter_mut().enumerate() {
for (x_pow_sub_one, coeff) in coeffs.iter_mut().enumerate() {
*coeff = <_>::conditional_select(coeff, &value, index.ct_eq(&CoefficientIndex { y_pow: (y_pow_sub_one + 1).try_into().unwrap(), x_pow: (x_pow_sub_one + 1).try_into().unwrap() }));
}
}
for (x_pow_sub_one, coeff) in poly.x_coefficients.iter_mut().enumerate() {
*coeff = <_>::conditional_select(coeff, &value, index.ct_eq(&CoefficientIndex { y_pow: 0, x_pow: (x_pow_sub_one + 1).try_into().unwrap() }));
}
poly.zero_coefficient = <_>::conditional_select(&poly.zero_coefficient, &value, index.ct_eq(&CoefficientIndex { y_pow: 0, x_pow: 0 }));
}
let get = |poly: &Poly<F>, y_pow: usize, x_pow: usize| -> F {
if (y_pow == 0) && (x_pow == 0) {
poly.zero_coefficient
} else if x_pow == 0 {
poly.y_coefficients[y_pow - 1]
} else if y_pow == 0 {
poly.x_coefficients[x_pow - 1]
} else {
poly.yx_coefficients[y_pow - 1][x_pow - 1]
fn conditional_select_poly<F: From<u64> + Zeroize + PrimeField>(
mut a: Poly<F>,
mut b: Poly<F>,
choice: Choice,
) -> Poly<F> {
let pad_to = |a: &mut Poly<F>, b: &Poly<F>| {
while a.x_coefficients.len() < b.x_coefficients.len() {
a.x_coefficients.push(F::ZERO);
}
while a.yx_coefficients.len() < b.yx_coefficients.len() {
a.yx_coefficients.push(vec![]);
}
for (a, b) in a.yx_coefficients.iter_mut().zip(&b.yx_coefficients) {
while a.len() < b.len() {
a.push(F::ZERO);
}
}
while a.y_coefficients.len() < b.y_coefficients.len() {
a.y_coefficients.push(F::ZERO);
}
};
let coeff_numerator = get(&remainder, rem_y, rem_x);
let coeff_denominator = get(divisor, div_y, div_x);
// Pad these to be the same size/layout as each other
pad_to(&mut a, &b);
pad_to(&mut b, &a);
// We want coeff_denominator scaled by x to equal coeff_numerator
// x * d = n
// n / d = x
let mut quotient_term = Poly::zero();
// Because this is the coefficient for the leading term of a tidied polynomial, it must be
// non-zero
quotient_term.zero_coefficient = coeff_numerator * coeff_denominator.invert().unwrap();
let mut res = Poly::zero();
for (a, b) in a.y_coefficients.iter().zip(&b.y_coefficients) {
res.y_coefficients.push(<_>::conditional_select(a, b, choice));
}
for (a, b) in a.yx_coefficients.iter().zip(&b.yx_coefficients) {
let mut yx_coefficients = Vec::with_capacity(a.len());
for (a, b) in a.iter().zip(b) {
yx_coefficients.push(<_>::conditional_select(a, b, choice))
}
res.yx_coefficients.push(yx_coefficients);
}
for (a, b) in a.x_coefficients.iter().zip(&b.x_coefficients) {
res.x_coefficients.push(<_>::conditional_select(a, b, choice));
}
res.zero_coefficient = <_>::conditional_select(&a.zero_coefficient, &b.zero_coefficient, choice);
// Add the necessary yx powers
let delta_y = rem_y - div_y;
let delta_x = rem_x - div_x;
let quotient_term = quotient_term.shift_by_y(delta_y).shift_by_x(delta_x);
let to_remove = quotient_term.clone() * divisor.clone();
debug_assert_eq!(get(&to_remove, rem_y, rem_x), coeff_numerator);
remainder = remainder - to_remove;
quotient = quotient + &quotient_term;
res
}
// The following long division algorithm only works if the denominator actually has a variable
// If the denominator isn't variable to anything, short-circuit to scalar 'division'
// This is safe as `leading_coefficient` is based on the structure, not the values, of the poly
let denominator_leading_coefficient = denominator.leading_coefficient();
if denominator_leading_coefficient == (0, 0) {
return (self * denominator.zero_coefficient.invert().unwrap(), Poly::zero());
}
// The structure of the quotient, which is the the numerator with all coefficients set to 0
let mut quotient_structure = Poly {
y_coefficients: vec![F::ZERO; self.y_coefficients.len()],
yx_coefficients: self.yx_coefficients.clone(),
x_coefficients: vec![F::ZERO; self.x_coefficients.len()],
zero_coefficient: F::ZERO,
};
for coeff in quotient_structure
.yx_coefficients
.iter_mut()
.flat_map(|yx_coefficients| yx_coefficients.iter_mut())
{
*coeff = F::ZERO;
}
// Calculate the amount of iterations we need to perform
let iterations = self.y_coefficients.len() +
self.yx_coefficients.iter().map(|yx_coefficients| yx_coefficients.len()).sum::<usize>() +
self.x_coefficients.len();
// Find the highest non-zero coefficient in the denominator
// This is the coefficient which we actually perform division with
let denominator_dividing_coefficient =
denominator.greater_than_or_equal_coefficient(&CoefficientIndex { y_pow: 0, x_pow: 0 });
let denominator_dividing_coefficient_inv =
ct_get(denominator, denominator_dividing_coefficient).invert().unwrap();
let mut quotient = quotient_structure.clone();
let mut remainder = self.clone();
for _ in 0 .. iterations {
// Find the numerator coefficient we're clearing
// This will be (0, 0) if we aren't clearing a coefficient
let numerator_coefficient =
remainder.greater_than_or_equal_coefficient(&denominator_dividing_coefficient);
// We only apply the effects of this iteration if the numerator's coefficient is actually >=
let meaningful_iteration = numerator_coefficient.ct_gt(&denominator_dividing_coefficient) |
numerator_coefficient.ct_eq(&denominator_dividing_coefficient);
// 1) Find the scalar `q` such that the leading coefficient of `q * denominator` is equal to
// the leading coefficient of self.
let numerator_coefficient_value = ct_get(&remainder, numerator_coefficient);
let q = numerator_coefficient_value * denominator_dividing_coefficient_inv;
// 2) Calculate the full term of the quotient by scaling with the necessary powers of y/x
let proper_powers_of_yx = CoefficientIndex {
y_pow: numerator_coefficient.y_pow.wrapping_sub(denominator_dividing_coefficient.y_pow),
x_pow: numerator_coefficient.x_pow.wrapping_sub(denominator_dividing_coefficient.x_pow),
};
let fallabck_powers_of_yx = CoefficientIndex { y_pow: 0, x_pow: 0 };
let mut quotient_term = quotient_structure.clone();
ct_set(
&mut quotient_term,
// If the numerator coefficient isn't >=, proper_powers_of_yx will have garbage in them
<_>::conditional_select(&fallabck_powers_of_yx, &proper_powers_of_yx, meaningful_iteration),
q,
);
let quotient_if_meaningful = quotient.clone() + &quotient_term;
quotient = conditional_select_poly(quotient, quotient_if_meaningful, meaningful_iteration);
// 3) Remove what we've divided out from self
let remainder_if_meaningful = remainder.clone() - (quotient_term * denominator);
remainder =
conditional_select_poly(remainder, remainder_if_meaningful, meaningful_iteration);
}
quotient = conditional_select_poly(
quotient,
// If the dividing coefficient was for y**0 x**0, we return the poly scaled by its inverse
self.clone() * denominator_dividing_coefficient_inv,
denominator_dividing_coefficient.ct_eq(&CoefficientIndex { y_pow: 0, x_pow: 0 }),
);
remainder = conditional_select_poly(
remainder,
// If the dividing coefficient was for y**0 x**0, we're able to perfectly divide and there's
// no remainder
Poly::zero(),
denominator_dividing_coefficient.ct_eq(&CoefficientIndex { y_pow: 0, x_pow: 0 }),
);
// Clear any junk terms out of the remainder which are less than the denominator
let denominator_leading_coefficient = CoefficientIndex {
y_pow: denominator_leading_coefficient.0.try_into().unwrap(),
x_pow: denominator_leading_coefficient.1.try_into().unwrap(),
};
if denominator_leading_coefficient != (CoefficientIndex { y_pow: 0, x_pow: 0 }) {
while {
let index =
CoefficientIndex { y_pow: remainder.y_coefficients.len().try_into().unwrap(), x_pow: 0 };
bool::from(
index.ct_gt(&denominator_leading_coefficient) |
index.ct_eq(&denominator_leading_coefficient),
)
} {
let popped = remainder.y_coefficients.pop();
debug_assert_eq!(popped, Some(F::ZERO));
}
while {
let index = CoefficientIndex {
y_pow: remainder.yx_coefficients.len().try_into().unwrap(),
x_pow: remainder
.yx_coefficients
.last()
.map(|yx_coefficients| yx_coefficients.len())
.unwrap_or(0)
.try_into()
.unwrap(),
};
bool::from(
index.ct_gt(&denominator_leading_coefficient) |
index.ct_eq(&denominator_leading_coefficient),
)
} {
let popped = remainder.yx_coefficients.last_mut().unwrap().pop();
// This may have been `vec![]`
if let Some(popped) = popped {
debug_assert_eq!(popped, F::ZERO);
}
if remainder.yx_coefficients.last().unwrap().is_empty() {
let popped = remainder.yx_coefficients.pop();
debug_assert_eq!(popped, Some(vec![]));
}
}
while {
let index =
CoefficientIndex { y_pow: 0, x_pow: remainder.x_coefficients.len().try_into().unwrap() };
bool::from(
index.ct_gt(&denominator_leading_coefficient) |
index.ct_eq(&denominator_leading_coefficient),
)
} {
let popped = remainder.x_coefficients.pop();
debug_assert_eq!(popped, Some(F::ZERO));
}
}
debug_assert_eq!((quotient.clone() * divisor.clone()) + &remainder, self);
(quotient, remainder)
}
}
impl<F: PrimeField + From<u64>> Rem<&Self> for Poly<F> {
impl<F: From<u64> + Zeroize + PrimeField> Rem<&Self> for Poly<F> {
type Output = Self;
fn rem(self, modulus: &Self) -> Self {
@ -328,10 +608,10 @@ impl<F: PrimeField + From<u64>> Rem<&Self> for Poly<F> {
}
}
impl<F: PrimeField + From<u64>> Poly<F> {
impl<F: From<u64> + Zeroize + PrimeField> Poly<F> {
/// Evaluate this polynomial with the specified x/y values.
///
/// Panics on polynomials with terms whose powers exceed 2**64.
/// Panics on polynomials with terms whose powers exceed 2^64.
#[must_use]
pub fn eval(&self, x: F, y: F) -> F {
let mut res = self.zero_coefficient;
@ -358,14 +638,11 @@ impl<F: PrimeField + From<u64>> Poly<F> {
res
}
/// Differentiate a polynomial, reduced by a modulus with a leading y term y**2 x**0, by x and y.
/// Differentiate a polynomial, reduced by a modulus with a leading y term y^2 x^0, by x and y.
///
/// This function panics if a y**2 term is present within the polynomial.
/// This function has undefined behavior if unreduced.
#[must_use]
pub fn differentiate(&self) -> (Poly<F>, Poly<F>) {
assert!(self.y_coefficients.len() <= 1);
assert!(self.yx_coefficients.len() <= 1);
// Differentation by x practically involves:
// - Dropping everything without an x component
// - Shifting everything down a power of x
@ -391,17 +668,18 @@ impl<F: PrimeField + From<u64>> Poly<F> {
if !self.yx_coefficients.is_empty() {
let mut yx_coeffs = self.yx_coefficients[0].clone();
diff_x.y_coefficients = vec![yx_coeffs.remove(0)];
diff_x.yx_coefficients = vec![yx_coeffs];
if !yx_coeffs.is_empty() {
diff_x.y_coefficients = vec![yx_coeffs.remove(0)];
diff_x.yx_coefficients = vec![yx_coeffs];
let mut prior_x_power = F::from(2);
for yx_coeff in &mut diff_x.yx_coefficients[0] {
*yx_coeff *= prior_x_power;
prior_x_power += F::ONE;
let mut prior_x_power = F::from(2);
for yx_coeff in &mut diff_x.yx_coefficients[0] {
*yx_coeff *= prior_x_power;
prior_x_power += F::ONE;
}
}
}
diff_x.tidy();
diff_x
};

View file

@ -6,6 +6,8 @@ use pasta_curves::{Ep, Eq};
use crate::{DivisorCurve, Poly, new_divisor};
mod poly;
// Equation 4 in the security proofs
fn check_divisor<C: DivisorCurve>(points: Vec<C>) {
// Create the divisor
@ -184,16 +186,16 @@ fn test_subset_sum_to_infinity<C: DivisorCurve>() {
#[test]
fn test_divisor_pallas() {
test_divisor::<Ep>();
test_same_point::<Ep>();
test_subset_sum_to_infinity::<Ep>();
test_divisor::<Ep>();
}
#[test]
fn test_divisor_vesta() {
test_divisor::<Eq>();
test_same_point::<Eq>();
test_subset_sum_to_infinity::<Eq>();
test_divisor::<Eq>();
}
#[test]
@ -229,7 +231,7 @@ fn test_divisor_ed25519() {
}
}
test_divisor::<EdwardsPoint>();
test_same_point::<EdwardsPoint>();
test_subset_sum_to_infinity::<EdwardsPoint>();
test_divisor::<EdwardsPoint>();
}

View file

@ -1,3 +1,5 @@
use rand_core::OsRng;
use group::ff::Field;
use pasta_curves::Ep;
@ -16,7 +18,24 @@ fn test_poly() {
let mut modulus = Poly::zero();
modulus.y_coefficients = vec![one];
assert_eq!(poly % &modulus, Poly::zero());
assert_eq!(
poly.clone().div_rem(&modulus).0,
Poly {
y_coefficients: vec![one],
yx_coefficients: vec![],
x_coefficients: vec![],
zero_coefficient: zero
}
);
assert_eq!(
poly % &modulus,
Poly {
y_coefficients: vec![],
yx_coefficients: vec![],
x_coefficients: vec![],
zero_coefficient: zero
}
);
}
{
@ -25,7 +44,7 @@ fn test_poly() {
let mut squared = Poly::zero();
squared.y_coefficients = vec![zero, zero, zero, one];
assert_eq!(poly.clone() * poly.clone(), squared);
assert_eq!(poly.clone() * &poly, squared);
}
{
@ -37,18 +56,18 @@ fn test_poly() {
let mut res = Poly::zero();
res.zero_coefficient = F::from(6u64);
assert_eq!(a.clone() * b.clone(), res);
assert_eq!(a.clone() * &b, res);
b.y_coefficients = vec![F::from(4u64)];
res.y_coefficients = vec![F::from(8u64)];
assert_eq!(a.clone() * b.clone(), res);
assert_eq!(b.clone() * a.clone(), res);
assert_eq!(a.clone() * &b, res);
assert_eq!(b.clone() * &a, res);
a.x_coefficients = vec![F::from(5u64)];
res.x_coefficients = vec![F::from(15u64)];
res.yx_coefficients = vec![vec![F::from(20u64)]];
assert_eq!(a.clone() * b.clone(), res);
assert_eq!(b * a.clone(), res);
assert_eq!(a.clone() * &b, res);
assert_eq!(b * &a, res);
// res is now 20xy + 8*y + 15*x + 6
// res ** 2 =
@ -60,7 +79,7 @@ fn test_poly() {
vec![vec![F::from(480u64), F::from(600u64)], vec![F::from(320u64), F::from(400u64)]];
squared.x_coefficients = vec![F::from(180u64), F::from(225u64)];
squared.zero_coefficient = F::from(36u64);
assert_eq!(res.clone() * res, squared);
assert_eq!(res.clone() * &res, squared);
}
}