cryprot_ot/
silent_ot.rs

1//! Semi-honest and malicious Silent OT implementation using expand-convolute code [[RRT23](https://eprint.iacr.org/2023/882)].
2#![allow(non_snake_case)]
3use std::{io, marker::PhantomData, mem};
4
5use bytemuck::cast_slice_mut;
6use cryprot_codes::ex_conv::{ExConvCode, ExConvCodeConfig};
7use cryprot_core::{
8    AES_PAR_BLOCKS, Block, aes_hash::FIXED_KEY_HASH, alloc::HugePageMemory, buf::Buf,
9    random_oracle::Hash, tokio_rayon::spawn_compute,
10};
11use cryprot_net::{Connection, ConnectionError};
12use cryprot_pprf::{PprfConfig, RegularPprfReceiver, RegularPprfSender};
13use futures::{SinkExt, StreamExt};
14use rand::{Rng, SeedableRng, rngs::StdRng};
15use subtle::Choice;
16use tracing::Level;
17
18use crate::{
19    Connected, Malicious, MaliciousMarker, RandChoiceRotReceiver, RandChoiceRotSender, RotReceiver,
20    RotSender, Security, SemiHonest, SemiHonestMarker,
21    extension::{self, OtExtensionReceiver, OtExtensionSender},
22    noisy_vole::{self, NoisyVoleReceiver, NoisyVoleSender},
23    phase,
24};
25
/// Security parameter used when sizing the noise and the malicious check.
///
/// Under malicious security, this many additional base OTs are spent on the
/// consistency check.
pub const SECURITY_PARAMETER: usize = 128;

