1use std::{
3 collections::{HashMap, hash_map::Entry},
4 future::Future,
5 io::{Error, IoSlice},
6 mem,
7 pin::{Pin, pin},
8 task::{Context, Poll},
9};
10
11use bincode::Options;
12use s2n_quic::{
13 connection::{Handle, StreamAcceptor as QuicStreamAcceptor},
14 stream::{ReceiveStream as QuicRecvStream, SendStream as QuicSendStream},
15};
16use serde::{Deserialize, Serialize, de::DeserializeOwned};
17use tokio::{
18 io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
19 select,
20 sync::{mpsc, oneshot},
21};
22use tokio_serde::{
23 SymmetricallyFramed,
24 formats::{Bincode, SymmetricalBincode},
25};
26use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec, length_delimited};
27use tracing::{Level, debug, error, event};
28
/// Metrics support; only compiled when the `metrics` feature is enabled.
#[cfg(feature = "metrics")]
pub mod metrics;
31
// Test utilities (this crate's tests use e.g. `local_conn` and `init_tracing` from here).
// Only compiled for tests or with the internal `__testing` feature, and hidden from the docs.
#[doc(hidden)]
#[cfg(any(test, feature = "__testing"))]
pub mod testing;
35
/// Explicit, caller-chosen identifier for a stream.
///
/// Passed to the `*_with_id` methods of [`Connection`]. A stream opened with an explicit id
/// is paired with the peer's stream opened with the same id on the corresponding connection.
/// Explicit ids are kept in a separate namespace from implicitly numbered streams, so they
/// never collide with the ids used by [`Connection::byte_stream`] and friends.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct Id(pub(crate) u64);
39
/// Identifier of a stream within one (sub-)connection, tagged with how it was allocated.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
enum StreamId {
    /// Taken from the connection's sequential implicit-id counter.
    Implicit(u64),
    /// Provided by the caller as an [`Id`].
    Explicit(u64),
}
45
/// Index of a sub-connection relative to its parent, assigned in creation order by
/// [`Connection::sub_connection`].
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
struct ConnectionId(pub(crate) u32);
51
/// Header written at the start of every opened stream so the peer can pair it with the
/// stream it requested locally.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
struct UniqueId {
    /// Path of sub-connection indices from the root connection (empty for the root).
    cids: Vec<ConnectionId>,
    /// Stream id within that (sub-)connection.
    id: StreamId,
}
58
/// One-shot channel halves through which the `StreamManager` hands an accepted remote stream
/// to the local side, together with the number of header bytes already read from it.
type StreamSend = oneshot::Sender<(QuicRecvStream, usize)>;
type StreamRecv = oneshot::Receiver<(QuicRecvStream, usize)>;
61
/// Routes incoming QUIC receive streams to the local streams that requested them.
///
/// Created together with a [`Connection`] by [`Connection::new`]. It only accepts and routes
/// streams while [`StreamManager::start`] is being awaited.
pub struct StreamManager {
    /// Accepts receive streams opened by the peer.
    acceptor: QuicStreamAcceptor,
    /// Sender cloned into `Connection`s and into the header-reading tasks.
    cmd_send: mpsc::UnboundedSender<Cmd>,
    /// Receiver of all commands processed by the manager loop.
    cmd_recv: mpsc::UnboundedReceiver<Cmd>,
    /// Local requests whose remote stream has not arrived yet.
    pending: HashMap<UniqueId, StreamSend>,
    /// Remote streams (with header bytes read) that arrived before being requested locally.
    accepted: HashMap<UniqueId, (QuicRecvStream, usize)>,
}
70
/// Handle for opening multiplexed streams over a single QUIC connection.
///
/// Every `Connection`, including those returned by [`Connection::sub_connection`], prefixes
/// the ids of the streams it opens with its own path of sub-connection indices.
#[derive(Debug)]
pub struct Connection {
    /// Path of sub-connection indices from the root connection; empty for the root.
    cids: Vec<ConnectionId>,
    /// Index assigned to the next sub-connection created from this one.
    next_cid: u32,
    /// QUIC handle used to open send streams.
    handle: Handle,
    /// Command channel to the `StreamManager`.
    cmd: mpsc::UnboundedSender<Cmd>,
    /// Counter for implicit stream ids.
    next_implicit_id: u64,
}
85
/// Raw byte sending half of a stream; implements [`AsyncWrite`].
///
/// Successful writes emit `bytes_written` trace events on the `cryprot_metrics` target.
pub struct SendStreamBytes {
    inner: QuicSendStream,
}
90
/// Raw byte receiving half of a stream; implements [`AsyncRead`].
///
/// Reads first wait until the peer's matching stream has been routed to this half, then read
/// from that stream. Bytes read are emitted as `bytes_read` trace events on the
/// `cryprot_metrics` target.
pub struct ReceiveStreamBytes {
    inner: ReceiveStreamWrapper,
}
95
/// Typed sending half: values of `T` are bincode-serialized into length-delimited frames
/// written to an owned [`SendStreamBytes`].
pub type SendStream<T> = SymmetricallyFramed<
    FramedWrite<SendStreamBytes, LengthDelimitedCodec>,
    T,
    SymmetricalBincode<T>,
