| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668 |
- use crate::nockapp::driver::{make_driver, IODriverFn, PokeResult, TaskJoinSet};
- use crate::nockapp::wire::{Wire, WireRepr};
- use crate::nockapp::NockAppError;
- use crate::noun::slab::NounSlab;
- use crate::Bytes;
- use bytes::buf::BufMut;
- use std::sync::Arc;
- use sword::noun::{D, T};
- use sword_macros::tas;
- use tokio::io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
- use tokio::net::{UnixListener, UnixStream};
- use tokio::select;
- use tokio::sync::Mutex;
- use tokio::task::JoinSet;
- use tokio::time::{sleep, Duration};
- use tracing::{debug, error};
/// Wires attached to the pokes that the NPC client driver sends into the kernel.
///
/// Every variant carries the `pid` found at the head of the client message
/// that caused the poke.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NpcWire {
    /// Wire for a client `poke` directive.
    Poke(u64),
    /// Wire for a client `pack` directive (poked in as `npc-pack`).
    Pack(u64),
    /// Wire for a client `nack` directive (poked in as `npc-nack`).
    Nack(u64),
    /// Wire for a client `bind` directive (poked in as `npc-bind`).
    Bind(u64),
}
- impl Wire for NpcWire {
- const VERSION: u64 = 1;
- const SOURCE: &'static str = "npc";
- fn to_wire(&self) -> WireRepr {
- let tags = match self {
- NpcWire::Poke(pid) => vec!["poke".into(), pid.into()],
- NpcWire::Pack(pid) => vec!["pack".into(), pid.into()],
- NpcWire::Nack(pid) => vec!["nack".into(), pid.into()],
- NpcWire::Bind(pid) => vec!["bind".into(), pid.into()],
- };
- WireRepr::new(Self::SOURCE, Self::VERSION, tags)
- }
- }
- /// NPC Listener IO driver
- pub fn npc_listener(listener: UnixListener) -> IODriverFn {
- make_driver(move |mut handle| async move {
- let mut client_join_set = TaskJoinSet::new();
- loop {
- select! {
- stream_res = listener.accept() => {
- debug!("Accepted new connection");
- match stream_res {
- Ok((stream, _)) => {
- let (my_handle, their_handle) = handle.dup();
- handle = my_handle;
- let _ = client_join_set.spawn(npc_client(stream)(their_handle));
- },
- Err(e) => {
- error!("Error accepting connection: {:?}", e);
- }
- }
- },
- Some(result) = client_join_set.join_next() => {
- match result {
- Ok(Ok(())) => debug!("npc: client task completed successfully"),
- Ok(Err(e)) => error!("npc: client task error: {:?}", e),
- Err(e) => error!("npc: client task join error: {:?}", e),
- }
- },
- // TODO: don't do this, revive robin hood
- _ = sleep(Duration::from_millis(100)) => {
- // avoid tight-looping
- }
- }
- }
- })
- }
- /// NPC Client IO driver
- pub fn npc_client(stream: UnixStream) -> IODriverFn {
- make_driver(move |handle| async move {
- let (stream_read, mut stream_write) = split(stream);
- let stream_read_arc = Arc::new(Mutex::new(stream_read));
- let mut read_message_join_set = JoinSet::new();
- read_message_join_set.spawn(read_message(stream_read_arc.clone()));
- 'driver: loop {
- select! {
- message = read_message_join_set.join_next() => {
- match message {
- Some(Ok(Ok(Some(mut slab)))) => {
- debug!("npc_client: read message");
- let Ok(message_cell) = unsafe { slab.root() }.as_cell() else {
- continue;
- };
- let (pid, directive_cell) = match (message_cell.head().as_direct(), message_cell.tail().as_cell()) {
- (Ok(direct), Ok(cell)) => (direct.data(), cell),
- _ => continue,
- };
- let Ok(directive_tag) = directive_cell.head().as_direct() else {
- continue;
- };
- let directive_tag = directive_tag.data();
- match directive_tag {
- tas!(b"poke") => {
- debug!("npc_client: poke");
- let mut poke_slab = NounSlab::new();
- let poke = directive_cell.tail();
- poke_slab.copy_into(poke);
- let wire = NpcWire::Poke(pid).to_wire();
- let result = handle.poke(wire, poke_slab).await?;
- let (tag, noun) = match result {
- PokeResult::Ack => (tas!(b"pack"), D(0)),
- PokeResult::Nack => (tas!(b"nack"), D(0)),
- };
- let mut response_slab = NounSlab::new();
- let response_noun = T(&mut response_slab, &[D(pid), D(tag), noun]);
- response_slab.set_root(response_noun);
- if !write_message(&mut stream_write, response_slab).await? {
- break 'driver;
- }
- },
- tas!(b"peek") => {
- debug!("npc_client: peek");
- let path = directive_cell.tail();
- slab.set_root(path);
- let peek_res = handle.peek(slab).await?;
- match peek_res {
- Some(mut bind_slab) => {
- bind_slab.modify(|root| {
- vec![D(pid), D(tas!(b"bind")), root]
- });
- if !write_message(&mut stream_write, bind_slab).await? {
- break 'driver;
- }
- },
- None => {
- error!("npc: peek failed!");
- }
- }
- },
- tas!(b"pack") | tas!(b"nack") | tas!(b"bind") => {
- debug!("npc_client: pack, nack, or bind");
- let tag = match directive_tag {
- tas!(b"pack") => tas!(b"npc-pack"),
- tas!(b"nack") => tas!(b"npc-nack"),
- tas!(b"bind") => tas!(b"npc-bind"),
- _ => unreachable!(),
- };
- let wire = match directive_tag {
- tas!(b"pack") => NpcWire::Pack(pid),
- tas!(b"nack") => NpcWire::Nack(pid),
- tas!(b"bind") => NpcWire::Bind(pid),
- _ => unreachable!(),
- };
- let poke = if tag == tas!(b"npc-bind") {
- T(&mut slab, &[D(tag), D(pid), directive_cell.tail()])
- } else {
- T(&mut slab, &[D(tag), D(pid)])
- };
- slab.set_root(poke);
- handle.poke(wire.to_wire(), slab).await?;
- },
- _ => {
- debug!("npc_client: unexpected message: {:?}", directive_tag);
- },
- }
- },
- Some(Ok(Ok(None))) => {
- break 'driver;
- },
- Some(Err(e)) => {
- error!("{e:?}");
- },
- Some(Ok(Err(e))) => {
- error!("{e:?}");
- },
- None => {
- read_message_join_set.spawn(read_message(stream_read_arc.clone()));
- }
- }
- },
- effect_res = handle.next_effect() => {
- let mut slab = effect_res?; // Closed error should error driver
- let Ok(effect_cell) = unsafe { slab.root() }.as_cell() else {
- continue;
- };
- // TODO: distinguish connections
- if unsafe { effect_cell.head().raw_equals(&D(tas!(b"npc"))) } {
- slab.set_root(effect_cell.tail());
- if !write_message(&mut stream_write, slab).await? {
- break 'driver;
- }
- }
- }
- }
- }
- Ok(())
- })
- }
- async fn read_message(
- stream_arc: Arc<Mutex<ReadHalf<UnixStream>>>,
- ) -> Result<Option<NounSlab>, NockAppError> {
- let mut stream = stream_arc.lock_owned().await;
- let mut size_bytes = [0u8; 8];
- debug!("Attempting to read message size...");
- match stream.read_exact(&mut size_bytes).await {
- Ok(0) => {
- debug!("Connection closed");
- return Ok(None);
- }
- Ok(size) => {
- debug!("Read size: {:?}", size);
- }
- Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
- debug!("Connection closed unexpectedly");
- return Ok(None);
- }
- Err(e) => {
- debug!("Error reading size: {:?}", e);
- return Err(NockAppError::IoError(e));
- }
- }
- let size = usize::from_le_bytes(size_bytes);
- debug!("Message size: {} bytes", size);
- let mut buf = Vec::with_capacity(size).limit(size);
- while buf.remaining_mut() > 0 {
- debug!(
- "Reading message content, {} bytes remaining",
- buf.remaining_mut()
- );
- match stream.read_buf(&mut buf).await {
- Ok(0) => {
- debug!("Connection closed while reading message content");
- return Ok(None);
- }
- Ok(_) => {}
- Err(e) => return Err(NockAppError::IoError(e)),
- }
- }
- debug!("Successfully read entire message");
- let mut slab = NounSlab::new();
- let noun = slab.cue_into(Bytes::from(buf.into_inner()))?;
- slab.set_root(noun);
- Ok(Some(slab))
- }
- async fn write_message(
- stream: &mut WriteHalf<UnixStream>,
- msg_slab: NounSlab,
- ) -> Result<bool, NockAppError> {
- let msg_bytes = msg_slab.jam();
- let msg_len = msg_bytes.len();
- debug!("Attempting to write message of {} bytes", msg_len);
- let mut msg_len_bytes = &msg_len.to_le_bytes()[..];
- let mut msg_buf = &msg_bytes[..];
- while !msg_len_bytes.is_empty() {
- debug!(
- "Writing message length, {} bytes remaining",
- msg_len_bytes.len()
- );
- let bytes = stream
- .write_buf(&mut msg_len_bytes)
- .await
- .map_err(NockAppError::IoError)?;
- if bytes == 0 {
- debug!("Wrote 0 bytes for message length, returning false");
- return Ok(false);
- }
- }
- while !msg_buf.is_empty() {
- debug!("Writing message content, {} bytes remaining", msg_buf.len());
- let bytes = stream
- .write_buf(&mut msg_buf)
- .await
- .map_err(NockAppError::IoError)?;
- if bytes == 0 {
- debug!("Wrote 0 bytes for message content, returning false");
- return Ok(false);
- }
- }
- debug!("Successfully wrote entire message");
- Ok(true)
- }
#[cfg(test)]
mod tests {
    use crate::nockapp::driver::{IOAction, NockAppHandle};
    use super::*;
    use std::io::{Read, Write};
    use std::os::unix::net::UnixStream as StdUnixStream;
    use std::time::Duration;
    use tempfile::tempdir;
    use tokio::net::UnixStream;
    use tokio::sync::{broadcast, mpsc};
    use tokio::time::timeout;
    use tracing_test::traced_test;

