Move message-queue to a fully binary representation (#454)

* Move message-queue to a fully binary representation

Additionally adds a timeout to the message queue test.

* coordinator clippy

* Remove contention for the message-queue socket by using per-request sockets

* clippy
This commit is contained in:
Luke Parker 2023-11-26 11:22:18 -05:00 committed by GitHub
parent c6c74684c9
commit b79cf8abde
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 241 additions and 228 deletions

4
Cargo.lock generated
View file

@ -7719,7 +7719,6 @@ dependencies = [
"env_logger", "env_logger",
"flexible-transcript", "flexible-transcript",
"hex", "hex",
"jsonrpsee",
"log", "log",
"once_cell", "once_cell",
"rand_core", "rand_core",
@ -7727,9 +7726,6 @@ dependencies = [
"serai-db", "serai-db",
"serai-env", "serai-env",
"serai-primitives", "serai-primitives",
"serde",
"serde_json",
"simple-request",
"tokio", "tokio",
"zeroize", "zeroize",
] ]

View file

@ -769,7 +769,7 @@ async fn handle_processor_messages<D: Db, Pro: Processors, P: P2p>(
mut db: D, mut db: D,
key: Zeroizing<<Ristretto as Ciphersuite>::F>, key: Zeroizing<<Ristretto as Ciphersuite>::F>,
serai: Arc<Serai>, serai: Arc<Serai>,
mut processors: Pro, processors: Pro,
p2p: P, p2p: P,
cosign_channel: mpsc::UnboundedSender<CosignedBlock>, cosign_channel: mpsc::UnboundedSender<CosignedBlock>,
network: NetworkId, network: NetworkId,

View file

@ -15,8 +15,8 @@ pub struct Message {
#[async_trait::async_trait] #[async_trait::async_trait]
pub trait Processors: 'static + Send + Sync + Clone { pub trait Processors: 'static + Send + Sync + Clone {
async fn send(&self, network: NetworkId, msg: impl Send + Into<CoordinatorMessage>); async fn send(&self, network: NetworkId, msg: impl Send + Into<CoordinatorMessage>);
async fn recv(&mut self, network: NetworkId) -> Message; async fn recv(&self, network: NetworkId) -> Message;
async fn ack(&mut self, msg: Message); async fn ack(&self, msg: Message);
} }
#[async_trait::async_trait] #[async_trait::async_trait]
@ -28,7 +28,7 @@ impl Processors for Arc<MessageQueue> {
let msg = borsh::to_vec(&msg).unwrap(); let msg = borsh::to_vec(&msg).unwrap();
self.queue(metadata, msg).await; self.queue(metadata, msg).await;
} }
async fn recv(&mut self, network: NetworkId) -> Message { async fn recv(&self, network: NetworkId) -> Message {
let msg = self.next(Service::Processor(network)).await; let msg = self.next(Service::Processor(network)).await;
assert_eq!(msg.from, Service::Processor(network)); assert_eq!(msg.from, Service::Processor(network));
@ -40,7 +40,7 @@ impl Processors for Arc<MessageQueue> {
return Message { id, network, msg }; return Message { id, network, msg };
} }
async fn ack(&mut self, msg: Message) { async fn ack(&self, msg: Message) {
MessageQueue::ack(self, Service::Processor(msg.network), msg.id).await MessageQueue::ack(self, Service::Processor(msg.network), msg.id).await
} }
} }

View file

@ -35,10 +35,10 @@ impl Processors for MemProcessors {
let processor = processors.entry(network).or_insert_with(VecDeque::new); let processor = processors.entry(network).or_insert_with(VecDeque::new);
processor.push_back(msg.into()); processor.push_back(msg.into());
} }
async fn recv(&mut self, _: NetworkId) -> Message { async fn recv(&self, _: NetworkId) -> Message {
todo!() todo!()
} }
async fn ack(&mut self, _: Message) { async fn ack(&self, _: Message) {
todo!() todo!()
} }
} }

View file