>;

/// Like [`SendStream`], but borrowing the underlying [`SendStreamBytes`]
/// (see [`SendStreamBytes::as_stream`]).
pub type SendStreamTemp<'a, T> = SymmetricallyFramed<
    FramedWrite<&'a mut SendStreamBytes, LengthDelimitedCodec>,
    T,
    SymmetricalBincode<T>,
>;
109
/// Typed receiving half: length-delimited frames read from an owned [`ReceiveStreamBytes`]
/// are bincode-deserialized into values of `T`.
pub type ReceiveStream<T> = SymmetricallyFramed<
    FramedRead<ReceiveStreamBytes, LengthDelimitedCodec>,
    T,
    SymmetricalBincode<T>,
>;

/// Like [`ReceiveStream`], but borrowing the underlying [`ReceiveStreamBytes`]
/// (see [`ReceiveStreamBytes::as_stream`]).
pub type ReceiveStreamTemp<'a, T> = SymmetricallyFramed<
    FramedRead<&'a mut ReceiveStreamBytes, LengthDelimitedCodec>,
    T,
    SymmetricalBincode<T>,
>;
123
/// State of a [`ReceiveStreamBytes`].
enum ReceiveStreamWrapper {
    /// Still waiting for the `StreamManager` to deliver the matching remote stream.
    Channel { stream_recv: StreamRecv },
    /// The remote stream has been delivered and is read directly.
    Stream { recv_stream: QuicRecvStream },
}
128
/// Messages processed by the [`StreamManager`] loop.
#[derive(Debug)]
enum Cmd {
    /// A local connection requests the remote stream with `uid`; it is delivered through
    /// `stream_return`.
    NewStream {
        uid: UniqueId,
        stream_return: StreamSend,
    },
    /// A remote stream whose `uid` header has been read; `bytes_read` is the header size.
    AcceptedStream {
        uid: UniqueId,
        stream: QuicRecvStream,
        bytes_read: usize,
    },
}
141
142impl StreamManager {
143 pub fn new(acceptor: QuicStreamAcceptor) -> Self {
144 let (cmd_send, cmd_recv) = mpsc::unbounded_channel();
145 Self {
146 acceptor,
147 cmd_send,
148 cmd_recv,
149 pending: Default::default(),
150 accepted: Default::default(),
151 }
152 }
153
154 #[tracing::instrument(skip_all)]
158 pub async fn start(mut self) {
159 loop {
160 let mut receive_stream = pin!(self.acceptor.accept_receive_stream());
162 select! {
163 res = &mut receive_stream => {
164 match res {
165 Ok(Some(stream)) => {
166 debug!("accepted stream");
167 Self::accepted(stream, self.cmd_send.clone());
168 }
169 Ok(None) => {
170 debug!("remote closed");
171 return;
172 }
173 Err(err) => {
174 error!(%err, "unable to accept stream");
175 return;
176 }
177 }
178 }
179 Some(cmd) = self.cmd_recv.recv() => { debug!(?cmd, "received cmd");
181 match cmd {
182 Cmd::NewStream {uid, stream_return} => {
183 if let Some(accepted) = self.accepted.remove(&uid) {
184 if stream_return.send(accepted).is_err() {
185 debug!("accepted remote stream but local receiver is closed");
186 }
187 debug!("sending new stream to receiver");
188 continue;
189 }
190 match self.pending.entry(uid) {
191 Entry::Occupied(occupied_entry) => {
192 panic!("Duplicate unique id: {:?}", occupied_entry.key())
193 },
194 Entry::Vacant(vacant_entry) => {vacant_entry.insert(stream_return);},
195 }
196 }
197 Cmd::AcceptedStream {uid, stream, bytes_read} => {
198 if let Some(stream_ret) = self.pending.remove(&uid) {
199 if stream_ret.send((stream, bytes_read)).is_err() {
200 debug!("accepted remote stream but local receiver is closed");
201 }
202 } else {
203 debug!("accepted stream but no pending");
204 self.accepted.insert(uid, (stream, bytes_read));
205 }
206 }
207 }
208 }
209 }
210 }
211 }
212
213 fn accepted(mut stream: QuicRecvStream, cmd_send: mpsc::UnboundedSender<Cmd>) {
215 tokio::spawn(async move {
216 let (uid, bytes_read) = match UniqueId::read_from(&mut stream).await {
217 Ok(ret) => ret,
218 Err(err) => {
219 error!(?err, "unable to read stream unique id");
220 return;
221 }
222 };
223 cmd_send
224 .send(Cmd::AcceptedStream {
225 uid,
226 stream,
227 bytes_read,
228 })
229 .expect("cmd_rcv is owned by StreamManager")
230 });
231 }
232}
233
/// Errors that can occur while establishing a stream on a [`Connection`].
#[derive(thiserror::Error, Debug)]
pub enum ConnectionError {
    /// The QUIC layer refused to open a new send stream.
    #[error("Unable to open stream")]
    OpenStream(#[source] s2n_quic::connection::Error),
    /// Reading or writing the stream id header failed.
    #[error("io error during stream establishment")]
    IoError(#[source] io::Error),
    /// The `StreamManager` command channel is closed.
    #[error("StreamManager is dropped and not accepting connections")]
    StreamManagerDropped,
    /// The stream id header received from the peer could not be decoded.
    #[error("Stream unique id deserialization failed")]
    UniqueIdDeserialization(#[source] bincode::Error),
    /// The stream id header could not be encoded.
    #[error("Stream unique id serialization failed")]
    UniqueIdSerialization(#[source] bincode::Error),
    /// The encoded stream id header is too long, which happens when sub-connections are
    /// nested too deeply.
    #[error("Reached maximum number of sub connections")]
    SubConnectionLimitReached,
}
250
251impl Connection {
252 pub fn new(quic_conn: s2n_quic::Connection) -> (Self, StreamManager) {
253 let (handle, acceptor) = quic_conn.split();
254 let stream_manager = StreamManager::new(acceptor);
255 let conn = Self {
256 cids: vec![],
257 next_cid: 0,
258 handle,
259 cmd: stream_manager.cmd_send.clone(),
260 next_implicit_id: 0,
261 };
262 (conn, stream_manager)
263 }
264
265 #[tracing::instrument(level = Level::DEBUG, skip(self), ret)]
270 pub fn sub_connection(&mut self) -> Self {
271 let cid = self.next_cid;
272 self.next_cid += 1;
273 let mut cids = self.cids.clone();
274 cids.push(ConnectionId(cid));
275 Self {
276 cids,
277 next_cid: 0,
278 handle: self.handle.clone(),
279 cmd: self.cmd.clone(),
280 next_implicit_id: 0,
281 }
282 }
283
284 async fn internal_byte_stream(
285 &self,
286 stream_id: StreamId,
287 ) -> Result<(SendStreamBytes, ReceiveStreamBytes), ConnectionError> {
288 let uid = UniqueId::new(self.cids.clone(), stream_id);
289 let mut snd = self
290 .handle
291 .clone()
292 .open_send_stream()
293 .await
294 .map_err(ConnectionError::OpenStream)?;
295 let bytes_written = uid.write_into(&mut snd).await?;
296 event!(target: "cryprot_metrics", Level::TRACE, bytes_written = bytes_written);
297 let (stream_return, stream_recv) = oneshot::channel();
298 self.cmd
299 .send(Cmd::NewStream { uid, stream_return })
300 .map_err(|_| ConnectionError::StreamManagerDropped)?;
301 let snd = SendStreamBytes { inner: snd };
302 let recv = ReceiveStreamBytes {
303 inner: ReceiveStreamWrapper::Channel { stream_recv },
304 };
305 Ok((snd, recv))
306 }
307
308 pub async fn byte_stream(
310 &mut self,
311 ) -> Result<(SendStreamBytes, ReceiveStreamBytes), ConnectionError> {
312 self.next_implicit_id += 1;
313 self.internal_byte_stream(StreamId::Implicit(self.next_implicit_id - 1))
314 .await
315 }
316
317 pub async fn byte_stream_with_id(
319 &self,
320 id: Id,
321 ) -> Result<(SendStreamBytes, ReceiveStreamBytes), ConnectionError> {
322 self.internal_byte_stream(StreamId::Explicit(id.0)).await
323 }
324
325 async fn internal_stream<T: Serialize + DeserializeOwned>(
327 &self,
328 id: StreamId,
329 ) -> Result<(SendStream<T>, ReceiveStream<T>), ConnectionError> {
330 let (send_bytes, recv_bytes) = self.internal_byte_stream(id).await?;
331 let mut ld_codec = LengthDelimitedCodec::builder();
332 const MB: usize = 1024 * 1024;
334 ld_codec.max_frame_length(256 * MB);
335 let framed_send = ld_codec.new_write(send_bytes);
336 let framed_read = ld_codec.new_read(recv_bytes);
337 let serde_send = SymmetricallyFramed::new(framed_send, Bincode::default());
338 let serde_read = SymmetricallyFramed::new(framed_read, Bincode::default());
339 Ok((serde_send, serde_read))
340 }
341
342 pub async fn stream<T: Serialize + DeserializeOwned>(
344 &mut self,
345 ) -> Result<(SendStream<T>, ReceiveStream<T>), ConnectionError> {
346 self.next_implicit_id += 1;
347 self.internal_stream(StreamId::Implicit(self.next_implicit_id - 1))
348 .await
349 }
350
351 pub async fn stream_with_id<T: Serialize + DeserializeOwned>(
354 &self,
355 id: Id,
356 ) -> Result<(SendStream<T>, ReceiveStream<T>), ConnectionError> {
357 self.internal_stream(StreamId::Explicit(id.0)).await
358 }
359
360 async fn internal_request_response_stream<T: Serialize, S: DeserializeOwned>(
361 &self,
362 id: StreamId,
363 ) -> Result<(SendStream<T>, ReceiveStream<S>), ConnectionError> {
364 let (send_bytes, recv_bytes) = self.internal_byte_stream(id).await?;
365 let framed_send = default_codec().new_write(send_bytes);
366 let framed_read = default_codec().new_read(recv_bytes);
367 let serde_send = SymmetricallyFramed::new(framed_send, Bincode::default());
368 let serde_read = SymmetricallyFramed::new(framed_read, Bincode::default());
369 Ok((serde_send, serde_read))
370 }
371
372 pub async fn request_response_stream<T: Serialize, S: DeserializeOwned>(
375 &mut self,
376 ) -> Result<(SendStream<T>, ReceiveStream<S>), ConnectionError> {
377 self.next_implicit_id += 1;
378 self.internal_request_response_stream(StreamId::Implicit(self.next_implicit_id - 1))
379 .await
380 }
381
382 pub async fn request_response_stream_with_id<T: Serialize, S: DeserializeOwned>(
385 &self,
386 id: Id,
387 ) -> Result<(SendStream<T>, ReceiveStream<S>), ConnectionError> {
388 self.internal_request_response_stream(StreamId::Explicit(id.0))
389 .await
390 }
391}
392
393impl Id {
394 pub fn new(id: u64) -> Self {
395 Self(id)
396 }
397}
398
399fn bincode_opts() -> impl bincode::Options {
400 bincode::options().with_big_endian().with_varint_encoding()
401}
402
403impl UniqueId {
404 fn new(cids: Vec<ConnectionId>, id: StreamId) -> Self {
405 Self { cids, id }
406 }
407
408 async fn write_into<W: AsyncWrite>(&self, write: W) -> Result<usize, ConnectionError> {
409 let mut write = pin!(write);
410 let mut options = bincode_opts();
411 let serialized = (&mut options)
412 .serialize(self)
413 .map_err(ConnectionError::UniqueIdSerialization)?;
414 write
415 .write_u32(
416 serialized
417 .len()
418 .try_into()
419 .map_err(|_| ConnectionError::SubConnectionLimitReached)?,
420 )
421 .await
422 .map_err(ConnectionError::IoError)?;
423 write
424 .write_all(&serialized)
425 .await
426 .map_err(ConnectionError::IoError)?;
427 Ok(mem::size_of::<u32>() + serialized.len())
428 }
429
430 async fn read_from<R: AsyncRead>(reader: R) -> Result<(Self, usize), ConnectionError> {
431 let mut reader = pin!(reader);
432 let len = reader.read_u32().await.map_err(ConnectionError::IoError)?;
433 let mut buf = vec![0; len as usize];
434 reader
435 .read_exact(&mut buf)
436 .await
437 .map_err(ConnectionError::IoError)?;
438 let uid = bincode_opts()
439 .deserialize(&buf)
440 .map_err(ConnectionError::UniqueIdDeserialization)?;
441 Ok((uid, mem::size_of::<u32>() + len as usize))
442 }
443}
444
/// Errors returned by the explicit control methods of [`SendStreamBytes`].
#[derive(thiserror::Error, Debug)]
pub enum StreamError {
    /// Returned by [`SendStreamBytes::flush`].
    #[error("unable to flush stream")]
    Flush(#[source] s2n_quic::stream::Error),
    /// Returned by [`SendStreamBytes::close`].
    #[error("unable to close stream")]
    Close(#[source] s2n_quic::stream::Error),
    /// Returned by [`SendStreamBytes::finish`].
    #[error("unable to finish stream")]
    Finish(#[source] s2n_quic::stream::Error),
}
455
456impl SendStreamBytes {
457 pub async fn flush(&mut self) -> Result<(), StreamError> {
458 self.inner.flush().await.map_err(StreamError::Flush)
459 }
460
461 pub fn finish(&mut self) -> Result<(), StreamError> {
462 self.inner.finish().map_err(StreamError::Finish)
463 }
464
465 pub async fn close(&mut self) -> Result<(), StreamError> {
466 self.inner.close().await.map_err(StreamError::Close)
467 }
468
469 pub fn as_stream<T: Serialize>(&mut self) -> SendStreamTemp<'_, T> {
470 let framed_send = default_codec().new_write(self);
471 SymmetricallyFramed::new(framed_send, Bincode::default())
472 }
473}
474
475impl AsyncWrite for SendStreamBytes {
476 fn poll_write(
477 mut self: Pin<&mut Self>,
478 cx: &mut Context<'_>,
479 buf: &[u8],
480 ) -> Poll<Result<usize, Error>> {
481 let inner = Pin::new(&mut self.inner);
482 trace_poll(inner.poll_write(cx, buf))
483 }
484
485 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
486 let inner = Pin::new(&mut self.inner);
487 AsyncWrite::poll_flush(inner, cx)
488 }
489
490 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
491 let inner = Pin::new(&mut self.inner);
492 inner.poll_shutdown(cx)
493 }
494
495 fn poll_write_vectored(
496 mut self: Pin<&mut Self>,
497 cx: &mut Context<'_>,
498 bufs: &[IoSlice<'_>],
499 ) -> Poll<Result<usize, Error>> {
500 let inner = Pin::new(&mut self.inner);
501 trace_poll(inner.poll_write_vectored(cx, bufs))
502 }
503
504 fn is_write_vectored(&self) -> bool {
505 self.inner.is_write_vectored()
506 }
507}
508
509fn trace_poll(p: Poll<io::Result<usize>>) -> Poll<io::Result<usize>> {
510 if let Poll::Ready(Ok(bytes)) = p {
511 event!(target: "cryprot_metrics", Level::TRACE, bytes_written = bytes);
512 }
513 p
514}
515
516impl ReceiveStreamBytes {
517 pub fn as_stream<T: DeserializeOwned>(&mut self) -> ReceiveStreamTemp<'_, T> {
518 let framed_read = default_codec().new_read(self);
519 SymmetricallyFramed::new(framed_read, Bincode::default())
520 }
521}
522
523impl AsyncRead for ReceiveStreamBytes {
526 fn poll_read(
527 mut self: Pin<&mut Self>,
528 cx: &mut Context<'_>,
529 buf: &mut ReadBuf<'_>,
530 ) -> Poll<std::io::Result<()>> {
531 match &mut self.inner {
532 ReceiveStreamWrapper::Channel { stream_recv } => match Pin::new(stream_recv).poll(cx) {
533 Poll::Pending => Poll::Pending,
534 Poll::Ready(Ok((recv_stream, bytes_read))) => {
535 event!(target: "cryprot_metrics", Level::TRACE, bytes_read);
538 self.inner = ReceiveStreamWrapper::Stream { recv_stream };
539 self.poll_read(cx, buf)
540 }
541 Poll::Ready(Err(err)) => Poll::Ready(Err(std::io::Error::other(Box::new(err)))),
542 },
543 ReceiveStreamWrapper::Stream { recv_stream } => {
544 let len = buf.filled().len();
545 let poll = Pin::new(recv_stream).poll_read(cx, buf);
546 if let Poll::Ready(Ok(())) = poll {
547 let bytes = buf.filled().len() - len;
548 if bytes > 0 {
549 event!(target: "cryprot_metrics", Level::TRACE, bytes_read = bytes);
550 }
551 }
552 poll
553 }
554 }
555 }
556}
557
558fn default_codec() -> length_delimited::Builder {
559 let mut ld_codec = LengthDelimitedCodec::builder();
560 const MB: usize = 1024 * 1024;
561 ld_codec.max_frame_length(20 * MB);
562 ld_codec
563}
564
#[cfg(test)]
mod tests {
    // `u8::MAX` below resolves to the primitive's associated constant; importing the legacy
    // `std::u8` module would shadow it with the deprecated module constant.
    use anyhow::{Context, Result};
    use futures::{SinkExt, StreamExt};
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        task::JoinSet,
    };
    use tracing::debug;

    use crate::{
        Id,
        testing::{init_tracing, local_conn},
    };

    #[tokio::test]
    async fn create_local_conn() -> Result<()> {
        let _g = init_tracing();
        let _ = local_conn().await?;
        Ok(())
    }

    #[tokio::test]
    async fn byte_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send, _) = s.byte_stream().await?;
        let (_, mut c_recv) = c.byte_stream().await?;
        let send_buf = b"hello there";
        s_send.write_all(send_buf).await?;
        let mut buf = [0; 11];
        c_recv.read_exact(&mut buf).await?;
        assert_eq!(send_buf, &buf);
        Ok(())
    }

    #[tokio::test]
    async fn byte_stream_explicit_implicit_id() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send1, _) = s.byte_stream_with_id(Id::new(u32::MAX as u64 + 42)).await?;
        let (mut s_send2, _) = s.byte_stream().await?;
        let (_, mut c_recv1) = c.byte_stream_with_id(Id::new(u32::MAX as u64 + 42)).await?;
        let (_, mut c_recv2) = c.byte_stream().await?;
        let send_buf1 = b"hello there";
        s_send1.write_all(send_buf1).await?;
        let mut buf = [0; 11];
        c_recv1.read_exact(&mut buf).await?;
        assert_eq!(send_buf1, &buf);

