cryprot_ot/
extension.rs

1//! Fast OT extension using optimized [[IKNP03](https://www.iacr.org/archive/crypto2003/27290145/27290145.pdf)] (semi-honest)
2//! or [[KOS15](https://eprint.iacr.org/2015/546.pdf)] (malicious) protocol.
3//!
4//! The protocols are optimized for the availability of `aes` and `avx2` target
5//! features for the semi-honest protocol and additionally `pclmulqdq` for the
6//! malicious protocol.
7//!
8//! ## Batching
//! The protocols automatically compute the OTs in batches to increase
//! throughput. The [`DEFAULT_OT_BATCH_SIZE`] has been chosen to maximise
//! throughput in very low latency settings for large numbers of OTs.
//! The batch size can be changed using the corresponding methods on the sender
//! and receiver (e.g. [`OtExtensionSender::with_batch_size`]); both parties
//! must be configured with the same batch size.
14use std::{io, iter, marker::PhantomData, mem, panic::resume_unwind, task::Poll};
15
16use bytemuck::cast_slice_mut;
17use cryprot_core::{
18    Block,
19    aes_hash::FIXED_KEY_HASH,
20    aes_rng::AesRng,
21    alloc::allocate_zeroed_vec,
22    buf::Buf,
23    tokio_rayon::spawn_compute,
24    transpose::transpose_bitmatrix,
25    utils::{and_inplace_elem, xor_inplace},
26};
27use cryprot_net::{Connection, ConnectionError};
28use futures::{FutureExt, SinkExt, StreamExt, future::poll_fn};
29use rand::{Rng, RngCore, SeedableRng, distr::StandardUniform, rngs::StdRng};
30use subtle::{Choice, ConditionallySelectable};
31use tokio::{
32    io::{AsyncReadExt, AsyncWriteExt},
33    sync::mpsc,
34};
35use tracing::Level;
36
37use crate::{
38    Connected, CotReceiver, CotSender, Malicious, MaliciousMarker, RotReceiver, RotSender,
39    Security, SemiHonest, SemiHonestMarker,
40    adapter::CorrelatedFromRandom,
41    base::{self, SimplestOt},
42    phase, random_choices,
43};
44
/// Number of base OTs performed to set up the extension.
pub const BASE_OT_COUNT: usize = 128;