@ -16,12 +16,10 @@ rustdoc-args = ["--cfg", "docsrs"]
[dependencies] [dependencies]
# Macros # Macros
once_cell = { version = "1", default-features = false } once_cell = { version = "1", default-features = false }
serde = { version = "1", default-features = false, features = ["std", "derive"] }
# Encoders # Encoders
hex = { version = "0.4", default-features = false, features = ["std"] } hex = { version = "0.4", default-features = false, features = ["std"] }
borsh = { version = "1", default-features = false, features = ["std", "derive", "de_strict_order"] } borsh = { version = "1", default-features = false, features = ["std", "derive", "de_strict_order"] }
serde_json = { version = "1", default-features = false, features = ["std"] }
# Libs # Libs
zeroize = { version = "1", default-features = false, features = ["std"] } zeroize = { version = "1", default-features = false, features = ["std"] }
@ -37,16 +35,13 @@ log = { version = "0.4", default-features = false, features = ["std"] }
env_logger = { version = "0.10", default-features = false, features = ["humantime"] } env_logger = { version = "0.10", default-features = false, features = ["humantime"] }
# Uses a single threaded runtime since this shouldn't ever be CPU-bound # Uses a single threaded runtime since this shouldn't ever be CPU-bound
tokio = { version = "1", default-features = false, features = ["rt", "time", "macros"] } tokio = { version = "1", default-features = false, features = ["rt", "time", "io-util", "net", "macros"] }
serai-db = { path = "../common/db", features = ["rocksdb"], optional = true } serai-db = { path = "../common/db", features = ["rocksdb"], optional = true }
serai-env = { path = "../common/env" } serai-env = { path = "../common/env" }
serai-primitives = { path = "../substrate/primitives", features = ["borsh", "serde"] } serai-primitives = { path = "../substrate/primitives", features = ["borsh"] }
jsonrpsee = { version = "0.16", default-features = false, features = ["server"], optional = true }
simple-request = { path = "../common/request", default-features = false }
[features] [features]
binaries = ["serai-db", "jsonrpsee"] binaries = ["serai-db"]

View file

