Add a proper database trait

This commit is contained in:
Luke Parker 2022-06-05 06:00:21 -04:00
parent 3617ed4eb7
commit a46432b829
No known key found for this signature in database
GPG key ID: F9F1386DB1E119B6
4 changed files with 88 additions and 28 deletions

View file

@ -24,10 +24,10 @@ impl OutputTrait for Output {
// While we could use (tx, o), using the key ensures we won't be susceptible to the burning bug. // While we could use (tx, o), using the key ensures we won't be susceptible to the burning bug.
// While the Monero library offers a variant which allows senders to ensure their TXs have unique // While the Monero library offers a variant which allows senders to ensure their TXs have unique
// output keys, Serai can still be targeted using the classic burning bug // output keys, Serai can still be targeted using the classic burning bug
type Id = CompressedEdwardsY; type Id = [u8; 32];
fn id(&self) -> Self::Id { fn id(&self) -> Self::Id {
self.0.key.compress() self.0.key.compress().to_bytes()
} }
fn amount(&self) -> u64 { fn amount(&self) -> u64 {

View file

@ -13,7 +13,7 @@ mod wallet;
mod tests; mod tests;
pub trait Output: Sized + Clone { pub trait Output: Sized + Clone {
type Id; type Id: AsRef<[u8]>;
fn id(&self) -> Self::Id; fn id(&self) -> Self::Id;
fn amount(&self) -> u64; fn amount(&self) -> u64;

View file

@ -2,14 +2,14 @@ use std::rc::Rc;
use rand::rngs::OsRng; use rand::rngs::OsRng;
use crate::{Coin, coins::monero::Monero, wallet::{WalletKeys, Wallet}}; use crate::{Coin, coins::monero::Monero, wallet::{WalletKeys, MemCoinDb, Wallet}};
#[tokio::test] #[tokio::test]
async fn test() { async fn test() {
let monero = Monero::new("http://127.0.0.1:18081".to_string()); let monero = Monero::new("http://127.0.0.1:18081".to_string());
println!("{}", monero.get_height().await.unwrap()); println!("{}", monero.get_height().await.unwrap());
let mut keys = frost::tests::key_gen::<_, <Monero as Coin>::Curve>(&mut OsRng); let mut keys = frost::tests::key_gen::<_, <Monero as Coin>::Curve>(&mut OsRng);
let mut wallet = Wallet::new(monero); let mut wallet = Wallet::new(MemCoinDb::new(), monero);
wallet.acknowledge_height(0, 0); wallet.acknowledge_height(0, 0);
wallet.add_keys(&WalletKeys::new(Rc::try_unwrap(keys.remove(&1).take().unwrap()).unwrap(), 0)); wallet.add_keys(&WalletKeys::new(Rc::try_unwrap(keys.remove(&1).take().unwrap()).unwrap(), 0));
dbg!(0); dbg!(0);

View file

@ -33,28 +33,83 @@ impl<C: Curve> WalletKeys<C> {
} }
} }
pub struct CoinDb { pub trait CoinDb {
// Set a height as scanned to
fn scanned_to_height(&mut self, height: usize);
// Acknowledge a given coin height for a canonical height
fn acknowledge_height(&mut self, canonical: usize, height: usize);
// Adds an output to the DB. Returns false if the output was already added
fn add_output<O: Output>(&mut self, output: &O) -> bool;
// Height this coin has been scanned to
fn scanned_height(&self) -> usize;
// Acknowledged height for a given canonical height
fn acknowledged_height(&self, canonical: usize) -> usize;
}
pub struct MemCoinDb {
// Height this coin has been scanned to // Height this coin has been scanned to
scanned_height: usize, scanned_height: usize,
// Acknowledged height for a given canonical height // Acknowledged height for a given canonical height
acknowledged_heights: HashMap<usize, usize> acknowledged_heights: HashMap<usize, usize>,
outputs: HashMap<Vec<u8>, Vec<u8>>
} }
pub struct Wallet<C: Coin> {
db: CoinDb, impl MemCoinDb {
pub fn new() -> MemCoinDb {
MemCoinDb {
scanned_height: 0,
acknowledged_heights: HashMap::new(),
outputs: HashMap::new()
}
}
}
impl CoinDb for MemCoinDb {
fn scanned_to_height(&mut self, height: usize) {
self.scanned_height = height;
}
fn acknowledge_height(&mut self, canonical: usize, height: usize) {
debug_assert!(!self.acknowledged_heights.contains_key(&canonical));
self.acknowledged_heights.insert(canonical, height);
}
fn add_output<O: Output>(&mut self, output: &O) -> bool {
// This would be insecure as we're indexing by ID and this will replace the output as a whole
// Multiple outputs may have the same ID in edge cases such as Monero, where outputs are ID'd
// by key image, not by hash + index
// self.outputs.insert(output.id(), output).is_some()
let id = output.id().as_ref().to_vec();
if self.outputs.contains_key(&id) {
return false;
}
self.outputs.insert(id, output.serialize());
true
}
fn scanned_height(&self) -> usize {
self.scanned_height
}
fn acknowledged_height(&self, canonical: usize) -> usize {
self.acknowledged_heights[&canonical]
}
}
pub struct Wallet<D: CoinDb, C: Coin> {
db: D,
coin: C, coin: C,
keys: Vec<(Arc<MultisigKeys<C::Curve>>, Vec<C::Output>)>, keys: Vec<(Arc<MultisigKeys<C::Curve>>, Vec<C::Output>)>,
pending: Vec<(usize, MultisigKeys<C::Curve>)> pending: Vec<(usize, MultisigKeys<C::Curve>)>
} }
impl<C: Coin> Wallet<C> { impl<D: CoinDb, C: Coin> Wallet<D, C> {
pub fn new(coin: C) -> Wallet<C> { pub fn new(db: D, coin: C) -> Wallet<D, C> {
Wallet { Wallet {
db: CoinDb { db,
scanned_height: 0,
acknowledged_heights: HashMap::new(),
},
coin, coin,
keys: vec![], keys: vec![],
@ -62,13 +117,12 @@ impl<C: Coin> Wallet<C> {
} }
} }
pub fn scanned_height(&self) -> usize { self.db.scanned_height } pub fn scanned_height(&self) -> usize { self.db.scanned_height() }
pub fn acknowledge_height(&mut self, canonical: usize, height: usize) { pub fn acknowledge_height(&mut self, canonical: usize, height: usize) {
debug_assert!(!self.db.acknowledged_heights.contains_key(&canonical)); self.db.acknowledge_height(canonical, height);
self.db.acknowledged_heights.insert(canonical, height);
} }
pub fn acknowledged_height(&self, canonical: usize) -> usize { pub fn acknowledged_height(&self, canonical: usize) -> usize {
self.db.acknowledged_heights[&canonical] self.db.acknowledged_height(canonical)
} }
pub fn add_keys(&mut self, keys: &WalletKeys<C::Curve>) { pub fn add_keys(&mut self, keys: &WalletKeys<C::Curve>) {
@ -83,7 +137,10 @@ impl<C: Coin> Wallet<C> {
{ {
let mut k = 0; let mut k = 0;
while k < self.pending.len() { while k < self.pending.len() {
if height >= self.pending[k].0 { // TODO
//if height < self.pending[k].0 {
//} else if height == self.pending[k].0 {
if height <= self.pending[k].0 {
self.keys.push((Arc::new(self.pending.swap_remove(k).1), vec![])); self.keys.push((Arc::new(self.pending.swap_remove(k).1), vec![]));
} else { } else {
k += 1; k += 1;
@ -95,7 +152,7 @@ impl<C: Coin> Wallet<C> {
for (keys, outputs) in self.keys.iter_mut() { for (keys, outputs) in self.keys.iter_mut() {
outputs.extend( outputs.extend(
self.coin.get_outputs(&block, keys.group_key()).await.iter().cloned().filter( self.coin.get_outputs(&block, keys.group_key()).await.iter().cloned().filter(
|_output| true // !self.db.handled.contains_key(output.id()) // TODO |output| self.db.add_output(output)
) )
); );
} }
@ -103,18 +160,23 @@ impl<C: Coin> Wallet<C> {
Ok(()) Ok(())
} }
// This should be called whenever new outputs are received, meaning there was a new block
// If these outputs were received and sent to Substrate, it should be called after they're
// included in a block and we have results to act on
// If these outputs weren't sent to Substrate (change), it should be called immediately
// with all payments still queued from the last call
pub async fn prepare_sends( pub async fn prepare_sends(
&mut self, &mut self,
canonical: usize, canonical: usize,
payments: Vec<(C::Address, u64)> payments: Vec<(C::Address, u64)>
) -> Result<Vec<C::SignableTransaction>, CoinError> { ) -> Result<(Vec<(C::Address, u64)>, Vec<C::SignableTransaction>), CoinError> {
if payments.len() == 0 { if payments.len() == 0 {
return Ok(vec![]); return Ok((vec![], vec![]));
} }
let acknowledged_height = self.acknowledged_height(canonical); let acknowledged_height = self.acknowledged_height(canonical);
// TODO: Log schedule outputs when MAX_OUTPUTS is low // TODO: Log schedule outputs when MAX_OUTPUTS is lower than payments.len()
// Payments is the first set of TXs in the schedule // Payments is the first set of TXs in the schedule
// As each payment re-appears, let mut payments = schedule[payment] where the only input is // As each payment re-appears, let mut payments = schedule[payment] where the only input is
// the source payment // the source payment
@ -177,8 +239,6 @@ impl<C: Coin> Wallet<C> {
} }
} }
// TODO: Remaining payments? Ok((payments, txs))
Ok(txs)
} }
} }