1use std::{
3 collections::{HashMap, hash_map::Entry},
4 future::Future,
5 io::{Error, IoSlice},
6 mem,
7 pin::{Pin, pin},
8 sync::{
9 Arc,
10 atomic::{AtomicU32, Ordering},
11 },
12 task::{Context, Poll},
13};
14
15use bincode::Options;
16use s2n_quic::{
17 connection::{Handle, StreamAcceptor as QuicStreamAcceptor},
18 stream::{ReceiveStream as QuicRecvStream, SendStream as QuicSendStream},
19};
20use serde::{Deserialize, Serialize, de::DeserializeOwned};
21use tokio::{
22 io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
23 select,
24 sync::{mpsc, oneshot},
25};
26use tokio_serde::{
27 SymmetricallyFramed,
28 formats::{Bincode, SymmetricalBincode},
29};
30use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec, length_delimited};
31use tracing::{Level, debug, error, event};
32
33#[cfg(feature = "metrics")]
34pub mod metrics;
35
36#[doc(hidden)]
37#[cfg(any(test, feature = "__testing"))]
38pub mod testing;
39
/// Caller-chosen identifier for a stream, used by the `*_with_id` stream constructors
/// of [`Connection`].
///
/// Both peers must use the same `Id` on corresponding connections for their streams to
/// be paired.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct Id(pub(crate) u64);
43
/// Identifies a stream within a single (sub) connection.
///
/// `Implicit` ids come from the connection's internal counter (used by `byte_stream`,
/// `stream` and `request_response_stream`). `Explicit` ids are supplied by the caller
/// through [`Id`]. The two kinds live in separate variants, so an explicit id never
/// compares equal to an implicit one even when the numbers match.
///
/// NOTE: this type is bincode-serialized as part of [`UniqueId`]. The variant order
/// determines the encoded variant index, so do not reorder the variants.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
enum StreamId {
    /// Id taken from the connection's implicit counter.
    Implicit(u64),
    /// Id provided explicitly by the caller.
    Explicit(u64),
}
49
50#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
54struct ConnectionId(pub(crate) u32);
55
/// Header written at the start of every stream.
///
/// The peer's [`StreamManager`] uses it to pair an accepted remote stream with the
/// locally opened stream that has the same id.
///
/// NOTE: field order is part of the bincode wire encoding; do not reorder.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
struct UniqueId {
    /// Path of connection ids from the root connection to the (sub) connection that
    /// opened the stream; empty for the root connection.
    cids: Vec<ConnectionId>,
    /// Stream id within that (sub) connection.
    id: StreamId,
}
62
/// Used by the [`StreamManager`] to deliver a matched remote receive stream.
///
/// The stream is sent together with the number of header bytes (length prefix +
/// encoded `UniqueId`) already read from it.
type StreamSend = oneshot::Sender<(QuicRecvStream, usize)>;
/// Receiving half of [`StreamSend`], held by a not-yet-connected [`ReceiveStreamBytes`].
type StreamRecv = oneshot::Receiver<(QuicRecvStream, usize)>;
65
/// Pairs locally opened streams with the receive streams accepted from the peer.
///
/// Streams opened on a [`Connection`] only become readable while
/// [`StreamManager::start`] is running.
pub struct StreamManager {
    /// Source of incoming receive streams from the peer.
    acceptor: QuicStreamAcceptor,
    /// Sender half of the command channel, cloned for accept tasks and [`Connection`]s.
    /// Because the manager holds it, `cmd_recv.recv()` cannot return `None` while the
    /// manager is alive.
    cmd_send: mpsc::UnboundedSender<Cmd>,
    /// Receiver half of the command channel, drained by `start`.
    cmd_recv: mpsc::UnboundedReceiver<Cmd>,
    /// Local streams still waiting for their remote counterpart, keyed by unique id.
    pending: HashMap<UniqueId, StreamSend>,
    /// Remote streams (and header bytes read) that arrived before the local side asked
    /// for them, keyed by unique id.
    accepted: HashMap<UniqueId, (QuicRecvStream, usize)>,
}
74
/// Handle for opening multiplexed streams over one QUIC connection.
///
/// Each opened stream consists of a local QUIC send stream and the peer's QUIC send
/// stream carrying the same [`UniqueId`]. Both peers therefore have to open
/// corresponding streams (same id, same (sub) connection) for data to flow.
#[derive(Debug)]
pub struct Connection {
    /// Connection id path of this (sub) connection; empty for the root connection.
    cids: Vec<ConnectionId>,
    /// Counter that hands out ids to this connection's sub connections.
    next_cid: Arc<AtomicU32>,
    /// QUIC handle used to open new send streams.
    handle: Handle,
    /// Command channel to the [`StreamManager`].
    cmd: mpsc::UnboundedSender<Cmd>,
    /// Next id handed out for implicitly identified streams.
    next_implicit_id: u64,
}
89
/// Byte-oriented sending half of a stream (implements `AsyncWrite`).
pub struct SendStreamBytes {
    /// Underlying QUIC send stream; the stream's `UniqueId` header has already been
    /// written to it.
    inner: QuicSendStream,
}
94
/// Byte-oriented receiving half of a stream (implements `AsyncRead`).
///
/// Reads stay pending until the [`StreamManager`] has delivered the peer's matching
/// stream.
pub struct ReceiveStreamBytes {
    /// Either the pending delivery channel or the connected QUIC receive stream.
    inner: ReceiveStreamWrapper,
}
99
/// Typed sending half: each `T` is bincode-encoded and sent as one length-delimited frame.
pub type SendStream<T> = SymmetricallyFramed<
    FramedWrite<SendStreamBytes, LengthDelimitedCodec>,
    T,
    SymmetricalBincode<T>,