@ -9,19 +9,20 @@ use ciphersuite::{
}; };
use schnorr_signatures::SchnorrSignature; use schnorr_signatures::SchnorrSignature;
use serde::{Serialize, Deserialize}; use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
use simple_request::{hyper::Request, Client}; net::TcpStream,
};
use serai_env as env; use serai_env as env;
use crate::{Service, Metadata, QueuedMessage, message_challenge, ack_challenge}; #[rustfmt::skip]
use crate::{Service, Metadata, QueuedMessage, MessageQueueRequest, message_challenge, ack_challenge};
pub struct MessageQueue { pub struct MessageQueue {
pub service: Service, pub service: Service,
priv_key: Zeroizing<<Ristretto as Ciphersuite>::F>, priv_key: Zeroizing<<Ristretto as Ciphersuite>::F>,
pub_key: <Ristretto as Ciphersuite>::G, pub_key: <Ristretto as Ciphersuite>::G,
client: Client,
url: String, url: String,
} }
@ -37,17 +38,8 @@ impl MessageQueue {
if !url.contains(':') { if !url.contains(':') {
url += ":2287"; url += ":2287";
} }
if !url.starts_with("http://") {
url = "http://".to_string() + &url;
}
MessageQueue { MessageQueue { service, pub_key: Ristretto::generator() * priv_key.deref(), priv_key, url }
service,
pub_key: Ristretto::generator() * priv_key.deref(),
priv_key,
client: Client::with_connection_pool(),
url,
}
} }
pub fn from_env(service: Service) -> MessageQueue { pub fn from_env(service: Service) -> MessageQueue {
@ -72,60 +64,14 @@ impl MessageQueue {
Self::new(service, url, priv_key) Self::new(service, url, priv_key)
} }
async fn json_call(&self, method: &'static str, params: serde_json::Value) -> serde_json::Value { #[must_use]
#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)] async fn send(socket: &mut TcpStream, msg: MessageQueueRequest) -> bool {
struct JsonRpcRequest { let msg = borsh::to_vec(&msg).unwrap();
jsonrpc: &'static str, let Ok(_) = socket.write_all(&u32::try_from(msg.len()).unwrap().to_le_bytes()).await else {
method: &'static str, return false;
params: serde_json::Value,
id: u64,
}
let mut res = loop {
// Make the request
match self
.client
.request(
Request::post(&self.url)
.header("Content-Type", "application/json")
.body(
serde_json::to_vec(&JsonRpcRequest {
jsonrpc: "2.0",
method,
params: params.clone(),
id: 0,
})
.unwrap()
.into(),
)
.unwrap(),
)
.await
{
Ok(req) => {
// Get the response
match req.body().await {
Ok(res) => break res,
Err(e) => {
dbg!(e);
}
}
}
Err(e) => {
dbg!(e);
}
}
// Sleep for a second before trying again
tokio::time::sleep(core::time::Duration::from_secs(1)).await;
}; };
let Ok(_) = socket.write_all(&msg).await else { return false };
let json: serde_json::Value = true
serde_json::from_reader(&mut res).expect("message-queue returned invalid JSON");
if json.get("result").is_none() {
panic!("call failed: {json}");
}
json
} }
pub async fn queue(&self, metadata: Metadata, msg: Vec<u8>) { pub async fn queue(&self, metadata: Metadata, msg: Vec<u8>) {
@ -146,30 +92,76 @@ impl MessageQueue {
) )
.serialize(); .serialize();
let json = self.json_call("queue", serde_json::json!([metadata, msg, sig])).await; let msg = MessageQueueRequest::Queue { meta: metadata, msg, sig };
if json.get("result") != Some(&serde_json::Value::Bool(true)) { let mut first = true;
panic!("failed to queue message: {json}"); loop {
// Sleep, so we don't hammer re-attempts
if !first {
tokio::time::sleep(core::time::Duration::from_secs(5)).await;
}
first = false;
let Ok(mut socket) = TcpStream::connect(&self.url).await else { continue };
if !Self::send(&mut socket, msg.clone()).await {
continue;
}
if socket.read_u8().await.ok() != Some(1) {
continue;
}
break;
} }
} }
pub async fn next(&self, from: Service) -> QueuedMessage { pub async fn next(&self, from: Service) -> QueuedMessage {
let msg = MessageQueueRequest::Next { from, to: self.service };
let mut first = true;
'outer: loop {
if !first {
tokio::time::sleep(core::time::Duration::from_secs(5)).await;
continue;
}
first = false;
let Ok(mut socket) = TcpStream::connect(&self.url).await else { continue };
loop { loop {
let json = self.json_call("next", serde_json::json!([from, self.service])).await; if !Self::send(&mut socket, msg.clone()).await {
continue 'outer;
// Convert from a Value to a type via reserialization }
let msg: Option<QueuedMessage> = serde_json::from_str( let Ok(status) = socket.read_u8().await else {
&serde_json::to_string( continue 'outer;
&json.get("result").expect("successful JSON RPC call didn't have result"), };
)
.unwrap(),
)
.expect("next didn't return an Option<QueuedMessage>");
// If there wasn't a message, check again in 1s // If there wasn't a message, check again in 1s
let Some(msg) = msg else { if status == 0 {
tokio::time::sleep(core::time::Duration::from_secs(1)).await; tokio::time::sleep(core::time::Duration::from_secs(1)).await;
continue; continue;
}
assert_eq!(status, 1);
break;
}
// Timeout after 5 seconds in case there's an issue with the length handling
let Ok(msg) = tokio::time::timeout(core::time::Duration::from_secs(5), async {
// Read the message length
let Ok(len) = socket.read_u32_le().await else {
return vec![];
}; };
let mut buf = vec![0; usize::try_from(len).unwrap()];
// Read the message
let Ok(_) = socket.read_exact(&mut buf).await else {
return vec![];
};
buf
})
.await
else {
continue;
};
if msg.is_empty() {
continue;
}
let msg: QueuedMessage = borsh::from_slice(msg.as_slice()).unwrap();
// Verify the message // Verify the message
// Verify the sender is sane // Verify the sender is sane
@ -202,9 +194,22 @@ impl MessageQueue {
) )
.serialize(); .serialize();
let json = self.json_call("ack", serde_json::json!([from, self.service, id, sig])).await; let msg = MessageQueueRequest::Ack { from, to: self.service, id, sig };
if json.get("result") != Some(&serde_json::Value::Bool(true)) { let mut first = true;
panic!("failed to ack message {id}: {json}"); loop {
if !first {
tokio::time::sleep(core::time::Duration::from_secs(5)).await;
}
first = false;
let Ok(mut socket) = TcpStream::connect(&self.url).await else { continue };
if !Self::send(&mut socket, msg.clone()).await {
continue;
}
if socket.read_u8().await.ok() != Some(1) {
continue;
}
break;
} }
} }
} }

View file