/// Ratio of code length to the number of requested OTs.
const SCALER: usize = 2;
28
29pub type SemiHonestSilentOtSender = SilentOtSender<SemiHonestMarker>;
30pub type SemiHonestSilentOtReceiver = SilentOtReceiver<SemiHonestMarker>;
31
32pub type MaliciousSilentOtSender = SilentOtSender<MaliciousMarker>;
33pub type MaliciousSilentOtReceiver = SilentOtReceiver<MaliciousMarker>;
34
35pub struct SilentOtSender<S> {
36    conn: Connection,
37    ot_sender: OtExtensionSender<SemiHonestMarker>,
38    rng: StdRng,
39    s: PhantomData<S>,
40}
41
/// Selects the expand-convolute code used to compress the PPRF output.
///
/// The variant name encodes the expander weight and the accumulator size,
/// e.g. `ExConv7x24` uses expander weight 7 and an accumulator of size 24.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
pub enum MultType {
    /// Expander weight 7, accumulator size 24 (the default).
    #[default]
    ExConv7x24,
    /// Expander weight 21, accumulator size 24.
    ExConv21x24,
}
48
49#[derive(thiserror::Error, Debug)]
50pub enum Error {
51    #[error("unable to perform base OTs for silent OTs")]
52    BaseOt(#[from] extension::Error),
53    #[error("error in pprf expansion for silent OTs")]
54    Pprf(#[from] cryprot_pprf::Error),
55    #[error("io error during malicious check")]
56    Io(#[from] io::Error),
57    #[error("error in connection to peer")]
58    Connection(#[from] ConnectionError),
59    #[error("error in noisy vole during malicious check")]
60    NoisyVole(#[from] noisy_vole::Error),
61    #[error("sender did not transmit hash in malicious check")]
62    MissingSenderHash,
63    #[error("receiver did not transmit seed in malicious check")]
64    MissingReceiverSeed,
65    #[error("malicious check failed")]
66    MaliciousCheck,
67}
68
69impl<S: Security> SilentOtSender<S> {
70    pub fn new(mut conn: Connection) -> Self {
71        let ot_sender = OtExtensionSender::new(conn.sub_connection());
72        Self {
73            conn,
74            ot_sender,
75            rng: StdRng::from_os_rng(),
76            s: PhantomData,
77        }
78    }
79
80    // Needed minimum buffer size when using `correlated_send_into` method.
81    pub fn ots_buf_size(count: usize) -> usize {
82        let conf = Config::configure(count, MultType::default());
83        let pprf_conf = PprfConfig::from(conf);
84        pprf_conf.size()
85    }
86
87    pub async fn random_send(&mut self, count: usize) -> Result<impl Buf<[Block; 2]>, Error> {
88        let mut ots = HugePageMemory::zeroed(count);
89        self.random_sent_into(count, &mut ots).await?;
90        Ok(ots)
91    }
92
93    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::SILENT_RANDOM_EXTENSION))]
94    pub async fn random_sent_into(
95        &mut self,
96        count: usize,
97        ots: &mut impl Buf<[Block; 2]>,
98    ) -> Result<(), Error> {
99        assert_eq!(count, ots.len());
100        let delta = self.rng.random();
101        let mut ots_buf = mem::take(ots);
102        let correlated = self.correlated_send(count, delta).await?;
103
104        let ots_buf = spawn_compute(move || {
105            let masked_delta = delta & Block::MASK_LSB;
106            for ((chunk_idx, ot_chunk), corr_chunk) in ots_buf
107                .chunks_mut(AES_PAR_BLOCKS)
108                .enumerate()
109                .zip(correlated.chunks(AES_PAR_BLOCKS))
110            {
111                for (ots, corr) in ot_chunk.iter_mut().zip(corr_chunk) {
112                    let masked = *corr & Block::MASK_LSB;
113                    *ots = [masked, masked ^ masked_delta]
114                }
115                if S::MALICIOUS_SECURITY {
116                    // It is currently unknown whether a cr hash is sufficient for Silent OT, so we
117                    // use the safe choice of a tccr hash at the cost of some performance.
118                    // See https://github.com/osu-crypto/libOTe/issues/166 for discussion
119                    FIXED_KEY_HASH.tccr_hash_slice_mut(cast_slice_mut(ot_chunk), |i| {
120                        Block::from(chunk_idx * AES_PAR_BLOCKS + i / 2)
121                    });
122                } else {
123                    FIXED_KEY_HASH.cr_hash_slice_mut(cast_slice_mut(ot_chunk));
124                }
125            }
126            ots_buf
127        })
128        .await
129        .expect("worker panic");
130        *ots = ots_buf;
131        Ok(())
132    }
133
134    pub async fn correlated_send(
135        &mut self,
136        count: usize,
137        delta: Block,
138    ) -> Result<impl Buf<Block>, Error> {
139        let mut ots = HugePageMemory::zeroed(Self::ots_buf_size(count));
140        self.correlated_send_into(count, delta, &mut ots).await?;
141        Ok(ots)
142    }
143
144    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::SILENT_CORRELATED_EXTENSION))]
145    pub async fn correlated_send_into(
146        &mut self,
147        count: usize,
148        delta: Block,
149        ots: &mut impl Buf<Block>,
150    ) -> Result<(), Error> {
151        let mult_type = MultType::default();
152        let conf = Config::configure(count, mult_type);
153        let pprf_conf = PprfConfig::from(conf);
154        assert!(
155            ots.len() >= pprf_conf.size(),
156            "ots Buf not big enough. Allocate at least Self::ots_buf_size"
157        );
158
159        let mal_check_ot_count = S::MALICIOUS_SECURITY as usize * SECURITY_PARAMETER;
160        let base_ot_count = pprf_conf.base_ot_count().next_multiple_of(128) + mal_check_ot_count;
161
162        // count must be divisable by 128 for ot_extension
163        let mut base_ots = self.ot_sender.send(base_ot_count).await?;
164
165        let mal_check_ots = base_ots.split_off(base_ot_count - mal_check_ot_count);
166
167        base_ots.truncate(pprf_conf.base_ot_count());
168
169        let pprf_sender =
170            RegularPprfSender::new_with_conf(self.conn.sub_connection(), pprf_conf, base_ots);
171        let mut B = mem::take(ots);
172        pprf_sender
173            .expand(delta, self.rng.random(), conf.pprf_out_fmt(), &mut B)
174            .await?;
175
176        if S::MALICIOUS_SECURITY {
177            self.ferret_mal_check(delta, &mut B, mal_check_ots).await?;
178        }
179
180        let enc = Encoder::new(count, mult_type);
181        *ots = enc.send_compress(B).await;
182        Ok(())
183    }
184
185    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::MALICIOUS_CHECK))]
186    async fn ferret_mal_check(
187        &mut self,
188        delta: Block,
189        B: &mut impl Buf<Block>,
190        mal_check_ots: Vec<[Block; 2]>,
191    ) -> Result<(), Error> {
192        assert_eq!(SECURITY_PARAMETER, mal_check_ots.len());
193        let (mut tx, mut rx) = self.conn.request_response_stream().await?;
194        let mal_check_seed: Block = rx.next().await.ok_or(Error::MissingReceiverSeed)??;
195
196        let owned_B = mem::take(B);
197        let jh = spawn_compute(move || {
198            let mut xx = mal_check_seed;
199            let (sum_low, sum_high) = owned_B.iter().fold(
200                (Block::ZERO, Block::ZERO),
201                |(mut sum_low, mut sum_high), b| {
202                    let (low, high) = xx.clmul(b);
203                    sum_low ^= low;
204                    sum_high ^= high;
205                    xx = xx.gf_mul(&mal_check_seed);
206                    (sum_low, sum_high)
207                },
208            );
209            (Block::gf_reduce(&sum_low, &sum_high), owned_B)
210        });
211
212        let mut receiver = NoisyVoleReceiver::new(self.conn.sub_connection());
213        let a = receiver.receive(vec![delta], mal_check_ots).await?;
214
215        let (my_sum, owned_B) = jh.await.expect("worker panic");
216        *B = owned_B;
217
218        let my_hash = (my_sum ^ a[0]).ro_hash();
219        tx.send(my_hash).await?;
220
221        Ok(())
222    }
223}
224
225pub struct SilentOtReceiver<S> {
226    conn: Connection,
227    ot_receiver: OtExtensionReceiver<SemiHonestMarker>,
228    rng: StdRng,
229    s: PhantomData<S>,
230}
231
232impl<S: Security> SilentOtReceiver<S> {
233    pub fn new(mut conn: Connection) -> Self {
234        let ot_receiver = OtExtensionReceiver::new(conn.sub_connection());
235        Self {
236            conn,
237            ot_receiver,
238            rng: StdRng::from_os_rng(),
239            s: PhantomData,
240        }
241    }
242
243    // Needed minimum buffer size when using `receive_into` methods.
244    pub fn ots_buf_size(count: usize) -> usize {
245        let conf = Config::configure(count, MultType::default());
246        let pprf_conf = PprfConfig::from(conf);
247        pprf_conf.size()
248    }
249
250    pub async fn random_receive(
251        &mut self,
252        count: usize,
253    ) -> Result<(impl Buf<Block>, Vec<Choice>), Error> {
254        let mut ots = HugePageMemory::zeroed(Self::ots_buf_size(count));
255        let choices = self.random_receive_into(count, &mut ots).await?;
256        Ok((ots, choices))
257    }
258
259    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::SILENT_RANDOM_EXTENSION))]
260    pub async fn random_receive_into(
261        &mut self,
262        count: usize,
263        ots: &mut impl Buf<Block>,
264    ) -> Result<Vec<Choice>, Error> {
265        self.internal_correlated_receive_into(count, ChoiceBitPacking::Packed, ots)
266            .await?;
267
268        let mut ots_buf = mem::take(ots);
269        let (ots_buf, choices) = spawn_compute(move || {
270            let choices = ots_buf
271                .iter_mut()
272                .map(|block| {
273                    let choice = Choice::from(block.lsb() as u8);
274                    *block &= Block::MASK_LSB;
275                    choice
276                })
277                .collect();
278
279            if S::MALICIOUS_SECURITY {
280                FIXED_KEY_HASH.tccr_hash_slice_mut(&mut ots_buf, Block::from);
281            } else {
282                FIXED_KEY_HASH.cr_hash_slice_mut(&mut ots_buf);
283            }
284            (ots_buf, choices)
285        })
286        .await
287        .expect("worker panic");
288        *ots = ots_buf;
289        Ok(choices)
290    }
291
292    pub async fn correlated_receive(
293        &mut self,
294        count: usize,
295    ) -> Result<(impl Buf<Block>, Vec<Choice>), Error> {
296        let mut ots = HugePageMemory::zeroed(Self::ots_buf_size(count));
297        let choices = self.correlated_receive_into(count, &mut ots).await?;
298        Ok((ots, choices))
299    }
300
301    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::SILENT_CORRELATED_EXTENSION))]
302    pub async fn correlated_receive_into(
303        &mut self,
304        count: usize,
305        ots: &mut impl Buf<Block>,
306    ) -> Result<Vec<Choice>, Error> {
307        self.internal_correlated_receive_into(count, ChoiceBitPacking::NotPacked, ots)
308            .await
309            .map(|cb| cb.expect("not choice packed"))
310    }
311
312    async fn internal_correlated_receive_into(
313        &mut self,
314        count: usize,
315        cb_packing: ChoiceBitPacking,
316        ots: &mut impl Buf<Block>,
317    ) -> Result<Option<Vec<Choice>>, Error> {
318        let mult_type = MultType::default();
319        let conf = Config::configure(count, mult_type);
320        let pprf_conf = PprfConfig::from(conf);
321        assert_eq!(ots.len(), pprf_conf.size());
322
323        let base_choices = pprf_conf.sample_choice_bits(&mut self.rng);
324        let noisy_points = pprf_conf.get_points(conf.pprf_out_fmt(), &base_choices);
325
326        let mut base_choices_subtle: Vec<_> =
327            base_choices.iter().copied().map(Choice::from).collect();
328        // we will discard these base OTs so we simply set the choice to 0. The ot
329        // extension implementation can currently only handle num ots that are multiple
330        // of 128
331        base_choices_subtle.resize(
332            pprf_conf.base_ot_count().next_multiple_of(128),
333            Choice::from(0),
334        );
335
336        let mut mal_check_seed = Block::ZERO;
337        let mut mal_check_x = Block::ZERO;
338        if S::MALICIOUS_SECURITY {
339            mal_check_seed = self.rng.random();
340
341            for &p in &noisy_points {
342                mal_check_x ^= mal_check_seed.gf_pow(p as u64 + 1);
343            }
344            base_choices_subtle.extend(mal_check_x.bits().map(|b| Choice::from(b as u8)));
345        }
346
347        let mut base_ots = self.ot_receiver.receive(&base_choices_subtle).await?;
348        let mal_check_ots = base_ots
349            .split_off(base_ots.len() - (S::MALICIOUS_SECURITY as usize * SECURITY_PARAMETER));
350
351        base_ots.truncate(pprf_conf.base_ot_count());
352
353        let pprf_receiver = RegularPprfReceiver::new_with_conf(
354            self.conn.sub_connection(),
355            pprf_conf,
356            base_ots,
357            base_choices,
358        );
359        let mut A = mem::take(ots);
360        pprf_receiver.expand(conf.pprf_out_fmt(), &mut A).await?;
361
362        if S::MALICIOUS_SECURITY {
363            self.ferret_mal_check(&mut A, mal_check_seed, mal_check_x, mal_check_ots)
364                .await?;
365        }
366
367        let enc = Encoder::new(count, mult_type);
368        let (A, choices) = enc.receive_compress(A, noisy_points, cb_packing).await;
369        *ots = A;
370        Ok(choices)
371    }
372
373    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::MALICIOUS_CHECK))]
374    async fn ferret_mal_check(
375        &mut self,
376        A: &mut impl Buf<Block>,
377        mal_check_seed: Block,
378        mal_check_x: Block,
379        mal_check_ots: Vec<Block>,
380    ) -> Result<(), Error> {
381        assert_eq!(SECURITY_PARAMETER, mal_check_ots.len());
382        let (mut tx, mut rx) = self.conn.request_response_stream().await?;
383        tx.send(mal_check_seed).await?;
384
385        let owned_A = mem::take(A);
386        let jh = spawn_compute(move || {
387            let mut xx = mal_check_seed;
388            let (sum_low, sum_high) = owned_A.iter().fold(
389                (Block::ZERO, Block::ZERO),
390                |(mut sum_low, mut sum_high), a| {
391                    let (low, high) = xx.clmul(a);
392                    sum_low ^= low;
393                    sum_high ^= high;
394                    xx = xx.gf_mul(&mal_check_seed);
395                    (sum_low, sum_high)
396                },
397            );
398            (Block::gf_reduce(&sum_low, &sum_high), owned_A)
399        });
400
401        let mut sender = NoisyVoleSender::new(self.conn.sub_connection());
402        let b = sender.send(1, mal_check_x, mal_check_ots).await?;
403
404        let (my_sum, owned_A) = jh.await.expect("worker panic");
405        *A = owned_A;
406
407        let my_hash = (my_sum ^ b[0]).ro_hash();
408
409        let their_hash: Hash = rx.next().await.ok_or(Error::MissingSenderHash)??;
410        if my_hash != their_hash {
411            return Err(Error::MaliciousCheck);
412        }
413        Ok(())
414    }
415}
416
417impl SemiHonest for SilentOtSender<SemiHonestMarker> {}
418impl SemiHonest for SilentOtReceiver<SemiHonestMarker> {}
419
420impl SemiHonest for SilentOtSender<MaliciousMarker> {}
421impl SemiHonest for SilentOtReceiver<MaliciousMarker> {}
422impl Malicious for SilentOtSender<MaliciousMarker> {}
423impl Malicious for SilentOtReceiver<MaliciousMarker> {}
424
425impl<S> Connected for SilentOtSender<S> {
426    fn connection(&mut self) -> &mut Connection {
427        &mut self.conn
428    }
429}
430
431impl<S: Security> RandChoiceRotSender for SilentOtSender<S> {}
432
433impl<S: Security> RotSender for SilentOtSender<S> {
434    type Error = Error;
435
436    async fn send_into(&mut self, ots: &mut impl Buf<[Block; 2]>) -> Result<(), Self::Error> {
437        self.random_sent_into(ots.len(), ots).await?;
438        Ok(())
439    }
440}
441
442impl<S> Connected for SilentOtReceiver<S> {
443    fn connection(&mut self) -> &mut Connection {
444        &mut self.conn
445    }
446}
447
448impl<S: Security> RandChoiceRotReceiver for SilentOtReceiver<S> {
449    type Error = Error;
450
451    async fn rand_choice_receive_into(
452        &mut self,
453        ots: &mut impl Buf<Block>,
454    ) -> Result<Vec<Choice>, Self::Error> {
455        let count = ots.len();
456        ots.grow_zeroed(Self::ots_buf_size(count));
457        let choices = self.random_receive_into(count, ots).await?;
458        Ok(choices)
459    }
460}
461
/// How the receiver hands out its choice bits.
///
/// * `Packed` (default): the choice bit of each OT is stored in the least
///   significant bit of the corresponding output block, and no separate
///   choice vector is produced.
/// * `NotPacked`: the choice bits are returned as a separate vector.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
enum ChoiceBitPacking {
    #[default]
    Packed,
    NotPacked,
}
468
469#[derive(Debug, Clone, Copy, PartialEq, Eq)]
470struct Config {
471    num_partitions: usize,
472    size_per: usize,
473    mult_type: MultType,
474}
475
476impl Config {
477    fn configure(num_ots: usize, mult_type: MultType) -> Self {
478        let min_dist = match mult_type {
479            MultType::ExConv7x24 => 0.15,
480            MultType::ExConv21x24 => 0.2,
481        };
482        let num_partitions = get_reg_noise_weight(min_dist, num_ots * SCALER, SECURITY_PARAMETER);
483        let size_per = 4.max(
484            (num_ots * SCALER)
485                .div_ceil(num_partitions)
486                .next_multiple_of(2),
487        );
488
489        Self {
490            num_partitions,
491            size_per,
492            mult_type,
493        }
494    }
495
496    fn pprf_out_fmt(&self) -> cryprot_pprf::OutFormat {
497        cryprot_pprf::OutFormat::Interleaved
498    }
499}
500
501impl From<Config> for PprfConfig {
502    fn from(value: Config) -> Self {
503        Self::new(value.size_per, value.num_partitions)
504    }
505}
506
507struct Encoder {
508    code: ExConvCode,
509}
510
511impl Encoder {
512    fn new(num_ots: usize, mult_type: MultType) -> Self {
513        let expander_weight = match mult_type {
514            MultType::ExConv7x24 => 7,
515            MultType::ExConv21x24 => 21,
516        };
517        let code = ExConvCode::new_with_conf(
518            num_ots,
519            ExConvCodeConfig {
520                code_size: num_ots * SCALER,
521                expander_weight,
522                ..Default::default()
523            },
524        );
525        assert_eq!(code.conf().accumulator_size, 24);
526        Self { code }
527    }
528
529    async fn send_compress<B: Buf<Block>>(self, mut b: B) -> B {
530        spawn_compute(move || {
531            self.code.dual_encode(&mut b[..self.code.conf().code_size]);
532            b.set_len(self.code.message_size());
533            b
534        })
535        .await
536        .expect("worker panic")
537    }
538
539    async fn receive_compress<B: Buf<Block>>(
540        self,
541        mut a: B,
542        noisy_points: Vec<usize>,
543        cb_packing: ChoiceBitPacking,
544    ) -> (B, Option<Vec<Choice>>) {
545        let jh = spawn_compute(move || {
546            let (mut a, cb) = if cb_packing == ChoiceBitPacking::Packed {
547                // Set lsb of noisy point idx to 1, all others to 0
548                let mask_lsb = Block::ONES ^ Block::ONE;
549                for block in a.iter_mut() {
550                    *block &= mask_lsb;
551                }
552
553                for idx in noisy_points {
554                    a[idx] |= Block::ONE
555                }
556
557                self.code.dual_encode(&mut a[..self.code.conf().code_size]);
558                (a, None::<Vec<Choice>>)
559            } else {
560                self.code.dual_encode(&mut a[..self.code.conf().code_size]);
561                let mut choices = vec![0_u8; self.code.conf().code_size];
562                for idx in noisy_points {
563                    if idx < choices.len() {
564                        choices[idx] = 1;
565                    }
566                }
567                self.code.dual_encode(&mut choices);
568                let mut choices: Vec<_> = choices.into_iter().map(Choice::from).collect();
569                choices.truncate(self.code.message_size());
570                (a, Some(choices))
571            };
572
573            a.set_len(self.code.message_size());
574            (a, cb)
575        });
576        jh.await.expect("worker panic")
577    }
578}
579
580#[allow(non_snake_case)]
581fn get_reg_noise_weight(min_dist_ratio: f64, N: usize, sec_param: usize) -> usize {
582    assert!(min_dist_ratio <= 0.5 && min_dist_ratio > 0.0);
583    let d = (1.0 - 2.0 * min_dist_ratio).log2();
584    let mut t = 40.max((-(sec_param as f64) / d) as usize);
585    if N < 512 {
586        t = t.max(64);
587    }
588    t.next_multiple_of(cryprot_pprf::PARALLEL_TREES)
589}
590
#[cfg(test)]
mod tests {
    use cryprot_core::Block;
    use cryprot_net::testing::{init_tracing, local_conn};
    use subtle::Choice;

