Expand task management

These extensions are necessary for the signers task management.
This commit is contained in:
Luke Parker 2024-09-08 00:30:55 -04:00
parent 59ff944152
commit 7484eadbbb
4 changed files with 105 additions and 27 deletions

View file

@ -1,28 +1,54 @@
use core::time::Duration; use core::time::Duration;
use std::sync::Arc;
use tokio::sync::mpsc; use tokio::sync::{mpsc, oneshot, Mutex};
/// A handle to immediately run an iteration of a task. enum Closed {
NotClosed(Option<oneshot::Receiver<()>>),
Closed,
}
/// A handle for a task.
#[derive(Clone)] #[derive(Clone)]
pub struct RunNowHandle(mpsc::Sender<()>); pub struct TaskHandle {
/// An instruction recipient to immediately run an iteration of a task. run_now: mpsc::Sender<()>,
pub struct RunNowRecipient(mpsc::Receiver<()>); close: mpsc::Sender<()>,
closed: Arc<Mutex<Closed>>,
}
/// A task's internal structures.
pub struct Task {
run_now: mpsc::Receiver<()>,
close: mpsc::Receiver<()>,
closed: oneshot::Sender<()>,
}
impl RunNowHandle { impl Task {
/// Create a new run-now handle to be assigned to a task. /// Create a new task definition.
pub fn new() -> (Self, RunNowRecipient) { pub fn new() -> (Self, TaskHandle) {
// Uses a capacity of 1 as any call to run as soon as possible satisfies all calls to run as // Uses a capacity of 1 as any call to run as soon as possible satisfies all calls to run as
// soon as possible // soon as possible
let (send, recv) = mpsc::channel(1); let (run_now_send, run_now_recv) = mpsc::channel(1);
(Self(send), RunNowRecipient(recv)) // And any call to close satisfies all calls to close
let (close_send, close_recv) = mpsc::channel(1);
let (closed_send, closed_recv) = oneshot::channel();
(
Self { run_now: run_now_recv, close: close_recv, closed: closed_send },
TaskHandle {
run_now: run_now_send,
close: close_send,
closed: Arc::new(Mutex::new(Closed::NotClosed(Some(closed_recv)))),
},
)
} }
}
impl TaskHandle {
/// Tell the task to run now (and not whenever its next iteration on a timer is). /// Tell the task to run now (and not whenever its next iteration on a timer is).
/// ///
/// Panics if the task has been dropped. /// Panics if the task has been dropped.
pub fn run_now(&self) { pub fn run_now(&self) {
#[allow(clippy::match_same_arms)] #[allow(clippy::match_same_arms)]
match self.0.try_send(()) { match self.run_now.try_send(()) {
Ok(()) => {} Ok(()) => {}
// NOP on full, as this task will already be ran as soon as possible // NOP on full, as this task will already be ran as soon as possible
Err(mpsc::error::TrySendError::Full(())) => {} Err(mpsc::error::TrySendError::Full(())) => {}
@ -31,6 +57,24 @@ impl RunNowHandle {
} }
} }
} }
/// Close the task.
///
/// Returns once the task shuts down after it finishes its current iteration (which may be of
/// unbounded time).
pub async fn close(self) {
// If another instance of the handle called tfhis, don't error
let _ = self.close.send(()).await;
// Wait until we receive the closed message
let mut closed = self.closed.lock().await;
match &mut *closed {
Closed::NotClosed(ref mut recv) => {
assert_eq!(recv.take().unwrap().await, Ok(()), "continually ran task dropped itself?");
*closed = Closed::Closed;
}
Closed::Closed => {}
}
}
} }
/// A task to be continually ran. /// A task to be continually ran.
@ -50,10 +94,7 @@ pub trait ContinuallyRan: Sized {
async fn run_iteration(&mut self) -> Result<bool, String>; async fn run_iteration(&mut self) -> Result<bool, String>;
/// Continually run the task. /// Continually run the task.
/// async fn continually_run(mut self, mut task: Task, dependents: Vec<TaskHandle>) {
/// This returns a channel which can have a message set to immediately trigger a new run of an
/// iteration.
async fn continually_run(mut self, mut run_now: RunNowRecipient, dependents: Vec<RunNowHandle>) {
// The default number of seconds to sleep before running the task again // The default number of seconds to sleep before running the task again
let default_sleep_before_next_task = Self::DELAY_BETWEEN_ITERATIONS; let default_sleep_before_next_task = Self::DELAY_BETWEEN_ITERATIONS;
// The current number of seconds to sleep before running the task again // The current number of seconds to sleep before running the task again
@ -66,6 +107,15 @@ pub trait ContinuallyRan: Sized {
}; };
loop { loop {
// If we were told to close/all handles were dropped, drop it
{
let should_close = task.close.try_recv();
match should_close {
Ok(()) | Err(mpsc::error::TryRecvError::Disconnected) => break,
Err(mpsc::error::TryRecvError::Empty) => {}
}
}
match self.run_iteration().await { match self.run_iteration().await {
Ok(run_dependents) => { Ok(run_dependents) => {
// Upon a successful (error-free) loop iteration, reset the amount of time we sleep // Upon a successful (error-free) loop iteration, reset the amount of time we sleep
@ -86,8 +136,15 @@ pub trait ContinuallyRan: Sized {
// Don't run the task again for another few seconds UNLESS told to run now // Don't run the task again for another few seconds UNLESS told to run now
tokio::select! { tokio::select! {
() = tokio::time::sleep(Duration::from_secs(current_sleep_before_next_task)) => {}, () = tokio::time::sleep(Duration::from_secs(current_sleep_before_next_task)) => {},
msg = run_now.0.recv() => assert_eq!(msg, Some(()), "run now handle was dropped"), msg = task.run_now.recv() => {
// Check if this is firing because the handle was dropped
if msg.is_none() {
break;
}
},
} }
} }
task.closed.send(()).unwrap();
} }
} }

View file

@ -343,7 +343,7 @@ pub trait Scheduler<S: ScannerFeed>: 'static + Send {
/// A representation of a scanner. /// A representation of a scanner.
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub struct Scanner<S: ScannerFeed> { pub struct Scanner<S: ScannerFeed> {
substrate_handle: RunNowHandle, substrate_handle: TaskHandle,
_S: PhantomData<S>, _S: PhantomData<S>,
} }
impl<S: ScannerFeed> Scanner<S> { impl<S: ScannerFeed> Scanner<S> {
@ -362,11 +362,11 @@ impl<S: ScannerFeed> Scanner<S> {
let substrate_task = substrate::SubstrateTask::<_, S>::new(db.clone()); let substrate_task = substrate::SubstrateTask::<_, S>::new(db.clone());
let eventuality_task = eventuality::EventualityTask::<_, _, Sch>::new(db, feed, start_block); let eventuality_task = eventuality::EventualityTask::<_, _, Sch>::new(db, feed, start_block);
let (_index_handle, index_run) = RunNowHandle::new(); let (index_run, _index_handle) = Task::new();
let (scan_handle, scan_run) = RunNowHandle::new(); let (scan_run, scan_handle) = Task::new();
let (report_handle, report_run) = RunNowHandle::new(); let (report_run, report_handle) = Task::new();
let (substrate_handle, substrate_run) = RunNowHandle::new(); let (substrate_run, substrate_handle) = Task::new();
let (eventuality_handle, eventuality_run) = RunNowHandle::new(); let (eventuality_run, eventuality_handle) = Task::new();
// Upon indexing a new block, scan it // Upon indexing a new block, scan it
tokio::spawn(index_task.continually_run(index_run, vec![scan_handle.clone()])); tokio::spawn(index_task.continually_run(index_run, vec![scan_handle.clone()]));

View file

@ -9,6 +9,7 @@ create_db! {
RegisteredKeys: () -> Vec<Session>, RegisteredKeys: () -> Vec<Session>,
SerializedKeys: (session: Session) -> Vec<u8>, SerializedKeys: (session: Session) -> Vec<u8>,
LatestRetiredSession: () -> Session, LatestRetiredSession: () -> Session,
ToCleanup: () -> Vec<(Session, Vec<u8>)>,
} }
} }

View file

@ -3,6 +3,7 @@
#![deny(missing_docs)] #![deny(missing_docs)]
use core::{fmt::Debug, marker::PhantomData}; use core::{fmt::Debug, marker::PhantomData};
use std::collections::HashMap;
use zeroize::Zeroizing; use zeroize::Zeroizing;
@ -13,6 +14,7 @@ use frost::dkg::{ThresholdCore, ThresholdKeys};
use serai_db::{DbTxn, Db}; use serai_db::{DbTxn, Db};
use primitives::task::TaskHandle;
use scheduler::{Transaction, SignableTransaction, TransactionsToSign}; use scheduler::{Transaction, SignableTransaction, TransactionsToSign};
pub(crate) mod db; pub(crate) mod db;
@ -39,7 +41,10 @@ pub trait TransactionPublisher<T: Transaction>: 'static + Send + Sync {
} }
/// The signers used by a processor. /// The signers used by a processor.
pub struct Signers<ST: SignableTransaction>(PhantomData<ST>); pub struct Signers<ST: SignableTransaction> {
tasks: HashMap<Session, Vec<TaskHandle>>,
_ST: PhantomData<ST>,
}
/* /*
This is completely outside of consensus, so the worst that can happen is: This is completely outside of consensus, so the worst that can happen is:
@ -58,6 +63,8 @@ impl<ST: SignableTransaction> Signers<ST> {
/// ///
/// This will spawn tasks for any historically registered keys. /// This will spawn tasks for any historically registered keys.
pub fn new(db: impl Db) -> Self { pub fn new(db: impl Db) -> Self {
let mut tasks = HashMap::new();
for session in db::RegisteredKeys::get(&db).unwrap_or(vec![]) { for session in db::RegisteredKeys::get(&db).unwrap_or(vec![]) {
let buf = db::SerializedKeys::get(&db, session).unwrap(); let buf = db::SerializedKeys::get(&db, session).unwrap();
let mut buf = buf.as_slice(); let mut buf = buf.as_slice();
@ -74,7 +81,7 @@ impl<ST: SignableTransaction> Signers<ST> {
todo!("TODO") todo!("TODO")
} }
todo!("TODO") Self { tasks, _ST: PhantomData }
} }
/// Register a set of keys to sign with. /// Register a set of keys to sign with.
@ -87,6 +94,7 @@ impl<ST: SignableTransaction> Signers<ST> {
substrate_keys: Vec<ThresholdKeys<Ristretto>>, substrate_keys: Vec<ThresholdKeys<Ristretto>>,
network_keys: Vec<ThresholdKeys<ST::Ciphersuite>>, network_keys: Vec<ThresholdKeys<ST::Ciphersuite>>,
) { ) {
// Don't register already retired keys
if Some(session.0) <= db::LatestRetiredSession::get(txn).map(|session| session.0) { if Some(session.0) <= db::LatestRetiredSession::get(txn).map(|session| session.0) {
return; return;
} }
@ -125,9 +133,6 @@ impl<ST: SignableTransaction> Signers<ST> {
db::LatestRetiredSession::set(txn, &session); db::LatestRetiredSession::set(txn, &session);
} }
// Kill the tasks
todo!("TODO");
// Update RegisteredKeys/SerializedKeys // Update RegisteredKeys/SerializedKeys
if let Some(registered) = db::RegisteredKeys::get(txn) { if let Some(registered) = db::RegisteredKeys::get(txn) {
db::RegisteredKeys::set( db::RegisteredKeys::set(
@ -137,6 +142,20 @@ impl<ST: SignableTransaction> Signers<ST> {
} }
db::SerializedKeys::del(txn, session); db::SerializedKeys::del(txn, session);
// Queue the session for clean up
let mut to_cleanup = db::ToCleanup::get(txn).unwrap_or(vec![]);
to_cleanup.push((session, external_key.to_bytes().as_ref().to_vec()));
db::ToCleanup::set(txn, &to_cleanup);
// TODO: Handle all of the following cleanup on a task
/*
// Kill the tasks
if let Some(tasks) = self.tasks.remove(&session) {
for task in tasks {
task.close().await;
}
}
// Drain the transactions to sign // Drain the transactions to sign
// Presumably, TransactionsToSign will be fully populated before retiry occurs, making this // Presumably, TransactionsToSign will be fully populated before retiry occurs, making this
// perfect in not leaving any pending blobs behind // perfect in not leaving any pending blobs behind
@ -152,6 +171,7 @@ impl<ST: SignableTransaction> Signers<ST> {
while db::SlashReportSignerToCoordinatorMessages::try_recv(txn, session).is_some() {} while db::SlashReportSignerToCoordinatorMessages::try_recv(txn, session).is_some() {}
while db::CoordinatorToCosignerMessages::try_recv(txn, session).is_some() {} while db::CoordinatorToCosignerMessages::try_recv(txn, session).is_some() {}
while db::CosignerToCoordinatorMessages::try_recv(txn, session).is_some() {} while db::CosignerToCoordinatorMessages::try_recv(txn, session).is_some() {}
*/
} }
} }