/// Default number of OTs processed per batch (2^16 = 65536).
pub const DEFAULT_OT_BATCH_SIZE: usize = 1 << 16;
48
/// OT extension sender generic over its [`Security`] level.
pub struct OtExtensionSender<S: Security> {
    // Source of randomness for the base OT choice bits and for the per-call
    // rngs derived in `send_into`.
    rng: StdRng,
    // Base OT instance; the extension sender acts as the base OT *receiver*.
    base_ot: SimplestOt,
    // Connection to the OT extension receiver.
    conn: Connection,
    // One rng per base OT, seeded with the received base OT messages. Empty
    // until base OTs have been performed.
    base_rngs: Vec<AesRng>,
    // Random choice bits used in the base OTs.
    base_choices: Vec<Choice>,
    // `base_choices` packed into a `Block`; `None` until base OTs are done.
    delta: Option<Block>,
    // Number of OTs processed per batch; must match the receiver's.
    batch_size: usize,
    security: PhantomData<S>,
}
60
/// OT extension receiver generic over its [`Security`] level.
pub struct OtExtensionReceiver<S: Security> {
    // Base OT instance; the extension receiver acts as the base OT *sender*.
    base_ot: SimplestOt,
    // Connection to the OT extension sender.
    conn: Connection,
    // Two rngs per base OT, seeded with the two base OT messages. Empty until
    // base OTs have been performed.
    base_rngs: Vec<[AesRng; 2]>,
    // Number of OTs processed per batch; must match the sender's.
    batch_size: usize,
    security: PhantomData<S>,
    // Source of randomness for the base OT rng and the per-call rngs derived
    // in `receive_into`.
    rng: StdRng,
}
70
71/// SemiHonest OT extension sender alias.
72pub type SemiHonestOtExtensionSender = OtExtensionSender<SemiHonestMarker>;
73/// SemiHonest OT extension receiver alias.
74pub type SemiHonestOtExtensionReceiver = OtExtensionReceiver<SemiHonestMarker>;
75
76/// Malicious OT extension sender alias.
77pub type MaliciousOtExtensionSender = OtExtensionSender<MaliciousMarker>;
78/// Malicious OT extension receiver alias.
79pub type MaliciousOtExtensionReceiver = OtExtensionReceiver<MaliciousMarker>;
80
81/// Error type returned by the OT extension protocols.
82#[derive(thiserror::Error, Debug)]
83#[non_exhaustive]
84pub enum Error {
85    #[error("unable to compute base OTs")]
86    BaseOT(#[from] base::Error),
87    #[error("connection error to peer")]
88    Connection(#[from] ConnectionError),
89    #[error("error in sending/receiving data")]
90    Communication(#[from] io::Error),
91    #[error("connection closed by peer")]
92    UnexcpectedClose,
93    /// Only possible for malicious variant.
94    #[error("Commitment does not match seed")]
95    WrongCommitment,
96    /// Only possible for malicious variant.
97    #[error("sender did not receiver x value in KOS check")]
98    MissingXValue,
99    /// Only possible for malicious variant.
100    #[error("malicious check failed")]
101    MaliciousCheck,
102    #[doc(hidden)]
103    #[error("async task is dropped. This error should not be observable.")]
104    AsyncTaskDropped,
105}
106
107impl<S: Security> OtExtensionSender<S> {
108    /// Create a new sender for the given [`Connection`].
109    pub fn new(conn: Connection) -> Self {
110        Self::new_with_rng(conn, StdRng::from_os_rng())
111    }
112
113    /// Create a new sender for the given [`Connection`] and [`StdRng`].
114    ///
115    /// For an rng seeded with a fixed seed, the output is deterministic.
116    pub fn new_with_rng(mut conn: Connection, mut rng: StdRng) -> Self {
117        let base_ot = SimplestOt::new_with_rng(conn.sub_connection(), StdRng::from_rng(&mut rng));
118        Self {
119            rng,
120            base_ot,
121            conn,
122            base_rngs: vec![],
123            base_choices: vec![],
124            delta: None,
125            batch_size: DEFAULT_OT_BATCH_SIZE,
126            security: PhantomData,
127        }
128    }
129
130    /// Set the OT batch size for the sender.
131    ///
132    /// If the sender batch size is changed, the receiver's must also be changed
133    /// (see [`OtExtensionReceiver::with_batch_size`]).
134    /// Note that [`OtExtensionSender::send`] methods will fail if `count %
135    /// self.batch_size()` is not divisable by 128.
136    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
137        self.batch_size = batch_size;
138        self
139    }
140
141    /// The currently configured OT batch size.
142    pub fn batch_size(&self) -> usize {
143        self.batch_size
144    }
145
146    /// Returns true if base OTs have been performed. Subsequent calls to send
147    /// will not perform base OTs again.
148    pub fn has_base_ots(&self) -> bool {
149        self.base_rngs.len() == BASE_OT_COUNT
150    }
151
152    /// Perform base OTs for later extension. Subsequent calls to send
153    /// will not perform base OTs again.
154    pub async fn do_base_ots(&mut self) -> Result<(), Error> {
155        let base_choices = random_choices(BASE_OT_COUNT, &mut self.rng);
156        let base_ots = self.base_ot.receive(&base_choices).await?;
157        self.base_rngs = base_ots.into_iter().map(AesRng::from_seed).collect();
158        self.delta = Some(Block::from_choices(&base_choices));
159        self.base_choices = base_choices;
160        Ok(())
161    }
162}
163
164impl<S: Security> Connected for OtExtensionSender<S> {
165    fn connection(&mut self) -> &mut Connection {
166        &mut self.conn
167    }
168}
169
170impl SemiHonest for OtExtensionSender<SemiHonestMarker> {}
171/// A maliciously secure sender also offers semi-honest security at decreased
172/// performance.
173impl SemiHonest for OtExtensionSender<MaliciousMarker> {}
174
175impl Malicious for OtExtensionSender<MaliciousMarker> {}
176
177impl<S: Security> RotSender for OtExtensionSender<S> {
178    type Error = Error;
179
180    /// Sender part of OT extension.
181    ///
182    /// # Panics
183    /// - If `count` is not divisable by 128.
184    /// - If `count % self.batch_size()` is not divisable by 128.
185    #[tracing::instrument(level = Level::DEBUG, skip_all, fields(count = ots.len()))]
186    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::OT_EXTENSION))]
187    async fn send_into(&mut self, ots: &mut impl Buf<[Block; 2]>) -> Result<(), Self::Error> {
188        let count = ots.len();
189        assert_eq!(0, count % 128, "count must be multiple of 128");
190        let batch_size = self.batch_size();
191        let batches = count / batch_size;
192        let batch_size_remainder = count % batch_size;
193        let num_extra = (S::MALICIOUS_SECURITY as usize) * 128;
194
195        assert_eq!(
196            0,
197            batch_size_remainder % 128,
198            "count % batch_size must be multiple of 128"
199        );
200
201        let batch_sizes = iter::repeat_n(batch_size, batches)
202            .chain((batch_size_remainder != 0).then_some(batch_size_remainder));
203
204        if !self.has_base_ots() {
205            self.do_base_ots().await?;
206        }
207
208        let delta = self.delta.expect("base OTs are done");
209        let mut sub_conn = self.conn.sub_connection();
210
211        // channel for communication between async task and compute thread
212        let (ch_s, ch_r) = std::sync::mpsc::channel::<Vec<Block>>();
213        let (kos_ch_s, mut kos_ch_r_task) = tokio::sync::mpsc::unbounded_channel::<Block>();
214        let (kos_ch_s_task, kos_ch_r) = std::sync::mpsc::channel::<Vec<Block>>();
215        // take these to move them into compute thread, will be returned via ret channel
216        let mut base_rngs = mem::take(&mut self.base_rngs);
217        let base_choices = mem::take(&mut self.base_choices);
218        let batch_sizes_th = batch_sizes.clone();
219        let owned_ots = mem::take(ots);
220        let mut rng = StdRng::from_rng(&mut self.rng);
221
222        // spawn compute thread for CPU intensive work. This way we increase throughput
223        // and don't risk of blocking tokio worker threads
224        let jh = spawn_compute(move || {
225            let mut ots = owned_ots;
226            let mut extra_messages: Vec<[Block; 2]> = Vec::zeroed(num_extra);
227            let mut transposed = Vec::zeroed(batch_size);
228            let mut owned_v_mat: Vec<Block> = if S::MALICIOUS_SECURITY {
229                Vec::zeroed(ots.len())
230            } else {
231                vec![]
232            };
233            let mut extra_v_mat = vec![Block::ZERO; num_extra];
234
235            for (ots, batch_sizes, extra) in [
236                (
237                    &mut ots[..],
238                    &mut batch_sizes_th.clone() as &mut dyn Iterator<Item = _>,
239                    false,
240                ),
241                (&mut extra_messages[..], &mut iter::once(num_extra), true),
242            ] {
243                // to increase throughput, we divide the `count` many OTs into batches of size
244                // self.batch_size(). Crucially, this allows us to do the transpose
245                // and hash step while not having received the complete data from the
246                // OtExtensionReceiver.
247                for (chunk_idx, (ot_batch, curr_batch_size)) in
248                    ots.chunks_mut(batch_size).zip(batch_sizes).enumerate()
249                {
250                    let v_mat = if S::MALICIOUS_SECURITY {
251                        if extra {
252                            &mut extra_v_mat
253                        } else {
254                            let offset = chunk_idx * batch_size;
255                            &mut owned_v_mat[offset..offset + curr_batch_size]
256                        }
257                    } else {
258                        // we temporarily use the output OT buffer to hold the current chunk of the
259                        // V matrix which we XOR with our received row or 0
260                        // and then transpose into `transposed`
261                        cast_slice_mut(&mut ot_batch[..curr_batch_size / 2])
262                    };
263                    let v_mat = cast_slice_mut(v_mat);
264
265                    let cols_byte_batch = curr_batch_size / 8;
266                    let row_iter = v_mat.chunks_exact_mut(cols_byte_batch);
267
268                    for ((v_row, base_rng), base_choice) in
269                        row_iter.zip(&mut base_rngs).zip(&base_choices)
270                    {
271                        base_rng.fill_bytes(v_row);
272                        let mut recv_row = ch_r.recv()?;
273                        // constant time version of
274                        // if !base_choice {
275                        //   v_row ^= recv_row;
276                        // }
277                        let choice_mask =
278                            Block::conditional_select(&Block::ZERO, &Block::ONES, *base_choice);
279                        // if choice_mask == 0, we zero out recv_row
280                        // if choice_mask == 1, recv_row is not changed
281                        and_inplace_elem(&mut recv_row, choice_mask);
282                        let v_row = bytemuck::cast_slice_mut(v_row);
283                        // if choice_mask == 0, v_row = v_row ^ 000000..
284                        // if choice_mask == 1, v_row = v_row ^ recv_row
285                        xor_inplace(v_row, &recv_row);
286                    }
287                    {
288                        let transposed = bytemuck::cast_slice_mut(&mut transposed);
289                        transpose_bitmatrix(v_mat, &mut transposed[..v_mat.len()], BASE_OT_COUNT);
290                    }
291
292                    for (v, ots) in transposed.iter().zip(ot_batch.iter_mut()) {
293                        *ots = [*v, *v ^ delta]
294                    }
295
296                    if S::MALICIOUS_SECURITY {
297                        FIXED_KEY_HASH.tccr_hash_slice_mut(
298                            bytemuck::must_cast_slice_mut(ot_batch),
299                            |i| {
300                                // use batch_size here, which is the batch_size of all batches
301                                // except potentially the last. If we use curr_batch_size, our
302                                // offset would be wrong for the last batch if curr_batch_size <
303                                // batch_size
304                                Block::from(chunk_idx * batch_size + (i / 2))
305                            },
306                        );
307                    } else {
308                        FIXED_KEY_HASH.cr_hash_slice_mut(bytemuck::must_cast_slice_mut(ot_batch));
309                    }
310                }
311            }
312
313            if S::MALICIOUS_SECURITY {
314                let seed: Block = rng.random();
315                kos_ch_s.send(seed)?;
316                let rng = AesRng::from_seed(seed);
317
318                let mut q1 = extra_v_mat;
319                let mut q2 = vec![Block::ZERO; BASE_OT_COUNT];
320
321                let owned_v_mat_ref = &owned_v_mat;
322
323                let challenges: Vec<Block> = rng
324                    .sample_iter(StandardUniform)
325                    .take(ots.len() / BASE_OT_COUNT)
326                    .collect();
327
328                let block_batch_size = batch_size / BASE_OT_COUNT;
329
330                let challenge_iter =
331                    batch_sizes_th
332                        .clone()
333                        .enumerate()
334                        .flat_map(|(batch, curr_batch_size)| {
335                            challenges[batch * block_batch_size
336                                ..batch * block_batch_size + curr_batch_size / BASE_OT_COUNT]
337                                .iter()
338                                .cycle()
339                                .take(curr_batch_size)
340                        });
341
342                let q_idx_iter = batch_sizes_th.flat_map(|curr_batch_size| {
343                    (0..BASE_OT_COUNT).flat_map(move |t_idx| {
344                        iter::repeat_n(t_idx, curr_batch_size / BASE_OT_COUNT)
345                    })
346                });
347
348                for ((v, s), q_idx) in owned_v_mat_ref.iter().zip(challenge_iter).zip(q_idx_iter) {
349                    let (qi, qi2) = v.clmul(s);
350                    q1[q_idx] ^= qi;
351                    q2[q_idx] ^= qi2;
352                }
353
354                for (q1i, q2i) in q1.iter_mut().zip(&q2) {
355                    *q1i = Block::gf_reduce(q1i, q2i);
356                }
357                let mut u = kos_ch_r.recv()?;
358                let Some(received_x) = u.pop() else {
359                    return Err(Error::MissingXValue);
360                };
361                for ((received_t, base_choice), q1i) in u.iter().zip(&base_choices).zip(&q1) {
362                    let tt =
363                        Block::conditional_select(&Block::ZERO, &received_x, *base_choice) ^ *q1i;
364                    if tt != *received_t {
365                        return Err(Error::MaliciousCheck);
366                    }
367                }
368            }
369
370            Ok::<_, Error>((ots, base_rngs, base_choices))
371        });
372
373        let (_, mut recv) = sub_conn.byte_stream().await?;
374
375        for batch_size in batch_sizes.chain((num_extra != 0).then_some(num_extra)) {
376            for _ in 0..BASE_OT_COUNT {
377                let mut recv_row = allocate_zeroed_vec(batch_size / Block::BITS);
378                recv.read_exact(bytemuck::cast_slice_mut(&mut recv_row))
379                    .await?;
380                if ch_s.send(recv_row).is_err() {
381                    // If we can't send on the channel, the channel must've been dropped due to a
382                    // panic in the worker thread. So we try to join the compute task to resume the
383                    // panic
384                    resume_unwind(jh.await.map(drop).expect_err("expected thread error"));
385                };
386            }
387        }
388
389        if S::MALICIOUS_SECURITY {
390            let (mut kos_send, mut kos_recv) = sub_conn.byte_stream().await?;
391            let success = 'success: {
392                let Some(blk) = kos_ch_r_task.recv().await else {
393                    break 'success false;
394                };
395                kos_send.as_stream().send(blk).await?;
396
397                {
398                    let mut kos_recv = kos_recv.as_stream();
399                    let u = kos_recv.next().await.ok_or(Error::UnexcpectedClose)??;
400                    if kos_ch_s_task.send(u).is_err() {
401                        break 'success false;
402                    }
403                }
404
405                true
406            };
407            if !success {
408                resume_unwind(jh.await.map(drop).expect_err("expected thread error"));
409            }
410        }
411
412        let (owned_ots, base_rngs, base_choices) = match jh.await {
413            Ok(res) => res?,
414            Err(panicked) => resume_unwind(panicked),
415        };
416        self.base_rngs = base_rngs;
417        self.base_choices = base_choices;
418        *ots = owned_ots;
419        Ok(())
420    }
421}
422
423impl SemiHonest for OtExtensionReceiver<SemiHonestMarker> {}
424impl SemiHonest for OtExtensionReceiver<MaliciousMarker> {}
425
426impl Malicious for OtExtensionReceiver<MaliciousMarker> {}
427
428impl<S: Security> OtExtensionReceiver<S> {
429    /// Create a new sender for the given [`Connection`].
430    pub fn new(conn: Connection) -> Self {
431        Self::new_with_rng(conn, StdRng::from_os_rng())
432    }
433
434    /// Create a new sender for the given [`Connection`] and [`StdRng`].
435    ///
436    /// For an rng seeded with a fixed seed, the output is deterministic.
437    pub fn new_with_rng(mut conn: Connection, mut rng: StdRng) -> Self {
438        let base_ot = SimplestOt::new_with_rng(conn.sub_connection(), StdRng::from_rng(&mut rng));
439        Self {
440            rng,
441            base_ot,
442            conn,
443            base_rngs: vec![],
444            batch_size: DEFAULT_OT_BATCH_SIZE,
445            security: PhantomData,
446        }
447    }
448
449    /// Set the OT batch size for the receiver.
450    ///
451    /// If the receiver batch size is changed, the senders's must also be
452    /// changed (see [`OtExtensionSender::with_batch_size`]).
453    /// Note that [`OtExtensionReceiver::receive`] methods will fail if `count %
454    /// self.batch_size()` is not divisable by 128.
455    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
456        self.batch_size = batch_size;
457        self
458    }
459
460    /// The currently configured OT batch size.
461    pub fn batch_size(&self) -> usize {
462        self.batch_size
463    }
464
465    /// Returns true if base OTs have been performed. Subsequent calls to send
466    /// will not perform base OTs again.
467    pub fn has_base_ots(&self) -> bool {
468        self.base_rngs.len() == BASE_OT_COUNT
469    }
470
471    /// Perform base OTs for later extension. Subsequent calls to send
472    /// will not perform base OTs again.
473    pub async fn do_base_ots(&mut self) -> Result<(), Error> {
474        let base_ots = self.base_ot.send(BASE_OT_COUNT).await?;
475        self.base_rngs = base_ots
476            .into_iter()
477            .map(|[s1, s2]| [AesRng::from_seed(s1), AesRng::from_seed(s2)])
478            .collect();
479        Ok(())
480    }
481}
482
483impl<S: Security> Connected for OtExtensionReceiver<S> {
484    fn connection(&mut self) -> &mut Connection {
485        &mut self.conn
486    }
487}
488
489impl<S: Security> RotReceiver for OtExtensionReceiver<S> {
490    type Error = Error;
491
492    /// Receiver part of OT extension.
493    ///
494    /// # Panics
495    /// - If `choices.len()` is not divisable by 128.
496    /// - If `choices.len() % self.batch_size()` is not divisable by 128.
497    #[tracing::instrument(level = Level::DEBUG, skip_all, fields(count = ots.len()))]
498    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::OT_EXTENSION))]
499    async fn receive_into(
500        &mut self,
501        ots: &mut impl Buf<Block>,
502        choices: &[Choice],
503    ) -> Result<(), Self::Error> {
504        assert_eq!(choices.len(), ots.len());
505        assert_eq!(
506            0,
507            choices.len() % 128,
508            "choices.len() must be multiple of 128"
509        );
510        let batch_size = self.batch_size();
511        let count = choices.len();
512        let batch_size_remainder = count % batch_size;
513        assert_eq!(
514            0,
515            batch_size_remainder % 128,
516            "count % batch_size must be multiple of 128"
517        );
518
519        if !self.has_base_ots() {
520            self.do_base_ots().await?;
521        }
522
523        let mut sub_conn = self.conn.sub_connection();
524
525        let cols_byte_batch = batch_size / 8;
526        let choice_vec = choices_to_u8_vec(choices);
527
528        let (ch_s, mut ch_r) = mpsc::unbounded_channel::<Vec<u8>>();
529        let (kos_ch_s, mut kos_ch_r_task) = tokio::sync::mpsc::unbounded_channel::<Vec<Block>>();
530        let (kos_ch_s_task, kos_ch_r) = std::sync::mpsc::channel::<Block>();
531        let mut rng = StdRng::from_rng(&mut self.rng);
532
533        let mut base_rngs = mem::take(&mut self.base_rngs);
534        let owned_ots = mem::take(ots);
535        let mut jh = spawn_compute(move || {
536            let mut ots = owned_ots;
537            let t_mat_size = if S::MALICIOUS_SECURITY {
538                ots.len()
539            } else {
540                batch_size
541            };
542            let num_extra = (S::MALICIOUS_SECURITY as usize) * 128;
543            let mut t_mat = vec![Block::ZERO; t_mat_size];
544            let mut extra_t_mat = vec![Block::ZERO; num_extra];
545            let mut extra_messages: Vec<Block> = Vec::zeroed(num_extra);
546            let extra_choices = random_choices(num_extra, &mut rng);
547            let extra_choice_vec = choices_to_u8_vec(&extra_choices);
548
549            for (ots, choice_vec, extra) in [
550                (&mut ots[..], &choice_vec, false),
551                (&mut extra_messages[..], &extra_choice_vec, true),
552            ] {
553                for (chunk_idx, (output_chunk, choice_batch)) in ots
554                    .chunks_mut(batch_size)
555                    .zip(choice_vec.chunks(cols_byte_batch))
556                    .enumerate()
557                {
558                    let curr_batch_size = output_chunk.len();
559                    let chunk_t_mat = if S::MALICIOUS_SECURITY {
560                        if extra {
561                            &mut extra_t_mat
562                        } else {
563                            let offset = chunk_idx * batch_size;
564                            &mut t_mat[offset..offset + curr_batch_size]
565                        }
566                    } else {
567                        &mut t_mat[..curr_batch_size]
568                    };
569                    assert_eq!(output_chunk.len(), chunk_t_mat.len());
570                    assert_eq!(choice_batch.len() * 8, chunk_t_mat.len());
571                    let chunk_t_mat: &mut [u8] = bytemuck::must_cast_slice_mut(chunk_t_mat);
572                    // might change for last chunk
573                    let cols_byte_batch = choice_batch.len();
574                    for (row, [rng1, rng2]) in chunk_t_mat
575                        .chunks_exact_mut(cols_byte_batch)
576                        .zip(&mut base_rngs)
577                    {
578                        rng1.fill_bytes(row);
579                        let mut send_row = vec![0_u8; cols_byte_batch];
580                        rng2.fill_bytes(&mut send_row);
581                        // TODO wouldn't this be better on Blocks instead of u8?
582                        for ((v2, v1), choices) in send_row.iter_mut().zip(row).zip(choice_batch) {
583                            *v2 ^= *v1 ^ *choices;
584                        }
585                        ch_s.send(send_row)?;
586                    }
587                    let output_bytes = bytemuck::cast_slice_mut(output_chunk);
588                    transpose_bitmatrix(
589                        &chunk_t_mat[..BASE_OT_COUNT * cols_byte_batch],
590                        output_bytes,
591                        BASE_OT_COUNT,
592                    );
593                    if S::MALICIOUS_SECURITY {
594                        FIXED_KEY_HASH.tccr_hash_slice_mut(output_chunk, |i| {
595                            Block::from(chunk_idx * batch_size + i)
596                        });
597                    } else {
598                        FIXED_KEY_HASH.cr_hash_slice_mut(output_chunk);
599                    }
600                }
601            }
602
603            if S::MALICIOUS_SECURITY {
604                // dropping ch_s is important so the async task exits the ch_r loop
605                drop(ch_s);
606                let seed = kos_ch_r.recv()?;
607
608                let mut t1 = extra_t_mat;
609                let mut t2 = vec![Block::ZERO; BASE_OT_COUNT];
610
611                let mut x1 = Block::from_choices(&extra_choices);
612                let mut x2 = Block::ZERO;
613
614                let rng = AesRng::from_seed(seed);
615
616                let t_mat_ref = &t_mat;
617                let batches = count / batch_size;
618                let batch_sizes = iter::repeat_n(batch_size, batches)
619                    .chain((batch_size_remainder != 0).then_some(batch_size_remainder));
620
621                let choice_blocks: Vec<_> = choice_vec
622                    .chunks_exact(Block::BYTES)
623                    .map(|chunk| Block::try_from(chunk).expect("chunk is 16 bytes"))
624                    .collect();
625
626                let challenges: Vec<Block> = rng
627                    .sample_iter(StandardUniform)
628                    .take(choice_blocks.len())
629                    .collect();
630
631                for (x, s) in choice_blocks.iter().zip(challenges.iter()) {
632                    let (xi, xi2) = x.clmul(s);
633                    x1 ^= xi;
634                    x2 ^= xi2;
635                }
636
637                let block_batch_size = batch_size / BASE_OT_COUNT;
638
639                let challenge_iter =
640                    batch_sizes
641                        .clone()
642                        .enumerate()
643                        .flat_map(|(batch, curr_batch_size)| {
644                            challenges[batch * block_batch_size
645                                ..batch * block_batch_size + curr_batch_size / BASE_OT_COUNT]
646                                .iter()
647                                .cycle()
648                                .take(curr_batch_size)
649                        });
650                let t_idx_iter = batch_sizes.flat_map(|curr_batch_size| {
651                    (0..BASE_OT_COUNT).flat_map(move |t_idx| {
652                        iter::repeat_n(t_idx, curr_batch_size / BASE_OT_COUNT)
653                    })
654                });
655
656                for ((t, s), t_idx) in t_mat_ref.iter().zip(challenge_iter).zip(t_idx_iter) {
657                    let (ti, ti2) = t.clmul(s);
658                    t1[t_idx] ^= ti;
659                    t2[t_idx] ^= ti2;
660                }
661
662                for (t1i, t2i) in t1.iter_mut().zip(&mut t2) {
663                    *t1i = Block::gf_reduce(t1i, t2i);
664                }
665                t1.push(Block::gf_reduce(&x1, &x2));
666                kos_ch_s.send(t1)?;
667            }
668            Ok::<_, Error>((ots, base_rngs))
669        });
670
671        let (mut send, _) = sub_conn.byte_stream().await?;
672        while let Some(row) = ch_r.recv().await {
673            send.write_all(&row).await.map_err(Error::Communication)?;
674        }
675
676        if S::MALICIOUS_SECURITY {
677            // If the worker thread panics we break early from the above loop. We check for
678            // the panic to prevent a deadlock where we try to get the next message but the
679            // peer is still in the worker thread
680            let err = poll_fn(|cx| match jh.poll_unpin(cx) {
681                Poll::Ready(res) => Poll::Ready(res.map(drop)),
682                Poll::Pending => Poll::Ready(Ok(())),
683            })
684            .await;
685            if let Err(err) = err {
686                resume_unwind(err);
687            };
688            let (mut kos_send, mut kos_recv) = sub_conn.byte_stream().await?;
689
690            let seed = {
691                let mut kos_recv = kos_recv.as_stream::<Block>();
692                kos_recv.next().await.ok_or(Error::UnexcpectedClose)??
693            };
694
695            let success = 'success: {
696                if kos_ch_s_task.send(seed).is_err() {
697                    break 'success false;
698                }
699
700                let mut kos_send = kos_send.as_stream::<Vec<Block>>();
701                let Some(v) = kos_ch_r_task.recv().await else {
702                    break 'success false;
703                };
704                kos_send.send(v).await.map_err(Error::Communication)?;
705
706                true
707            };
708            if !success {
709                resume_unwind(jh.await.map(drop).expect_err("expected thread error"));
710            }
711        }
712
713        let (owned_ots, base_rngs) = match jh.await {
714            Ok(res) => res?,
715            Err(panicked) => resume_unwind(panicked),
716        };
717
718        self.base_rngs = base_rngs;
719        *ots = owned_ots;
720        Ok(())
721    }
722}
723
724impl<S: Security> CotSender for OtExtensionSender<S> {
725    type Error = Error;
726
727    async fn correlated_send_into<B, F>(
728        &mut self,
729        ots: &mut B,
730        correlation: F,
731    ) -> Result<(), Self::Error>
732    where
733        B: Buf<Block>,
734        F: FnMut(usize) -> Block + Send,
735    {
736        CorrelatedFromRandom::new(self)
737            .correlated_send_into(ots, correlation)
738            .await
739    }
740}
741
742impl<S: Security> CotReceiver for OtExtensionReceiver<S> {
743    type Error = Error;
744
745    async fn correlated_receive_into<B>(
746        &mut self,
747        ots: &mut B,
748        choices: &[Choice],
749    ) -> Result<(), Self::Error>
750    where
751        B: Buf<Block>,
752    {
753        CorrelatedFromRandom::new(self)
754            .correlated_receive_into(ots, choices)
755            .await
756    }
757}
758
759fn choices_to_u8_vec(choices: &[Choice]) -> Vec<u8> {
760    assert_eq!(0, choices.len() % 8);
761    let mut v = vec![0_u8; choices.len() / 8];
762    for (chunk, byte) in choices.chunks_exact(8).zip(&mut v) {
763        for (i, choice) in chunk.iter().enumerate() {
764            *byte ^= choice.unwrap_u8() << i;
765        }
766    }
767    v
768}
769
770impl From<std::sync::mpsc::RecvError> for Error {
771    fn from(_: std::sync::mpsc::RecvError) -> Self {
772        Error::AsyncTaskDropped
773    }
774}
775
776impl<T> From<tokio::sync::mpsc::error::SendError<T>> for Error {
777    fn from(_: tokio::sync::mpsc::error::SendError<T>) -> Self {
778        Error::AsyncTaskDropped
779    }
780}
781
#[cfg(test)]
mod tests {