    use crate::{
        RandChoiceRotReceiver, RotSender,
        silent_ot::{
            MaliciousSilentOtReceiver, MaliciousSilentOtSender, SemiHonestSilentOtReceiver,
            SemiHonestSilentOtSender,
        },
    };

    /// Asserts that receiver blocks `a` and sender blocks `b` form correlated
    /// OTs for `delta`.
    ///
    /// If `choice` is `None`, the choice bits are taken from the lsb of `a`
    /// (choice-bit packing) and the lsb is ignored when comparing blocks.
    fn check_correlated(a: &[Block], b: &[Block], choice: Option<&[Choice]>, delta: Block) {
        let n = a.len();
        assert_eq!(b.len(), n);
        if let Some(choice) = choice {
            assert_eq!(choice.len(), n)
        }
        let mask = if choice.is_some() {
            // don't mask off lsb when not using choice packing
            Block::ONES
        } else {
            // mask off lsb
            Block::ONES ^ Block::ONE
        };

        for i in 0..n {
            let m1 = a[i];
            let c = if let Some(choice) = choice {
                choice[i].unwrap_u8() as usize
            } else {
                // extract choice bit from m1
                ((m1 & Block::ONE) == Block::ONE) as usize
            };
            let m1 = m1 & mask;
            let m2a = b[i] & mask;
            let m2b = (b[i] ^ delta) & mask;

            let eqq = [m1 == m2a, m1 == m2b];
            assert!(
                eqq[c] && !eqq[c ^ 1],
                "Blocks at {i} differ. Choice: {c} {m1:?}, {m2a:?}, {m2b:?}"
            );
        }
    }

