1use std::{array, cmp::Ordering, io, mem};
3
4use aes::{
5 Aes128,
6 cipher::{BlockCipherEncrypt, KeyInit},
7};
8use bytemuck::{cast_slice, cast_slice_mut};
9use cryprot_core::{
10 AES_PAR_BLOCKS, Block,
11 aes_hash::FIXED_KEY_HASH,
12 aes_rng::AesRng,
13 alloc::allocate_zeroed_vec,
14 buf::Buf,
15 tokio_rayon::spawn_compute,
16 utils::{log2_ceil, xor_inplace},
17};
18use cryprot_net::{Connection, ConnectionError};
19use futures::{SinkExt, StreamExt};
20use ndarray::{Array2, ArrayView2};
21use rand::{CryptoRng, Rng, RngCore, SeedableRng, distr::Uniform, prelude::Distribution};
22use serde::{Deserialize, Serialize};
23use tokio::sync::mpsc::unbounded_channel;
24use tracing::Level;
25
/// Sender side of the regular PPRF expansion.
///
/// The sender computes every leaf of every tree. After [`RegularPprfSender::expand`]
/// and the matching receiver expansion, the two outputs agree everywhere except at
/// one point per tree, where they differ by the sender's `value` (this is what the
/// module tests check).
pub struct RegularPprfSender {
    // Connection to the receiver; `expand` opens a sub-stream on it.
    conn: Connection,
    conf: PprfConfig,
    // Base OT message pairs, shape `[pnt_count, depth]` (row = tree), see `new_with_conf`.
    base_ots: Array2<[Block; 2]>,
}
32
/// Receiver side of the regular PPRF expansion.
///
/// The receiver holds, for every tree, the base OT messages it received and the
/// corresponding choice bits. The choice bits of a tree encode the point
/// (see [`PprfConfig::get_points`]) at which its output differs from the sender's.
pub struct RegularPprfReceiver {
    // Connection to the sender; `expand` opens a sub-stream on it.
    conn: Connection,
    conf: PprfConfig,
    // Received base OT messages, shape `[pnt_count, depth]` (row = tree).
    base_ots: Array2<Block>,
    // Base OT choice bits (0/1), shape `[pnt_count, depth]`; column `k` is bit `k` of the point.
    base_choices: Array2<u8>,
}
40
/// Memory layout of the expanded PPRF output buffer.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum OutFormat {
    /// Leaf-major: leaf `leaf` of tree `tree` is stored at `leaf * pnt_count + tree`.
    ByLeafIndex,
    /// Tree-major layout. Not implemented yet: the code paths for this format
    /// currently end in `todo!()`.
    ByTreeIndex,
    /// Trees are grouped by [`PARALLEL_TREES`]. Leaf `leaf` of tree `tree` is stored at
    /// `((tree / PARALLEL_TREES) * domain + leaf) * PARALLEL_TREES + tree % PARALLEL_TREES`.
    Interleaved,
}
48
/// Errors returned by the sender and receiver `expand` methods.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Opening the sub-stream on the connection failed.
    #[error("unable to establish sub-stream to pprf peer")]
    Connection(#[from] ConnectionError),
    /// Writing a message to the peer failed.
    #[error("error in sending data to pprf peer")]
    Send(#[source] io::Error),
    /// Reading a message from the peer failed.
    #[error("error in receiving data from pprf peer")]
    Receive(#[source] io::Error),
}
59
/// Number of trees expanded together as one group.
///
/// Each tree node is stored as a `[Block; PARALLEL_TREES]`, so a single
/// `encrypt_blocks_b2b` call processes the same node of every tree in the group.
/// It is set to `AES_PAR_BLOCKS` from `cryprot_core`, presumably to match the
/// AES batch width. `pnt_count` must be a multiple of this value.
pub const PARALLEL_TREES: usize = AES_PAR_BLOCKS;
64
/// Value of the `phase` field on the `cryprot_metrics` tracing spans emitted by
/// the sender's and receiver's `expand`.
pub const COMMUNICATION_PHASE: &str = "pprf-expansion";
67
/// Parameters of a regular PPRF: `pnt_count` trees, each with `domain` output leaves.
///
/// Construct it with [`PprfConfig::new`], which validates the parameters and derives `depth`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PprfConfig {
    // Number of trees (one punctured point per tree).
    pnt_count: usize,
    // Number of output leaves per tree (even, >= 2).
    domain: usize,
    // `log2_ceil(domain)`; also the number of base OTs per tree.
    depth: usize,
}
75
76impl PprfConfig {
77 pub fn new(domain: usize, pnt_count: usize) -> Self {
84 assert!(domain >= 2, "domain must be at least 2");
85 assert_eq!(0, domain % 2, "domain must be even");
86 assert_eq!(
87 0,
88 pnt_count % PARALLEL_TREES,
89 "pnt_count must be divisable by {PARALLEL_TREES}"
90 );
91 let depth = log2_ceil(domain);
92 Self {
93 pnt_count,
94 domain,
95 depth,
96 }
97 }
98
99 pub fn base_ot_count(&self) -> usize {
101 self.depth * self.pnt_count
102 }
103
104 pub fn pnt_count(&self) -> usize {
105 self.pnt_count
106 }
107
108 pub fn domain(&self) -> usize {
109 self.domain
110 }
111
112 pub fn depth(&self) -> usize {
113 self.depth
114 }
115
116 pub fn size(&self) -> usize {
117 self.domain() * self.pnt_count()
118 }
119
120 pub fn sample_choice_bits<R: RngCore + CryptoRng>(&self, rng: &mut R) -> Vec<u8> {
124 let mut choices = vec![0_u8; self.pnt_count() * self.depth()];
125 let dist = Uniform::new(0, self.domain()).expect("correct range");
126 for choice in choices.chunks_exact_mut(self.depth()) {
127 let mut idx = dist.sample(rng);
128 for choice_bit in choice {
129 *choice_bit = (idx & 1) as u8;
130 idx >>= 1;
131 }
132 }
133 choices
134 }
135
136 pub fn get_points(&self, out_fmt: OutFormat, base_choices: &[u8]) -> Vec<usize> {
137 match out_fmt {
138 OutFormat::Interleaved => {
139 let mut points = self.get_points(OutFormat::ByLeafIndex, base_choices);
140 for (i, point) in points.iter_mut().enumerate() {
141 *point = interleave_point(*point, i, self.domain())
142 }
143 points
144 }
145 OutFormat::ByLeafIndex => {
146 let base_choices =
147 ArrayView2::from_shape([self.pnt_count(), self.depth()], base_choices)
148 .expect("base_choices has wrong size for this conf");
149
150 base_choices
151 .rows()
152 .into_iter()
153 .map(|choice_bits| {
154 debug_assert_eq!(self.depth(), choice_bits.len());
155 let point = get_active_path(choice_bits.iter().copied());
156 debug_assert!(point < self.domain());
157 point
158 })
159 .collect()
160 }
161 _ => todo!(),
162 }
163 }
164}
165
/// Message sent from sender to receiver for one group of [`PARALLEL_TREES`] trees.
///
/// Field order is also the serialization order, so keep it stable.
#[derive(Serialize, Deserialize, Default, Clone, Debug)]
struct TreeGrp {
    // Index of the first tree in this group (checked by the receiver).
    g: usize,
    // `sums[b][d][j]`: XOR of all parity-`b` children at level `d + 1` of tree `j`,
    // masked with a base OT message. The sender truncates it to `depth - 1`
    // levels before sending; the last level travels in `last_ots`.
    sums: [Vec<[Block; PARALLEL_TREES]>; 2],
    // Four hash-masked last-level values per tree. The receiver can unmask the
    // pair selected by its last-level (lowest) choice bit.
    last_ots: Vec<[Block; 4]>,
}
172
impl RegularPprfSender {
    /// Creates a sender for `conf` that talks to the receiver over `conn`.
    ///
    /// `base_ots` must contain `conf.base_ot_count()` OT message pairs. They are
    /// reshaped row-major into `[pnt_count, depth]`, so row `t` holds the base OTs of tree `t`.
    ///
    /// # Panics
    /// If `base_ots.len() != conf.base_ot_count()`.
    pub fn new_with_conf(conn: Connection, conf: PprfConfig, base_ots: Vec<[Block; 2]>) -> Self {
        assert_eq!(
            conf.base_ot_count(),
            base_ots.len(),
            "wrong number of base OTs"
        );
        let base_ots = Array2::from_shape_vec([conf.pnt_count(), conf.depth()], base_ots)
            .expect("base_ots.len() is checked before");
        Self {
            conn,
            conf,
            base_ots,
        }
    }

    /// Expands all trees, sends the receiver its correction messages, and writes
    /// the sender's leaves into `out` in `out_fmt` layout.
    ///
    /// Tree roots are sampled from an `AesRng` seeded with `seed`. Each child is
    /// `AES_k(parent) ^ parent`, with one fixed key for even children and another
    /// for odd children. For every group of [`PARALLEL_TREES`] trees, one
    /// `TreeGrp` is sent over a new sub-stream of the connection. `value` is the
    /// offset at the receiver's point: it is folded into the last-level OT payloads.
    ///
    /// `out` is grown (zero-filled) to at least `conf.size()` blocks if it is smaller.
    ///
    /// # Errors
    /// [`Error::Connection`] if the sub-stream cannot be opened, and
    /// [`Error::Send`] if sending a message fails. On error, `out` is left in
    /// its `Default` state, because the buffer was moved out via `mem::take`.
    ///
    /// # Panics
    /// If the compute task panics. Also panics for `OutFormat::ByTreeIndex`
    /// (`copy_out` is `todo!()` for it).
    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = COMMUNICATION_PHASE))]
    pub async fn expand(
        mut self,
        value: Block,
        seed: Block,
        out_fmt: OutFormat,
        out: &mut impl Buf<Block>,
    ) -> Result<(), Error> {
        let size = self.conf.size();
        // Move the buffer into the compute task; it is written back only on success.
        let mut output = mem::take(out);
        let (mut tx, _) = self.conn.stream().await?;
        // The compute task produces one `TreeGrp` per group; this task forwards them to the network.
        let (send, mut recv) = unbounded_channel();
        let jh = spawn_compute(move || {
            if output.len() < size {
                output.grow_zeroed(size);
            }
            // aes[0] derives even children, aes[1] odd children.
            let aes = create_fixed_aes();
            let depth = self.conf.depth();
            let pnt_count = self.conf.pnt_count();
            let domain = self.conf.domain();

            let mut rng = AesRng::from_seed(seed);
            // Interleaved writes the last level straight into `output`, so the
            // scratch tree only needs levels 0..depth-1 in that case.
            let dd = match out_fmt {
                OutFormat::Interleaved => depth,
                _ => depth + 1,
            };

            // Breadth-first storage: level i occupies 2^i entries at offset 2^i - 1.
            let mut tree: Vec<[Block; PARALLEL_TREES]> = Vec::zeroed(2_usize.pow(dd as u32));

            for g in (0..pnt_count).step_by(PARALLEL_TREES) {
                let mut tree_grp = TreeGrp {
                    g,
                    ..Default::default()
                };
                let min = PARALLEL_TREES.min(pnt_count - g);
                // Random roots for all trees of this group.
                let level: &mut [u8] = cast_slice_mut(get_level(&mut tree, 0));
                rng.fill_bytes(level);
                tree_grp.sums[0].resize(depth, Default::default());
                tree_grp.sums[1].resize(depth, Default::default());

                for d in 0..depth {
                    let (lvl0, lvl1) = if out_fmt == OutFormat::Interleaved && d + 1 == depth {
                        (
                            get_level(&mut tree, d),
                            get_level_output(&mut output, g, domain),
                        )
                    } else {
                        get_cons_levels(&mut tree, d)
                    };

                    let width = lvl1.len();
                    let mut child_idx = 0;
                    while child_idx < width {
                        let parent_idx = child_idx >> 1;
                        let parent = &lvl0[parent_idx];
                        // First pass: even child with aes[0]/sums[0]; second pass: odd child with aes[1]/sums[1].
                        for (aes, sums) in aes.iter().zip(&mut tree_grp.sums) {
                            let child = &mut lvl1[child_idx];
                            let sum = &mut sums[d];
                            aes.encrypt_blocks_b2b(cast_slice(parent), cast_slice_mut(child))
                                .expect("parent and child have same len");
                            xor_inplace(child, parent);
                            xor_inplace(sum, child);
                            child_idx += 1;
                        }
                    }
                }

                // Mask the per-level sums (except the last level) with base OT message
                // `idx ^ 1` of the bit controlling level d+1. A receiver whose choice bit
                // is c can therefore only unmask the parity-(c ^ 1) sum.
                let mut mask_sums = |idx: usize| {
                    for (d, sums) in tree_grp.sums[idx].iter_mut().take(depth - 1).enumerate() {
                        for (j, sum) in sums.iter_mut().enumerate().take(min) {
                            *sum ^= self.base_ots[(g + j, depth - 1 - d)][idx ^ 1];
                        }
                    }
                };
                mask_sums(0);
                mask_sums(1);

                // Last level. Slots [0, 1] are opened by the receiver holding message 1,
                // giving (even sum, odd sum ^ value). Slots [2, 3] are opened by the
                // receiver holding message 0, giving (odd sum, even sum ^ value).
                let d = depth - 1;
                tree_grp.last_ots.resize(min, Default::default());
                for j in 0..min {
                    tree_grp.last_ots[j][0] = tree_grp.sums[0][d][j];
                    tree_grp.last_ots[j][1] = tree_grp.sums[1][d][j] ^ value;
                    tree_grp.last_ots[j][2] = tree_grp.sums[1][d][j];
                    tree_grp.last_ots[j][3] = tree_grp.sums[0][d][j] ^ value;

                    let mask_in = [
                        self.base_ots[(g + j, 0)][1],
                        self.base_ots[(g + j, 0)][1] ^ Block::ONES,
                        self.base_ots[(g + j, 0)][0],
                        self.base_ots[(g + j, 0)][0] ^ Block::ONES,
                    ];
                    let masks = FIXED_KEY_HASH.cr_hash_blocks(&mask_in);
                    xor_inplace(&mut tree_grp.last_ots[j], &masks);
                }
                // The last level travels in `last_ots`, so it is not sent as a plain sum.
                tree_grp.sums[0].truncate(depth - 1);
                tree_grp.sums[1].truncate(depth - 1);

                // The async side has gone away (e.g. a send error): stop early.
                if send.send(tree_grp).is_err() {
                    return output;
                }
                if out_fmt != OutFormat::Interleaved {
                    let last_lvl = get_level(&mut tree, depth);
                    copy_out(last_lvl, &mut output, g, out_fmt, self.conf);
                }
            }
            output
        });

        while let Some(tree_group) = recv.recv().await {
            tx.send(tree_group).await.map_err(Error::Send)?;
        }

        *out = jh.await.expect("panic in worker thread");
        Ok(())
    }
}
316
317impl RegularPprfReceiver {
318 pub fn new_with_conf(
325 conn: Connection,
326 conf: PprfConfig,
327 base_ots: Vec<Block>,
328 base_choices: Vec<u8>,
329 ) -> Self {
330 assert_eq!(
331 conf.base_ot_count(),
332 base_ots.len(),
333 "wrong number of base OTs"
334 );
335 assert_eq!(
336 conf.base_ot_count(),
337 base_choices.len(),
338 "wrong number of base choices"
339 );
340 let base_ots = Array2::from_shape_vec([conf.pnt_count(), conf.depth()], base_ots)
341 .expect("base_ots.len() is checked before");
342 let base_choices = Array2::from_shape_vec([conf.pnt_count(), conf.depth()], base_choices)
343 .expect("base_ots.len() is checked before");
344 Self {
345 conn,
346 conf,
347 base_ots,
348 base_choices,
349 }
350 }
351
352 #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = COMMUNICATION_PHASE))]
358 pub async fn expand(
359 mut self,
360 out_fmt: OutFormat,
361 out: &mut impl Buf<Block>,
362 ) -> Result<(), Error> {
363 let size = self.conf.size();
364 let mut output = mem::take(out);
365 let (_, mut rx) = self.conn.stream().await?;
366 let (send, recv) = std::sync::mpsc::channel();
367 let jh = spawn_compute(move || {
368 if output.len() < size {
369 output.grow_zeroed(size);
370 }
371 let aes = create_fixed_aes();
372 let points = self.conf.get_points(
373 OutFormat::ByLeafIndex,
374 self.base_choices
375 .as_slice()
376 .expect("array order is unchanged"),
377 );
378 let depth = self.conf.depth();
379 let pnt_count = self.conf.pnt_count();
380 let domain = self.conf.domain();
381 let dd = match out_fmt {
382 OutFormat::Interleaved => depth,
383 _ => depth + 1,
384 };
385 let mut tree: Vec<[Block; PARALLEL_TREES]> =
386 allocate_zeroed_vec(2_usize.pow(dd as u32));
387
388 for g in (0..pnt_count).step_by(PARALLEL_TREES) {
389 let Ok(tree_grp): Result<TreeGrp, _> = recv.recv() else {
390 return output;
392 };
393 assert_eq!(g, tree_grp.g);
394
395 if depth > 1 {
396 let lvl1 = get_level(&mut tree, 1);
397 #[allow(clippy::needless_range_loop)]
398 for i in 0..PARALLEL_TREES {
399 let active = self.base_choices[(i + g, depth - 1)] as usize;
400 lvl1[active ^ 1][i] =
401 self.base_ots[(i + g, depth - 1)] ^ tree_grp.sums[active ^ 1][0][i];
402 lvl1[active][i] = Block::ZERO;
403 }
404 }
405
406 let mut my_sums = [[Block::ZERO; PARALLEL_TREES]; 2];
407
408 for d in 1..depth {
409 let (lvl0, lvl1) = if out_fmt == OutFormat::Interleaved && d + 1 == depth {
410 (
411 get_level(&mut tree, d),
412 get_level_output(&mut output, g, domain),
413 )
414 } else {
415 get_cons_levels(&mut tree, d)
416 };
417
418 my_sums = [[Block::ZERO; PARALLEL_TREES]; 2];
419
420 let width = lvl1.len();
421 let mut child_idx = 0;
422 while child_idx < width {
423 let parent_idx = child_idx >> 1;
424 let parent = &lvl0[parent_idx];
425 for (aes, sum) in aes.iter().zip(&mut my_sums) {
426 let child = &mut lvl1[child_idx];
427 aes.encrypt_blocks_b2b(cast_slice(parent), cast_slice_mut(child))
428 .expect("parent and child have same len");
429 xor_inplace(child, parent);
430 xor_inplace(sum, child);
431 child_idx += 1;
432 }
433 }
434
435 if d != depth - 1 {
436 for i in 0..PARALLEL_TREES {
437 let leaf_idx = points[i + g];
438 let active_child_idx = leaf_idx >> (depth - 1 - d);
439 let inactive_child_idx = active_child_idx ^ 1;
440 let not_ai = inactive_child_idx & 1;
441 let inactive_child = &mut lvl1[inactive_child_idx][i];
442 let correct_sum = *inactive_child ^ tree_grp.sums[not_ai][d][i];
443 *inactive_child = correct_sum
444 ^ my_sums[not_ai][i]
445 ^ self.base_ots[(i + g, depth - 1 - d)];
446 }
447 }
448 }
449 let lvl = if out_fmt == OutFormat::Interleaved {
450 get_level_output(&mut output, g, domain)
451 } else {
452 get_level(&mut tree, depth)
453 };
454
455 for j in 0..PARALLEL_TREES {
456 let active_child_idx = points[j + g];
457 let inactive_child_idx = active_child_idx ^ 1;
458 let not_ai = inactive_child_idx & 1;
459
460 let mask_in = [
461 self.base_ots[(g + j, 0)],
462 self.base_ots[(g + j, 0)] ^ Block::ONES,
463 ];
464 let masks = FIXED_KEY_HASH.cr_hash_blocks(&mask_in);
465
466 let ots: [Block; 2] =
467 array::from_fn(|i| tree_grp.last_ots[j][2 * not_ai + i] ^ masks[i]);
468
469 let [inactive_child, active_child] =
470 get_inactive_active_child(j, lvl, inactive_child_idx, active_child_idx);
471
472 let inactive_sum = my_sums[not_ai][j] ^ *inactive_child;
475 let active_sum = my_sums[not_ai ^ 1][j] ^ *active_child;
476 *inactive_child = ots[0] ^ inactive_sum;
477 *active_child = ots[1] ^ active_sum;
478 }
479 if out_fmt != OutFormat::Interleaved {
480 let last_lvl = get_level(&mut tree, depth);
481 copy_out(last_lvl, &mut output, g, out_fmt, self.conf);
482 }
483 }
484 output
485 });
486
487 while let Some(tree_grp) = rx.next().await {
488 let tree_grp = tree_grp.map_err(Error::Receive)?;
489 if send.send(tree_grp).is_err() {
490 break;
492 }
493 }
494
495 *out = jh.await.expect("panic in worker thread");
496 Ok(())
497 }
498}
499
500fn get_level(tree: &mut [[Block; PARALLEL_TREES]], i: usize) -> &mut [[Block; PARALLEL_TREES]] {
504 let size = 1 << i;
505 let offset = size - 1;
506 &mut tree[offset..offset + size]
507}
508
509fn get_cons_levels(
511 tree: &mut [[Block; PARALLEL_TREES]],
512 i: usize,
513) -> (
514 &mut [[Block; PARALLEL_TREES]],
515 &mut [[Block; PARALLEL_TREES]],
516) {
517 let size0 = 1 << i;
518 let offset0 = size0 - 1;
519 let tree = &mut tree[offset0..];
520 let (level0, rest) = tree.split_at_mut(size0);
521 let size1 = 1 << (i + 1);
522 debug_assert_eq!(size0 + offset0, size1 - 1);
523 let level1 = &mut rest[..size1];
524 (level0, level1)
525}
526
527fn get_level_output(
528 output: &mut [Block],
529 tree_idx: usize,
530 domain: usize,
531) -> &mut [[Block; PARALLEL_TREES]] {
532 let out = cast_slice_mut(output);
533 let forest = tree_idx / PARALLEL_TREES;
534 debug_assert_eq!(tree_idx % PARALLEL_TREES, 0);
535 let start = forest * domain;
536 &mut out[start..start + domain]
537}
538
/// Reassembles a leaf index from its choice bits, given least-significant bit first.
fn get_active_path<I>(choice_bits: I) -> usize
where
    I: Iterator<Item = u8> + ExactSizeIterator,
{
    let mut point = 0;
    for (shift, bit) in choice_bits.enumerate() {
        point |= usize::from(bit) << shift;
    }
    point
}
547
548fn get_inactive_active_child(
549 tree: usize,
550 lvl: &mut [[Block; PARALLEL_TREES]],
551 inactive_child_idx: usize,
552 active_child_idx: usize,
553) -> [&mut Block; 2] {
554 let children = match active_child_idx.cmp(&inactive_child_idx) {
555 Ordering::Less => {
556 let (left, right) = lvl.split_at_mut(inactive_child_idx);
557 [&mut right[0], &mut left[active_child_idx]]
558 }
559 Ordering::Greater => {
560 let (left, right) = lvl.split_at_mut(active_child_idx);
561 [&mut left[inactive_child_idx], &mut right[0]]
562 }
563 Ordering::Equal => {
564 unreachable!("Impossible, active and inactive indices are always different")
565 }
566 };
567 children.map(|arr| &mut arr[tree])
568}
569
570fn interleave_point(point: usize, tree_idx: usize, domain: usize) -> usize {
571 let sub_tree = tree_idx % PARALLEL_TREES;
572 let forest = tree_idx / PARALLEL_TREES;
573 (forest * domain + point) * PARALLEL_TREES + sub_tree
574}
575
576fn copy_out(
577 last_lvl: &[[Block; PARALLEL_TREES]],
578 output: &mut [Block],
579 tree_idx: usize,
580 out_fmt: OutFormat,
581 conf: PprfConfig,
582) {
583 let total_trees = conf.pnt_count();
584 let curr_size = PARALLEL_TREES.min(total_trees - tree_idx);
585 let last_lvl: &[Block] = cast_slice(last_lvl);
586 let domain = conf.domain();
588 match out_fmt {
589 OutFormat::ByLeafIndex => {
590 for leaf_idx in 0..domain {
591 let o_idx = total_trees * leaf_idx + tree_idx;
592 let i_idx = leaf_idx * PARALLEL_TREES;
593 output[o_idx..curr_size + o_idx]
595 .copy_from_slice(&last_lvl[i_idx..curr_size + i_idx]);
596 }
597 }
598 OutFormat::ByTreeIndex => todo!(),
599 OutFormat::Interleaved => panic!("Do not copy_out for OutFormat::Interleaved"),
600 }
601}
602
603fn create_fixed_aes() -> [Aes128; 2] {
605 [
606 Aes128::new(
607 &91389970179024809574621370423327856399_u128
608 .to_le_bytes()
609 .into(),
610 ),
611 Aes128::new(
612 &297966570818470707816499469807199042980_u128
613 .to_le_bytes()
614 .into(),
615 ),
616 ]
617}
618
619#[doc(hidden)]
622pub fn fake_base<R: RngCore + CryptoRng>(
623 conf: PprfConfig,
624 rng: &mut R,
625) -> (Vec<[Block; 2]>, Vec<Block>, Vec<u8>) {
626 let base_ot_count = conf.base_ot_count();
627 let msg2: Vec<[Block; 2]> = (0..base_ot_count).map(|_| rng.random()).collect();
628 let choices = conf.sample_choice_bits(rng);
629 let msg = msg2
630 .iter()
631 .zip(choices.iter())
632 .map(|(m, c)| m[*c as usize])
633 .collect();
634 (msg2, msg, choices)
635}
636
#[cfg(test)]
mod tests {
    use cryprot_core::{Block, alloc::HugePageMemory, buf::Buf, utils::xor_inplace};
    use cryprot_net::testing::local_conn;
    use rand::{Rng, SeedableRng, rngs::StdRng};