    use cryprot_core::Block;
    use cryprot_net::testing::{init_tracing, local_conn};
    use rand::{SeedableRng, rngs::StdRng};
    use subtle::Choice;

    use crate::{
        CotReceiver, CotSender, MaliciousMarker, RotReceiver, RotSender,
        extension::{
            DEFAULT_OT_BATCH_SIZE, OtExtensionReceiver, OtExtensionSender,
            SemiHonestOtExtensionReceiver, SemiHonestOtExtensionSender,
        },
        random_choices,
    };

    /// Checks that a run of random OTs is consistent.
    ///
    /// Asserts that sender, receiver and choice vector all have exactly
    /// `expected` entries, and that `recv_ots[i] == send_ots[i][choices[i]]`
    /// for every `i`. The length checks matter: `zip` stops at the shortest
    /// input, so without them a run that produced too few (or zero) OTs
    /// would pass vacuously.
    fn assert_rots(
        send_ots: &[[Block; 2]],
        recv_ots: &[Block],
        choices: &[Choice],
        expected: usize,
    ) {
        assert_eq!(expected, send_ots.len(), "sender OT count");
        assert_eq!(expected, recv_ots.len(), "receiver OT count");
        assert_eq!(expected, choices.len(), "choice count");
        for (i, ((r, s), c)) in recv_ots.iter().zip(send_ots).zip(choices).enumerate() {
            assert_eq!(r, &s[c.unwrap_u8() as usize], "OT {i}");
        }
    }