>;
106
/// Borrowing counterpart of [`SendStream`], returned by `SendStreamBytes::as_stream`.
///
/// It frames `T` values onto a `&mut SendStreamBytes` without consuming the byte stream.
pub type SendStreamTemp<'a, T> = SymmetricallyFramed<
    FramedWrite<&'a mut SendStreamBytes, LengthDelimitedCodec>,
    T,
    SymmetricalBincode<T>,
>;
113
/// Typed receiving half: each length-delimited frame is bincode-decoded into a `T`.
pub type ReceiveStream<T> = SymmetricallyFramed<
    FramedRead<ReceiveStreamBytes, LengthDelimitedCodec>,
    T,
    SymmetricalBincode<T>,
>;
120
/// Borrowing counterpart of [`ReceiveStream`], returned by `ReceiveStreamBytes::as_stream`.
///
/// It decodes `T` values from a `&mut ReceiveStreamBytes` without consuming the byte
/// stream.
pub type ReceiveStreamTemp<'a, T> = SymmetricallyFramed<
    FramedRead<&'a mut ReceiveStreamBytes, LengthDelimitedCodec>,
    T,
    SymmetricalBincode<T>,
>;
127
/// Connection state of a [`ReceiveStreamBytes`].
enum ReceiveStreamWrapper {
    /// Waiting for the [`StreamManager`] to deliver the matching remote stream.
    Channel { stream_recv: StreamRecv },
    /// Remote stream delivered; reads go straight to it.
    Stream { recv_stream: QuicRecvStream },
}
132
/// Messages processed by [`StreamManager::start`].
#[derive(Debug)]
enum Cmd {
    /// A local connection opened a stream with `uid`.
    ///
    /// It waits for the matching remote receive stream on `stream_return`.
    NewStream {
        uid: UniqueId,
        stream_return: StreamSend,
    },
    /// A remote receive stream was accepted and its `uid` header has been read.
    ///
    /// `bytes_read` is the number of header bytes consumed.
    AcceptedStream {
        uid: UniqueId,
        stream: QuicRecvStream,
        bytes_read: usize,
    },
}
145
146impl StreamManager {
147 pub fn new(acceptor: QuicStreamAcceptor) -> Self {
148 let (cmd_send, cmd_recv) = mpsc::unbounded_channel();
149 Self {
150 acceptor,
151 cmd_send,
152 cmd_recv,
153 pending: Default::default(),
154 accepted: Default::default(),
155 }
156 }
157
158 #[tracing::instrument(skip_all)]
162 pub async fn start(mut self) {
163 loop {
164 let mut receive_stream = pin!(self.acceptor.accept_receive_stream());
166 select! {
167 res = &mut receive_stream => {
168 match res {
169 Ok(Some(stream)) => {
170 debug!("accepted stream");
171 Self::accepted(stream, self.cmd_send.clone());
172 }
173 Ok(None) => {
174 debug!("remote closed");
175 return;
176 }
177 Err(err) => {
178 error!(%err, "unable to accept stream");
179 return;
180 }
181 }
182 }
183 Some(cmd) = self.cmd_recv.recv() => { debug!(?cmd, "received cmd");
185 match cmd {
186 Cmd::NewStream {uid, stream_return} => {
187 if let Some(accepted) = self.accepted.remove(&uid) {
188 if stream_return.send(accepted).is_err() {
189 debug!("accepted remote stream but local receiver is closed");
190 }
191 debug!("sending new stream to receiver");
192 continue;
193 }
194 match self.pending.entry(uid) {
195 Entry::Occupied(occupied_entry) => {
196 panic!("Duplicate unique id: {:?}", occupied_entry.key())
197 },
198 Entry::Vacant(vacant_entry) => {vacant_entry.insert(stream_return);},
199 }
200 }
201 Cmd::AcceptedStream {uid, stream, bytes_read} => {
202 if let Some(stream_ret) = self.pending.remove(&uid) {
203 if stream_ret.send((stream, bytes_read)).is_err() {
204 debug!("accepted remote stream but local receiver is closed");
205 }
206 } else {
207 debug!("accepted stream but no pending");
208 self.accepted.insert(uid, (stream, bytes_read));
209 }
210 }
211 }
212 }
213 }
214 }
215 }
216
217 fn accepted(mut stream: QuicRecvStream, cmd_send: mpsc::UnboundedSender<Cmd>) {
219 tokio::spawn(async move {
220 let (uid, bytes_read) = match UniqueId::read_from(&mut stream).await {
221 Ok(ret) => ret,
222 Err(err) => {
223 error!(?err, "unable to read stream unique id");
224 return;
225 }
226 };
227 cmd_send
228 .send(Cmd::AcceptedStream {
229 uid,
230 stream,
231 bytes_read,
232 })
233 .expect("cmd_rcv is owned by StreamManager")
234 });
235 }
236}
237
/// Errors that can occur while establishing a stream on a [`Connection`].
#[derive(thiserror::Error, Debug)]
pub enum ConnectionError {
    /// Opening the underlying QUIC send stream failed.
    #[error("Unable to open stream")]
    OpenStream(#[source] s2n_quic::connection::Error),
    /// Writing or reading the stream's unique id header failed.
    #[error("io error during stream establishment")]
    IoError(#[source] io::Error),
    /// The [`StreamManager`] no longer accepts commands because its receiver was dropped.
    #[error("StreamManager is dropped and not accepting connections")]
    StreamManagerDropped,
    /// The unique id header received from the peer could not be decoded.
    #[error("Stream unique id deserialization failed")]
    UniqueIdDeserialization(#[source] bincode::Error),
    /// The unique id could not be encoded.
    #[error("Stream unique id serialization failed")]
    UniqueIdSerialization(#[source] bincode::Error),
    /// The encoded unique id does not fit behind a `u32` length prefix.
    ///
    /// The id grows with sub connection nesting depth.
    #[error("Reached maximum number of sub connections")]
    SubConnectionLimitReached,
}
254
255impl Connection {
256 pub fn new(quic_conn: s2n_quic::Connection) -> (Self, StreamManager) {
257 let (handle, acceptor) = quic_conn.split();
258 let stream_manager = StreamManager::new(acceptor);
259 let conn = Self {
260 cids: vec![],
261 next_cid: Arc::new(AtomicU32::new(0)),
262 handle,
263 cmd: stream_manager.cmd_send.clone(),
264 next_implicit_id: 0,
265 };
266 (conn, stream_manager)
267 }
268
269 #[tracing::instrument(level = Level::DEBUG, skip(self), ret)]
274 pub fn sub_connection(&mut self) -> Self {
275 let cid = self.next_cid.fetch_add(1, Ordering::Relaxed);
276 let mut cids = self.cids.clone();
277 cids.push(ConnectionId(cid));
278 Self {
279 cids,
280 next_cid: Arc::new(AtomicU32::new(0)),
281 handle: self.handle.clone(),
282 cmd: self.cmd.clone(),
283 next_implicit_id: 0,
284 }
285 }
286
287 async fn internal_byte_stream(
288 &self,
289 stream_id: StreamId,
290 ) -> Result<(SendStreamBytes, ReceiveStreamBytes), ConnectionError> {
291 let uid = UniqueId::new(self.cids.clone(), stream_id);
292 let mut snd = self
293 .handle
294 .clone()
295 .open_send_stream()
296 .await
297 .map_err(ConnectionError::OpenStream)?;
298 let bytes_written = uid.write_into(&mut snd).await?;
299 event!(target: "cryprot_metrics", Level::TRACE, bytes_written = bytes_written);
300 let (stream_return, stream_recv) = oneshot::channel();
301 self.cmd
302 .send(Cmd::NewStream { uid, stream_return })
303 .map_err(|_| ConnectionError::StreamManagerDropped)?;
304 let snd = SendStreamBytes { inner: snd };
305 let recv = ReceiveStreamBytes {
306 inner: ReceiveStreamWrapper::Channel { stream_recv },
307 };
308 Ok((snd, recv))
309 }
310
311 pub async fn byte_stream(
313 &mut self,
314 ) -> Result<(SendStreamBytes, ReceiveStreamBytes), ConnectionError> {
315 self.next_implicit_id += 1;
316 self.internal_byte_stream(StreamId::Implicit(self.next_implicit_id - 1))
317 .await
318 }
319
320 pub async fn byte_stream_with_id(
322 &self,
323 id: Id,
324 ) -> Result<(SendStreamBytes, ReceiveStreamBytes), ConnectionError> {
325 self.internal_byte_stream(StreamId::Explicit(id.0)).await
326 }
327
328 async fn internal_stream<T: Serialize + DeserializeOwned>(
330 &self,
331 id: StreamId,
332 ) -> Result<(SendStream<T>, ReceiveStream<T>), ConnectionError> {
333 let (send_bytes, recv_bytes) = self.internal_byte_stream(id).await?;
334 let mut ld_codec = LengthDelimitedCodec::builder();
335 const MB: usize = 1024 * 1024;
337 ld_codec.max_frame_length(256 * MB);
338 let framed_send = ld_codec.new_write(send_bytes);
339 let framed_read = ld_codec.new_read(recv_bytes);
340 let serde_send = SymmetricallyFramed::new(framed_send, Bincode::default());
341 let serde_read = SymmetricallyFramed::new(framed_read, Bincode::default());
342 Ok((serde_send, serde_read))
343 }
344
345 pub async fn stream<T: Serialize + DeserializeOwned>(
347 &mut self,
348 ) -> Result<(SendStream<T>, ReceiveStream<T>), ConnectionError> {
349 self.next_implicit_id += 1;
350 self.internal_stream(StreamId::Implicit(self.next_implicit_id - 1))
351 .await
352 }
353
354 pub async fn stream_with_id<T: Serialize + DeserializeOwned>(
357 &self,
358 id: Id,
359 ) -> Result<(SendStream<T>, ReceiveStream<T>), ConnectionError> {
360 self.internal_stream(StreamId::Explicit(id.0)).await
361 }
362
363 async fn internal_request_response_stream<T: Serialize, S: DeserializeOwned>(
364 &self,
365 id: StreamId,
366 ) -> Result<(SendStream<T>, ReceiveStream<S>), ConnectionError> {
367 let (send_bytes, recv_bytes) = self.internal_byte_stream(id).await?;
368 let framed_send = default_codec().new_write(send_bytes);
369 let framed_read = default_codec().new_read(recv_bytes);
370 let serde_send = SymmetricallyFramed::new(framed_send, Bincode::default());
371 let serde_read = SymmetricallyFramed::new(framed_read, Bincode::default());
372 Ok((serde_send, serde_read))
373 }
374
375 pub async fn request_response_stream<T: Serialize, S: DeserializeOwned>(
378 &mut self,
379 ) -> Result<(SendStream<T>, ReceiveStream<S>), ConnectionError> {
380 self.next_implicit_id += 1;
381 self.internal_request_response_stream(StreamId::Implicit(self.next_implicit_id - 1))
382 .await
383 }
384
385 pub async fn request_response_stream_with_id<T: Serialize, S: DeserializeOwned>(
388 &self,
389 id: Id,
390 ) -> Result<(SendStream<T>, ReceiveStream<S>), ConnectionError> {
391 self.internal_request_response_stream(StreamId::Explicit(id.0))
392 .await
393 }
394}
395
396impl Id {
397 pub fn new(id: u64) -> Self {
398 Self(id)
399 }
400}
401
402fn bincode_opts() -> impl bincode::Options {
403 bincode::options().with_big_endian().with_varint_encoding()
404}
405
406impl UniqueId {
407 fn new(cids: Vec<ConnectionId>, id: StreamId) -> Self {
408 Self { cids, id }
409 }
410
411 async fn write_into<W: AsyncWrite>(&self, write: W) -> Result<usize, ConnectionError> {
412 let mut write = pin!(write);
413 let mut options = bincode_opts();
414 let serialized = (&mut options)
415 .serialize(self)
416 .map_err(ConnectionError::UniqueIdSerialization)?;
417 write
418 .write_u32(
419 serialized
420 .len()
421 .try_into()
422 .map_err(|_| ConnectionError::SubConnectionLimitReached)?,
423 )
424 .await
425 .map_err(ConnectionError::IoError)?;
426 write
427 .write_all(&serialized)
428 .await
429 .map_err(ConnectionError::IoError)?;
430 Ok(mem::size_of::<u32>() + serialized.len())
431 }
432
433 async fn read_from<R: AsyncRead>(reader: R) -> Result<(Self, usize), ConnectionError> {
434 let mut reader = pin!(reader);
435 let len = reader.read_u32().await.map_err(ConnectionError::IoError)?;
436 let mut buf = vec![0; len as usize];
437 reader
438 .read_exact(&mut buf)
439 .await
440 .map_err(ConnectionError::IoError)?;
441 let uid = bincode_opts()
442 .deserialize(&buf)
443 .map_err(ConnectionError::UniqueIdDeserialization)?;
444 Ok((uid, mem::size_of::<u32>() + len as usize))
445 }
446}
447
/// Errors returned by the explicit stream control methods of [`SendStreamBytes`].
#[derive(thiserror::Error, Debug)]
pub enum StreamError {
    /// Returned by `SendStreamBytes::flush`.
    #[error("unable to flush stream")]
    Flush(#[source] s2n_quic::stream::Error),
    /// Returned by `SendStreamBytes::close`.
    #[error("unable to close stream")]
    Close(#[source] s2n_quic::stream::Error),
    /// Returned by `SendStreamBytes::finish`.
    #[error("unable to finish stream")]
    Finish(#[source] s2n_quic::stream::Error),
}
458
459impl SendStreamBytes {
460 pub async fn flush(&mut self) -> Result<(), StreamError> {
461 self.inner.flush().await.map_err(StreamError::Flush)
462 }
463
464 pub fn finish(&mut self) -> Result<(), StreamError> {
465 self.inner.finish().map_err(StreamError::Finish)
466 }
467
468 pub async fn close(&mut self) -> Result<(), StreamError> {
469 self.inner.close().await.map_err(StreamError::Close)
470 }
471
472 pub fn as_stream<T: Serialize>(&mut self) -> SendStreamTemp<T> {
473 let framed_send = default_codec().new_write(self);
474 SymmetricallyFramed::new(framed_send, Bincode::default())
475 }
476}
477
/// Forwards all writes to the underlying QUIC send stream.
///
/// Successful `poll_write` / `poll_write_vectored` calls also emit a `cryprot_metrics`
/// trace event with the number of bytes written (see `trace_poll`).
impl AsyncWrite for SendStreamBytes {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, Error>> {
        let inner = Pin::new(&mut self.inner);
        trace_poll(inner.poll_write(cx, buf))
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        let inner = Pin::new(&mut self.inner);
        // Fully qualified call to the `AsyncWrite` trait method. NOTE(review): presumably
        // this avoids picking an inherent `poll_flush` of the s2n_quic send stream with a
        // different error type — confirm before simplifying.
        AsyncWrite::poll_flush(inner, cx)
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        let inner = Pin::new(&mut self.inner);
        inner.poll_shutdown(cx)
    }

    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, Error>> {
        let inner = Pin::new(&mut self.inner);
        trace_poll(inner.poll_write_vectored(cx, bufs))
    }

    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }
}
511
512fn trace_poll(p: Poll<io::Result<usize>>) -> Poll<io::Result<usize>> {
513 if let Poll::Ready(Ok(bytes)) = p {
514 event!(target: "cryprot_metrics", Level::TRACE, bytes_written = bytes);
515 }
516 p
517}
518
519impl ReceiveStreamBytes {
520 pub fn as_stream<T: DeserializeOwned>(&mut self) -> ReceiveStreamTemp<T> {
521 let framed_read = default_codec().new_read(self);
522 SymmetricallyFramed::new(framed_read, Bincode::default())
523 }
524}
525
/// Reads from the peer's matching stream once the [`StreamManager`] has delivered it.
///
/// While in the `Channel` state, reads wait for the delivery. On delivery the header
/// byte count is reported as a `cryprot_metrics` trace event, the state switches to
/// `Stream`, and the read is retried immediately. In the `Stream` state each read that
/// produced bytes reports the byte count as a trace event.
impl AsyncRead for ReceiveStreamBytes {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        match &mut self.inner {
            ReceiveStreamWrapper::Channel { stream_recv } => match Pin::new(stream_recv).poll(cx) {
                Poll::Pending => Poll::Pending,
                Poll::Ready(Ok((recv_stream, bytes_read))) => {
                    // Account for the unique id header the accept task already consumed.
                    event!(target: "cryprot_metrics", Level::TRACE, bytes_read);
                    self.inner = ReceiveStreamWrapper::Stream { recv_stream };
                    // Retry in the new state so the caller gets data (or a registered
                    // waker) from the real stream in this same call.
                    self.poll_read(cx, buf)
                }
                // The sender was dropped without delivering a stream.
                // NOTE(review): the state stays `Channel` here, and tokio's oneshot
                // receiver panics if polled again after completing. A caller that polls
                // again after this error may therefore hit that panic — confirm.
                Poll::Ready(Err(err)) => Poll::Ready(Err(std::io::Error::other(Box::new(err)))),
            },
            ReceiveStreamWrapper::Stream { recv_stream } => {
                let len = buf.filled().len();
                let poll = Pin::new(recv_stream).poll_read(cx, buf);
                if let Poll::Ready(Ok(())) = poll {
                    // Only count bytes newly appended by this read.
                    let bytes = buf.filled().len() - len;
                    if bytes > 0 {
                        event!(target: "cryprot_metrics", Level::TRACE, bytes_read = bytes);
                    }
                }
                poll
            }
        }
    }
}
560
561fn default_codec() -> length_delimited::Builder {
562 let mut ld_codec = LengthDelimitedCodec::builder();
563 const MB: usize = 1024 * 1024;
564 ld_codec.max_frame_length(20 * MB);
565 ld_codec
566}
567
#[cfg(test)]
mod tests {
    use anyhow::{Context, Result};
    use futures::{SinkExt, StreamExt};
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        task::JoinSet,
    };
    use tracing::debug;

    use crate::{
        Id,
        testing::{init_tracing, local_conn},
    };

    #[tokio::test]
    async fn create_local_conn() -> Result<()> {
        let _g = init_tracing();
        let _ = local_conn().await?;
        Ok(())
    }

    #[tokio::test]
    async fn byte_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send, _) = s.byte_stream().await?;
        let (_, mut c_recv) = c.byte_stream().await?;
        let send_buf = b"hello there";
        s_send.write_all(send_buf).await?;
        let mut buf = [0; 11];
        c_recv.read_exact(&mut buf).await?;
        assert_eq!(send_buf, &buf);
        Ok(())
    }

    #[tokio::test]
    async fn byte_stream_explicit_implicit_id() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send1, _) = s.byte_stream_with_id(Id::new(u32::MAX as u64 + 42)).await?;
        let (mut s_send2, _) = s.byte_stream().await?;
        let (_, mut c_recv1) = c.byte_stream_with_id(Id::new(u32::MAX as u64 + 42)).await?;
        let (_, mut c_recv2) = c.byte_stream().await?;
        let send_buf1 = b"hello there";
        s_send1.write_all(send_buf1).await?;
        let mut buf = [0; 11];
        c_recv1.read_exact(&mut buf).await?;
        assert_eq!(send_buf1, &buf);

        let send_buf2 = b"general kenobi";
        s_send2.write_all(send_buf2).await?;
        let mut buf = [0; 14];
        c_recv2.read_exact(&mut buf).await?;
        assert_eq!(send_buf2, &buf);
        Ok(())
    }