    use crate::{
        OutFormat, PARALLEL_TREES, PprfConfig, RegularPprfReceiver, RegularPprfSender, fake_base,
    };

    /// Block expected in `sender_out ^ receiver_out`: the sender's correction
    /// value (`Block::ONES` in every test here) at a punctured point, zero elsewhere.
    fn expected_block(at_point: bool) -> Block {
        match at_point {
            true => Block::ONES,
            false => Block::ZERO,
        }
    }

    #[tokio::test]
    async fn test_pprf_by_leaf() {
        let conf = PprfConfig::new(334, 5 * PARALLEL_TREES);
        let out_fmt = OutFormat::ByLeafIndex;
        let mut rng = StdRng::seed_from_u64(42);

        let (c1, c2) = local_conn().await.unwrap();
        let (sender_base_ots, receiver_base_ots, base_choices) = fake_base(conf, &mut rng);
        let points = conf.get_points(out_fmt, &base_choices);

        let sender = RegularPprfSender::new_with_conf(c1, conf, sender_base_ots);
        let receiver =
            RegularPprfReceiver::new_with_conf(c2, conf, receiver_base_ots, base_choices);
        eprintln!("{points:?}");
        let mut s_out = HugePageMemory::zeroed(conf.size());
        let mut r_out = HugePageMemory::zeroed(conf.size());
        let seed = rng.random();
        tokio::try_join!(
            sender.expand(Block::ONES, seed, out_fmt, &mut s_out),
            receiver.expand(out_fmt, &mut r_out)
        )
        .unwrap();

        // Fold the receiver's share into the sender's buffer.
        xor_inplace(&mut s_out, &r_out);

        // ByLeafIndex stores leaf `leaf` of tree `tree` at `leaf * trees + tree`.
        let trees = points.len();
        for (tree, &point) in points.iter().enumerate() {
            for leaf in 0..conf.domain() {
                assert_eq!(expected_block(point == leaf), s_out[leaf * trees + tree]);
            }
        }
    }

