1use std::{io, iter, marker::PhantomData, mem, panic::resume_unwind, task::Poll};
15
16use bytemuck::cast_slice_mut;
17use cryprot_core::{
18 Block,
19 aes_hash::FIXED_KEY_HASH,
20 aes_rng::AesRng,
21 alloc::allocate_zeroed_vec,
22 buf::Buf,
23 tokio_rayon::spawn_compute,
24 transpose::transpose_bitmatrix,
25 utils::{and_inplace_elem, xor_inplace},
26};
27use cryprot_net::{Connection, ConnectionError};
28use futures::{FutureExt, SinkExt, StreamExt, future::poll_fn};
29use rand::{Rng, RngCore, SeedableRng, distr::StandardUniform, rngs::StdRng};
30use subtle::{Choice, ConditionallySelectable};
31use tokio::{
32 io::{AsyncReadExt, AsyncWriteExt},
33 sync::mpsc,
34};
35use tracing::Level;
36
37use crate::{
38 Connected, CotReceiver, CotSender, Malicious, MaliciousMarker, RotReceiver, RotSender,
39 Security, SemiHonest, SemiHonestMarker,
40 adapter::CorrelatedFromRandom,
41 base::{self, SimplestOt},
42 phase, random_choices,
43};
44
45pub const BASE_OT_COUNT: usize = 128;
46
47pub const DEFAULT_OT_BATCH_SIZE: usize = 2_usize.pow(16);
48
49pub struct OtExtensionSender<S: Security> {
51 rng: StdRng,
52 base_ot: SimplestOt,
53 conn: Connection,
54 base_rngs: Vec<AesRng>,
55 base_choices: Vec<Choice>,
56 delta: Option<Block>,
57 batch_size: usize,
58 security: PhantomData<S>,
59}
60
61pub struct OtExtensionReceiver<S: Security> {
63 base_ot: SimplestOt,
64 conn: Connection,
65 base_rngs: Vec<[AesRng; 2]>,
66 batch_size: usize,
67 security: PhantomData<S>,
68 rng: StdRng,
69}
70
71pub type SemiHonestOtExtensionSender = OtExtensionSender<SemiHonestMarker>;
73pub type SemiHonestOtExtensionReceiver = OtExtensionReceiver<SemiHonestMarker>;
75
76pub type MaliciousOtExtensionSender = OtExtensionSender<MaliciousMarker>;
78pub type MaliciousOtExtensionReceiver = OtExtensionReceiver<MaliciousMarker>;
80
81#[derive(thiserror::Error, Debug)]
83#[non_exhaustive]
84pub enum Error {
85 #[error("unable to compute base OTs")]
86 BaseOT(#[from] base::Error),
87 #[error("connection error to peer")]
88 Connection(#[from] ConnectionError),
89 #[error("error in sending/receiving data")]
90 Communication(#[from] io::Error),
91 #[error("connection closed by peer")]
92 UnexcpectedClose,
93 #[error("Commitment does not match seed")]
95 WrongCommitment,
96 #[error("sender did not receiver x value in KOS check")]
98 MissingXValue,
99 #[error("malicious check failed")]
101 MaliciousCheck,
102 #[doc(hidden)]
103 #[error("async task is dropped. This error should not be observable.")]
104 AsyncTaskDropped,
105}
106
107impl<S: Security> OtExtensionSender<S> {
108 pub fn new(conn: Connection) -> Self {
110 Self::new_with_rng(conn, StdRng::from_os_rng())
111 }
112
113 pub fn new_with_rng(mut conn: Connection, mut rng: StdRng) -> Self {
117 let base_ot = SimplestOt::new_with_rng(conn.sub_connection(), StdRng::from_rng(&mut rng));
118 Self {
119 rng,
120 base_ot,
121 conn,
122 base_rngs: vec![],
123 base_choices: vec![],
124 delta: None,
125 batch_size: DEFAULT_OT_BATCH_SIZE,
126 security: PhantomData,
127 }
128 }
129
130 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
137 self.batch_size = batch_size;
138 self
139 }
140
141 pub fn batch_size(&self) -> usize {
143 self.batch_size
144 }
145
146 pub fn has_base_ots(&self) -> bool {
149 self.base_rngs.len() == BASE_OT_COUNT
150 }
151
152 pub async fn do_base_ots(&mut self) -> Result<(), Error> {
155 let base_choices = random_choices(BASE_OT_COUNT, &mut self.rng);
156 let base_ots = self.base_ot.receive(&base_choices).await?;
157 self.base_rngs = base_ots.into_iter().map(AesRng::from_seed).collect();
158 self.delta = Some(Block::from_choices(&base_choices));
159 self.base_choices = base_choices;
160 Ok(())
161 }
162}
163
164impl<S: Security> Connected for OtExtensionSender<S> {
165 fn connection(&mut self) -> &mut Connection {
166 &mut self.conn
167 }
168}
169
170impl SemiHonest for OtExtensionSender<SemiHonestMarker> {}
171impl SemiHonest for OtExtensionSender<MaliciousMarker> {}
174
175impl Malicious for OtExtensionSender<MaliciousMarker> {}
176
177impl<S: Security> RotSender for OtExtensionSender<S> {
178 type Error = Error;
179
180 #[tracing::instrument(level = Level::DEBUG, skip_all, fields(count = ots.len()))]
186 #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::OT_EXTENSION))]
187 async fn send_into(&mut self, ots: &mut impl Buf<[Block; 2]>) -> Result<(), Self::Error> {
188 let count = ots.len();
189 assert_eq!(0, count % 128, "count must be multiple of 128");
190 let batch_size = self.batch_size();
191 let batches = count / batch_size;
192 let batch_size_remainder = count % batch_size;
193 let num_extra = (S::MALICIOUS_SECURITY as usize) * 128;
194
195 assert_eq!(
196 0,
197 batch_size_remainder % 128,
198 "count % batch_size must be multiple of 128"
199 );
200
201 let batch_sizes = iter::repeat_n(batch_size, batches)
202 .chain((batch_size_remainder != 0).then_some(batch_size_remainder));
203
204 if !self.has_base_ots() {
205 self.do_base_ots().await?;
206 }
207
208 let delta = self.delta.expect("base OTs are done");
209 let mut sub_conn = self.conn.sub_connection();
210
211 let (ch_s, ch_r) = std::sync::mpsc::channel::<Vec<Block>>();
213 let (kos_ch_s, mut kos_ch_r_task) = tokio::sync::mpsc::unbounded_channel::<Block>();
214 let (kos_ch_s_task, kos_ch_r) = std::sync::mpsc::channel::<Vec<Block>>();
215 let mut base_rngs = mem::take(&mut self.base_rngs);
217 let base_choices = mem::take(&mut self.base_choices);
218 let batch_sizes_th = batch_sizes.clone();
219 let owned_ots = mem::take(ots);
220 let mut rng = StdRng::from_rng(&mut self.rng);
221
222 let jh = spawn_compute(move || {
225 let mut ots = owned_ots;
226 let mut extra_messages: Vec<[Block; 2]> = Vec::zeroed(num_extra);
227 let mut transposed = Vec::zeroed(batch_size);
228 let mut owned_v_mat: Vec<Block> = if S::MALICIOUS_SECURITY {
229 Vec::zeroed(ots.len())
230 } else {
231 vec![]
232 };
233 let mut extra_v_mat = vec![Block::ZERO; num_extra];
234
235 for (ots, batch_sizes, extra) in [
236 (
237 &mut ots[..],
238 &mut batch_sizes_th.clone() as &mut dyn Iterator<Item = _>,
239 false,
240 ),
241 (&mut extra_messages[..], &mut iter::once(num_extra), true),
242 ] {
243 for (chunk_idx, (ot_batch, curr_batch_size)) in
248 ots.chunks_mut(batch_size).zip(batch_sizes).enumerate()
249 {
250 let v_mat = if S::MALICIOUS_SECURITY {
251 if extra {
252 &mut extra_v_mat
253 } else {
254 let offset = chunk_idx * batch_size;
255 &mut owned_v_mat[offset..offset + curr_batch_size]
256 }
257 } else {
258 cast_slice_mut(&mut ot_batch[..curr_batch_size / 2])
262 };
263 let v_mat = cast_slice_mut(v_mat);
264
265 let cols_byte_batch = curr_batch_size / 8;
266 let row_iter = v_mat.chunks_exact_mut(cols_byte_batch);
267
268 for ((v_row, base_rng), base_choice) in
269 row_iter.zip(&mut base_rngs).zip(&base_choices)
270 {
271 base_rng.fill_bytes(v_row);
272 let mut recv_row = ch_r.recv()?;
273 let choice_mask =
278 Block::conditional_select(&Block::ZERO, &Block::ONES, *base_choice);
279 and_inplace_elem(&mut recv_row, choice_mask);
282 let v_row = bytemuck::cast_slice_mut(v_row);
283 xor_inplace(v_row, &recv_row);
286 }
287 {
288 let transposed = bytemuck::cast_slice_mut(&mut transposed);
289 transpose_bitmatrix(v_mat, &mut transposed[..v_mat.len()], BASE_OT_COUNT);
290 }
291
292 for (v, ots) in transposed.iter().zip(ot_batch.iter_mut()) {
293 *ots = [*v, *v ^ delta]
294 }
295
296 if S::MALICIOUS_SECURITY {
297 FIXED_KEY_HASH.tccr_hash_slice_mut(
298 bytemuck::must_cast_slice_mut(ot_batch),
299 |i| {
300 Block::from(chunk_idx * batch_size + (i / 2))
305 },
306 );
307 } else {
308 FIXED_KEY_HASH.cr_hash_slice_mut(bytemuck::must_cast_slice_mut(ot_batch));
309 }
310 }
311 }
312
313 if S::MALICIOUS_SECURITY {
314 let seed: Block = rng.random();
315 kos_ch_s.send(seed)?;
316 let rng = AesRng::from_seed(seed);
317
318 let mut q1 = extra_v_mat;
319 let mut q2 = vec![Block::ZERO; BASE_OT_COUNT];
320
321 let owned_v_mat_ref = &owned_v_mat;
322
323 let challenges: Vec<Block> = rng
324 .sample_iter(StandardUniform)
325 .take(ots.len() / BASE_OT_COUNT)
326 .collect();
327
328 let block_batch_size = batch_size / BASE_OT_COUNT;
329
330 let challenge_iter =
331 batch_sizes_th
332 .clone()
333 .enumerate()
334 .flat_map(|(batch, curr_batch_size)| {
335 challenges[batch * block_batch_size
336 ..batch * block_batch_size + curr_batch_size / BASE_OT_COUNT]
337 .iter()
338 .cycle()
339 .take(curr_batch_size)
340 });
341
342 let q_idx_iter = batch_sizes_th.flat_map(|curr_batch_size| {
343 (0..BASE_OT_COUNT).flat_map(move |t_idx| {
344 iter::repeat_n(t_idx, curr_batch_size / BASE_OT_COUNT)
345 })
346 });
347
348 for ((v, s), q_idx) in owned_v_mat_ref.iter().zip(challenge_iter).zip(q_idx_iter) {
349 let (qi, qi2) = v.clmul(s);
350 q1[q_idx] ^= qi;
351 q2[q_idx] ^= qi2;
352 }
353
354 for (q1i, q2i) in q1.iter_mut().zip(&q2) {
355 *q1i = Block::gf_reduce(q1i, q2i);
356 }
357 let mut u = kos_ch_r.recv()?;
358 let Some(received_x) = u.pop() else {
359 return Err(Error::MissingXValue);
360 };
361 for ((received_t, base_choice), q1i) in u.iter().zip(&base_choices).zip(&q1) {
362 let tt =
363 Block::conditional_select(&Block::ZERO, &received_x, *base_choice) ^ *q1i;
364 if tt != *received_t {
365 return Err(Error::MaliciousCheck);
366 }
367 }
368 }
369
370 Ok::<_, Error>((ots, base_rngs, base_choices))
371 });
372
373 let (_, mut recv) = sub_conn.byte_stream().await?;
374
375 for batch_size in batch_sizes.chain((num_extra != 0).then_some(num_extra)) {
376 for _ in 0..BASE_OT_COUNT {
377 let mut recv_row = allocate_zeroed_vec(batch_size / Block::BITS);
378 recv.read_exact(bytemuck::cast_slice_mut(&mut recv_row))
379 .await?;
380 if ch_s.send(recv_row).is_err() {
381 resume_unwind(jh.await.map(drop).expect_err("expected thread error"));
385 };
386 }
387 }
388
389 if S::MALICIOUS_SECURITY {
390 let (mut kos_send, mut kos_recv) = sub_conn.byte_stream().await?;
391 let success = 'success: {
392 let Some(blk) = kos_ch_r_task.recv().await else {
393 break 'success false;
394 };
395 kos_send.as_stream().send(blk).await?;
396
397 {
398 let mut kos_recv = kos_recv.as_stream();
399 let u = kos_recv.next().await.ok_or(Error::UnexcpectedClose)??;
400 if kos_ch_s_task.send(u).is_err() {
401 break 'success false;
402 }
403 }
404
405 true
406 };
407 if !success {
408 resume_unwind(jh.await.map(drop).expect_err("expected thread error"));
409 }
410 }
411
412 let (owned_ots, base_rngs, base_choices) = match jh.await {
413 Ok(res) => res?,
414 Err(panicked) => resume_unwind(panicked),
415 };
416 self.base_rngs = base_rngs;
417 self.base_choices = base_choices;
418 *ots = owned_ots;
419 Ok(())
420 }
421}
422
423impl SemiHonest for OtExtensionReceiver<SemiHonestMarker> {}
424impl SemiHonest for OtExtensionReceiver<MaliciousMarker> {}
425
426impl Malicious for OtExtensionReceiver<MaliciousMarker> {}
427
428impl<S: Security> OtExtensionReceiver<S> {
429 pub fn new(conn: Connection) -> Self {
431 Self::new_with_rng(conn, StdRng::from_os_rng())
432 }
433
434 pub fn new_with_rng(mut conn: Connection, mut rng: StdRng) -> Self {
438 let base_ot = SimplestOt::new_with_rng(conn.sub_connection(), StdRng::from_rng(&mut rng));
439 Self {
440 rng,
441 base_ot,
442 conn,
443 base_rngs: vec![],
444 batch_size: DEFAULT_OT_BATCH_SIZE,
445 security: PhantomData,
446 }
447 }
448
449 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
456 self.batch_size = batch_size;
457 self
458 }
459
460 pub fn batch_size(&self) -> usize {
462 self.batch_size
463 }
464
465 pub fn has_base_ots(&self) -> bool {
468 self.base_rngs.len() == BASE_OT_COUNT
469 }
470
471 pub async fn do_base_ots(&mut self) -> Result<(), Error> {
474 let base_ots = self.base_ot.send(BASE_OT_COUNT).await?;
475 self.base_rngs = base_ots
476 .into_iter()
477 .map(|[s1, s2]| [AesRng::from_seed(s1), AesRng::from_seed(s2)])
478 .collect();
479 Ok(())
480 }
481}
482
483impl<S: Security> Connected for OtExtensionReceiver<S> {
484 fn connection(&mut self) -> &mut Connection {
485 &mut self.conn
486 }
487}
488
489impl<S: Security> RotReceiver for OtExtensionReceiver<S> {
490 type Error = Error;
491
492 #[tracing::instrument(level = Level::DEBUG, skip_all, fields(count = ots.len()))]
498 #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::OT_EXTENSION))]
499 async fn receive_into(
500 &mut self,
501 ots: &mut impl Buf<Block>,
502 choices: &[Choice],
503 ) -> Result<(), Self::Error> {
504 assert_eq!(choices.len(), ots.len());
505 assert_eq!(
506 0,
507 choices.len() % 128,
508 "choices.len() must be multiple of 128"
509 );
510 let batch_size = self.batch_size();
511 let count = choices.len();
512 let batch_size_remainder = count % batch_size;
513 assert_eq!(
514 0,
515 batch_size_remainder % 128,
516 "count % batch_size must be multiple of 128"
517 );
518
519 if !self.has_base_ots() {
520 self.do_base_ots().await?;
521 }
522
523 let mut sub_conn = self.conn.sub_connection();
524
525 let cols_byte_batch = batch_size / 8;
526 let choice_vec = choices_to_u8_vec(choices);
527
528 let (ch_s, mut ch_r) = mpsc::unbounded_channel::<Vec<u8>>();
529 let (kos_ch_s, mut kos_ch_r_task) = tokio::sync::mpsc::unbounded_channel::<Vec<Block>>();
530 let (kos_ch_s_task, kos_ch_r) = std::sync::mpsc::channel::<Block>();
531 let mut rng = StdRng::from_rng(&mut self.rng);
532
533 let mut base_rngs = mem::take(&mut self.base_rngs);
534 let owned_ots = mem::take(ots);
535 let mut jh = spawn_compute(move || {
536 let mut ots = owned_ots;
537 let t_mat_size = if S::MALICIOUS_SECURITY {
538 ots.len()
539 } else {
540 batch_size
541 };
542 let num_extra = (S::MALICIOUS_SECURITY as usize) * 128;
543 let mut t_mat = vec![Block::ZERO; t_mat_size];
544 let mut extra_t_mat = vec![Block::ZERO; num_extra];
545 let mut extra_messages: Vec<Block> = Vec::zeroed(num_extra);
546 let extra_choices = random_choices(num_extra, &mut rng);
547 let extra_choice_vec = choices_to_u8_vec(&extra_choices);
548
549 for (ots, choice_vec, extra) in [
550 (&mut ots[..], &choice_vec, false),
551 (&mut extra_messages[..], &extra_choice_vec, true),
552 ] {
553 for (chunk_idx, (output_chunk, choice_batch)) in ots
554 .chunks_mut(batch_size)
555 .zip(choice_vec.chunks(cols_byte_batch))
556 .enumerate()
557 {
558 let curr_batch_size = output_chunk.len();
559 let chunk_t_mat = if S::MALICIOUS_SECURITY {
560 if extra {
561 &mut extra_t_mat
562 } else {
563 let offset = chunk_idx * batch_size;
564 &mut t_mat[offset..offset + curr_batch_size]
565 }
566 } else {
567 &mut t_mat[..curr_batch_size]
568 };
569 assert_eq!(output_chunk.len(), chunk_t_mat.len());
570 assert_eq!(choice_batch.len() * 8, chunk_t_mat.len());
571 let chunk_t_mat: &mut [u8] = bytemuck::must_cast_slice_mut(chunk_t_mat);
572 let cols_byte_batch = choice_batch.len();
574 for (row, [rng1, rng2]) in chunk_t_mat
575 .chunks_exact_mut(cols_byte_batch)
576 .zip(&mut base_rngs)
577 {
578 rng1.fill_bytes(row);
579 let mut send_row = vec![0_u8; cols_byte_batch];
580 rng2.fill_bytes(&mut send_row);
581 for ((v2, v1), choices) in send_row.iter_mut().zip(row).zip(choice_batch) {
583 *v2 ^= *v1 ^ *choices;
584 }
585 ch_s.send(send_row)?;
586 }
587 let output_bytes = bytemuck::cast_slice_mut(output_chunk);
588 transpose_bitmatrix(
589 &chunk_t_mat[..BASE_OT_COUNT * cols_byte_batch],
590 output_bytes,
591 BASE_OT_COUNT,
592 );
593 if S::MALICIOUS_SECURITY {
594 FIXED_KEY_HASH.tccr_hash_slice_mut(output_chunk, |i| {
595 Block::from(chunk_idx * batch_size + i)
596 });
597 } else {
598 FIXED_KEY_HASH.cr_hash_slice_mut(output_chunk);
599 }
600 }
601 }
602
603 if S::MALICIOUS_SECURITY {
604 drop(ch_s);
606 let seed = kos_ch_r.recv()?;
607
608 let mut t1 = extra_t_mat;
609 let mut t2 = vec![Block::ZERO; BASE_OT_COUNT];
610
611 let mut x1 = Block::from_choices(&extra_choices);
612 let mut x2 = Block::ZERO;
613
614 let rng = AesRng::from_seed(seed);
615
616 let t_mat_ref = &t_mat;
617 let batches = count / batch_size;
618 let batch_sizes = iter::repeat_n(batch_size, batches)
619 .chain((batch_size_remainder != 0).then_some(batch_size_remainder));
620
621 let choice_blocks: Vec<_> = choice_vec
622 .chunks_exact(Block::BYTES)
623 .map(|chunk| Block::try_from(chunk).expect("chunk is 16 bytes"))
624 .collect();
625
626 let challenges: Vec<Block> = rng
627 .sample_iter(StandardUniform)
628 .take(choice_blocks.len())
629 .collect();
630
631 for (x, s) in choice_blocks.iter().zip(challenges.iter()) {
632 let (xi, xi2) = x.clmul(s);
633 x1 ^= xi;
634 x2 ^= xi2;
635 }
636
637 let block_batch_size = batch_size / BASE_OT_COUNT;
638
639 let challenge_iter =
640 batch_sizes
641 .clone()
642 .enumerate()
643 .flat_map(|(batch, curr_batch_size)| {
644 challenges[batch * block_batch_size
645 ..batch * block_batch_size + curr_batch_size / BASE_OT_COUNT]
646 .iter()
647 .cycle()
648 .take(curr_batch_size)
649 });
650 let t_idx_iter = batch_sizes.flat_map(|curr_batch_size| {
651 (0..BASE_OT_COUNT).flat_map(move |t_idx| {
652 iter::repeat_n(t_idx, curr_batch_size / BASE_OT_COUNT)
653 })
654 });
655
656 for ((t, s), t_idx) in t_mat_ref.iter().zip(challenge_iter).zip(t_idx_iter) {
657 let (ti, ti2) = t.clmul(s);
658 t1[t_idx] ^= ti;
659 t2[t_idx] ^= ti2;
660 }
661
662 for (t1i, t2i) in t1.iter_mut().zip(&mut t2) {
663 *t1i = Block::gf_reduce(t1i, t2i);
664 }
665 t1.push(Block::gf_reduce(&x1, &x2));
666 kos_ch_s.send(t1)?;
667 }
668 Ok::<_, Error>((ots, base_rngs))
669 });
670
671 let (mut send, _) = sub_conn.byte_stream().await?;
672 while let Some(row) = ch_r.recv().await {
673 send.write_all(&row).await.map_err(Error::Communication)?;
674 }
675
676 if S::MALICIOUS_SECURITY {
677 let err = poll_fn(|cx| match jh.poll_unpin(cx) {
681 Poll::Ready(res) => Poll::Ready(res.map(drop)),
682 Poll::Pending => Poll::Ready(Ok(())),
683 })
684 .await;
685 if let Err(err) = err {
686 resume_unwind(err);
687 };
688 let (mut kos_send, mut kos_recv) = sub_conn.byte_stream().await?;
689
690 let seed = {
691 let mut kos_recv = kos_recv.as_stream::<Block>();
692 kos_recv.next().await.ok_or(Error::UnexcpectedClose)??
693 };
694
695 let success = 'success: {
696 if kos_ch_s_task.send(seed).is_err() {
697 break 'success false;
698 }
699
700 let mut kos_send = kos_send.as_stream::<Vec<Block>>();
701 let Some(v) = kos_ch_r_task.recv().await else {
702 break 'success false;
703 };
704 kos_send.send(v).await.map_err(Error::Communication)?;
705
706 true
707 };
708 if !success {
709 resume_unwind(jh.await.map(drop).expect_err("expected thread error"));
710 }
711 }
712
713 let (owned_ots, base_rngs) = match jh.await {
714 Ok(res) => res?,
715 Err(panicked) => resume_unwind(panicked),
716 };
717
718 self.base_rngs = base_rngs;
719 *ots = owned_ots;
720 Ok(())
721 }
722}
723
724impl<S: Security> CotSender for OtExtensionSender<S> {
725 type Error = Error;
726
727 async fn correlated_send_into<B, F>(
728 &mut self,
729 ots: &mut B,
730 correlation: F,
731 ) -> Result<(), Self::Error>
732 where
733 B: Buf<Block>,
734 F: FnMut(usize) -> Block + Send,
735 {
736 CorrelatedFromRandom::new(self)
737 .correlated_send_into(ots, correlation)
738 .await
739 }
740}
741
742impl<S: Security> CotReceiver for OtExtensionReceiver<S> {
743 type Error = Error;
744
745 async fn correlated_receive_into<B>(
746 &mut self,
747 ots: &mut B,
748 choices: &[Choice],
749 ) -> Result<(), Self::Error>
750 where
751 B: Buf<Block>,
752 {
753 CorrelatedFromRandom::new(self)
754 .correlated_receive_into(ots, choices)
755 .await
756 }
757}
758
759fn choices_to_u8_vec(choices: &[Choice]) -> Vec<u8> {
760 assert_eq!(0, choices.len() % 8);
761 let mut v = vec![0_u8; choices.len() / 8];
762 for (chunk, byte) in choices.chunks_exact(8).zip(&mut v) {
763 for (i, choice) in chunk.iter().enumerate() {
764 *byte ^= choice.unwrap_u8() << i;
765 }
766 }
767 v
768}
769
770impl From<std::sync::mpsc::RecvError> for Error {
771 fn from(_: std::sync::mpsc::RecvError) -> Self {
772 Error::AsyncTaskDropped
773 }
774}
775
776impl<T> From<tokio::sync::mpsc::error::SendError<T>> for Error {
777 fn from(_: tokio::sync::mpsc::error::SendError<T>) -> Self {
778 Error::AsyncTaskDropped
779 }
780}
781
782#[cfg(test)]
783mod tests {
784
785 use cryprot_core::Block;
786 use cryprot_net::testing::{init_tracing, local_conn};
787 use rand::{SeedableRng, rngs::StdRng};
788
789 use crate::{
790 CotReceiver, CotSender, MaliciousMarker, RotReceiver, RotSender,
791 extension::{
792 DEFAULT_OT_BATCH_SIZE, OtExtensionReceiver, OtExtensionSender,
793 SemiHonestOtExtensionReceiver, SemiHonestOtExtensionSender,
794 },
795 random_choices,
796 };
797
798 #[tokio::test]
799 async fn test_extension() {
800 let _g = init_tracing();
801 const COUNT: usize = 2 * DEFAULT_OT_BATCH_SIZE;
802 let (c1, c2) = local_conn().await.unwrap();
803 let rng1 = StdRng::seed_from_u64(42);
804 let mut rng2 = StdRng::seed_from_u64(24);
805 let choices = random_choices(COUNT, &mut rng2);
806 let mut sender = SemiHonestOtExtensionSender::new_with_rng(c1, rng1);
807 let mut receiver = SemiHonestOtExtensionReceiver::new_with_rng(c2, rng2);
808 let (send_ots, recv_ots) =
809 tokio::try_join!(sender.send(COUNT), receiver.receive(&choices)).unwrap();
810 for ((r, s), c) in recv_ots.into_iter().zip(send_ots).zip(choices) {
811 assert_eq!(r, s[c.unwrap_u8() as usize]);
812 }
813 }
814
815 #[tokio::test]
816 async fn test_extension_half_batch() {
817 let _g = init_tracing();
818 const COUNT: usize = 2 * DEFAULT_OT_BATCH_SIZE + DEFAULT_OT_BATCH_SIZE / 2;
819 let (c1, c2) = local_conn().await.unwrap();
820 let rng1 = StdRng::seed_from_u64(42);
821 let mut rng2 = StdRng::seed_from_u64(24);
822 let choices = random_choices(COUNT, &mut rng2);
823 let mut sender = SemiHonestOtExtensionSender::new_with_rng(c1, rng1);
824 let mut receiver = SemiHonestOtExtensionReceiver::new_with_rng(c2, rng2);
825 let (send_ots, recv_ots) =
826 tokio::try_join!(sender.send(COUNT), receiver.receive(&choices)).unwrap();
827 for ((r, s), c) in recv_ots.into_iter().zip(send_ots).zip(choices) {
828 assert_eq!(r, s[c.unwrap_u8() as usize]);
829 }
830 }
831
832 #[tokio::test]
833 async fn test_extension_partial_batch() {
834 let _g = init_tracing();
835 const COUNT: usize = DEFAULT_OT_BATCH_SIZE / 2 + 128;
836 let (c1, c2) = local_conn().await.unwrap();
837 let rng1 = StdRng::seed_from_u64(42);
838 let mut rng2 = StdRng::seed_from_u64(24);
839 let choices = random_choices(COUNT, &mut rng2);
840 let mut sender = SemiHonestOtExtensionSender::new_with_rng(c1, rng1);
841 let mut receiver = SemiHonestOtExtensionReceiver::new_with_rng(c2, rng2);
842 let (send_ots, recv_ots) =
843 tokio::try_join!(sender.send(COUNT), receiver.receive(&choices)).unwrap();
844 for ((r, s), c) in recv_ots.into_iter().zip(send_ots).zip(choices) {
845 assert_eq!(r, s[c.unwrap_u8() as usize]);
846 }
847 }
848
849 #[tokio::test]
850 async fn test_extension_malicious_half_batch() {
851 let _g = init_tracing();
852 const COUNT: usize = DEFAULT_OT_BATCH_SIZE / 2;
853 let (c1, c2) = local_conn().await.unwrap();
854 let rng1 = StdRng::seed_from_u64(42);
855 let mut rng2 = StdRng::seed_from_u64(24);
856 let choices = random_choices(COUNT, &mut rng2);
857 let mut sender = OtExtensionSender::<MaliciousMarker>::new_with_rng(c1, rng1);
858 let mut receiver = OtExtensionReceiver::<MaliciousMarker>::new_with_rng(c2, rng2);
859
860 let (send_ots, recv_ots) =
861 tokio::try_join!(sender.send(COUNT), receiver.receive(&choices)).unwrap();
862 for ((r, s), c) in recv_ots.into_iter().zip(send_ots).zip(choices) {
863 assert_eq!(r, s[c.unwrap_u8() as usize]);
864 }
865 }
866
867 #[tokio::test]
868 async fn test_extension_malicious_partial_batch() {
869 let _g = init_tracing();
870 const COUNT: usize = DEFAULT_OT_BATCH_SIZE + DEFAULT_OT_BATCH_SIZE / 2 + 128;
871 let (c1, c2) = local_conn().await.unwrap();
872 let rng1 = StdRng::seed_from_u64(42);
873 let mut rng2 = StdRng::seed_from_u64(24);
874 let choices = random_choices(COUNT, &mut rng2);
875 let mut sender = OtExtensionSender::<MaliciousMarker>::new_with_rng(c1, rng1);
876 let mut receiver = OtExtensionReceiver::<MaliciousMarker>::new_with_rng(c2, rng2);
877
878 let (send_ots, recv_ots) =
879 tokio::try_join!(sender.send(COUNT), receiver.receive(&choices)).unwrap();
880 for ((r, s), c) in recv_ots.into_iter().zip(send_ots).zip(choices) {
881 assert_eq!(r, s[c.unwrap_u8() as usize]);
882 }
883 }
884
885 #[tokio::test]
886 async fn test_extension_malicious_multiple_batch() {
887 let _g = init_tracing();
888 const COUNT: usize = DEFAULT_OT_BATCH_SIZE * 2;
889 let (c1, c2) = local_conn().await.unwrap();
890 let rng1 = StdRng::seed_from_u64(42);
891 let mut rng2 = StdRng::seed_from_u64(24);
892 let choices = random_choices(COUNT, &mut rng2);
893 let mut sender = OtExtensionSender::<MaliciousMarker>::new_with_rng(c1, rng1);
894 let mut receiver = OtExtensionReceiver::<MaliciousMarker>::new_with_rng(c2, rng2);
895
896 let (send_ots, recv_ots) =
897 tokio::try_join!(sender.send(COUNT), receiver.receive(&choices)).unwrap();
898 for ((r, s), c) in recv_ots.into_iter().zip(send_ots).zip(choices) {
899 assert_eq!(r, s[c.unwrap_u8() as usize]);
900 }
901 }
902
903 #[tokio::test]
904 async fn test_correlated_extension() {
905 let _g = init_tracing();
906 const COUNT: usize = 128;
907 let (c1, c2) = local_conn().await.unwrap();
908 let rng1 = StdRng::seed_from_u64(42);
909 let mut rng2 = StdRng::seed_from_u64(24);
910 let choices = random_choices(COUNT, &mut rng2);
911 let mut sender = SemiHonestOtExtensionSender::new_with_rng(c1, rng1);
912 let mut receiver = SemiHonestOtExtensionReceiver::new_with_rng(c2, rng2);
913 let (send_ots, recv_ots) = tokio::try_join!(
914 sender.correlated_send(COUNT, |_| Block::ONES),
915 receiver.correlated_receive(&choices)
916 )
917 .unwrap();
918 for (i, ((r, s), c)) in recv_ots.into_iter().zip(send_ots).zip(choices).enumerate() {
919 if bool::from(c) {
920 assert_eq!(r ^ Block::ONES, s, "Block {i}");
921 } else {
922 assert_eq!(r, s, "Block {i}")
923 }
924 }
925 }
926}