        let send_buf2 = b"general kenobi";
        s_send2.write_all(send_buf2).await?;
        let mut buf = [0; 14];
        c_recv2.read_exact(&mut buf).await?;
        assert_eq!(send_buf2, &buf);
        Ok(())
    }

    #[tokio::test]
    async fn byte_stream_different_order() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send, mut s_recv) = s.byte_stream().await?;
        let s_send_buf = b"hello there";
        s_send.write_all(s_send_buf).await?;
        let mut s_recv_buf = [0; 2];
        let jh = tokio::spawn(async move {
            s_recv.read_exact(&mut s_recv_buf).await.unwrap();
            s_recv_buf
        });
        let (mut c_send, mut c_recv) = c.byte_stream().await?;
        let mut c_recv_buf = [0; 11];
        c_recv.read_exact(&mut c_recv_buf).await?;
        assert_eq!(s_send_buf, &c_recv_buf);
        let c_send_buf = b"42";
        c_send.write_all(c_send_buf).await?;
        let s_recv_buf = jh.await?;
        assert_eq!(c_send_buf, &s_recv_buf);
        Ok(())
    }

    #[tokio::test]
    async fn many_parallel_byte_streams() -> Result<()> {
        let _g = init_tracing();
        let (mut c1, mut c2) = local_conn().await?;
        let mut jhs = JoinSet::new();
        for i in 0..10 {
            let ((mut s, _), (_, mut r)) =
                tokio::try_join!(c1.byte_stream(), c2.byte_stream()).unwrap();

            let jh = tokio::spawn(async move {
                let buf = vec![0; 10 * 1024 * 1024];
                s.write_all(&buf).await.unwrap();
                debug!("wrote buf {i}");
            });
            jhs.spawn(jh);
            let jh = tokio::spawn(async move {
                let mut buf = vec![0; 10 * 1024 * 1024];
                r.read_exact(&mut buf).await.unwrap();
                debug!("received buf {i}");
            });
            jhs.spawn(jh);
        }
        let res = jhs.join_all().await;
        for res in res {
            res.unwrap();
        }
        Ok(())
    }

    #[tokio::test]
    async fn serde_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut snd, _) = s.stream::<Vec<i32>>().await?;
        let (_, mut recv) = c.stream::<Vec<i32>>().await?;
        snd.send(vec![1, 2, 3]).await?;
        let ret = recv.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        drop(snd);
        let ret = recv.next().await.map(|res| res.map_err(|_| ()));
        assert_eq!(None, ret);
        Ok(())
    }

    #[tokio::test]
    async fn serde_stream_block() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut snd, _) = s.stream().await?;
        let (_, mut recv) = c.stream().await?;
        snd.send(vec![u8::MAX; 16]).await?;
        let ret: Vec<_> = recv.next().await.context("recv")??;
        assert_eq!(vec![u8::MAX; 16], ret);
        Ok(())
    }

    #[tokio::test]
    async fn serde_byte_stream_as_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send, _) = s.byte_stream().await?;
        let (_, mut c_recv) = c.byte_stream().await?;
        {
            let mut send_ser1 = s_send.as_stream::<i32>();
            let mut recv_ser1 = c_recv.as_stream::<i32>();
            send_ser1.send(42).await?;
            let ret = recv_ser1.next().await.context("recv")??;
            assert_eq!(42, ret);
        }
        {
            let mut send_ser2 = s_send.as_stream::<Vec<i32>>();
            let mut recv_ser2 = c_recv.as_stream::<Vec<i32>>();
            send_ser2.send(vec![1, 2, 3]).await?;
            let ret = recv_ser2.next().await.context("recv")??;
            assert_eq!(vec![1, 2, 3], ret);
        }
        Ok(())
    }

    #[tokio::test]
    async fn serde_request_response_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut snd1, mut recv1) = s.request_response_stream::<Vec<i32>, String>().await?;
        let (mut snd2, mut recv2) = c.request_response_stream::<String, Vec<i32>>().await?;
        snd1.send(vec![1, 2, 3]).await?;
        let ret = recv2.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        snd2.send("hello there".to_string()).await?;
        let ret = recv1.next().await.context("recv2")??;
        assert_eq!("hello there", &ret);
        Ok(())
    }

    #[tokio::test]
    async fn sub_connection() -> Result<()> {
        let _g = init_tracing();
        let (mut s1, mut c1) = local_conn().await?;
        let mut s2 = s1.sub_connection();
        let mut c2 = c1.sub_connection();
        // Actually open streams on the parent connections (futures are lazy; the previous
        // `let _ = s1.byte_stream();` never ran). This checks that the parent's implicit ids
        // do not interfere with the sub-connection's streams. The streams are kept alive
        // until the end of the test.
        let _s1_streams = s1.byte_stream().await?;
        let _c1_streams = c1.byte_stream().await?;
        let (mut snd, _) = s2.stream::<Vec<i32>>().await?;
        let (_, mut recv) = c2.stream::<Vec<i32>>().await?;

        snd.send(vec![1, 2, 3]).await?;
        let ret = recv.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        Ok(())
    }

    #[tokio::test]
    async fn sub_sub_connection() -> Result<()> {
        let _g = init_tracing();
        let (mut s1, mut c1) = local_conn().await?;
        let mut s2 = s1.sub_connection();
        let mut c2 = c1.sub_connection();
        let mut s3 = s2.sub_connection();
        let mut c3 = c2.sub_connection();
        // Open (and keep alive) streams on every ancestor connection so the nested
        // sub-connection's stream ids are exercised against occupied parent ids.
        let _s1_streams = s1.byte_stream().await?;
        let _c1_streams = c1.byte_stream().await?;
        let _s2_streams = s2.byte_stream().await?;
        let _c2_streams = c2.byte_stream().await?;
        let (mut snd, _) = s3.stream::<Vec<i32>>().await?;
        let (_, mut recv) = c3.stream::<Vec<i32>>().await?;

        snd.send(vec![1, 2, 3]).await?;
        let ret = recv.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        Ok(())
    }
}