Skip to main content

cryprot_net/
lib.rs

//! A networking library providing abstractions on top of [`s2n_quic`].
2use std::{
3    collections::{HashMap, hash_map::Entry},
4    future::Future,
5    io::{Error, IoSlice},
6    mem,
7    pin::{Pin, pin},
8    task::{Context, Poll},
9};
10
11use bincode::Options;
12use s2n_quic::{
13    connection::{Handle, StreamAcceptor as QuicStreamAcceptor},
14    stream::{ReceiveStream as QuicRecvStream, SendStream as QuicSendStream},
15};
16use serde::{Deserialize, Serialize, de::DeserializeOwned};
17use tokio::{
18    io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
19    select,
20    sync::{mpsc, oneshot},
21};
22use tokio_serde::{
23    SymmetricallyFramed,
24    formats::{Bincode, SymmetricalBincode},
25};
26use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec, length_delimited};
27use tracing::{Level, debug, error, event};
28
// Metrics support, only compiled with the `metrics` feature. Presumably this
// consumes the `cryprot_metrics` tracing events emitted in this file — TODO confirm.
#[cfg(feature = "metrics")]
pub mod metrics;

// Test utilities (the tests in this file import `init_tracing` and `local_conn`
// from here). Hidden from the documentation and only compiled for this crate's
// tests or with the internal `__testing` feature.
#[doc(hidden)]
#[cfg(any(test, feature = "__testing"))]
pub mod testing;
35
/// User-chosen, explicit id of a stream of a specific [`Connection`].
///
/// Streams opened with the same explicit id on corresponding connections of
/// both parties are paired with each other. Explicit ids live in a separate
/// namespace from the ids that are assigned implicitly, so they never collide.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct Id(pub(crate) u64);
39
/// Id of a stream within a single [`Connection`].
///
/// `Implicit` ids are assigned sequentially by the id-less stream methods of
/// [`Connection`]; `Explicit` ids wrap a user-provided [`Id`]. Because the
/// variant is part of equality, hashing and serialization, an explicit id never
/// matches an implicit one with the same number.
///
/// NOTE: serialized with bincode as part of [`UniqueId`]; the variant order is
/// part of the wire format and must not be changed.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
enum StreamId {
    Implicit(u64),
    Explicit(u64),
}

/// Id of a [`Connection`]. Does not include parent Ids of this connection.
/// It is only unique with respect to its sibling connections created by
/// [`Connection::sub_connection`].
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
struct ConnectionId(pub(crate) u32);

/// Unique id of a stream and all its parent [`ConnectionId`]s.
///
/// It is serialized and written as a length-prefixed header at the start of
/// every new QUIC send stream, so the other party's [`StreamManager`] can pair
/// the incoming stream with its local stream of the same id. The field order is
/// part of that wire format.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
struct UniqueId {
    // Path of sub-connection ids from the root connection (empty for the root).
    cids: Vec<ConnectionId>,
    id: StreamId,
}
58
/// Sending half of the oneshot channel over which the [`StreamManager`] hands an
/// accepted remote QUIC stream, together with the number of header bytes it
/// already read from that stream, to the local [`ReceiveStreamBytes`].
type StreamSend = oneshot::Sender<(QuicRecvStream, usize)>;
/// Receiving half corresponding to [`StreamSend`].
type StreamRecv = oneshot::Receiver<(QuicRecvStream, usize)>;