@ -15,9 +15,12 @@ mod binaries {
pub(crate) use serai_primitives::NetworkId; pub(crate) use serai_primitives::NetworkId;
use serai_db::{Get, DbTxn, Db as DbTrait}; pub(crate) use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpListener,
};
pub(crate) use jsonrpsee::{RpcModule, server::ServerBuilder}; use serai_db::{Get, DbTxn, Db as DbTrait};
pub(crate) use crate::messages::*; pub(crate) use crate::messages::*;
@ -51,7 +54,7 @@ mod binaries {
successful ordering by the time this call returns. successful ordering by the time this call returns.
*/ */
pub(crate) fn queue_message( pub(crate) fn queue_message(
db: &RwLock<Db>, db: &mut Db,
meta: Metadata, meta: Metadata,
msg: Vec<u8>, msg: Vec<u8>,
sig: SchnorrSignature<Ristretto>, sig: SchnorrSignature<Ristretto>,
@ -78,7 +81,6 @@ mod binaries {
fn intent_key(from: Service, to: Service, intent: &[u8]) -> Vec<u8> { fn intent_key(from: Service, to: Service, intent: &[u8]) -> Vec<u8> {
key(b"intent_seen", borsh::to_vec(&(from, to, intent)).unwrap()) key(b"intent_seen", borsh::to_vec(&(from, to, intent)).unwrap())
} }
let mut db = db.write().unwrap();
let mut txn = db.txn(); let mut txn = db.txn();
let intent_key = intent_key(meta.from, meta.to, &meta.intent); let intent_key = intent_key(meta.from, meta.to, &meta.intent);
if Get::get(&txn, &intent_key).is_some() { if Get::get(&txn, &intent_key).is_some() {
@ -148,7 +150,7 @@ mod binaries {
} }
#[cfg(feature = "binaries")] #[cfg(feature = "binaries")]
#[tokio::main] #[tokio::main(flavor = "current_thread")]
async fn main() { async fn main() {
use binaries::*; use binaries::*;
@ -225,48 +227,55 @@ async fn main() {
register_service(Service::Coordinator, read_key("COORDINATOR_KEY").unwrap()); register_service(Service::Coordinator, read_key("COORDINATOR_KEY").unwrap());
// Start server // Start server
let builder = ServerBuilder::new();
// TODO: Add middleware to check some key is present in the header, making this an authed
// connection
// TODO: Set max request/response size
// 5132 ^ ((b'M' << 8) | b'Q') // 5132 ^ ((b'M' << 8) | b'Q')
let listen_on: &[std::net::SocketAddr] = &["0.0.0.0:2287".parse().unwrap()]; let server = TcpListener::bind("0.0.0.0:2287").await.unwrap();
let server = builder.build(listen_on).await.unwrap();
let mut module = RpcModule::new(RwLock::new(db)); loop {
module let (mut socket, _) = server.accept().await.unwrap();
.register_method("queue", |args, db| { // TODO: Add a magic value with a key at the start of the connection to make this authed
let args = args.parse::<(Metadata, Vec<u8>, Vec<u8>)>().unwrap(); let mut db = db.clone();
tokio::spawn(async move {
loop {
let Ok(msg_len) = socket.read_u32_le().await else { break };
let mut buf = vec![0; usize::try_from(msg_len).unwrap()];
let Ok(_) = socket.read_exact(&mut buf).await else { break };
let msg = borsh::from_slice(&buf).unwrap();
match msg {
MessageQueueRequest::Queue { meta, msg, sig } => {
queue_message( queue_message(
db, &mut db,
args.0, meta,
args.1, msg,
SchnorrSignature::<Ristretto>::read(&mut args.2.as_slice()).unwrap(), SchnorrSignature::<Ristretto>::read(&mut sig.as_slice()).unwrap(),
); );
Ok(true) let Ok(_) = socket.write_all(&[1]).await else { break };
}) }
.unwrap(); MessageQueueRequest::Next { from, to } => match get_next_message(from, to) {
module Some(msg) => {
.register_method("next", |args, _| { let Ok(_) = socket.write_all(&[1]).await else { break };
let (from, to) = args.parse::<(Service, Service)>().unwrap(); let msg = borsh::to_vec(&msg).unwrap();
Ok(get_next_message(from, to)) let len = u32::try_from(msg.len()).unwrap();
}) let Ok(_) = socket.write_all(&len.to_le_bytes()).await else { break };
.unwrap(); let Ok(_) = socket.write_all(&msg).await else { break };
module }
.register_method("ack", |args, _| { None => {
let args = args.parse::<(Service, Service, u64, Vec<u8>)>().unwrap(); let Ok(_) = socket.write_all(&[0]).await else { break };
}
},
MessageQueueRequest::Ack { from, to, id, sig } => {
ack_message( ack_message(
args.0, from,
args.1, to,
args.2, id,
SchnorrSignature::<Ristretto>::read(&mut args.3.as_slice()).unwrap(), SchnorrSignature::<Ristretto>::read(&mut sig.as_slice()).unwrap(),
); );
Ok(true) let Ok(_) = socket.write_all(&[1]).await else { break };
}) }
.unwrap(); }
}
// Run until stopped, which it never will });
server.start(module).unwrap().stopped().await; }
} }
#[cfg(not(feature = "binaries"))] #[cfg(not(feature = "binaries"))]

View file

@ -2,19 +2,16 @@ use transcript::{Transcript, RecommendedTranscript};
use ciphersuite::{group::GroupEncoding, Ciphersuite, Ristretto}; use ciphersuite::{group::GroupEncoding, Ciphersuite, Ristretto};
use borsh::{BorshSerialize, BorshDeserialize}; use borsh::{BorshSerialize, BorshDeserialize};
use serde::{Serialize, Deserialize};
use serai_primitives::NetworkId; use serai_primitives::NetworkId;
#[derive( #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, BorshSerialize, BorshDeserialize)]
Clone, Copy, PartialEq, Eq, Hash, Debug, BorshSerialize, BorshDeserialize, Serialize, Deserialize,
)]
pub enum Service { pub enum Service {
Processor(NetworkId), Processor(NetworkId),
Coordinator, Coordinator,
} }
#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)] #[derive(Clone, PartialEq, Eq, Debug, BorshSerialize, BorshDeserialize)]
pub struct QueuedMessage { pub struct QueuedMessage {
pub from: Service, pub from: Service,
pub id: u64, pub id: u64,
@ -22,13 +19,20 @@ pub struct QueuedMessage {
pub sig: Vec<u8>, pub sig: Vec<u8>,
} }
#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)] #[derive(Clone, PartialEq, Eq, Debug, BorshSerialize, BorshDeserialize)]
pub struct Metadata { pub struct Metadata {
pub from: Service, pub from: Service,
pub to: Service, pub to: Service,
pub intent: Vec<u8>, pub intent: Vec<u8>,
} }
#[derive(Clone, PartialEq, Eq, Debug, BorshSerialize, BorshDeserialize)]
pub enum MessageQueueRequest {
Queue { meta: Metadata, msg: Vec<u8>, sig: Vec<u8> },
Next { from: Service, to: Service },
Ack { from: Service, to: Service, id: u64, sig: Vec<u8> },
}
pub fn message_challenge( pub fn message_challenge(
from: Service, from: Service,
from_key: <Ristretto as Ciphersuite>::G, from_key: <Ristretto as Ciphersuite>::G,

View file

@ -45,7 +45,7 @@ impl<D: Db> Queue<D> {
let msg_key = self.message_key(id); let msg_key = self.message_key(id);
let msg_count_key = self.message_count_key(); let msg_count_key = self.message_count_key();
txn.put(msg_key, serde_json::to_vec(&msg).unwrap()); txn.put(msg_key, borsh::to_vec(&msg).unwrap());
txn.put(msg_count_key, (id + 1).to_le_bytes()); txn.put(msg_count_key, (id + 1).to_le_bytes());
id id
@ -53,7 +53,7 @@ impl<D: Db> Queue<D> {
pub(crate) fn get_message(&self, id: u64) -> Option<QueuedMessage> { pub(crate) fn get_message(&self, id: u64) -> Option<QueuedMessage> {
let msg: Option<QueuedMessage> = let msg: Option<QueuedMessage> =
self.0.get(self.message_key(id)).map(|bytes| serde_json::from_slice(&bytes).unwrap()); self.0.get(self.message_key(id)).map(|bytes| borsh::from_slice(&bytes).unwrap());
if let Some(msg) = msg.as_ref() { if let Some(msg) = msg.as_ref() {
assert_eq!(msg.id, id, "message stored at {id} has ID {}", msg.id); assert_eq!(msg.id, id, "message stored at {id} has ID {}", msg.id);
} }

View file

@ -70,6 +70,7 @@ fn basic_functionality() {
let (coord_key, priv_keys, composition) = instance(); let (coord_key, priv_keys, composition) = instance();
test.provide_container(composition); test.provide_container(composition);
test.run(|ops| async move { test.run(|ops| async move {
tokio::time::timeout(core::time::Duration::from_secs(60), async move {
// Sleep for a second for the message-queue to boot // Sleep for a second for the message-queue to boot
// It isn't an error to start immediately, it just silences an error // It isn't an error to start immediately, it just silences an error
tokio::time::sleep(core::time::Duration::from_secs(1)).await; tokio::time::sleep(core::time::Duration::from_secs(1)).await;
@ -157,5 +158,8 @@ fn basic_functionality() {
tokio::time::timeout(core::time::Duration::from_secs(10), monero.next(Service::Coordinator)) tokio::time::timeout(core::time::Duration::from_secs(10), monero.next(Service::Coordinator))
.await .await
.unwrap_err(); .unwrap_err();
})
.await
.unwrap();
}); });
} }