    /// Semi-honest, exactly two full batches.
    #[tokio::test]
    async fn test_extension() {
        let _g = init_tracing();
        const COUNT: usize = 2 * DEFAULT_OT_BATCH_SIZE;
        let (c1, c2) = local_conn().await.unwrap();
        let rng1 = StdRng::seed_from_u64(42);
        let mut rng2 = StdRng::seed_from_u64(24);
        let choices = random_choices(COUNT, &mut rng2);
        let mut sender = SemiHonestOtExtensionSender::new_with_rng(c1, rng1);
        let mut receiver = SemiHonestOtExtensionReceiver::new_with_rng(c2, rng2);
        let (send_ots, recv_ots) =
            tokio::try_join!(sender.send(COUNT), receiver.receive(&choices)).unwrap();
        assert_rots(&send_ots, &recv_ots, &choices, COUNT);
    }

    /// Semi-honest, two full batches followed by a trailing half batch.
    #[tokio::test]
    async fn test_extension_half_batch() {
        let _g = init_tracing();
        const COUNT: usize = 2 * DEFAULT_OT_BATCH_SIZE + DEFAULT_OT_BATCH_SIZE / 2;
        let (c1, c2) = local_conn().await.unwrap();
        let rng1 = StdRng::seed_from_u64(42);
        let mut rng2 = StdRng::seed_from_u64(24);
        let choices = random_choices(COUNT, &mut rng2);
        let mut sender = SemiHonestOtExtensionSender::new_with_rng(c1, rng1);
        let mut receiver = SemiHonestOtExtensionReceiver::new_with_rng(c2, rng2);
        let (send_ots, recv_ots) =
            tokio::try_join!(sender.send(COUNT), receiver.receive(&choices)).unwrap();
        assert_rots(&send_ots, &recv_ots, &choices, COUNT);
    }

