cryprot_net/
lib.rs

1//! A networking library providing abstractions on top of [`s2n_quic`].
2use std::{
3    collections::{HashMap, hash_map::Entry},
4    future::Future,
5    io::{Error, IoSlice},
6    mem,
7    pin::{Pin, pin},
8    sync::{
9        Arc,
10        atomic::{AtomicU32, Ordering},
11    },
12    task::{Context, Poll},
13};
14
15use bincode::Options;
16use s2n_quic::{
17    connection::{Handle, StreamAcceptor as QuicStreamAcceptor},
18    stream::{ReceiveStream as QuicRecvStream, SendStream as QuicSendStream},
19};
20use serde::{Deserialize, Serialize, de::DeserializeOwned};
21use tokio::{
22    io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
23    select,
24    sync::{mpsc, oneshot},
25};
26use tokio_serde::{
27    SymmetricallyFramed,
28    formats::{Bincode, SymmetricalBincode},
29};
30use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec, length_delimited};
31use tracing::{Level, debug, error, event};
32
33#[cfg(feature = "metrics")]
34pub mod metrics;
35
36#[doc(hidden)]
37#[cfg(any(test, feature = "__testing"))]
38pub mod testing;
39
/// User-chosen identifier for a stream on a specific [`Connection`].
///
/// Build one with [`Id::new`] and hand it to the `*_with_id` methods of
/// [`Connection`] to open a stream under an explicit id.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Id(pub(crate) u64);
43
44#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
45enum StreamId {
46    Implicit(u64),
47    Explicit(u64),
48}
49
50/// Id of a [`Connection`]. Does not include parent Ids of this connection.
51/// It is only unique with respect to its sibling connections created by
52/// [`Connection::sub_connection`].
53#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
54struct ConnectionId(pub(crate) u32);
55
56/// Unique id of a stream and all its parent [`ConnectionId`]s.
57#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
58struct UniqueId {
59    cids: Vec<ConnectionId>,
60    id: StreamId,
61}
62
63type StreamSend = oneshot::Sender<(QuicRecvStream, usize)>;
64type StreamRecv = oneshot::Receiver<(QuicRecvStream, usize)>;
65
66/// Manages accepting of new streams.
67pub struct StreamManager {
68    acceptor: QuicStreamAcceptor,
69    cmd_send: mpsc::UnboundedSender<Cmd>,
70    cmd_recv: mpsc::UnboundedReceiver<Cmd>,
71    pending: HashMap<UniqueId, StreamSend>,
72    accepted: HashMap<UniqueId, (QuicRecvStream, usize)>,
73}
74
75/// Used to create grouped sub-streams.
76///
77/// Connections can have sub-connections. Streams created via
78/// [`Connection::byte_stream`] and [`Connection::stream`] are tied to their
79/// connection. Streams created with the same [`Id`] but for different
80/// connections will not conflict with each other.
81#[derive(Debug)]
82pub struct Connection {
83    cids: Vec<ConnectionId>,
84    next_cid: Arc<AtomicU32>,
85    handle: Handle,
86    cmd: mpsc::UnboundedSender<Cmd>,
87    next_implicit_id: u64,
88}
89
90/// Send part of the bytes stream.
91pub struct SendStreamBytes {
92    inner: QuicSendStream,
93}
94
95/// Receive part of the bytes stream.
96pub struct ReceiveStreamBytes {
97    inner: ReceiveStreamWrapper,
98}
99
100/// Send part of the serialized stream.
101pub type SendStream<T> = SymmetricallyFramed<
102    FramedWrite<SendStreamBytes, LengthDelimitedCodec>,
103    T,
104    SymmetricalBincode<T>,
105>;
106
107/// A temporary typed send stream which borrows a [`SendStreamBytes`].
108pub type SendStreamTemp<'a, T> = SymmetricallyFramed<
109    FramedWrite<&'a mut SendStreamBytes, LengthDelimitedCodec>,
110    T,
111    SymmetricalBincode<T>,
112>;
113
114/// Receive part of the serialized stream.
115pub type ReceiveStream<T> = SymmetricallyFramed<
116    FramedRead<ReceiveStreamBytes, LengthDelimitedCodec>,
117    T,
118    SymmetricalBincode<T>,
119>;
120
121/// A temporary typed receive stream which borrows a [`ReceiveStreamBytes`].
122pub type ReceiveStreamTemp<'a, T> = SymmetricallyFramed<
123    FramedRead<&'a mut ReceiveStreamBytes, LengthDelimitedCodec>,
124    T,
125    SymmetricalBincode<T>,
126>;
127
128enum ReceiveStreamWrapper {
129    Channel { stream_recv: StreamRecv },
130    Stream { recv_stream: QuicRecvStream },
131}
132
133#[derive(Debug)]
134enum Cmd {
135    NewStream {
136        uid: UniqueId,
137        stream_return: StreamSend,
138    },
139    AcceptedStream {
140        uid: UniqueId,
141        stream: QuicRecvStream,
142        bytes_read: usize,
143    },
144}
145
146impl StreamManager {
147    pub fn new(acceptor: QuicStreamAcceptor) -> Self {
148        let (cmd_send, cmd_recv) = mpsc::unbounded_channel();
149        Self {
150            acceptor,
151            cmd_send,
152            cmd_recv,
153            pending: Default::default(),
154            accepted: Default::default(),
155        }
156    }
157
158    /// Start the StreamManager to accept streams.
159    ///
160    /// This method needs to be continually polled to establish new streams.
161    #[tracing::instrument(skip_all)]
162    pub async fn start(mut self) {
163        loop {
164            // Guard against possible cancellation unsafety of `accept_receive_stream`
165            let mut receive_stream = pin!(self.acceptor.accept_receive_stream());
166            select! {
167                res = &mut receive_stream => {
168                    match res {
169                        Ok(Some(stream)) => {
170                            debug!("accepted stream");
171                            Self::accepted(stream, self.cmd_send.clone());
172                        }
173                        Ok(None) => {
174                            debug!("remote closed");
175                            return;
176                        }
177                        Err(err) => {
178                            error!(%err, "unable to accept stream");
179                            return;
180                        }
181                    }
182                }
183                Some(cmd) = self.cmd_recv.recv() => {   // recv() is cancel safe
184                    debug!(?cmd, "received cmd");
185                    match cmd {
186                        Cmd::NewStream {uid, stream_return} => {
187                            if let Some(accepted) = self.accepted.remove(&uid) {
188                                if stream_return.send(accepted).is_err() {
189                                    debug!("accepted remote stream but local receiver is closed");
190                                }
191                                debug!("sending new stream to receiver");
192                                continue;
193                            }
194                            match self.pending.entry(uid) {
195                                Entry::Occupied(occupied_entry) => {
196                                    panic!("Duplicate unique id: {:?}", occupied_entry.key())
197                                },
198                                Entry::Vacant(vacant_entry) => {vacant_entry.insert(stream_return);},
199                            }
200                        }
201                        Cmd::AcceptedStream {uid, stream, bytes_read} => {
202                            if let Some(stream_ret) = self.pending.remove(&uid) {
203                               if stream_ret.send((stream, bytes_read)).is_err() {
204                                debug!("accepted remote stream but local receiver is closed");
205                               }
206                            } else {
207                                debug!("accepted stream but no pending");
208                                self.accepted.insert(uid, (stream, bytes_read));
209                            }
210                        }
211                    }
212                }
213            }
214        }
215    }
216
217    // not taking &self to work around borrow issue
218    fn accepted(mut stream: QuicRecvStream, cmd_send: mpsc::UnboundedSender<Cmd>) {
219        tokio::spawn(async move {
220            let (uid, bytes_read) = match UniqueId::read_from(&mut stream).await {
221                Ok(ret) => ret,
222                Err(err) => {
223                    error!(?err, "unable to read stream unique id");
224                    return;
225                }
226            };
227            cmd_send
228                .send(Cmd::AcceptedStream {
229                    uid,
230                    stream,
231                    bytes_read,
232                })
233                .expect("cmd_rcv is owned by StreamManager")
234        });
235    }
236}
237
238/// Possible connection errors.
239#[derive(thiserror::Error, Debug)]
240pub enum ConnectionError {
241    #[error("Unable to open stream")]
242    OpenStream(#[source] s2n_quic::connection::Error),
243    #[error("io error during stream establishment")]
244    IoError(#[source] io::Error),
245    #[error("StreamManager is dropped and not accepting connections")]
246    StreamManagerDropped,
247    #[error("Stream unique id deserialization failed")]
248    UniqueIdDeserialization(#[source] bincode::Error),
249    #[error("Stream unique id serialization failed")]
250    UniqueIdSerialization(#[source] bincode::Error),
251    #[error("Reached maximum number of sub connections")]
252    SubConnectionLimitReached,
253}
254
255impl Connection {
256    pub fn new(quic_conn: s2n_quic::Connection) -> (Self, StreamManager) {
257        let (handle, acceptor) = quic_conn.split();
258        let stream_manager = StreamManager::new(acceptor);
259        let conn = Self {
260            cids: vec![],
261            next_cid: Arc::new(AtomicU32::new(0)),
262            handle,
263            cmd: stream_manager.cmd_send.clone(),
264            next_implicit_id: 0,
265        };
266        (conn, stream_manager)
267    }
268
269    /// Create a sub-connection. The n'th call to sub_connection
270    /// is paired with the n'th call to `sub_connection` on the corresponding
271    /// [`Connection`] of the other party. Creating a sub-connection results
272    /// in no immediate communication and is a fast synchronous operation.
273    #[tracing::instrument(level = Level::DEBUG, skip(self), ret)]
274    pub fn sub_connection(&mut self) -> Self {
275        let cid = self.next_cid.fetch_add(1, Ordering::Relaxed);
276        let mut cids = self.cids.clone();
277        cids.push(ConnectionId(cid));
278        Self {
279            cids,
280            next_cid: Arc::new(AtomicU32::new(0)),
281            handle: self.handle.clone(),
282            cmd: self.cmd.clone(),
283            next_implicit_id: 0,
284        }
285    }
286
287    async fn internal_byte_stream(
288        &self,
289        stream_id: StreamId,
290    ) -> Result<(SendStreamBytes, ReceiveStreamBytes), ConnectionError> {
291        let uid = UniqueId::new(self.cids.clone(), stream_id);
292        let mut snd = self
293            .handle
294            .clone()
295            .open_send_stream()
296            .await
297            .map_err(ConnectionError::OpenStream)?;
298        let bytes_written = uid.write_into(&mut snd).await?;
299        event!(target: "cryprot_metrics", Level::TRACE, bytes_written = bytes_written);
300        let (stream_return, stream_recv) = oneshot::channel();
301        self.cmd
302            .send(Cmd::NewStream { uid, stream_return })
303            .map_err(|_| ConnectionError::StreamManagerDropped)?;
304        let snd = SendStreamBytes { inner: snd };
305        let recv = ReceiveStreamBytes {
306            inner: ReceiveStreamWrapper::Channel { stream_recv },
307        };
308        Ok((snd, recv))
309    }
310
311    /// Establish a byte stream over this connection with the provided Id.
312    pub async fn byte_stream(
313        &mut self,
314    ) -> Result<(SendStreamBytes, ReceiveStreamBytes), ConnectionError> {
315        self.next_implicit_id += 1;
316        self.internal_byte_stream(StreamId::Implicit(self.next_implicit_id - 1))
317            .await
318    }
319
320    /// Establish a byte stream over this connection with the provided Id.
321    pub async fn byte_stream_with_id(
322        &self,
323        id: Id,
324    ) -> Result<(SendStreamBytes, ReceiveStreamBytes), ConnectionError> {
325        self.internal_byte_stream(StreamId::Explicit(id.0)).await
326    }
327
328    /// Establish a typed stream over this connection.
329    async fn internal_stream<T: Serialize + DeserializeOwned>(
330        &self,
331        id: StreamId,
332    ) -> Result<(SendStream<T>, ReceiveStream<T>), ConnectionError> {
333        let (send_bytes, recv_bytes) = self.internal_byte_stream(id).await?;
334        let mut ld_codec = LengthDelimitedCodec::builder();
335        // TODO what is a sensible max length?
336        const MB: usize = 1024 * 1024;
337        ld_codec.max_frame_length(256 * MB);
338        let framed_send = ld_codec.new_write(send_bytes);
339        let framed_read = ld_codec.new_read(recv_bytes);
340        let serde_send = SymmetricallyFramed::new(framed_send, Bincode::default());
341        let serde_read = SymmetricallyFramed::new(framed_read, Bincode::default());
342        Ok((serde_send, serde_read))
343    }
344
345    /// Establish a typed stream over this connection.
346    pub async fn stream<T: Serialize + DeserializeOwned>(
347        &mut self,
348    ) -> Result<(SendStream<T>, ReceiveStream<T>), ConnectionError> {
349        self.next_implicit_id += 1;
350        self.internal_stream(StreamId::Implicit(self.next_implicit_id - 1))
351            .await
352    }
353
354    /// Establish a typed stream over this connection with the provided explicit
355    /// Id.
356    pub async fn stream_with_id<T: Serialize + DeserializeOwned>(
357        &self,
358        id: Id,
359    ) -> Result<(SendStream<T>, ReceiveStream<T>), ConnectionError> {
360        self.internal_stream(StreamId::Explicit(id.0)).await
361    }
362
363    async fn internal_request_response_stream<T: Serialize, S: DeserializeOwned>(
364        &self,
365        id: StreamId,
366    ) -> Result<(SendStream<T>, ReceiveStream<S>), ConnectionError> {
367        let (send_bytes, recv_bytes) = self.internal_byte_stream(id).await?;
368        let framed_send = default_codec().new_write(send_bytes);
369        let framed_read = default_codec().new_read(recv_bytes);
370        let serde_send = SymmetricallyFramed::new(framed_send, Bincode::default());
371        let serde_read = SymmetricallyFramed::new(framed_read, Bincode::default());
372        Ok((serde_send, serde_read))
373    }
374
375    /// Establish a typed request-response stream over this connection with
376    /// differing types for the request and response.
377    pub async fn request_response_stream<T: Serialize, S: DeserializeOwned>(
378        &mut self,
379    ) -> Result<(SendStream<T>, ReceiveStream<S>), ConnectionError> {
380        self.next_implicit_id += 1;
381        self.internal_request_response_stream(StreamId::Implicit(self.next_implicit_id - 1))
382            .await
383    }
384
385    /// Establish a typed request-response stream over this connection with
386    /// differing types for the request and response.
387    pub async fn request_response_stream_with_id<T: Serialize, S: DeserializeOwned>(
388        &self,
389        id: Id,
390    ) -> Result<(SendStream<T>, ReceiveStream<S>), ConnectionError> {
391        self.internal_request_response_stream(StreamId::Explicit(id.0))
392            .await
393    }
394}
395
396impl Id {
397    pub fn new(id: u64) -> Self {
398        Self(id)
399    }
400}
401
402fn bincode_opts() -> impl bincode::Options {
403    bincode::options().with_big_endian().with_varint_encoding()
404}
405
406impl UniqueId {
407    fn new(cids: Vec<ConnectionId>, id: StreamId) -> Self {
408        Self { cids, id }
409    }
410
411    async fn write_into<W: AsyncWrite>(&self, write: W) -> Result<usize, ConnectionError> {
412        let mut write = pin!(write);
413        let mut options = bincode_opts();
414        let serialized = (&mut options)
415            .serialize(self)
416            .map_err(ConnectionError::UniqueIdSerialization)?;
417        write
418            .write_u32(
419                serialized
420                    .len()
421                    .try_into()
422                    .map_err(|_| ConnectionError::SubConnectionLimitReached)?,
423            )
424            .await
425            .map_err(ConnectionError::IoError)?;
426        write
427            .write_all(&serialized)
428            .await
429            .map_err(ConnectionError::IoError)?;
430        Ok(mem::size_of::<u32>() + serialized.len())
431    }
432
433    async fn read_from<R: AsyncRead>(reader: R) -> Result<(Self, usize), ConnectionError> {
434        let mut reader = pin!(reader);
435        let len = reader.read_u32().await.map_err(ConnectionError::IoError)?;
436        let mut buf = vec![0; len as usize];
437        reader
438            .read_exact(&mut buf)
439            .await
440            .map_err(ConnectionError::IoError)?;
441        let uid = bincode_opts()
442            .deserialize(&buf)
443            .map_err(ConnectionError::UniqueIdDeserialization)?;
444        Ok((uid, mem::size_of::<u32>() + len as usize))
445    }
446}
447
448/// Possible byte stream errors.
449#[derive(thiserror::Error, Debug)]
450pub enum StreamError {
451    #[error("unable to flush stream")]
452    Flush(#[source] s2n_quic::stream::Error),
453    #[error("unable to close stream")]
454    Close(#[source] s2n_quic::stream::Error),
455    #[error("unable to finish stream")]
456    Finish(#[source] s2n_quic::stream::Error),
457}
458
459impl SendStreamBytes {
460    pub async fn flush(&mut self) -> Result<(), StreamError> {
461        self.inner.flush().await.map_err(StreamError::Flush)
462    }
463
464    pub fn finish(&mut self) -> Result<(), StreamError> {
465        self.inner.finish().map_err(StreamError::Finish)
466    }
467
468    pub async fn close(&mut self) -> Result<(), StreamError> {
469        self.inner.close().await.map_err(StreamError::Close)
470    }
471
472    pub fn as_stream<T: Serialize>(&mut self) -> SendStreamTemp<T> {
473        let framed_send = default_codec().new_write(self);
474        SymmetricallyFramed::new(framed_send, Bincode::default())
475    }
476}
477
478impl AsyncWrite for SendStreamBytes {
479    fn poll_write(
480        mut self: Pin<&mut Self>,
481        cx: &mut Context<'_>,
482        buf: &[u8],
483    ) -> Poll<Result<usize, Error>> {
484        let inner = Pin::new(&mut self.inner);
485        trace_poll(inner.poll_write(cx, buf))
486    }
487
488    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
489        let inner = Pin::new(&mut self.inner);
490        AsyncWrite::poll_flush(inner, cx)
491    }
492
493    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
494        let inner = Pin::new(&mut self.inner);
495        inner.poll_shutdown(cx)
496    }
497
498    fn poll_write_vectored(
499        mut self: Pin<&mut Self>,
500        cx: &mut Context<'_>,
501        bufs: &[IoSlice<'_>],
502    ) -> Poll<Result<usize, Error>> {
503        let inner = Pin::new(&mut self.inner);
504        trace_poll(inner.poll_write_vectored(cx, bufs))
505    }
506
507    fn is_write_vectored(&self) -> bool {
508        self.inner.is_write_vectored()
509    }
510}
511
512fn trace_poll(p: Poll<io::Result<usize>>) -> Poll<io::Result<usize>> {
513    if let Poll::Ready(Ok(bytes)) = p {
514        event!(target: "cryprot_metrics", Level::TRACE, bytes_written = bytes);
515    }
516    p
517}
518
519impl ReceiveStreamBytes {
520    pub fn as_stream<T: DeserializeOwned>(&mut self) -> ReceiveStreamTemp<T> {
521        let framed_read = default_codec().new_read(self);
522        SymmetricallyFramed::new(framed_read, Bincode::default())
523    }
524}
525
526// Implement AsyncRead for ReceiveStream to poll the oneshot Receiver first if
527// there is not already a channel.
528impl AsyncRead for ReceiveStreamBytes {
529    fn poll_read(
530        mut self: Pin<&mut Self>,
531        cx: &mut Context<'_>,
532        buf: &mut ReadBuf<'_>,
533    ) -> Poll<std::io::Result<()>> {
534        match &mut self.inner {
535            ReceiveStreamWrapper::Channel { stream_recv } => match Pin::new(stream_recv).poll(cx) {
536                Poll::Pending => Poll::Pending,
537                Poll::Ready(Ok((recv_stream, bytes_read))) => {
538                    // We know we read those bytes in the StreamManager, so we emit
539                    // the corresponding event here in the correct span.
540                    event!(target: "cryprot_metrics", Level::TRACE, bytes_read);
541                    self.inner = ReceiveStreamWrapper::Stream { recv_stream };
542                    self.poll_read(cx, buf)
543                }
544                Poll::Ready(Err(err)) => Poll::Ready(Err(std::io::Error::other(Box::new(err)))),
545            },
546            ReceiveStreamWrapper::Stream { recv_stream } => {
547                let len = buf.filled().len();
548                let poll = Pin::new(recv_stream).poll_read(cx, buf);
549                if let Poll::Ready(Ok(())) = poll {
550                    let bytes = buf.filled().len() - len;
551                    if bytes > 0 {
552                        event!(target: "cryprot_metrics", Level::TRACE, bytes_read = bytes);
553                    }
554                }
555                poll
556            }
557        }
558    }
559}
560
561fn default_codec() -> length_delimited::Builder {
562    let mut ld_codec = LengthDelimitedCodec::builder();
563    const MB: usize = 1024 * 1024;
564    ld_codec.max_frame_length(20 * MB);
565    ld_codec
566}
567
#[cfg(test)]
mod tests {
    use anyhow::{Context, Result};
    use futures::{SinkExt, StreamExt};
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        task::JoinSet,
    };
    use tracing::debug;

    use crate::{
        Id,
        testing::{init_tracing, local_conn},
    };

    /// A local connection pair can be established at all.
    #[tokio::test]
    async fn create_local_conn() -> Result<()> {
        let _g = init_tracing();
        let _ = local_conn().await?;
        Ok(())
    }

    /// Bytes written on one side of an implicit-id stream arrive on the other.
    #[tokio::test]
    async fn byte_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send, _) = s.byte_stream().await?;
        let (_, mut c_recv) = c.byte_stream().await?;
        let send_buf = b"hello there";
        s_send.write_all(send_buf).await?;
        let mut buf = [0; 11];
        c_recv.read_exact(&mut buf).await?;
        assert_eq!(send_buf, &buf);
        Ok(())
    }

    /// Explicit and implicit ids can be mixed on the same connection.
    #[tokio::test]
    async fn byte_stream_explicit_implicit_id() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send1, _) = s.byte_stream_with_id(Id::new(u32::MAX as u64 + 42)).await?;
        let (mut s_send2, _) = s.byte_stream().await?;
        let (_, mut c_recv1) = c.byte_stream_with_id(Id::new(u32::MAX as u64 + 42)).await?;
        let (_, mut c_recv2) = c.byte_stream().await?;
        let send_buf1 = b"hello there";
        s_send1.write_all(send_buf1).await?;
        let mut buf = [0; 11];
        c_recv1.read_exact(&mut buf).await?;
        assert_eq!(send_buf1, &buf);

        let send_buf2 = b"general kenobi";
        s_send2.write_all(send_buf2).await?;
        let mut buf = [0; 14];
        c_recv2.read_exact(&mut buf).await?;
        assert_eq!(send_buf2, &buf);
        Ok(())
    }

    /// Reading may start before the peer has opened its side of the stream.
    #[tokio::test]
    async fn byte_stream_different_order() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send, mut s_recv) = s.byte_stream().await?;
        let s_send_buf = b"hello there";
        s_send.write_all(s_send_buf).await?;
        let mut s_recv_buf = [0; 2];
        // By already spawning the read task before the client calls
        // c.byte_stream we check that the switch from channel to s2n stream
        // works
        let jh = tokio::spawn(async move {
            s_recv.read_exact(&mut s_recv_buf).await.unwrap();
            s_recv_buf
        });
        let (mut c_send, mut c_recv) = c.byte_stream().await?;
        let mut c_recv_buf = [0; 11];
        c_recv.read_exact(&mut c_recv_buf).await?;
        assert_eq!(s_send_buf, &c_recv_buf);
        let c_send_buf = b"42";
        c_send.write_all(c_send_buf).await?;
        let s_recv_buf = jh.await?;
        assert_eq!(c_send_buf, &s_recv_buf);
        Ok(())
    }

    /// Several large transfers can run concurrently on separate streams.
    #[tokio::test]
    async fn many_parallel_byte_streams() -> Result<()> {
        let _g = init_tracing();
        let (mut c1, mut c2) = local_conn().await?;
        let mut jhs = JoinSet::new();
        for i in 0..10 {
            let ((mut s, _), (_, mut r)) =
                tokio::try_join!(c1.byte_stream(), c2.byte_stream()).unwrap();

            let jh = tokio::spawn(async move {
                let buf = vec![0; 10 * 1024 * 1024];
                s.write_all(&buf).await.unwrap();
                debug!("wrote buf {i}");
            });
            jhs.spawn(jh);
            let jh = tokio::spawn(async move {
                let mut buf = vec![0; 10 * 1024 * 1024];
                r.read_exact(&mut buf).await.unwrap();
                debug!("received buf {i}");
            });
            jhs.spawn(jh);
        }
        let res = jhs.join_all().await;
        for res in res {
            res.unwrap();
        }
        Ok(())
    }

    /// Typed values round-trip and the stream ends once the sender is dropped.
    #[tokio::test]
    async fn serde_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut snd, _) = s.stream::<Vec<i32>>().await?;
        let (_, mut recv) = c.stream::<Vec<i32>>().await?;
        snd.send(vec![1, 2, 3]).await?;
        let ret = recv.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        drop(snd);
        let ret = recv.next().await.map(|res| res.map_err(|_| ()));
        assert_eq!(None, ret);
        Ok(())
    }

    /// A block of bytes round-trips through a typed stream.
    #[tokio::test]
    async fn serde_stream_block() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut snd, _) = s.stream().await?;
        let (_, mut recv) = c.stream().await?;
        snd.send(vec![u8::MAX; 16]).await?;
        let ret: Vec<_> = recv.next().await.context("recv")??;
        assert_eq!(vec![u8::MAX; 16], ret);
        Ok(())
    }

    /// A byte stream can be borrowed as typed streams of different types, one
    /// after the other.
    #[tokio::test]
    async fn serde_byte_stream_as_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send, _) = s.byte_stream().await?;
        let (_, mut c_recv) = c.byte_stream().await?;
        {
            let mut send_ser1 = s_send.as_stream::<i32>();
            let mut recv_ser1 = c_recv.as_stream::<i32>();
            send_ser1.send(42).await?;
            let ret = recv_ser1.next().await.context("recv")??;
            assert_eq!(42, ret);
        }
        {
            let mut send_ser2 = s_send.as_stream::<Vec<i32>>();
            let mut recv_ser2 = c_recv.as_stream::<Vec<i32>>();
            send_ser2.send(vec![1, 2, 3]).await?;
            let ret = recv_ser2.next().await.context("recv")??;
            assert_eq!(vec![1, 2, 3], ret);
        }
        Ok(())
    }

    /// Request and response types may differ between the two directions.
    #[tokio::test]
    async fn serde_request_response_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut snd1, mut recv1) = s.request_response_stream::<Vec<i32>, String>().await?;
        let (mut snd2, mut recv2) = c.request_response_stream::<String, Vec<i32>>().await?;
        snd1.send(vec![1, 2, 3]).await?;
        let ret = recv2.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        snd2.send("hello there".to_string()).await?;
        let ret = recv1.next().await.context("recv2")??;
        assert_eq!("hello there", &ret);
        Ok(())
    }

    /// Streams on a sub-connection do not conflict with streams that use the
    /// same implicit id on the parent connection.
    #[tokio::test]
    async fn sub_connection() -> Result<()> {
        let _g = init_tracing();
        let (mut s1, mut c1) = local_conn().await?;
        let mut s2 = s1.sub_connection();
        let mut c2 = c1.sub_connection();
        // The futures must be awaited, otherwise no stream is opened on the
        // parents. The streams are kept alive until the end of the test.
        let _s1_streams = s1.byte_stream().await?;
        let _c1_streams = c1.byte_stream().await?;
        let (mut snd, _) = s2.stream::<Vec<i32>>().await?;
        let (_, mut recv) = c2.stream::<Vec<i32>>().await?;

        snd.send(vec![1, 2, 3]).await?;
        let ret = recv.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        Ok(())
    }

    /// Same as `sub_connection`, one nesting level deeper.
    #[tokio::test]
    async fn sub_sub_connection() -> Result<()> {
        let _g = init_tracing();
        let (mut s1, mut c1) = local_conn().await?;
        let mut s2 = s1.sub_connection();
        let mut c2 = c1.sub_connection();
        let mut s3 = s2.sub_connection();
        let mut c3 = c2.sub_connection();
        // Awaited so that the streams on the parent connections really exist
        // while the innermost streams are used.
        let _s1_streams = s1.byte_stream().await?;
        let _c1_streams = c1.byte_stream().await?;
        let _s2_streams = s2.byte_stream().await?;
        let _c2_streams = c2.byte_stream().await?;
        let (mut snd, _) = s3.stream::<Vec<i32>>().await?;
        let (_, mut recv) = c3.stream::<Vec<i32>>().await?;

        snd.send(vec![1, 2, 3]).await?;
        let ret = recv.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        Ok(())
    }
}