npc.rs 25 KB


  1. use crate::nockapp::driver::{make_driver, IODriverFn, PokeResult, TaskJoinSet};
  2. use crate::nockapp::wire::{Wire, WireRepr};
  3. use crate::nockapp::NockAppError;
  4. use crate::noun::slab::NounSlab;
  5. use crate::Bytes;
  6. use bytes::buf::BufMut;
  7. use std::sync::Arc;
  8. use sword::noun::{D, T};
  9. use sword_macros::tas;
  10. use tokio::io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
  11. use tokio::net::{UnixListener, UnixStream};
  12. use tokio::select;
  13. use tokio::sync::Mutex;
  14. use tokio::task::JoinSet;
  15. use tokio::time::{sleep, Duration};
  16. use tracing::{debug, error};
  /// Wires used by the NPC driver when poking the kernel.
  ///
  /// Every variant carries the `u64` pid taken from the head of the client
  /// message that caused the poke.
  #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
  pub enum NpcWire {
      /// Wire for a `%poke` directive received from a client.
      Poke(u64),
      /// Wire for a `%pack` directive received from a client.
      Pack(u64),
      /// Wire for a `%nack` directive received from a client.
      Nack(u64),
      /// Wire for a `%bind` directive received from a client.
      Bind(u64),
  }
  23. impl Wire for NpcWire {
  24. const VERSION: u64 = 1;
  25. const SOURCE: &'static str = "npc";
  26. fn to_wire(&self) -> WireRepr {
  27. let tags = match self {
  28. NpcWire::Poke(pid) => vec!["poke".into(), pid.into()],
  29. NpcWire::Pack(pid) => vec!["pack".into(), pid.into()],
  30. NpcWire::Nack(pid) => vec!["nack".into(), pid.into()],
  31. NpcWire::Bind(pid) => vec!["bind".into(), pid.into()],
  32. };
  33. WireRepr::new(Self::SOURCE, Self::VERSION, tags)
  34. }
  35. }
  36. /// NPC Listener IO driver
  37. pub fn npc_listener(listener: UnixListener) -> IODriverFn {
  38. make_driver(move |mut handle| async move {
  39. let mut client_join_set = TaskJoinSet::new();
  40. loop {
  41. select! {
  42. stream_res = listener.accept() => {
  43. debug!("Accepted new connection");
  44. match stream_res {
  45. Ok((stream, _)) => {
  46. let (my_handle, their_handle) = handle.dup();
  47. handle = my_handle;
  48. let _ = client_join_set.spawn(npc_client(stream)(their_handle));
  49. },
  50. Err(e) => {
  51. error!("Error accepting connection: {:?}", e);
  52. }
  53. }
  54. },
  55. Some(result) = client_join_set.join_next() => {
  56. match result {
  57. Ok(Ok(())) => debug!("npc: client task completed successfully"),
  58. Ok(Err(e)) => error!("npc: client task error: {:?}", e),
  59. Err(e) => error!("npc: client task join error: {:?}", e),
  60. }
  61. },
  62. // TODO: don't do this, revive robin hood
  63. _ = sleep(Duration::from_millis(100)) => {
  64. // avoid tight-looping
  65. }
  66. }
  67. }
  68. })
  69. }
  70. /// NPC Client IO driver
  71. pub fn npc_client(stream: UnixStream) -> IODriverFn {
  72. make_driver(move |handle| async move {
  73. let (stream_read, mut stream_write) = split(stream);
  74. let stream_read_arc = Arc::new(Mutex::new(stream_read));
  75. let mut read_message_join_set = JoinSet::new();
  76. read_message_join_set.spawn(read_message(stream_read_arc.clone()));
  77. 'driver: loop {
  78. select! {
  79. message = read_message_join_set.join_next() => {
  80. match message {
  81. Some(Ok(Ok(Some(mut slab)))) => {
  82. debug!("npc_client: read message");
  83. let Ok(message_cell) = unsafe { slab.root() }.as_cell() else {
  84. continue;
  85. };
  86. let (pid, directive_cell) = match (message_cell.head().as_direct(), message_cell.tail().as_cell()) {
  87. (Ok(direct), Ok(cell)) => (direct.data(), cell),
  88. _ => continue,
  89. };
  90. let Ok(directive_tag) = directive_cell.head().as_direct() else {
  91. continue;
  92. };
  93. let directive_tag = directive_tag.data();
  94. match directive_tag {
  95. tas!(b"poke") => {
  96. debug!("npc_client: poke");
  97. let mut poke_slab = NounSlab::new();
  98. let poke = directive_cell.tail();
  99. poke_slab.copy_into(poke);
  100. let wire = NpcWire::Poke(pid).to_wire();
  101. let result = handle.poke(wire, poke_slab).await?;
  102. let (tag, noun) = match result {
  103. PokeResult::Ack => (tas!(b"pack"), D(0)),
  104. PokeResult::Nack => (tas!(b"nack"), D(0)),
  105. };
  106. let mut response_slab = NounSlab::new();
  107. let response_noun = T(&mut response_slab, &[D(pid), D(tag), noun]);
  108. response_slab.set_root(response_noun);
  109. if !write_message(&mut stream_write, response_slab).await? {
  110. break 'driver;
  111. }
  112. },
  113. tas!(b"peek") => {
  114. debug!("npc_client: peek");
  115. let path = directive_cell.tail();
  116. slab.set_root(path);
  117. let peek_res = handle.peek(slab).await?;
  118. match peek_res {
  119. Some(mut bind_slab) => {
  120. bind_slab.modify(|root| {
  121. vec![D(pid), D(tas!(b"bind")), root]
  122. });
  123. if !write_message(&mut stream_write, bind_slab).await? {
  124. break 'driver;
  125. }
  126. },
  127. None => {
  128. error!("npc: peek failed!");
  129. }
  130. }
  131. },
  132. tas!(b"pack") | tas!(b"nack") | tas!(b"bind") => {
  133. debug!("npc_client: pack, nack, or bind");
  134. let tag = match directive_tag {
  135. tas!(b"pack") => tas!(b"npc-pack"),
  136. tas!(b"nack") => tas!(b"npc-nack"),
  137. tas!(b"bind") => tas!(b"npc-bind"),
  138. _ => unreachable!(),
  139. };
  140. let wire = match directive_tag {
  141. tas!(b"pack") => NpcWire::Pack(pid),
  142. tas!(b"nack") => NpcWire::Nack(pid),
  143. tas!(b"bind") => NpcWire::Bind(pid),
  144. _ => unreachable!(),
  145. };
  146. let poke = if tag == tas!(b"npc-bind") {
  147. T(&mut slab, &[D(tag), D(pid), directive_cell.tail()])
  148. } else {
  149. T(&mut slab, &[D(tag), D(pid)])
  150. };
  151. slab.set_root(poke);
  152. handle.poke(wire.to_wire(), slab).await?;
  153. },
  154. _ => {
  155. debug!("npc_client: unexpected message: {:?}", directive_tag);
  156. },
  157. }
  158. },
  159. Some(Ok(Ok(None))) => {
  160. break 'driver;
  161. },
  162. Some(Err(e)) => {
  163. error!("{e:?}");
  164. },
  165. Some(Ok(Err(e))) => {
  166. error!("{e:?}");
  167. },
  168. None => {
  169. read_message_join_set.spawn(read_message(stream_read_arc.clone()));
  170. }
  171. }
  172. },
  173. effect_res = handle.next_effect() => {
  174. let mut slab = effect_res?; // Closed error should error driver
  175. let Ok(effect_cell) = unsafe { slab.root() }.as_cell() else {
  176. continue;
  177. };
  178. // TODO: distinguish connections
  179. if unsafe { effect_cell.head().raw_equals(&D(tas!(b"npc"))) } {
  180. slab.set_root(effect_cell.tail());
  181. if !write_message(&mut stream_write, slab).await? {
  182. break 'driver;
  183. }
  184. }
  185. }
  186. }
  187. }
  188. Ok(())
  189. })
  190. }
  191. async fn read_message(
  192. stream_arc: Arc<Mutex<ReadHalf<UnixStream>>>,
  193. ) -> Result<Option<NounSlab>, NockAppError> {
  194. let mut stream = stream_arc.lock_owned().await;
  195. let mut size_bytes = [0u8; 8];
  196. debug!("Attempting to read message size...");
  197. match stream.read_exact(&mut size_bytes).await {
  198. Ok(0) => {
  199. debug!("Connection closed");
  200. return Ok(None);
  201. }
  202. Ok(size) => {
  203. debug!("Read size: {:?}", size);
  204. }
  205. Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
  206. debug!("Connection closed unexpectedly");
  207. return Ok(None);
  208. }
  209. Err(e) => {
  210. debug!("Error reading size: {:?}", e);
  211. return Err(NockAppError::IoError(e));
  212. }
  213. }
  214. let size = usize::from_le_bytes(size_bytes);
  215. debug!("Message size: {} bytes", size);
  216. let mut buf = Vec::with_capacity(size).limit(size);
  217. while buf.remaining_mut() > 0 {
  218. debug!(
  219. "Reading message content, {} bytes remaining",
  220. buf.remaining_mut()
  221. );
  222. match stream.read_buf(&mut buf).await {
  223. Ok(0) => {
  224. debug!("Connection closed while reading message content");
  225. return Ok(None);
  226. }
  227. Ok(_) => {}
  228. Err(e) => return Err(NockAppError::IoError(e)),
  229. }
  230. }
  231. debug!("Successfully read entire message");
  232. let mut slab = NounSlab::new();
  233. let noun = slab.cue_into(Bytes::from(buf.into_inner()))?;
  234. slab.set_root(noun);
  235. Ok(Some(slab))
  236. }
  237. async fn write_message(
  238. stream: &mut WriteHalf<UnixStream>,
  239. msg_slab: NounSlab,
  240. ) -> Result<bool, NockAppError> {
  241. let msg_bytes = msg_slab.jam();
  242. let msg_len = msg_bytes.len();
  243. debug!("Attempting to write message of {} bytes", msg_len);
  244. let mut msg_len_bytes = &msg_len.to_le_bytes()[..];
  245. let mut msg_buf = &msg_bytes[..];
  246. while !msg_len_bytes.is_empty() {
  247. debug!(
  248. "Writing message length, {} bytes remaining",
  249. msg_len_bytes.len()
  250. );
  251. let bytes = stream
  252. .write_buf(&mut msg_len_bytes)
  253. .await
  254. .map_err(NockAppError::IoError)?;
  255. if bytes == 0 {
  256. debug!("Wrote 0 bytes for message length, returning false");
  257. return Ok(false);
  258. }
  259. }
  260. while !msg_buf.is_empty() {
  261. debug!("Writing message content, {} bytes remaining", msg_buf.len());
  262. let bytes = stream
  263. .write_buf(&mut msg_buf)
  264. .await
  265. .map_err(NockAppError::IoError)?;
  266. if bytes == 0 {
  267. debug!("Wrote 0 bytes for message content, returning false");
  268. return Ok(false);
  269. }
  270. }
  271. debug!("Successfully wrote entire message");
  272. Ok(true)
  273. }
  #[cfg(test)]
  mod tests {
      use crate::nockapp::driver::{IOAction, NockAppHandle};

      use super::*;
      use std::io::{Read, Write};
      use std::os::unix::net::UnixStream as StdUnixStream;
      use std::time::Duration;
      use tempfile::tempdir;
      use tokio::net::UnixStream;
      use tokio::sync::{broadcast, mpsc};
      use tokio::time::timeout;
      use tracing_test::traced_test;

      /// Unwraps a `Result`, panicking with the error, the call-site file and
      /// line, and the `GIT_SHA` build-time environment variable on failure.
      macro_rules! or_panic {
          ($result:expr) => {
              $result.unwrap_or_else(|err| {
                  panic!(
                      "Panicked with {err:?} at {}:{} (git sha: {:?})",
                      file!(),
                      line!(),
                      option_env!("GIT_SHA")
                  )
              })
          };
      }

      /// Binds a listener on a socket in a temporary directory and returns a
      /// connected pair: the accepted tokio stream and a blocking std client.
      async fn setup_socket_pair() -> (UnixStream, StdUnixStream) {
          let dir = or_panic!(tempdir());
          let socket_path = dir.path().join("test.sock");
          let listener = or_panic!(UnixListener::bind(&socket_path));
          let client = or_panic!(StdUnixStream::connect(&socket_path));
          let (server, _) = or_panic!(listener.accept().await);
          (server, client)
      }

      /// A written message is an 8-byte length followed by a jam that cues
      /// back to the original noun.
      #[tokio::test]
      #[cfg_attr(miri, ignore)]
      async fn test_write_message_format() {
          let (server, mut client) = setup_socket_pair().await;
          let (_, mut writer) = split(server);

          let mut outgoing = NounSlab::new();
          let pair = T(&mut outgoing, &[D(123), D(456)]);
          outgoing.set_root(pair);
          or_panic!(write_message(&mut writer, outgoing).await);

          let mut size_buf = [0u8; 8];
          or_panic!(client.read_exact(&mut size_buf));
          let size = usize::from_le_bytes(size_buf);
          let mut payload = vec![0u8; size];
          or_panic!(client.read_exact(&mut payload));

          let mut received = NounSlab::new();
          let noun = or_panic!(received.cue_into(Bytes::from(payload)));
          received.set_root(noun);
          let root = unsafe { received.root() };
          let cell = or_panic!(root.as_cell());
          assert_eq!(or_panic!(cell.head().as_direct()).data(), 123);
          assert_eq!(or_panic!(cell.tail().as_direct()).data(), 456);
      }

      /// Writing `[0 0]` succeeds and reports a non-zero payload length.
      #[tokio::test]
      #[cfg_attr(miri, ignore)]
      async fn test_write_message_empty() {
          let (server, mut client) = setup_socket_pair().await;
          let (_, mut writer) = split(server);

          let mut outgoing = NounSlab::new();
          let pair = T(&mut outgoing, &[D(0), D(0)]);
          outgoing.set_root(pair);
          assert!(or_panic!(write_message(&mut writer, outgoing).await));

          let mut size_buf = [0u8; 8];
          or_panic!(client.read_exact(&mut size_buf));
          assert!(usize::from_le_bytes(size_buf) > 0);
      }

      /// Reading from a stream whose peer has gone away yields `Ok(None)`.
      #[tokio::test]
      #[cfg_attr(miri, ignore)]
      async fn test_read_message_eof() {
          let (server, client) = setup_socket_pair().await;
          drop(client);
          let stream_arc = Arc::new(Mutex::new(split(server).0));
          let outcome = or_panic!(read_message(stream_arc).await);
          assert!(outcome.is_none());
      }

      /// A `%poke` message sent over the socket reaches the IO channel as a
      /// poke carrying the directive payload.
      #[tokio::test]
      #[traced_test]
      #[cfg_attr(miri, ignore)]
      async fn test_npc_driver() {
          // Setup
          let dir = or_panic!(tempdir());
          let socket_path = dir.path().join("test.sock");
          let listener = or_panic!(UnixListener::bind(&socket_path));

          // Create channels for driver communication
          let (tx_io, mut rx_io) = mpsc::channel(32);
          let (tx_effect, rx_effect) = broadcast::channel(32);
          let (tx_exit, _) = mpsc::channel(1);
          let handle = NockAppHandle {
              io_sender: tx_io,
              effect_sender: tx_effect.clone(),
              effect_receiver: Mutex::new(rx_effect),
              exit: tx_exit,
          };

          // Spawn the listener driver
          let _driver_task = tokio::spawn(npc_listener(listener)(handle));

          // Connect client
          let mut client = or_panic!(StdUnixStream::connect(&socket_path));

          // Create test noun slab: [1 %poke 123 456]
          let mut test_slab = NounSlab::new();
          let msg_noun = T(&mut test_slab, &[D(tas!(b"poke")), D(123), D(456)]);
          let test_noun = T(&mut test_slab, &[D(1), msg_noun]);
          test_slab.set_root(test_noun);

          // Jam the noun to bytes
          let msg_bytes = test_slab.jam();
          let msg_len = msg_bytes.len();

          // Write length prefix and jammed noun
          or_panic!(client.write_all(&(msg_len as u64).to_le_bytes()));
          or_panic!(client.write_all(&msg_bytes));
          debug!("client: wrote {} bytes", msg_len);

          // Verify driver received poke
          match or_panic!(timeout(Duration::from_secs(1), rx_io.recv()).await) {
              Some(IOAction::Poke {
                  wire: _wire,
                  poke: noun_slab,
                  ack_channel: _,
              }) => {
                  debug!("test_npc_driver: poke data: {:?}", unsafe {
                      noun_slab.root()
                  });

                  // Verify noun content
                  let noun = unsafe { noun_slab.root() };
                  let noun_cell = or_panic!(noun.as_cell());
                  assert_eq!(or_panic!(noun_cell.head().as_direct()).data(), 123);
                  assert_eq!(or_panic!(noun_cell.tail().as_direct()).data(), 456);

                  // TODO: make the ack round trip work. Intended steps:
                  //   1. bind `ack_channel` and send `PokeResult::Ack` on it;
                  //   2. broadcast the effect `[%npc 123 %pack 0]` on `tx_effect`;
                  //   3. read the 8-byte length and the payload back on `client`;
                  //   4. cue the payload and assert it is `[123 %pack 0]`.
              }
              _ => panic!("Did not receive poke message"),
          }

          // Cleanup
          drop(client);
      }
  }