    /// Semi-honest, less than one batch (half a batch plus 128 OTs).
    #[tokio::test]
    async fn test_extension_partial_batch() {
        let _g = init_tracing();
        const COUNT: usize = DEFAULT_OT_BATCH_SIZE / 2 + 128;
        let (c1, c2) = local_conn().await.unwrap();
        let rng1 = StdRng::seed_from_u64(42);
        let mut rng2 = StdRng::seed_from_u64(24);
        let choices = random_choices(COUNT, &mut rng2);
        let mut sender = SemiHonestOtExtensionSender::new_with_rng(c1, rng1);
        let mut receiver = SemiHonestOtExtensionReceiver::new_with_rng(c2, rng2);
        let (send_ots, recv_ots) =
            tokio::try_join!(sender.send(COUNT), receiver.receive(&choices)).unwrap();
        assert_rots(&send_ots, &recv_ots, &choices, COUNT);
    }

    /// Malicious, a single half batch.
    #[tokio::test]
    async fn test_extension_malicious_half_batch() {
        let _g = init_tracing();
        const COUNT: usize = DEFAULT_OT_BATCH_SIZE / 2;
        let (c1, c2) = local_conn().await.unwrap();
        let rng1 = StdRng::seed_from_u64(42);
        let mut rng2 = StdRng::seed_from_u64(24);
        let choices = random_choices(COUNT, &mut rng2);
        let mut sender = OtExtensionSender::<MaliciousMarker>::new_with_rng(c1, rng1);
        let mut receiver = OtExtensionReceiver::<MaliciousMarker>::new_with_rng(c2, rng2);

        let (send_ots, recv_ots) =
            tokio::try_join!(sender.send(COUNT), receiver.receive(&choices)).unwrap();
        assert_rots(&send_ots, &recv_ots, &choices, COUNT);
    }