/// Manages accepting of new streams.
///
/// Remote streams are paired with local stream requests by their [`UniqueId`].
/// Whichever side shows up first is parked until the other one arrives.
pub struct StreamManager {
    acceptor: QuicStreamAcceptor,
    // Cloned into every `Connection` and into the per-stream accept tasks.
    cmd_send: mpsc::UnboundedSender<Cmd>,
    cmd_recv: mpsc::UnboundedReceiver<Cmd>,
    // Local stream requests whose remote stream has not been accepted yet.
    pending: HashMap<UniqueId, StreamSend>,
    // Accepted remote streams (plus header bytes read) not yet requested locally.
    accepted: HashMap<UniqueId, (QuicRecvStream, usize)>,
}
70
/// Used to create grouped sub-streams.
///
/// Connections can have sub-connections. Streams created via
/// [`Connection::byte_stream`] and [`Connection::stream`] are tied to their
/// connection. Streams created with the same [`Id`] but for different
/// connections will not conflict with each other.
#[derive(Debug)]
pub struct Connection {
    // Sub-connection ids on the path from the root connection to this one;
    // empty for a connection created by `Connection::new`.
    cids: Vec<ConnectionId>,
    // `ConnectionId` handed out by the next `sub_connection` call.
    next_cid: u32,
    // Handle of the underlying QUIC connection, used to open send streams.
    handle: Handle,
    // Command channel to the `StreamManager`.
    cmd: mpsc::UnboundedSender<Cmd>,
    // Next implicit stream id (shared by all id-less stream methods).
    next_implicit_id: u64,
}

/// Send part of the bytes stream.
///
/// Implements [`AsyncWrite`] on top of a QUIC send stream.
pub struct SendStreamBytes {
    inner: QuicSendStream,
}

/// Receive part of the bytes stream.
///
/// Implements [`AsyncRead`]. Reads do not complete until the [`StreamManager`]
/// has handed over the matching stream of the other party.
pub struct ReceiveStreamBytes {
    inner: ReceiveStreamWrapper,
}
95
/// Send part of the serialized stream.
///
/// Values of type `T` are bincode-encoded into length-delimited frames.
pub type SendStream<T> = SymmetricallyFramed<
    FramedWrite<SendStreamBytes, LengthDelimitedCodec>,
    T,
    SymmetricalBincode<T>,
>;

/// A temporary typed send stream which borrows a [`SendStreamBytes`].
pub type SendStreamTemp<'a, T> = SymmetricallyFramed<
    FramedWrite<&'a mut SendStreamBytes, LengthDelimitedCodec>,
    T,
    SymmetricalBincode<T>,
>;

/// Receive part of the serialized stream.
///
/// Decodes length-delimited frames into bincode-encoded values of type `T`.
pub type ReceiveStream<T> = SymmetricallyFramed<
    FramedRead<ReceiveStreamBytes, LengthDelimitedCodec>,
    T,
    SymmetricalBincode<T>,
>;

/// A temporary typed receive stream which borrows a [`ReceiveStreamBytes`].
pub type ReceiveStreamTemp<'a, T> = SymmetricallyFramed<
    FramedRead<&'a mut ReceiveStreamBytes, LengthDelimitedCodec>,
    T,
    SymmetricalBincode<T>,
>;

