cryprot_core/
tokio_rayon.rs

1//! Tokio-Rayon integration to spawn compute tasks in async contexts.
2
3use std::{
4    future::Future,
5    panic::{AssertUnwindSafe, catch_unwind, resume_unwind},
6    pin::Pin,
7    task::{Context, Poll},
8    thread,
9};
10
11use tokio::sync::oneshot;
12
13pub struct TokioRayonJoinHandle<T: Send> {
14    rx: oneshot::Receiver<thread::Result<T>>,
15}
16
17/// Spawns a compute intensive task on the [`rayon`] global threadpool and
18/// returns a future that can be awaited without blocking the async task.
19pub fn spawn_compute<F, T>(func: F) -> TokioRayonJoinHandle<T>
20where
21    F: FnOnce() -> T + Send + 'static,
22    T: Send + 'static,
23{
24    let (tx, rx) = oneshot::channel();
25    rayon::spawn(|| {
26        let res = catch_unwind(AssertUnwindSafe(func));
27
28        if let Err(Err(err)) = tx.send(res) {
29            // if sending fails and func panicked, propagate panic to rayon panic handler
30            resume_unwind(err);
31        }
32    });
33    TokioRayonJoinHandle { rx }
34}
35
36impl<T: Send + 'static> Future for TokioRayonJoinHandle<T> {
37    type Output = thread::Result<T>;
38
39    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
40        let rx = Pin::new(&mut self.rx);
41        match rx.poll(cx) {
42            Poll::Ready(res) => {
43                Poll::Ready(res.expect("oneshot::Sender is not dropped before send"))
44            }
45            Poll::Pending => Poll::Pending,
46        }
47    }
48}