    /// Unwraps an `Ok`, or panics reporting the error together with the
    /// invoking file/line and the `GIT_SHA` the build was made from.
    macro_rules! must {
        ($result:expr) => {
            match $result {
                Ok(value) => value,
                Err(err) => panic!(
                    "Panicked with {err:?} at {}:{} (git sha: {:?})",
                    file!(),
                    line!(),
                    option_env!("GIT_SHA")
                ),
            }
        };
    }

    /// Returns a connected pair: the accepted (tokio) server end and the
    /// blocking std client end of a Unix socket bound in a temporary directory.
    async fn setup_socket_pair() -> (UnixStream, StdUnixStream) {
        let dir = must!(tempdir());
        let socket_path = dir.path().join("test.sock");
        let listener = must!(UnixListener::bind(&socket_path));
        let client = must!(StdUnixStream::connect(&socket_path));
        let (server, _) = must!(listener.accept().await);
        (server, client)
    }

    /// A written frame is an 8-byte length followed by a jam that cues back
    /// to the original noun.
    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn test_write_message_format() {
        let (server, mut client) = setup_socket_pair().await;
        let (_, mut writer) = split(server);

        let mut test_slab = NounSlab::new();
        let test_noun = T(&mut test_slab, &[D(123), D(456)]);
        test_slab.set_root(test_noun);
        must!(write_message(&mut writer, test_slab).await);

        // Read the frame back by hand on the blocking side.
        let mut size_buf = [0u8; 8];
        must!(client.read_exact(&mut size_buf));
        let size = usize::from_le_bytes(size_buf);
        let mut msg_buf = vec![0u8; size];
        must!(client.read_exact(&mut msg_buf));

        let mut received_slab = NounSlab::new();
        let received_noun = must!(received_slab.cue_into(Bytes::from(msg_buf)));
        received_slab.set_root(received_noun);

        let root = unsafe { received_slab.root() };
        let cell = must!(root.as_cell());
        assert_eq!(must!(cell.head().as_direct()).data(), 123);
        assert_eq!(must!(cell.tail().as_direct()).data(), 456);
    }