/// State of a [`ReceiveStreamBytes`].
enum ReceiveStreamWrapper {
    // Still waiting for the `StreamManager` to hand over the remote stream.
    Channel { stream_recv: StreamRecv },
    // The remote stream has been handed over; reads go directly to it.
    Stream { recv_stream: QuicRecvStream },
}
128
/// Commands processed by [`StreamManager::start`].
#[derive(Debug)]
enum Cmd {
    // A local stream with `uid` was requested; the matching remote stream is to
    // be sent through `stream_return` once it has been accepted.
    NewStream {
        uid: UniqueId,
        stream_return: StreamSend,
    },
    // A remote stream was accepted and its `bytes_read`-byte header parsed into `uid`.
    AcceptedStream {
        uid: UniqueId,
        stream: QuicRecvStream,
        bytes_read: usize,
    },
}
141
142impl StreamManager {
143    pub fn new(acceptor: QuicStreamAcceptor) -> Self {
144        let (cmd_send, cmd_recv) = mpsc::unbounded_channel();
145        Self {
146            acceptor,
147            cmd_send,
148            cmd_recv,
149            pending: Default::default(),
150            accepted: Default::default(),
151        }
152    }
153
154    /// Start the StreamManager to accept streams.
155    ///
156    /// This method needs to be continually polled to establish new streams.
157    #[tracing::instrument(skip_all)]
158    pub async fn start(mut self) {
159        loop {
160            // Guard against possible cancellation unsafety of `accept_receive_stream`
161            let mut receive_stream = pin!(self.acceptor.accept_receive_stream());
162            select! {
163                res = &mut receive_stream => {
164                    match res {
165                        Ok(Some(stream)) => {
166                            debug!("accepted stream");
167                            Self::accepted(stream, self.cmd_send.clone());
168                        }
169                        Ok(None) => {
170                            debug!("remote closed");
171                            return;
172                        }
173                        Err(err) => {
174                            error!(%err, "unable to accept stream");
175                            return;
176                        }
177                    }
178                }
179                Some(cmd) = self.cmd_recv.recv() => {   // recv() is cancel safe
180                    debug!(?cmd, "received cmd");
181                    match cmd {
182                        Cmd::NewStream {uid, stream_return} => {
183                            if let Some(accepted) = self.accepted.remove(&uid) {
184                                if stream_return.send(accepted).is_err() {
185                                    debug!("accepted remote stream but local receiver is closed");
186                                }
187                                debug!("sending new stream to receiver");
188                                continue;
189                            }
190                            match self.pending.entry(uid) {
191                                Entry::Occupied(occupied_entry) => {
192                                    panic!("Duplicate unique id: {:?}", occupied_entry.key())
193                                },
194                                Entry::Vacant(vacant_entry) => {vacant_entry.insert(stream_return);},
195                            }
196                        }
197                        Cmd::AcceptedStream {uid, stream, bytes_read} => {
198                            if let Some(stream_ret) = self.pending.remove(&uid) {
199                               if stream_ret.send((stream, bytes_read)).is_err() {
200                                debug!("accepted remote stream but local receiver is closed");
201                               }
202                            } else {
203                                debug!("accepted stream but no pending");
204                                self.accepted.insert(uid, (stream, bytes_read));
205                            }
206                        }
207                    }
208                }
209            }
210        }
211    }
212
213    // not taking &self to work around borrow issue
214    fn accepted(mut stream: QuicRecvStream, cmd_send: mpsc::UnboundedSender<Cmd>) {
215        tokio::spawn(async move {
216            let (uid, bytes_read) = match UniqueId::read_from(&mut stream).await {
217                Ok(ret) => ret,
218                Err(err) => {
219                    error!(?err, "unable to read stream unique id");
220                    return;
221                }
222            };
223            cmd_send
224                .send(Cmd::AcceptedStream {
225                    uid,
226                    stream,
227                    bytes_read,
228                })
229                .expect("cmd_rcv is owned by StreamManager")
230        });
231    }
232}
233
/// Possible connection errors.
#[derive(thiserror::Error, Debug)]
pub enum ConnectionError {
    /// Opening a new QUIC send stream failed.
    #[error("Unable to open stream")]
    OpenStream(#[source] s2n_quic::connection::Error),
    /// Reading or writing the unique id header of a stream failed.
    #[error("io error during stream establishment")]
    IoError(#[source] io::Error),
    /// The [`StreamManager`] no longer receives commands, e.g. because
    /// [`StreamManager::start`] has returned.
    #[error("StreamManager is dropped and not accepting connections")]
    StreamManagerDropped,
    /// The unique id header of an incoming stream could not be deserialized.
    #[error("Stream unique id deserialization failed")]
    UniqueIdDeserialization(#[source] bincode::Error),
    /// The unique id header of an outgoing stream could not be serialized.
    #[error("Stream unique id serialization failed")]
    UniqueIdSerialization(#[source] bincode::Error),
    /// The serialized unique id of a new stream is too long to be sent, which
    /// only happens for extremely deeply nested sub-connections.
    #[error("Reached maximum number of sub connections")]
    SubConnectionLimitReached,
}
250
251impl Connection {
252    pub fn new(quic_conn: s2n_quic::Connection) -> (Self, StreamManager) {
253        let (handle, acceptor) = quic_conn.split();
254        let stream_manager = StreamManager::new(acceptor);
255        let conn = Self {
256            cids: vec![],
257            next_cid: 0,
258            handle,
259            cmd: stream_manager.cmd_send.clone(),
260            next_implicit_id: 0,
261        };
262        (conn, stream_manager)
263    }
264
265    /// Create a sub-connection. The n'th call to sub_connection
266    /// is paired with the n'th call to `sub_connection` on the corresponding
267    /// [`Connection`] of the other party. Creating a sub-connection results
268    /// in no immediate communication and is a fast synchronous operation.
269    #[tracing::instrument(level = Level::DEBUG, skip(self), ret)]
270    pub fn sub_connection(&mut self) -> Self {
271        let cid = self.next_cid;
272        self.next_cid += 1;
273        let mut cids = self.cids.clone();
274        cids.push(ConnectionId(cid));
275        Self {
276            cids,
277            next_cid: 0,
278            handle: self.handle.clone(),
279            cmd: self.cmd.clone(),
280            next_implicit_id: 0,
281        }
282    }
283
284    async fn internal_byte_stream(
285        &self,
286        stream_id: StreamId,
287    ) -> Result<(SendStreamBytes, ReceiveStreamBytes), ConnectionError> {
288        let uid = UniqueId::new(self.cids.clone(), stream_id);
289        let mut snd = self
290            .handle
291            .clone()
292            .open_send_stream()
293            .await
294            .map_err(ConnectionError::OpenStream)?;
295        let bytes_written = uid.write_into(&mut snd).await?;
296        event!(target: "cryprot_metrics", Level::TRACE, bytes_written = bytes_written);
297        let (stream_return, stream_recv) = oneshot::channel();
298        self.cmd
299            .send(Cmd::NewStream { uid, stream_return })
300            .map_err(|_| ConnectionError::StreamManagerDropped)?;
301        let snd = SendStreamBytes { inner: snd };
302        let recv = ReceiveStreamBytes {
303            inner: ReceiveStreamWrapper::Channel { stream_recv },
304        };
305        Ok((snd, recv))
306    }
307
308    /// Establish a byte stream over this connection with the provided Id.
309    pub async fn byte_stream(
310        &mut self,
311    ) -> Result<(SendStreamBytes, ReceiveStreamBytes), ConnectionError> {
312        self.next_implicit_id += 1;
313        self.internal_byte_stream(StreamId::Implicit(self.next_implicit_id - 1))
314            .await
315    }
316
317    /// Establish a byte stream over this connection with the provided Id.
318    pub async fn byte_stream_with_id(
319        &self,
320        id: Id,
321    ) -> Result<(SendStreamBytes, ReceiveStreamBytes), ConnectionError> {
322        self.internal_byte_stream(StreamId::Explicit(id.0)).await
323    }
324
325    /// Establish a typed stream over this connection.
326    async fn internal_stream<T: Serialize + DeserializeOwned>(
327        &self,
328        id: StreamId,
329    ) -> Result<(SendStream<T>, ReceiveStream<T>), ConnectionError> {
330        let (send_bytes, recv_bytes) = self.internal_byte_stream(id).await?;
331        let mut ld_codec = LengthDelimitedCodec::builder();
332        // TODO what is a sensible max length?
333        const MB: usize = 1024 * 1024;
334        ld_codec.max_frame_length(256 * MB);
335        let framed_send = ld_codec.new_write(send_bytes);
336        let framed_read = ld_codec.new_read(recv_bytes);
337        let serde_send = SymmetricallyFramed::new(framed_send, Bincode::default());
338        let serde_read = SymmetricallyFramed::new(framed_read, Bincode::default());
339        Ok((serde_send, serde_read))
340    }
341
342    /// Establish a typed stream over this connection.
343    pub async fn stream<T: Serialize + DeserializeOwned>(
344        &mut self,
345    ) -> Result<(SendStream<T>, ReceiveStream<T>), ConnectionError> {
346        self.next_implicit_id += 1;
347        self.internal_stream(StreamId::Implicit(self.next_implicit_id - 1))
348            .await
349    }
350
351    /// Establish a typed stream over this connection with the provided explicit
352    /// Id.
353    pub async fn stream_with_id<T: Serialize + DeserializeOwned>(
354        &self,
355        id: Id,
356    ) -> Result<(SendStream<T>, ReceiveStream<T>), ConnectionError> {
357        self.internal_stream(StreamId::Explicit(id.0)).await
358    }
359
360    async fn internal_request_response_stream<T: Serialize, S: DeserializeOwned>(
361        &self,
362        id: StreamId,
363    ) -> Result<(SendStream<T>, ReceiveStream<S>), ConnectionError> {
364        let (send_bytes, recv_bytes) = self.internal_byte_stream(id).await?;
365        let framed_send = default_codec().new_write(send_bytes);
366        let framed_read = default_codec().new_read(recv_bytes);
367        let serde_send = SymmetricallyFramed::new(framed_send, Bincode::default());
368        let serde_read = SymmetricallyFramed::new(framed_read, Bincode::default());
369        Ok((serde_send, serde_read))
370    }
371
372    /// Establish a typed request-response stream over this connection with
373    /// differing types for the request and response.
374    pub async fn request_response_stream<T: Serialize, S: DeserializeOwned>(
375        &mut self,
376    ) -> Result<(SendStream<T>, ReceiveStream<S>), ConnectionError> {
377        self.next_implicit_id += 1;
378        self.internal_request_response_stream(StreamId::Implicit(self.next_implicit_id - 1))
379            .await
380    }
381
382    /// Establish a typed request-response stream over this connection with
383    /// differing types for the request and response.
384    pub async fn request_response_stream_with_id<T: Serialize, S: DeserializeOwned>(
385        &self,
386        id: Id,
387    ) -> Result<(SendStream<T>, ReceiveStream<S>), ConnectionError> {
388        self.internal_request_response_stream(StreamId::Explicit(id.0))
389            .await
390    }
391}
392
393impl Id {
394    pub fn new(id: u64) -> Self {
395        Self(id)
396    }
397}
398
399fn bincode_opts() -> impl bincode::Options {
400    bincode::options().with_big_endian().with_varint_encoding()
401}
402
403impl UniqueId {
404    fn new(cids: Vec<ConnectionId>, id: StreamId) -> Self {
405        Self { cids, id }
406    }
407
408    async fn write_into<W: AsyncWrite>(&self, write: W) -> Result<usize, ConnectionError> {
409        let mut write = pin!(write);
410        let mut options = bincode_opts();
411        let serialized = (&mut options)
412            .serialize(self)
413            .map_err(ConnectionError::UniqueIdSerialization)?;
414        write
415            .write_u32(
416                serialized
417                    .len()
418                    .try_into()
419                    .map_err(|_| ConnectionError::SubConnectionLimitReached)?,
420            )
421            .await
422            .map_err(ConnectionError::IoError)?;
423        write
424            .write_all(&serialized)
425            .await
426            .map_err(ConnectionError::IoError)?;
427        Ok(mem::size_of::<u32>() + serialized.len())
428    }
429
430    async fn read_from<R: AsyncRead>(reader: R) -> Result<(Self, usize), ConnectionError> {
431        let mut reader = pin!(reader);
432        let len = reader.read_u32().await.map_err(ConnectionError::IoError)?;
433        let mut buf = vec![0; len as usize];
434        reader
435            .read_exact(&mut buf)
436            .await
437            .map_err(ConnectionError::IoError)?;
438        let uid = bincode_opts()
439            .deserialize(&buf)
440            .map_err(ConnectionError::UniqueIdDeserialization)?;
441        Ok((uid, mem::size_of::<u32>() + len as usize))
442    }
443}
444
/// Possible byte stream errors.
#[derive(thiserror::Error, Debug)]
pub enum StreamError {
    /// Returned by [`SendStreamBytes::flush`].
    #[error("unable to flush stream")]
    Flush(#[source] s2n_quic::stream::Error),
    /// Returned by [`SendStreamBytes::close`].
    #[error("unable to close stream")]
    Close(#[source] s2n_quic::stream::Error),
    /// Returned by [`SendStreamBytes::finish`].
    #[error("unable to finish stream")]
    Finish(#[source] s2n_quic::stream::Error),
}
455
456impl SendStreamBytes {
457    pub async fn flush(&mut self) -> Result<(), StreamError> {
458        self.inner.flush().await.map_err(StreamError::Flush)
459    }
460
461    pub fn finish(&mut self) -> Result<(), StreamError> {
462        self.inner.finish().map_err(StreamError::Finish)
463    }
464
465    pub async fn close(&mut self) -> Result<(), StreamError> {
466        self.inner.close().await.map_err(StreamError::Close)
467    }
468
469    pub fn as_stream<T: Serialize>(&mut self) -> SendStreamTemp<'_, T> {
470        let framed_send = default_codec().new_write(self);
471        SymmetricallyFramed::new(framed_send, Bincode::default())
472    }
473}
474
475impl AsyncWrite for SendStreamBytes {
476    fn poll_write(
477        mut self: Pin<&mut Self>,
478        cx: &mut Context<'_>,
479        buf: &[u8],
480    ) -> Poll<Result<usize, Error>> {
481        let inner = Pin::new(&mut self.inner);
482        trace_poll(inner.poll_write(cx, buf))
483    }
484
485    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
486        let inner = Pin::new(&mut self.inner);
487        AsyncWrite::poll_flush(inner, cx)
488    }
489
490    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
491        let inner = Pin::new(&mut self.inner);
492        inner.poll_shutdown(cx)
493    }
494
495    fn poll_write_vectored(
496        mut self: Pin<&mut Self>,
497        cx: &mut Context<'_>,
498        bufs: &[IoSlice<'_>],
499    ) -> Poll<Result<usize, Error>> {
500        let inner = Pin::new(&mut self.inner);
501        trace_poll(inner.poll_write_vectored(cx, bufs))
502    }
503
504    fn is_write_vectored(&self) -> bool {
505        self.inner.is_write_vectored()
506    }
507}
508
509fn trace_poll(p: Poll<io::Result<usize>>) -> Poll<io::Result<usize>> {
510    if let Poll::Ready(Ok(bytes)) = p {
511        event!(target: "cryprot_metrics", Level::TRACE, bytes_written = bytes);
512    }
513    p
514}
515
516impl ReceiveStreamBytes {
517    pub fn as_stream<T: DeserializeOwned>(&mut self) -> ReceiveStreamTemp<'_, T> {
518        let framed_read = default_codec().new_read(self);
519        SymmetricallyFramed::new(framed_read, Bincode::default())
520    }
521}
522
523// Implement AsyncRead for ReceiveStream to poll the oneshot Receiver first if
524// there is not already a channel.
525impl AsyncRead for ReceiveStreamBytes {
526    fn poll_read(
527        mut self: Pin<&mut Self>,
528        cx: &mut Context<'_>,
529        buf: &mut ReadBuf<'_>,
530    ) -> Poll<std::io::Result<()>> {
531        match &mut self.inner {
532            ReceiveStreamWrapper::Channel { stream_recv } => match Pin::new(stream_recv).poll(cx) {
533                Poll::Pending => Poll::Pending,
534                Poll::Ready(Ok((recv_stream, bytes_read))) => {
535                    // We know we read those bytes in the StreamManager, so we emit
536                    // the corresponding event here in the correct span.
537                    event!(target: "cryprot_metrics", Level::TRACE, bytes_read);
538                    self.inner = ReceiveStreamWrapper::Stream { recv_stream };
539                    self.poll_read(cx, buf)
540                }
541                Poll::Ready(Err(err)) => Poll::Ready(Err(std::io::Error::other(Box::new(err)))),
542            },
543            ReceiveStreamWrapper::Stream { recv_stream } => {
544                let len = buf.filled().len();
545                let poll = Pin::new(recv_stream).poll_read(cx, buf);
546                if let Poll::Ready(Ok(())) = poll {
547                    let bytes = buf.filled().len() - len;
548                    if bytes > 0 {
549                        event!(target: "cryprot_metrics", Level::TRACE, bytes_read = bytes);
550                    }
551                }
552                poll
553            }
554        }
555    }
556}
557
558fn default_codec() -> length_delimited::Builder {
559    let mut ld_codec = LengthDelimitedCodec::builder();
560    const MB: usize = 1024 * 1024;
561    ld_codec.max_frame_length(20 * MB);
562    ld_codec
563}
564
#[cfg(test)]
mod tests {
    use anyhow::{Context, Result};
    use futures::{SinkExt, StreamExt};
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        task::JoinSet,
    };
    use tracing::debug;

    use crate::{
        Id,
        testing::{init_tracing, local_conn},
    };

    #[tokio::test]
    async fn create_local_conn() -> Result<()> {
        let _g = init_tracing();
        let _ = local_conn().await?;
        Ok(())
    }

    #[tokio::test]
    async fn byte_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send, _) = s.byte_stream().await?;
        let (_, mut c_recv) = c.byte_stream().await?;
        let send_buf = b"hello there";
        s_send.write_all(send_buf).await?;
        let mut buf = [0; 11];
        c_recv.read_exact(&mut buf).await?;
        assert_eq!(send_buf, &buf);
        Ok(())
    }

    #[tokio::test]
    async fn byte_stream_explicit_implicit_id() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send1, _) = s.byte_stream_with_id(Id::new(u32::MAX as u64 + 42)).await?;
        let (mut s_send2, _) = s.byte_stream().await?;
        let (_, mut c_recv1) = c.byte_stream_with_id(Id::new(u32::MAX as u64 + 42)).await?;
        let (_, mut c_recv2) = c.byte_stream().await?;
        let send_buf1 = b"hello there";
        s_send1.write_all(send_buf1).await?;
        let mut buf = [0; 11];
        c_recv1.read_exact(&mut buf).await?;
        assert_eq!(send_buf1, &buf);