    /// Malicious, one full batch plus a trailing partial batch.
    #[tokio::test]
    async fn test_extension_malicious_partial_batch() {
        let _g = init_tracing();
        const COUNT: usize = DEFAULT_OT_BATCH_SIZE + DEFAULT_OT_BATCH_SIZE / 2 + 128;
        let (c1, c2) = local_conn().await.unwrap();
        let rng1 = StdRng::seed_from_u64(42);
        let mut rng2 = StdRng::seed_from_u64(24);
        let choices = random_choices(COUNT, &mut rng2);
        let mut sender = OtExtensionSender::<MaliciousMarker>::new_with_rng(c1, rng1);
        let mut receiver = OtExtensionReceiver::<MaliciousMarker>::new_with_rng(c2, rng2);

        let (send_ots, recv_ots) =
            tokio::try_join!(sender.send(COUNT), receiver.receive(&choices)).unwrap();
        assert_rots(&send_ots, &recv_ots, &choices, COUNT);
    }

    /// Malicious, exactly two full batches.
    #[tokio::test]
    async fn test_extension_malicious_multiple_batch() {
        let _g = init_tracing();
        const COUNT: usize = DEFAULT_OT_BATCH_SIZE * 2;
        let (c1, c2) = local_conn().await.unwrap();
        let rng1 = StdRng::seed_from_u64(42);
        let mut rng2 = StdRng::seed_from_u64(24);
        let choices = random_choices(COUNT, &mut rng2);
        let mut sender = OtExtensionSender::<MaliciousMarker>::new_with_rng(c1, rng1);
        let mut receiver = OtExtensionReceiver::<MaliciousMarker>::new_with_rng(c2, rng2);

        let (send_ots, recv_ots) =
            tokio::try_join!(sender.send(COUNT), receiver.receive(&choices)).unwrap();
        assert_rots(&send_ots, &recv_ots, &choices, COUNT);
    }