    #[tokio::test]
    async fn byte_stream_different_order() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send, mut s_recv) = s.byte_stream().await?;
        let s_send_buf = b"hello there";
        s_send.write_all(s_send_buf).await?;
        let mut s_recv_buf = [0; 2];
        // Read on the server side before the client has opened its half of the stream.
        let jh = tokio::spawn(async move {
            s_recv.read_exact(&mut s_recv_buf).await.unwrap();
            s_recv_buf
        });
        let (mut c_send, mut c_recv) = c.byte_stream().await?;
        let mut c_recv_buf = [0; 11];
        c_recv.read_exact(&mut c_recv_buf).await?;
        assert_eq!(s_send_buf, &c_recv_buf);
        let c_send_buf = b"42";
        c_send.write_all(c_send_buf).await?;
        let s_recv_buf = jh.await?;
        assert_eq!(c_send_buf, &s_recv_buf);
        Ok(())
    }

    #[tokio::test]
    async fn many_parallel_byte_streams() -> Result<()> {
        let _g = init_tracing();
        let (mut c1, mut c2) = local_conn().await?;
        let mut jhs = JoinSet::new();
        for i in 0..10 {
            let ((mut s, _), (_, mut r)) = tokio::try_join!(c1.byte_stream(), c2.byte_stream())?;

            jhs.spawn(async move {
                let buf = vec![0; 10 * 1024 * 1024];
                s.write_all(&buf).await.unwrap();
                debug!("wrote buf {i}");
            });
            jhs.spawn(async move {
                let mut buf = vec![0; 10 * 1024 * 1024];
                r.read_exact(&mut buf).await.unwrap();
                debug!("received buf {i}");
            });
        }
        // `join_all` panics if any task panicked, which fails the test.
        jhs.join_all().await;
        Ok(())
    }

    #[tokio::test]
    async fn serde_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut snd, _) = s.stream::<Vec<i32>>().await?;
        let (_, mut recv) = c.stream::<Vec<i32>>().await?;
        snd.send(vec![1, 2, 3]).await?;
        let ret = recv.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        drop(snd);
        let ret = recv.next().await.map(|res| res.map_err(|_| ()));
        assert_eq!(None, ret);
        Ok(())
    }

    #[tokio::test]
    async fn serde_stream_block() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut snd, _) = s.stream().await?;
        let (_, mut recv) = c.stream().await?;
        snd.send(vec![u8::MAX; 16]).await?;
        let ret: Vec<_> = recv.next().await.context("recv")??;
        assert_eq!(vec![u8::MAX; 16], ret);
        Ok(())
    }

    #[tokio::test]
    async fn serde_byte_stream_as_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send, _) = s.byte_stream().await?;
        let (_, mut c_recv) = c.byte_stream().await?;
        {
            let mut send_ser1 = s_send.as_stream::<i32>();
            let mut recv_ser1 = c_recv.as_stream::<i32>();
            send_ser1.send(42).await?;
            let ret = recv_ser1.next().await.context("recv")??;
            assert_eq!(42, ret);
        }
        {
            let mut send_ser2 = s_send.as_stream::<Vec<i32>>();
            let mut recv_ser2 = c_recv.as_stream::<Vec<i32>>();
            send_ser2.send(vec![1, 2, 3]).await?;
            let ret = recv_ser2.next().await.context("recv")??;
            assert_eq!(vec![1, 2, 3], ret);
        }
        Ok(())
    }

    #[tokio::test]
    async fn serde_request_response_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut snd1, mut recv1) = s.request_response_stream::<Vec<i32>, String>().await?;
        let (mut snd2, mut recv2) = c.request_response_stream::<String, Vec<i32>>().await?;
        snd1.send(vec![1, 2, 3]).await?;
        let ret = recv2.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        snd2.send("hello there".to_string()).await?;
        let ret = recv1.next().await.context("recv2")??;
        assert_eq!("hello there", &ret);
        Ok(())
    }

    #[tokio::test]
    async fn sub_connection() -> Result<()> {
        let _g = init_tracing();
        let (mut s1, mut c1) = local_conn().await?;
        let mut s2 = s1.sub_connection();
        let mut c2 = c1.sub_connection();
        // Actually open (and keep alive) implicit streams on the parent connections. The
        // sub connection's implicit id 0 must then not be confused with the parent's.
        let _s1_streams = s1.byte_stream().await?;
        let _c1_streams = c1.byte_stream().await?;
        let (mut snd, _) = s2.stream::<Vec<i32>>().await?;
        let (_, mut recv) = c2.stream::<Vec<i32>>().await?;

        snd.send(vec![1, 2, 3]).await?;
        let ret = recv.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        Ok(())
    }

    #[tokio::test]
    async fn sub_sub_connection() -> Result<()> {
        let _g = init_tracing();
        let (mut s1, mut c1) = local_conn().await?;
        let mut s2 = s1.sub_connection();
        let mut c2 = c1.sub_connection();
        let mut s3 = s2.sub_connection();
        let mut c3 = c2.sub_connection();
        // Open (and keep alive) implicit streams on every ancestor connection.
        let _s1_streams = s1.byte_stream().await?;
        let _c1_streams = c1.byte_stream().await?;
        let _s2_streams = s2.byte_stream().await?;
        let _c2_streams = c2.byte_stream().await?;
        let (mut snd, _) = s3.stream::<Vec<i32>>().await?;
        let (_, mut recv) = c3.stream::<Vec<i32>>().await?;

        snd.send(vec![1, 2, 3]).await?;
        let ret = recv.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        Ok(())
    }
}