        let send_buf2 = b"general kenobi";
        s_send2.write_all(send_buf2).await?;
        let mut buf = [0; 14];
        c_recv2.read_exact(&mut buf).await?;
        assert_eq!(send_buf2, &buf);
        Ok(())
    }

    #[tokio::test]
    async fn byte_stream_different_order() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send, mut s_recv) = s.byte_stream().await?;
        let s_send_buf = b"hello there";
        s_send.write_all(s_send_buf).await?;
        let mut s_recv_buf = [0; 2];
        // By already spawning the read task before the client calls `c.byte_stream`
        // we check that the switch from channel to s2n stream works
        let jh = tokio::spawn(async move {
            s_recv.read_exact(&mut s_recv_buf).await.unwrap();
            s_recv_buf
        });
        let (mut c_send, mut c_recv) = c.byte_stream().await?;
        let mut c_recv_buf = [0; 11];
        c_recv.read_exact(&mut c_recv_buf).await?;
        assert_eq!(s_send_buf, &c_recv_buf);
        let c_send_buf = b"42";
        c_send.write_all(c_send_buf).await?;
        let s_recv_buf = jh.await?;
        assert_eq!(c_send_buf, &s_recv_buf);
        Ok(())
    }

    #[tokio::test]
    async fn many_parallel_byte_streams() -> Result<()> {
        let _g = init_tracing();
        let (mut c1, mut c2) = local_conn().await?;
        let mut jhs = JoinSet::new();
        for i in 0..10 {
            let ((mut s, _), (_, mut r)) =
                tokio::try_join!(c1.byte_stream(), c2.byte_stream()).unwrap();

            let jh = tokio::spawn(async move {
                let buf = vec![0; 10 * 1024 * 1024];
                s.write_all(&buf).await.unwrap();
                debug!("wrote buf {i}");
            });
            jhs.spawn(jh);
            let jh = tokio::spawn(async move {
                let mut buf = vec![0; 10 * 1024 * 1024];
                r.read_exact(&mut buf).await.unwrap();
                debug!("received buf {i}");
            });
            jhs.spawn(jh);
        }
        let res = jhs.join_all().await;
        for res in res {
            res.unwrap();
        }
        Ok(())
    }

    #[tokio::test]
    async fn serde_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut snd, _) = s.stream::<Vec<i32>>().await?;
        let (_, mut recv) = c.stream::<Vec<i32>>().await?;
        snd.send(vec![1, 2, 3]).await?;
        let ret = recv.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        drop(snd);
        let ret = recv.next().await.map(|res| res.map_err(|_| ()));
        assert_eq!(None, ret);
        Ok(())
    }

    #[tokio::test]
    async fn serde_stream_block() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut snd, _) = s.stream().await?;
        let (_, mut recv) = c.stream().await?;
        snd.send(vec![u8::MAX; 16]).await?;
        let ret: Vec<_> = recv.next().await.context("recv")??;
        assert_eq!(vec![u8::MAX; 16], ret);
        Ok(())
    }

    #[tokio::test]
    async fn serde_byte_stream_as_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send, _) = s.byte_stream().await?;
        let (_, mut c_recv) = c.byte_stream().await?;
        {
            let mut send_ser1 = s_send.as_stream::<i32>();
            let mut recv_ser1 = c_recv.as_stream::<i32>();
            send_ser1.send(42).await?;
            let ret = recv_ser1.next().await.context("recv")??;
            assert_eq!(42, ret);
        }
        {
            let mut send_ser2 = s_send.as_stream::<Vec<i32>>();
            let mut recv_ser2 = c_recv.as_stream::<Vec<i32>>();
            send_ser2.send(vec![1, 2, 3]).await?;
            let ret = recv_ser2.next().await.context("recv")??;
            assert_eq!(vec![1, 2, 3], ret);
        }
        Ok(())
    }

    #[tokio::test]
    async fn serde_request_response_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut snd1, mut recv1) = s.request_response_stream::<Vec<i32>, String>().await?;
        let (mut snd2, mut recv2) = c.request_response_stream::<String, Vec<i32>>().await?;
        snd1.send(vec![1, 2, 3]).await?;
        let ret = recv2.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        snd2.send("hello there".to_string()).await?;
        let ret = recv1.next().await.context("recv2")??;
        assert_eq!("hello there", &ret);
        Ok(())
    }

    #[tokio::test]
    async fn sub_connection() -> Result<()> {
        let _g = init_tracing();
        let (mut s1, mut c1) = local_conn().await?;
        let mut s2 = s1.sub_connection();
        let mut c2 = c1.sub_connection();
        // Actually establish streams with the same implicit id on the parent
        // connection (futures are lazy, so they must be awaited) and keep them
        // alive, so the sub-connection's streams must be kept apart from them.
        let _s1_streams = s1.byte_stream().await?;
        let _c1_streams = c1.byte_stream().await?;
        let (mut snd, _) = s2.stream::<Vec<i32>>().await?;
        let (_, mut recv) = c2.stream::<Vec<i32>>().await?;

        snd.send(vec![1, 2, 3]).await?;
        let ret = recv.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        Ok(())
    }

    #[tokio::test]
    async fn sub_sub_connection() -> Result<()> {
        let _g = init_tracing();
        let (mut s1, mut c1) = local_conn().await?;
        let mut s2 = s1.sub_connection();
        let mut c2 = c1.sub_connection();
        let mut s3 = s2.sub_connection();
        let mut c3 = c2.sub_connection();
        // Establish (and keep alive) streams with the same implicit id on all
        // ancestor connections.
        let _s1_streams = s1.byte_stream().await?;
        let _c1_streams = c1.byte_stream().await?;
        let _s2_streams = s2.byte_stream().await?;
        let _c2_streams = c2.byte_stream().await?;
        let (mut snd, _) = s3.stream::<Vec<i32>>().await?;
        let (_, mut recv) = c3.stream::<Vec<i32>>().await?;

        snd.send(vec![1, 2, 3]).await?;
        let ret = recv.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        Ok(())
    }
}