    /// Semi-honest correlated OTs with the constant correlation `Block::ONES`:
    /// the receiver's block must equal the sender's, XOR-ed with the
    /// correlation when the choice bit is set.
    #[tokio::test]
    async fn test_correlated_extension() {
        let _g = init_tracing();
        const COUNT: usize = 128;
        let (c1, c2) = local_conn().await.unwrap();
        let rng1 = StdRng::seed_from_u64(42);
        let mut rng2 = StdRng::seed_from_u64(24);
        let choices = random_choices(COUNT, &mut rng2);
        let mut sender = SemiHonestOtExtensionSender::new_with_rng(c1, rng1);
        let mut receiver = SemiHonestOtExtensionReceiver::new_with_rng(c2, rng2);
        let (send_ots, recv_ots) = tokio::try_join!(
            sender.correlated_send(COUNT, |_| Block::ONES),
            receiver.correlated_receive(&choices)
        )
        .unwrap();
        // Guard against `zip` silently truncating a short result.
        assert_eq!(COUNT, send_ots.len(), "sender OT count");
        assert_eq!(COUNT, recv_ots.len(), "receiver OT count");
        for (i, ((r, s), c)) in recv_ots.into_iter().zip(send_ots).zip(choices).enumerate() {
            if bool::from(c) {
                assert_eq!(r ^ Block::ONES, s, "Block {i}");
            } else {
                assert_eq!(r, s, "Block {i}")
            }
        }
    }
}