    /// Asserts that the receiver's blocks equal the sender's message selected
    /// by the respective choice bit.
    fn check_random(count: usize, s_ot: &[[Block; 2]], r_ot: &[Block], c: &[Choice]) {
        assert_eq!(s_ot.len(), count);
        assert_eq!(r_ot.len(), count);
        assert_eq!(c.len(), count);

        for i in 0..count {
            assert_eq!(
                r_ot[i],
                s_ot[i][c[i].unwrap_u8() as usize],
                "Difference at OT {i}\nr_ot: {:?}\ns_ot: {:?}\nc: {}",
                r_ot[i],
                s_ot[i],
                c[i].unwrap_u8()
            );
        }
    }

    #[tokio::test]
    async fn correlated_silent_ot() {
        let _g = init_tracing();
        let (c1, c2) = local_conn().await.unwrap();

        let mut sender = SemiHonestSilentOtSender::new(c1);
        let mut receiver = SemiHonestSilentOtReceiver::new(c2);
        let delta = Block::ONES;
        let count = 2_usize.pow(11);

        let (s_ot, (r_ot, choices)) = tokio::try_join!(
            sender.correlated_send(count, delta),
            receiver.correlated_receive(count)
        )
        .unwrap();

        assert_eq!(s_ot.len(), count);
        assert_eq!(r_ot.len(), count);

        check_correlated(&r_ot, &s_ot, Some(&choices), delta);
    }