    /// Writing the cell `[0 0]` succeeds and yields a non-zero length prefix.
    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn test_write_message_empty() {
        let (server, mut client) = setup_socket_pair().await;
        let (_, mut writer) = split(server);

        let mut test_slab = NounSlab::new();
        let test_noun = T(&mut test_slab, &[D(0), D(0)]);
        test_slab.set_root(test_noun);
        assert!(must!(write_message(&mut writer, test_slab).await));

        let mut size_buf = [0u8; 8];
        must!(client.read_exact(&mut size_buf));
        assert!(usize::from_le_bytes(size_buf) > 0);
    }

    /// Reading from a stream whose peer has already gone away yields `Ok(None)`.
    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn test_read_message_eof() {
        let (server, client) = setup_socket_pair().await;
        drop(client);

        let stream_arc = Arc::new(Mutex::new(split(server).0));
        let result = read_message(stream_arc).await;
        assert!(must!(result).is_none());
    }

    /// End to end: a framed `[1 %poke 123 456]` sent by a client shows up on
    /// the driver's IO channel as a poke whose noun is `[123 456]`.
    #[tokio::test]
    #[traced_test]
    #[cfg_attr(miri, ignore)]
    async fn test_npc_driver() {
        // Setup
        let dir = must!(tempdir());
        let socket_path = dir.path().join("test.sock");
        let listener = must!(UnixListener::bind(&socket_path));

        // Create channels for driver communication
        let (tx_io, mut rx_io) = mpsc::channel(32);
        let (tx_effect, rx_effect) = broadcast::channel(32);
        let (tx_exit, _) = mpsc::channel(1);
        let handle = NockAppHandle {
            io_sender: tx_io,
            effect_sender: tx_effect.clone(),
            effect_receiver: Mutex::new(rx_effect),
            exit: tx_exit,
        };

        // Spawn the listener driver
        let _driver_task = tokio::spawn(npc_listener(listener)(handle));

        // Connect client
        let mut client = must!(StdUnixStream::connect(&socket_path));

        // Create test noun slab
        let mut test_slab = NounSlab::new();
        let msg_noun = T(&mut test_slab, &[D(tas!(b"poke")), D(123), D(456)]);
        let test_noun = T(&mut test_slab, &[D(1), msg_noun]);
        test_slab.set_root(test_noun);

        // Jam the noun to bytes
        let msg_bytes = test_slab.jam();
        let msg_len = msg_bytes.len();

        // Write length prefix and jammed noun
        must!(client.write_all(&(msg_len as u64).to_le_bytes()));
        must!(client.write_all(&msg_bytes));
        debug!("client: wrote {} bytes", msg_len);

        // Verify driver received poke
        if let Some(IOAction::Poke {
            wire: _wire,
            poke: noun_slab,
            ack_channel: _,
        }) = must!(timeout(Duration::from_secs(1), rx_io.recv()).await)
        {
            debug!("test_npc_driver: poke data: {:?}", unsafe {
                noun_slab.root()
            });

            // Verify noun content
            let noun = unsafe { noun_slab.root() };
            let noun_cell = must!(noun.as_cell());
            assert_eq!(must!(noun_cell.head().as_direct()).data(), 123);
            assert_eq!(must!(noun_cell.tail().as_direct()).data(), 456);

            // TODO: make this work — bind `ack_channel`, answer the poke with
            // `PokeResult::Ack`, broadcast the effect `[%npc 123 %pack 0]` on
            // `tx_effect`, then read the length-prefixed reply from `client`
            // and check that it cues to `[123 %pack 0]`.
        } else {
            panic!("Did not receive poke message");
        }

        // Cleanup
        drop(client);
    }
}
|