use crate::api::Endpoint;
use heapless::consts::U500;
use hid_io_protocol::commands::CommandError;
use hid_io_protocol::{HidIoCommandId, HidIoPacketType};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::stream::StreamExt;
use tokio::sync::broadcast;
/// Capacity parameter for the payload of a [`HidIoPacketBuffer`] (`U500`, i.e. 500).
pub type HidIoPacketBufferDataSize = U500;

/// The protocol crate's packet buffer, specialised to [`HidIoPacketBufferDataSize`].
pub type HidIoPacketBuffer =
    hid_io_protocol::HidIoPacketBuffer<HidIoPacketBufferDataSize>;
/// Source/destination tag attached to every mailbox [`Message`].
///
/// Variants that carry a `uid` identify one specific registered endpoint; the
/// subscription variants are used as control addresses by the subscriber-dropping helpers.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Address {
    /// Not a specific endpoint; replies are matched with `dst == Address::All`.
    All,
    /// A Cap'n Proto API endpoint, identified by its uid.
    ApiCapnp { uid: u64 },
    /// Control destination asking for every subscription to be cancelled.
    CancelAllSubscriptions,
    /// Control destination asking for subscription `sid` of endpoint `uid` to be cancelled.
    CancelSubscription { uid: u64, sid: u64 },
    /// A HID-IO device endpoint, identified by its uid.
    DeviceHidio { uid: u64 },
    /// A plain HID device endpoint, identified by its uid.
    DeviceHid { uid: u64 },
    /// Source address used for subscription-cancellation control messages.
    DropSubscription,
    /// An internal module.
    Module,
}
/// Number of message slots in the mailbox broadcast channel (its capacity argument).
const CHANNEL_SLOTS: usize = 100;
/// Shared message hub: a broadcast channel plus the bookkeeping for registered endpoints.
///
/// Every field is either reference-counted or a channel sender, so clones of a `Mailbox`
/// share the same state.
#[derive(Clone, Debug)]
pub struct Mailbox {
    /// Endpoints added by `register_node` and removed by `unregister_node`.
    pub nodes: Arc<RwLock<Vec<Endpoint>>>,
    /// Last uid handed out by `assign_uid` (starts at 0, so 0 is never assigned).
    pub last_uid: Arc<RwLock<u64>>,
    /// Lookup key -> uids that have been recorded for that key via `add_uid`.
    pub lookup: Arc<RwLock<HashMap<String, Vec<u64>>>>,
    /// Sending half of the broadcast channel carrying all mailbox traffic.
    pub sender: tokio::sync::broadcast::Sender<Message>,
    /// How long the send helpers wait for an Ack (2000 ms when built with `new`).
    pub ack_timeout: Arc<RwLock<std::time::Duration>>,
    /// Tokio runtime handed to `Mailbox::new`.
    pub rt: Arc<tokio::runtime::Runtime>,
}
impl Mailbox {
pub fn new(rt: Arc<tokio::runtime::Runtime>) -> Mailbox {
let (sender, _) = broadcast::channel::<Message>(CHANNEL_SLOTS);
let nodes = Arc::new(RwLock::new(vec![]));
let lookup = Arc::new(RwLock::new(HashMap::new()));
let last_uid: Arc<RwLock<u64>> = Arc::new(RwLock::new(0));
let ack_timeout: Arc<RwLock<std::time::Duration>> =
Arc::new(RwLock::new(std::time::Duration::from_millis(2000)));
Mailbox {
nodes,
last_uid,
lookup,
sender,
ack_timeout,
rt,
}
}
pub fn get_uid(&mut self, key: String, path: String) -> Option<u64> {
let mut lookup = self.lookup.write().unwrap();
let lookup_entry = lookup.entry(key).or_default();
'outer: for uid in lookup_entry.iter() {
for mut node in (*self.nodes.read().unwrap()).clone() {
if node.uid() == *uid {
if node.path() == path {
return Some(0);
}
continue 'outer;
}
}
return Some(*uid);
}
None
}
pub fn add_uid(&mut self, key: String, uid: u64) {
let mut lookup = self.lookup.write().unwrap();
let lookup_entry = lookup.entry(key).or_default();
lookup_entry.push(uid);
}
pub fn assign_uid(&mut self, key: String, path: String) -> Result<u64, std::io::Error> {
match self.get_uid(key.clone(), path) {
Some(0) => Err(std::io::Error::new(
std::io::ErrorKind::Other,
"uid has already been registered!",
)),
Some(uid) => Ok(uid),
None => {
(*self.last_uid.write().unwrap()) += 1;
let uid = *self.last_uid.read().unwrap();
self.add_uid(key, uid);
Ok(uid)
}
}
}
pub fn register_node(&mut self, mut endpoint: Endpoint) {
info!("Registering endpoint: {}", endpoint.uid());
let mut nodes = self.nodes.write().unwrap();
(*nodes).push(endpoint);
}
pub fn unregister_node(&mut self, uid: u64) {
info!("Unregistering endpoint: {}", uid);
let mut nodes = self.nodes.write().unwrap();
*nodes = nodes
.drain_filter(|dev| dev.uid() != uid)
.collect::<Vec<_>>();
}
pub async fn send_command(
&self,
src: Address,
dst: Address,
id: HidIoCommandId,
data: Vec<u8>,
ack: bool,
) -> Result<Option<Message>, AckWaitError> {
let ptype = HidIoPacketType::Data;
let data = HidIoPacketBuffer {
ptype,
id,
max_len: 64,
data: heapless::Vec::from_slice(&data).unwrap(),
done: true,
};
if self.sender.receiver_count() == 0 {
error!("send_command (no active receivers)");
return Err(AckWaitError::NoActiveReceivers);
}
let receiver = self.sender.subscribe();
let result = self.sender.send(Message {
src,
dst,
data: data.clone(),
});
if let Err(e) = result {
error!(
"send_command failed, something is odd, should not get here... {:?}",
e
);
return Err(AckWaitError::NoActiveReceivers);
}
if !ack {
return Ok(None);
}
tokio::pin! {
let stream = receiver.into_stream()
.filter(Result::is_ok)
.map(Result::unwrap)
.filter(|msg| msg.src == src && msg.dst == Address::All && msg.data.id == id);
}
let ack_timeout = *self.ack_timeout.read().unwrap();
loop {
match tokio::time::timeout(ack_timeout, stream.next()).await {
Ok(msg) => {
if let Some(msg) = msg {
match msg.data.ptype {
HidIoPacketType::Ack => {
return Ok(Some(msg));
}
HidIoPacketType::Nak => {
let msg = Box::new(msg);
return Err(AckWaitError::NakReceived { msg });
}
_ => {}
}
} else {
return Err(AckWaitError::Invalid);
}
}
Err(_) => {
warn!("Timeout ({:?}) receiving Ack for: {}", ack_timeout, data);
return Err(AckWaitError::Timeout);
}
}
}
}
pub fn try_send_message(&self, msg: Message) -> Result<Option<Message>, CommandError> {
if self.sender.receiver_count() == 0 {
error!("send_command (no active receivers)");
return Err(CommandError::TxNoActiveReceivers);
}
let mut receiver = self.sender.subscribe();
let result = self.sender.send(msg.clone());
if let Err(e) = result {
error!(
"send_command failed, something is odd, should not get here... {:?}",
e
);
return Err(CommandError::TxNoActiveReceivers);
}
if msg.data.ptype != HidIoPacketType::Data {
return Ok(None);
}
let start_time = std::time::Instant::now();
loop {
if start_time.elapsed() >= *self.ack_timeout.read().unwrap() {
warn!(
"Timeout ({:?}) receiving Ack for command: src:{:?} dst:{:?}",
*self.ack_timeout.read().unwrap(),
msg.src,
msg.dst
);
return Err(CommandError::RxTimeout);
}
match receiver.try_recv() {
Ok(rcvmsg) => {
if rcvmsg.dst == Address::All
&& rcvmsg.src == msg.dst
&& rcvmsg.data.id == msg.data.id
{
match rcvmsg.data.ptype {
HidIoPacketType::Ack | HidIoPacketType::Nak => {
return Ok(Some(rcvmsg));
}
_ => {}
}
}
}
Err(broadcast::error::TryRecvError::Empty) => {
std::thread::yield_now();
std::thread::sleep(std::time::Duration::from_millis(1));
}
Err(broadcast::error::TryRecvError::Lagged(_skipped)) => {}
Err(broadcast::error::TryRecvError::Closed) => {
return Err(CommandError::TxBufferSendFailed);
}
}
}
}
pub fn try_send_command(
&self,
src: Address,
dst: Address,
id: HidIoCommandId,
data: Vec<u8>,
ack: bool,
) -> Result<Option<Message>, AckWaitError> {
let ptype = HidIoPacketType::Data;
let data = HidIoPacketBuffer {
ptype,
id,
max_len: 64,
data: heapless::Vec::from_slice(&data).unwrap(),
done: true,
};
if self.sender.receiver_count() == 0 {
error!("send_command (no active receivers)");
return Err(AckWaitError::NoActiveReceivers);
}
let mut receiver = self.sender.subscribe();
let result = self.sender.send(Message { src, dst, data });
if let Err(e) = result {
error!(
"send_command failed, something is odd, should not get here... {:?}",
e
);
return Err(AckWaitError::NoActiveReceivers);
}
if !ack {
return Ok(None);
}
let start_time = std::time::Instant::now();
loop {
if start_time.elapsed() >= *self.ack_timeout.read().unwrap() {
warn!(
"Timeout ({:?}) receiving Ack for command: src:{:?} dst:{:?}",
*self.ack_timeout.read().unwrap(),
src,
dst
);
return Err(AckWaitError::Timeout);
}
match receiver.try_recv() {
Ok(msg) => {
if msg.dst == Address::All && msg.src == dst && msg.data.id == id {
match msg.data.ptype {
HidIoPacketType::Ack => {
return Ok(Some(msg));
}
HidIoPacketType::Nak => {
let msg = Box::new(msg);
return Err(AckWaitError::NakReceived { msg });
}
_ => {}
}
}
}
Err(broadcast::error::TryRecvError::Empty) => {
std::thread::yield_now();
std::thread::sleep(std::time::Duration::from_millis(1));
}
Err(broadcast::error::TryRecvError::Lagged(_skipped)) => {}
Err(broadcast::error::TryRecvError::Closed) => {
return Err(AckWaitError::ChannelClosed);
}
}
}
}
pub fn drop_subscriber(&self, uid: u64, sid: u64) {
let data = HidIoPacketBuffer::default();
let result = self.sender.send(Message {
src: Address::DropSubscription,
dst: Address::CancelSubscription { uid, sid },
data,
});
if let Err(e) = result {
error!("drop_subscriber {:?}", e);
}
}
pub fn drop_all_subscribers(&self) {
let data = HidIoPacketBuffer::default();
let result = self.sender.send(Message {
src: Address::DropSubscription,
dst: Address::CancelAllSubscriptions,
data,
});
if let Err(e) = result {
error!("drop_all_subscribers {:?}", e);
}
}
}
impl Default for Mailbox {
    /// Builds a mailbox that owns a dedicated current-thread Tokio runtime with all
    /// drivers enabled.
    ///
    /// # Panics
    /// If the runtime cannot be built.
    fn default() -> Self {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        Self::new(Arc::new(runtime))
    }
}
/// One packet on the mailbox broadcast channel, tagged with its sender and destination.
#[derive(Clone, Debug, PartialEq)]
pub struct Message {
    /// Address of the sender.
    pub src: Address,
    /// Address of the intended recipient.
    pub dst: Address,
    /// The HID-IO packet being carried.
    pub data: HidIoPacketBuffer,
}
impl Message {
pub fn new(src: Address, dst: Address, data: HidIoPacketBuffer) -> Message {
Message { src, dst, data }
}
pub fn send_ack(&self, sender: broadcast::Sender<Message>, data: Vec<u8>) {
let src = self.dst;
let dst = self.src;
let data = HidIoPacketBuffer {
ptype: HidIoPacketType::Ack,
id: self.data.id,
max_len: 64,
data: heapless::Vec::from_slice(&data).unwrap(),
done: true,
};
let result = sender.send(Message { src, dst, data });
if let Err(e) = result {
error!("send_ack {:?}", e);
}
}
pub fn send_nak(&self, sender: broadcast::Sender<Message>, data: Vec<u8>) {
let src = self.dst;
let dst = self.src;
let data = HidIoPacketBuffer {
ptype: HidIoPacketType::Nak,
id: self.data.id,
max_len: 64,
data: heapless::Vec::from_slice(&data).unwrap(),
done: true,
};
let result = sender.send(Message { src, dst, data });
if let Err(e) = result {
error!("send_ack {:?}", e);
}
}
}
/// Reasons the mailbox send helpers can fail while sending a command or waiting for
/// its Ack.
#[derive(Debug)]
pub enum AckWaitError {
    /// NOTE(review): not produced by the visible mailbox code; presumably used elsewhere.
    TooManySyncs,
    /// The peer answered with a Nak; the Nak message is boxed inside.
    NakReceived { msg: Box<Message> },
    /// The reply stream ended before any Ack/Nak was seen.
    Invalid,
    /// Nobody was listening on the channel, so the command could not be sent.
    NoActiveReceivers,
    /// No Ack/Nak arrived within the mailbox's Ack timeout.
    Timeout,
    /// The channel closed while waiting for a reply.
    ChannelClosed,
}