    #[tokio::test]
    async fn random_silent_ot() {
        let _g = init_tracing();
        let (c1, c2) = local_conn().await.unwrap();

        let mut sender = SemiHonestSilentOtSender::new(c1);
        let mut receiver = SemiHonestSilentOtReceiver::new(c2);
        let count = 2_usize.pow(11);

        let (s_ot, (r_ot, choices)) =
            tokio::try_join!(sender.random_send(count), receiver.random_receive(count)).unwrap();

        check_random(count, &s_ot, &r_ot[..], &choices);
    }

    #[tokio::test]
    async fn test_rot_trait_for_silent_ot() {
        let _g = init_tracing();
        let (c1, c2) = local_conn().await.unwrap();

        let mut sender = SemiHonestSilentOtSender::new(c1);
        let mut receiver = SemiHonestSilentOtReceiver::new(c2);
        let count = 2_usize.pow(11);

        let (s_ot, (r_ot, c)) =
            tokio::try_join!(sender.send(count), receiver.rand_choice_receive(count)).unwrap();

        check_random(count, &s_ot, &r_ot, &c);
    }

    #[tokio::test]
    async fn test_malicious_silent_ot() {
        let _g = init_tracing();
        let (c1, c2) = local_conn().await.unwrap();

        let mut sender = MaliciousSilentOtSender::new(c1);
        let mut receiver = MaliciousSilentOtReceiver::new(c2);
        let count = 2_usize.pow(11);

        let (s_ot, (r_ot, choices)) =
            tokio::try_join!(sender.random_send(count), receiver.random_receive(count)).unwrap();

        check_random(count, &s_ot, &r_ot[..], &choices);
    }

    // This test, when run with RUST_LOG=info and --nocapture will print the
    // communication for 2^18 silent OTs
    #[cfg(not(debug_assertions))]
    #[tokio::test]
    async fn silent_ot_comm() {
        let _g = init_tracing();
        let (c1, c2) = local_conn().await.unwrap();

        let mut sender = SemiHonestSilentOtSender::new(c1);
        let mut receiver = SemiHonestSilentOtReceiver::new(c2);
        let delta = Block::ONES;
        let count = 2_usize.pow(18);

        // `correlated_receive` takes only the count and always returns the
        // choice bits separately (no choice-bit packing).
        let (s_ot, (r_ot, choices)) = tokio::try_join!(
            sender.correlated_send(count, delta),
            receiver.correlated_receive(count)
        )
        .unwrap();

        assert_eq!(s_ot.len(), count);
        assert_eq!(r_ot.len(), count);

        check_correlated(&r_ot, &s_ot, Some(&choices), delta);
    }
}