diff --git a/substrate/tendermint/machine/src/ext.rs b/substrate/tendermint/machine/src/ext.rs index f3030027..9852f528 100644 --- a/substrate/tendermint/machine/src/ext.rs +++ b/substrate/tendermint/machine/src/ext.rs @@ -33,8 +33,35 @@ pub struct BlockNumber(pub u64); #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Encode, Decode)] pub struct Round(pub u32); -/// A signature scheme used by validators. +/// A signer for a validator. #[async_trait] +pub trait Signer: Send + Sync { + // Type used to identify validators. + type ValidatorId: ValidatorId; + /// Signature type. + type Signature: Signature; + + /// Returns the validator's current ID. + async fn validator_id(&self) -> Self::ValidatorId; + /// Sign a signature with the current validator's private key. + async fn sign(&self, msg: &[u8]) -> Self::Signature; +} + +#[async_trait] +impl<S: Signer> Signer for Arc<S> { + type ValidatorId = S::ValidatorId; + type Signature = S::Signature; + + async fn validator_id(&self) -> Self::ValidatorId { + self.as_ref().validator_id().await + } + + async fn sign(&self, msg: &[u8]) -> Self::Signature { + self.as_ref().sign(msg).await + } +} + +/// A signature scheme used by validators. pub trait SignatureScheme: Send + Sync { // Type used to identify validators. type ValidatorId: ValidatorId; @@ -46,8 +73,9 @@ pub trait SignatureScheme: Send + Sync { /// It could even be a threshold signature scheme, though that's currently unexpected. type AggregateSignature: Signature; - /// Sign a signature with the current validator's private key. - async fn sign(&self, msg: &[u8]) -> Self::Signature; + /// Type representing a signer of this scheme. + type Signer: Signer<ValidatorId = Self::ValidatorId, Signature = Self::Signature>; + /// Verify a signature from the validator in question. #[must_use] fn verify(&self, validator: Self::ValidatorId, msg: &[u8], sig: &Self::Signature) -> bool; @@ -64,6 +92,31 @@ pub trait SignatureScheme: Send + Sync { ) -> bool; } +impl<S: SignatureScheme> SignatureScheme for Arc<S> { + type ValidatorId = S::ValidatorId; + type Signature = S::Signature; + type AggregateSignature = S::AggregateSignature; + type Signer = S::Signer; + + fn verify(&self, validator: Self::ValidatorId, msg: &[u8], sig: &Self::Signature) -> bool { + self.as_ref().verify(validator, msg, sig) + } + + fn aggregate(sigs: &[Self::Signature]) -> Self::AggregateSignature { + S::aggregate(sigs) + } + + #[must_use] + fn verify_aggregate( + &self, + signers: &[Self::ValidatorId], + msg: &[u8], + sig: &Self::AggregateSignature, + ) -> bool { + self.as_ref().verify_aggregate(signers, msg, sig) + } +} + /// A commit for a specific block. The list of validators have weight exceeding the threshold for /// a valid commit. #[derive(Clone, PartialEq, Debug, Encode, Decode)] @@ -97,6 +150,22 @@ pub trait Weights: Send + Sync { fn proposer(&self, number: BlockNumber, round: Round) -> Self::ValidatorId; } +impl<W: Weights> Weights for Arc<W> { + type ValidatorId = W::ValidatorId; + + fn total_weight(&self) -> u64 { + self.as_ref().total_weight() + } + + fn weight(&self, validator: Self::ValidatorId) -> u64 { + self.as_ref().weight(validator) + } + + fn proposer(&self, number: BlockNumber, round: Round) -> Self::ValidatorId { + self.as_ref().proposer(number, round) + } +} + /// Simplified error enum representing a block's validity. #[derive(Clone, Copy, PartialEq, Eq, Debug, Error, Encode, Decode)] pub enum BlockError { @@ -141,11 +210,12 @@ pub trait Network: Send + Sync { // Block time in seconds const BLOCK_TIME: u32; - /// Return the signature scheme in use. The instance is expected to have the validators' public - /// keys, along with an instance of the private key of the current validator. - fn signature_scheme(&self) -> Arc<Self::SignatureScheme>; - /// Return a reference to the validators' weights. - fn weights(&self) -> Arc<Self::Weights>; + /// Return a handle on the signer in use, usable for the entire lifetime of the machine. + fn signer(&self) -> <Self::SignatureScheme as SignatureScheme>::Signer; + /// Return a handle on the signing scheme in use, usable for the entire lifetime of the machine. + fn signature_scheme(&self) -> Self::SignatureScheme; + /// Return a handle on the validators' weights, usable for the entire lifetime of the machine. + fn weights(&self) -> Self::Weights; /// Verify a commit for a given block. Intended for use when syncing or when not an active /// validator. diff --git a/substrate/tendermint/machine/src/lib.rs b/substrate/tendermint/machine/src/lib.rs index b0ba1a3d..2eb0fb28 100644 --- a/substrate/tendermint/machine/src/lib.rs +++ b/substrate/tendermint/machine/src/lib.rs @@ -10,10 +10,7 @@ use parity_scale_codec::{Encode, Decode}; use tokio::{ task::{JoinHandle, yield_now}, - sync::{ - RwLock, - mpsc::{self, error::TryRecvError}, - }, + sync::mpsc::{self, error::TryRecvError}, time::sleep, }; @@ -90,7 +87,7 @@ impl<V: ValidatorId, B: Block, S: Signature> SignedMessage<V, B, S> { #[must_use] pub fn verify_signature<Scheme: SignatureScheme<ValidatorId = V, Signature = S>>( &self, - signer: &Arc<Scheme>, + signer: &Scheme, ) -> bool { signer.verify(self.msg.sender, &self.msg.encode(), &self.sig) } @@ -104,10 +101,12 @@ enum TendermintError<V: ValidatorId> { /// A machine executing the Tendermint protocol. pub struct TendermintMachine<N: Network> { - network: Arc<RwLock<N>>, - signer: Arc<N::SignatureScheme>, + network: N, + signer: <N::SignatureScheme as SignatureScheme>::Signer, + validators: N::SignatureScheme, weights: Arc<N::Weights>, - proposer: N::ValidatorId, + + validator_id: N::ValidatorId, number: BlockNumber, canonical_start_time: u64, @@ -178,13 +177,13 @@ impl<N: Network + 'static> TendermintMachine<N> { self.step = step; self.queue.push(( true, - Message { sender: self.proposer, number: self.number, round: self.round, data }, + Message { sender: self.validator_id, number: self.number, round: self.round, data }, )); } // 14-21 fn round_propose(&mut self) -> bool { - if self.weights.proposer(self.number, self.round) == self.proposer { + if self.weights.proposer(self.number, self.round) == self.validator_id { let (round, block) = self .valid .clone() @@ -222,6 +221,8 @@ impl<N: Network + 'static> TendermintMachine<N> { let round_end = self.end_time[&end_round]; sleep(round_end.saturating_duration_since(Instant::now())).await; + self.validator_id = self.signer.validator_id().await; + self.number.0 += 1; self.canonical_start_time = self.canonical_end_time(end_round); self.start_time = round_end; @@ -229,7 +230,7 @@ impl<N: Network + 'static> TendermintMachine<N> { self.queue = self.queue.drain(..).filter(|msg| msg.1.number == self.number).collect(); - self.log = MessageLog::new(self.network.read().await.weights()); + self.log = MessageLog::new(self.weights.clone()); self.end_time = HashMap::new(); self.locked = None; @@ -241,12 +242,7 @@ impl<N: Network + 'static> TendermintMachine<N> { /// Create a new Tendermint machine, for the specified proposer, from the specified block, with /// the specified block as the one to propose next, returning a handle for the machine. #[allow(clippy::new_ret_no_self)] - pub fn new( - network: N, - proposer: N::ValidatorId, - last: (BlockNumber, u64), - proposal: N::Block, - ) -> TendermintHandle<N> { + pub fn new(network: N, last: (BlockNumber, u64), proposal: N::Block) -> TendermintHandle<N> { let (msg_send, mut msg_recv) = mpsc::channel(100); // Backlog to accept. Currently arbitrary TendermintHandle { messages: msg_send, @@ -270,15 +266,18 @@ impl<N: Network + 'static> TendermintMachine<N> { instant_now - sys_now.duration_since(last_end).unwrap_or(Duration::ZERO) }; - let signer = network.signature_scheme(); - let weights = network.weights(); - let network = Arc::new(RwLock::new(network)); + let signer = network.signer(); + let validators = network.signature_scheme(); + let weights = Arc::new(network.weights()); + let validator_id = signer.validator_id().await; // 01-10 let mut machine = TendermintMachine { network, signer, + validators, weights: weights.clone(), - proposer, + + validator_id, number: BlockNumber(last.0 .0 + 1), canonical_start_time: last.1, @@ -334,7 +333,7 @@ impl<N: Network + 'static> TendermintMachine<N> { loop { match msg_recv.try_recv() { Ok(msg) => { - if !msg.verify_signature(&machine.signer) { + if !msg.verify_signature(&machine.validators) { continue; } machine.queue.push((false, msg.msg)); @@ -372,20 +371,20 @@ impl<N: Network + 'static> TendermintMachine<N> { validators, signature: N::SignatureScheme::aggregate(&sigs), }; - debug_assert!(machine.network.read().await.verify_commit(block.id(), &commit)); + debug_assert!(machine.network.verify_commit(block.id(), &commit)); - let proposal = machine.network.write().await.add_block(block, commit).await; + let proposal = machine.network.add_block(block, commit).await; machine.reset(msg.round, proposal).await; } Err(TendermintError::Malicious(validator)) => { - machine.network.write().await.slash(validator).await; + machine.network.slash(validator).await; } Err(TendermintError::Temporal) => (), } if broadcast { let sig = machine.signer.sign(&msg.encode()).await; - machine.network.write().await.broadcast(SignedMessage { msg, sig }).await; + machine.network.broadcast(SignedMessage { msg, sig }).await; } } @@ -405,7 +404,7 @@ impl<N: Network + 'static> TendermintMachine<N> { // Verify the end time and signature if this is a precommit if let Data::Precommit(Some((id, sig))) = &msg.data { - if !self.signer.verify( + if !self.validators.verify( msg.sender, &commit_msg(self.canonical_end_time(msg.round), id.as_ref()), sig, @@ -496,7 +495,7 @@ impl<N: Network + 'static> TendermintMachine<N> { // 22-33 if self.step == Step::Propose { // Delay error handling (triggering a slash) until after we vote. - let (valid, err) = match self.network.write().await.validate(block).await { + let (valid, err) = match self.network.validate(block).await { Ok(_) => (true, Ok(None)), Err(BlockError::Temporal) => (false, Ok(None)), Err(BlockError::Fatal) => (false, Err(TendermintError::Malicious(proposer))), @@ -538,7 +537,7 @@ impl<N: Network + 'static> TendermintMachine<N> { // being set, or only being set historically, means this has yet to be run if self.log.has_consensus(self.round, Data::Prevote(Some(block.id()))) { - match self.network.write().await.validate(block).await { + match self.network.validate(block).await { Ok(_) => (), Err(BlockError::Temporal) => (), Err(BlockError::Fatal) => Err(TendermintError::Malicious(proposer))?, diff --git a/substrate/tendermint/machine/tests/ext.rs b/substrate/tendermint/machine/tests/ext.rs index 5e9e7d52..0dd9c331 100644 --- a/substrate/tendermint/machine/tests/ext.rs +++ b/substrate/tendermint/machine/tests/ext.rs @@ -14,12 +14,15 @@ use tendermint_machine::{ext::*, SignedMessage, TendermintMachine, TendermintHan type TestValidatorId = u16; type TestBlockId = [u8; 4]; -struct TestSignatureScheme(u16); +struct TestSigner(u16); #[async_trait] -impl SignatureScheme for TestSignatureScheme { +impl Signer for TestSigner { type ValidatorId = TestValidatorId; type Signature = [u8; 32]; - type AggregateSignature = Vec<[u8; 32]>; + + async fn validator_id(&self) -> TestValidatorId { + self.0 + } async fn sign(&self, msg: &[u8]) -> [u8; 32] { let mut sig = [0; 32]; @@ -27,6 +30,14 @@ impl SignatureScheme for TestSignatureScheme { sig[2 .. (2 + 30.min(msg.len()))].copy_from_slice(&msg[.. 30.min(msg.len())]); sig } +} + +struct TestSignatureScheme; +impl SignatureScheme for TestSignatureScheme { + type ValidatorId = TestValidatorId; + type Signature = [u8; 32]; + type AggregateSignature = Vec<[u8; 32]>; + type Signer = TestSigner; #[must_use] fn verify(&self, validator: u16, msg: &[u8], sig: &[u8; 32]) -> bool { @@ -93,12 +104,16 @@ impl Network for TestNetwork { const BLOCK_TIME: u32 = 1; - fn signature_scheme(&self) -> Arc<TestSignatureScheme> { - Arc::new(TestSignatureScheme(self.0)) + fn signer(&self) -> TestSigner { + TestSigner(self.0) } - fn weights(&self) -> Arc<TestWeights> { - Arc::new(TestWeights) + fn signature_scheme(&self) -> TestSignatureScheme { + TestSignatureScheme + } + + fn weights(&self) -> TestWeights { + TestWeights } async fn broadcast(&mut self, msg: SignedMessage<TestValidatorId, Self::Block, [u8; 32]>) { @@ -137,7 +152,6 @@ impl TestNetwork { let i = u16::try_from(i).unwrap(); write.push(TendermintMachine::new( TestNetwork(i, arc.clone()), - i, (BlockNumber(1), (SystemTime::now().duration_since(UNIX_EPOCH)).unwrap().as_secs()), TestBlock { id: 1u32.to_le_bytes(), valid: Ok(()) }, ));