    #[tokio::test]
    async fn test_pprf_interleaved_simple() {
        let conf = PprfConfig::new(2, PARALLEL_TREES);
        let out_fmt = OutFormat::Interleaved;
        let mut rng = StdRng::seed_from_u64(42);

        let (c1, c2) = local_conn().await.unwrap();
        let (sender_base_ots, receiver_base_ots, base_choices) = fake_base(conf, &mut rng);
        let points = conf.get_points(out_fmt, &base_choices);

        println!("Base choices: {:?}", base_choices);

        let sender = RegularPprfSender::new_with_conf(c1, conf, sender_base_ots);
        let receiver =
            RegularPprfReceiver::new_with_conf(c2, conf, receiver_base_ots, base_choices);
        println!("Points: {:?}", points);
        let mut s_out = Vec::zeroed(conf.size());
        let mut r_out = Vec::zeroed(conf.size());
        let seed = rng.random();
        tokio::try_join!(
            sender.expand(Block::ONES, seed, out_fmt, &mut s_out),
            receiver.expand(out_fmt, &mut r_out)
        )
        .unwrap();

        xor_inplace(&mut s_out, &r_out);
        println!("XORed output: {:?}", s_out);
        // Interleaved points are already flat indices into the output buffer.
        for (i, blk) in s_out.iter().enumerate() {
            assert_eq!(
                expected_block(points.contains(&i)),
                *blk,
                "block {i} not as expected"
            );
        }
    }

    #[tokio::test]
    async fn test_pprf_interleaved() {
        let conf = PprfConfig::new(334, 5 * PARALLEL_TREES);
        let out_fmt = OutFormat::Interleaved;
        let mut rng = StdRng::seed_from_u64(42);

        let (c1, c2) = local_conn().await.unwrap();
        let (sender_base_ots, receiver_base_ots, base_choices) = fake_base(conf, &mut rng);
        let points = conf.get_points(out_fmt, &base_choices);

        let sender = RegularPprfSender::new_with_conf(c1, conf, sender_base_ots);
        let receiver =
            RegularPprfReceiver::new_with_conf(c2, conf, receiver_base_ots, base_choices);
        let mut s_out = HugePageMemory::zeroed(conf.size());
        let mut r_out = HugePageMemory::zeroed(conf.size());
        let seed = rng.random();
        tokio::try_join!(
            sender.expand(Block::ONES, seed, out_fmt, &mut s_out),
            receiver.expand(out_fmt, &mut r_out)
        )
        .unwrap();

        xor_inplace(&mut s_out, &r_out);
        for (i, blk) in s_out.iter().enumerate() {
            assert_eq!(
                expected_block(points.contains(&i)),
                *blk,
                "block {i} not as expected"
            );
        }
    }
}