mirror of
https://github.com/serai-dex/serai.git
synced 2025-01-31 15:05:58 +00:00
381 lines
8.8 KiB
Rust
381 lines
8.8 KiB
Rust
|
use core::{
|
||
|
ops::{Add, AddAssign, Neg, Sub, SubAssign, Mul, MulAssign},
|
||
|
iter::Sum,
|
||
|
};
|
||
|
|
||
|
use lazy_static::lazy_static;
|
||
|
|
||
|
use rand_core::RngCore;
|
||
|
|
||
|
use zeroize::Zeroize;
|
||
|
use subtle::{Choice, CtOption, ConstantTimeEq, ConditionallySelectable, ConditionallyNegatable};
|
||
|
|
||
|
use ff::{Field, PrimeField, PrimeFieldBits};
|
||
|
use group::{Group, GroupEncoding, prime::PrimeGroup};
|
||
|
|
||
|
use crate::{
|
||
|
scalar::{Scalar, MODULUS as SCALAR_MODULUS},
|
||
|
field::{FieldElement, Q_4},
|
||
|
};
|
||
|
|
||
|
lazy_static! {
  // Curve constant `d`: the field element -39081. It is the coefficient applied to the
  // x^2 * y^2 term wherever the curve equation is evaluated in this file.
  static ref D: FieldElement = FieldElement::from(39081u16).neg();
}
|
||
|
|
||
|
fn recover_x(y: FieldElement) -> CtOption<FieldElement> {
|
||
|
let ysq = y.square();
|
||
|
#[allow(non_snake_case)]
|
||
|
let D_ysq = *D * ysq;
|
||
|
(D_ysq - FieldElement::one()).invert().and_then(|inverted| {
|
||
|
let temp = (ysq - FieldElement::one()) * inverted;
|
||
|
let mut x = temp.pow(*Q_4);
|
||
|
x.conditional_negate(x.is_odd());
|
||
|
|
||
|
let xsq = x.square();
|
||
|
CtOption::new(x, (xsq + ysq).ct_eq(&(FieldElement::one() + (xsq * D_ysq))))
|
||
|
})
|
||
|
}
|
||
|
|
||
|
/// A curve point held as three field elements `(x, y, z)`.
///
/// The affine coordinates are `x / z` and `y / z`: equality is checked by cross-multiplying with
/// the other point's `z`, and serialization multiplies `x` and `y` by the inverse of `z`.
#[derive(Clone, Copy, Debug, Zeroize)]
pub struct Point {
  // Projective x coordinate
  x: FieldElement,
  // Projective y coordinate
  y: FieldElement,
  // Shared denominator; one for freshly decoded points and for the identity
  z: FieldElement,
}
|
||
|
|
||
|
#[rustfmt::skip]
|
||
|
lazy_static! {
|
||
|
static ref G_Y: FieldElement = FieldElement::from_repr(
|
||
|
hex_literal::hex!(
|
||
|
"14fa30f25b790898adc8d74e2c13bdfdc4397ce61cffd33ad7c2a0051e9c78874098a36c7373ea4b62c7c9563720768824bcb66e71463f6900"
|
||
|
)
|
||
|
.into()
|
||
|
)
|
||
|
.unwrap();
|
||
|
|
||
|
static ref G: Point = Point { x: recover_x(*G_Y).unwrap(), y: *G_Y, z: FieldElement::one() };
|
||
|
}
|
||
|
|
||
|
impl ConstantTimeEq for Point {
|
||
|
fn ct_eq(&self, other: &Self) -> Choice {
|
||
|
let x1 = self.x * other.z;
|
||
|
let x2 = other.x * self.z;
|
||
|
|
||
|
let y1 = self.y * other.z;
|
||
|
let y2 = other.y * self.z;
|
||
|
|
||
|
x1.ct_eq(&x2) & y1.ct_eq(&y2)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl PartialEq for Point {
  /// Equality as defined by `ct_eq`, collapsed into a `bool`.
  fn eq(&self, other: &Point) -> bool {
    bool::from(self.ct_eq(other))
  }
}
|
||
|
|
||
|
// Marker impl declaring that the PartialEq impl for Point is a full equivalence relation.
impl Eq for Point {}
|
||
|
|
||
|
impl ConditionallySelectable for Point {
|
||
|
fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
|
||
|
Point {
|
||
|
x: FieldElement::conditional_select(&a.x, &b.x, choice),
|
||
|
y: FieldElement::conditional_select(&a.y, &b.y, choice),
|
||
|
z: FieldElement::conditional_select(&a.z, &b.z, choice),
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl Add for Point {
|
||
|
type Output = Point;
|
||
|
fn add(self, other: Self) -> Self {
|
||
|
// 12 muls, 7 additions, 4 negations
|
||
|
let xcp = self.x * other.x;
|
||
|
let ycp = self.y * other.y;
|
||
|
let zcp = self.z * other.z;
|
||
|
#[allow(non_snake_case)]
|
||
|
let B = zcp.square();
|
||
|
#[allow(non_snake_case)]
|
||
|
let E = *D * xcp * ycp;
|
||
|
#[allow(non_snake_case)]
|
||
|
let F = B - E;
|
||
|
#[allow(non_snake_case)]
|
||
|
let G_ = B + E;
|
||
|
|
||
|
Point {
|
||
|
x: zcp * F * ((self.x + self.y) * (other.x + other.y) - xcp - ycp),
|
||
|
y: zcp * G_ * (ycp - xcp),
|
||
|
z: F * G_,
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl AddAssign for Point {
  /// In-place addition, defined via `Add`.
  fn add_assign(&mut self, other: Point) {
    let sum = *self + other;
    *self = sum;
  }
}
|
||
|
|
||
|
impl Add<&Point> for Point {
  type Output = Point;

  /// Addition with a borrowed right-hand side; copies it and defers to the by-value `Add`.
  fn add(self, other: &Point) -> Point {
    <Point as Add<Point>>::add(self, *other)
  }
}
|
||
|
|
||
|
impl AddAssign<&Point> for Point {
  /// In-place addition with a borrowed right-hand side.
  fn add_assign(&mut self, other: &Point) {
    *self = *self + *other;
  }
}
|
||
|
|
||
|
impl Neg for Point {
|
||
|
type Output = Point;
|
||
|
fn neg(self) -> Self {
|
||
|
Point { x: -self.x, y: self.y, z: self.z }
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl Sub for Point {
  type Output = Point;

  /// Subtraction, performed as addition of the negated right-hand side.
  // The lint fires on the `+` within a `Sub` impl, which is intentional here
  #[allow(clippy::suspicious_arithmetic_impl)]
  fn sub(self, other: Self) -> Self {
    let negated = -other;
    self + negated
  }
}
|
||
|
|
||
|
impl SubAssign for Point {
  /// In-place subtraction, defined via `Sub`.
  fn sub_assign(&mut self, other: Point) {
    let difference = *self - other;
    *self = difference;
  }
}
|
||
|
|
||
|
impl Sub<&Point> for Point {
  type Output = Point;

  /// Subtraction with a borrowed right-hand side; copies it and defers to the by-value `Sub`.
  fn sub(self, other: &Point) -> Point {
    <Point as Sub<Point>>::sub(self, *other)
  }
}
|
||
|
|
||
|
impl SubAssign<&Point> for Point {
  /// In-place subtraction with a borrowed right-hand side.
  fn sub_assign(&mut self, other: &Point) {
    *self = *self - *other;
  }
}
|
||
|
|
||
|
impl Group for Point {
|
||
|
type Scalar = Scalar;
|
||
|
// Ideally, this would be cryptographically secure, yet that's not a bound on the trait
|
||
|
// k256 also does this
|
||
|
fn random(rng: impl RngCore) -> Self {
|
||
|
Self::generator() * Scalar::random(rng)
|
||
|
}
|
||
|
fn identity() -> Self {
|
||
|
Point { x: FieldElement::zero(), y: FieldElement::one(), z: FieldElement::one() }
|
||
|
}
|
||
|
fn generator() -> Self {
|
||
|
*G
|
||
|
}
|
||
|
fn is_identity(&self) -> Choice {
|
||
|
self.ct_eq(&Self::identity())
|
||
|
}
|
||
|
fn double(&self) -> Self {
|
||
|
// 7 muls, 7 additions, 4 negations
|
||
|
let xsq = self.x.square();
|
||
|
let ysq = self.y.square();
|
||
|
let zsq = self.z.square();
|
||
|
let xy = self.x + self.y;
|
||
|
#[allow(non_snake_case)]
|
||
|
let F = xsq + ysq;
|
||
|
#[allow(non_snake_case)]
|
||
|
let J = F - zsq.double();
|
||
|
Point { x: J * (xy.square() - xsq - ysq), y: F * (xsq - ysq), z: F * J }
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl Sum<Point> for Point {
  /// Sums the yielded points, starting from the identity (which is the result for an empty
  /// iterator).
  fn sum<I: Iterator<Item = Point>>(iter: I) -> Point {
    iter.fold(Self::identity(), |accumulated, point| accumulated + point)
  }
}
|
||
|
|
||
|
impl<'a> Sum<&'a Point> for Point {
  /// Sums borrowed points by copying each one and deferring to the by-value `Sum`.
  fn sum<I: Iterator<Item = &'a Point>>(iter: I) -> Point {
    <Point as Sum<Point>>::sum(iter.copied())
  }
}
|
||
|
|
||
|
impl Mul<Scalar> for Point {
|
||
|
type Output = Point;
|
||
|
fn mul(self, other: Scalar) -> Point {
|
||
|
// Precompute the optimal amount that's a multiple of 2
|
||
|
let mut table = [Point::identity(); 16];
|
||
|
table[1] = self;
|
||
|
for i in 2 .. 16 {
|
||
|
table[i] = table[i - 1] + self;
|
||
|
}
|
||
|
|
||
|
let mut res = Self::identity();
|
||
|
let mut bits = 0;
|
||
|
for (i, bit) in other.to_le_bits().iter().rev().enumerate() {
|
||
|
bits <<= 1;
|
||
|
let bit = *bit as u8;
|
||
|
assert_eq!(bit | 1, 1);
|
||
|
bits |= bit;
|
||
|
|
||
|
if ((i + 1) % 4) == 0 {
|
||
|
if i != 3 {
|
||
|
for _ in 0 .. 4 {
|
||
|
res = res.double();
|
||
|
}
|
||
|
}
|
||
|
res += table[usize::from(bits)];
|
||
|
bits = 0;
|
||
|
}
|
||
|
}
|
||
|
res
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl MulAssign<Scalar> for Point {
  /// In-place scalar multiplication, defined via `Mul<Scalar>`.
  fn mul_assign(&mut self, other: Scalar) {
    let product = *self * other;
    *self = product;
  }
}
|
||
|
|
||
|
impl Mul<&Scalar> for Point {
  type Output = Point;

  /// Scalar multiplication by a borrowed scalar; copies it and defers to `Mul<Scalar>`.
  fn mul(self, other: &Scalar) -> Point {
    <Point as Mul<Scalar>>::mul(self, *other)
  }
}
|
||
|
|
||
|
impl MulAssign<&Scalar> for Point {
  /// In-place scalar multiplication by a borrowed scalar.
  fn mul_assign(&mut self, other: &Scalar) {
    *self = *self * *other;
  }
}
|
||
|
|
||
|
impl Point {
|
||
|
pub fn is_torsion_free(&self) -> Choice {
|
||
|
(*self * *SCALAR_MODULUS).is_identity()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl GroupEncoding for Point {
|
||
|
type Repr = <FieldElement as PrimeField>::Repr;
|
||
|
|
||
|
fn from_bytes(bytes: &Self::Repr) -> CtOption<Self> {
|
||
|
// Extract and clear the sign bit
|
||
|
let sign = Choice::from(bytes[56] >> 7);
|
||
|
let mut bytes = *bytes;
|
||
|
let mut_ref: &mut [u8] = bytes.as_mut();
|
||
|
mut_ref[56] &= !(1 << 7);
|
||
|
|
||
|
// Parse y, recover x
|
||
|
FieldElement::from_repr(bytes).and_then(|y| {
|
||
|
recover_x(y).and_then(|mut x| {
|
||
|
x.conditional_negate(x.is_odd().ct_eq(&!sign));
|
||
|
let not_negative_zero = !(x.is_zero() & sign);
|
||
|
let point = Point { x, y, z: FieldElement::one() };
|
||
|
CtOption::new(point, not_negative_zero & point.is_torsion_free())
|
||
|
})
|
||
|
})
|
||
|
}
|
||
|
|
||
|
fn from_bytes_unchecked(bytes: &Self::Repr) -> CtOption<Self> {
|
||
|
Point::from_bytes(bytes)
|
||
|
}
|
||
|
|
||
|
fn to_bytes(&self) -> Self::Repr {
|
||
|
let z = self.z.invert().unwrap();
|
||
|
let x = self.x * z;
|
||
|
let y = self.y * z;
|
||
|
|
||
|
let mut bytes = y.to_repr();
|
||
|
let mut_ref: &mut [u8] = bytes.as_mut();
|
||
|
mut_ref[56] |= x.is_odd().unwrap_u8() << 7;
|
||
|
bytes
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Marker impl declaring Point to be a prime-order group. Decoding rejects any point which fails
// is_torsion_free.
impl PrimeGroup for Point {}
|
||
|
|
||
|
// The identity must survive a serialization round trip and be its own double under addition.
#[test]
fn identity() {
  let id = Point::identity();
  let round_tripped = Point::from_bytes(&id.to_bytes()).unwrap();
  assert_eq!(round_tripped, id);
  assert_eq!(id + id, id);
}
|
||
|
|
||
|
// Repeated addition of the generator must agree with scalar multiplication for the multiples
// 1 through 9, and each such multiple must survive a serialization round trip.
#[test]
fn addition_multiplication_serialization() {
  let generator = Point::generator();
  let mut running_sum = Point::identity();
  for multiple in 1u8 .. 10 {
    running_sum += generator;
    let product = generator * Scalar::from(multiple);
    assert_eq!(running_sum, product);

    let decoded = Point::from_bytes(&product.to_bytes()).unwrap();
    assert_eq!(decoded, product);
  }
}
|
||
|
|
||
|
// A point known to carry torsion must be reported as such by is_torsion_free.
#[rustfmt::skip]
#[test]
fn torsion() {
  // Uses the originally suggested generator which had torsion
  let old_y_repr: <FieldElement as PrimeField>::Repr = hex_literal::hex!(
    "12796c1532041525945f322e414d434467cfd5c57c9a9af2473b27758c921c4828b277ca5f2891fc4f3d79afdf29a64c72fb28b59c16fa5100"
  ).into();
  let old_y = FieldElement::from_repr(old_y_repr).unwrap();

  // This point uses the negation of the x coordinate recover_x returns
  let recovered_x = recover_x(old_y).unwrap();
  let old = Point { x: -recovered_x, y: old_y, z: FieldElement::one() };

  let torsion_free = bool::from(old.is_torsion_free());
  assert!(!torsion_free);
}
|
||
|
|
||
|
// Checks doubling and scalar multiplication of the generator against fixed encoded points.
#[test]
fn vector() {
  use generic_array::GenericArray;

  // Decodes a hex-encoded point, panicking if the hex or the encoding is invalid
  let decode_point = |encoded: &str| -> Point {
    Point::from_bytes(GenericArray::from_slice(&hex::decode(encoded).unwrap())).unwrap()
  };

  // 2 * G
  let expected_double = decode_point(
    "\
ed8693eacdfbeada6ba0cdd1beb2bcbb98302a3a8365650db8c4d88a\
726de3b7d74d8835a0d76e03b0c2865020d659b38d04d74a63e905ae\
80",
  );
  assert_eq!(Point::generator().double(), expected_double);

  // scalar * G
  let scalar = Scalar::from_repr(*GenericArray::from_slice(
    &hex::decode(
      "\
6298e1eef3c379392caaed061ed8a31033c9e9e3420726f23b404158\
a401cd9df24632adfe6b418dc942d8a091817dd8bd70e1c72ba52f3c\
00",
    )
    .unwrap(),
  ))
  .unwrap();
  let expected_product = decode_point(
    "\
3832f82fda00ff5365b0376df705675b63d2a93c24c6e81d40801ba2\
65632be10f443f95968fadb70d10786827f30dc001c8d0f9b7c1d1b0\
00",
  );
  assert_eq!(Point::generator() * scalar, expected_product);
}
|