cryprot_pprf/
lib.rs

1//! Distributed Puncturable Pseudorandom Function (PPRF) Implementation.
2use std::{array, cmp::Ordering, io, mem};
3
4use aes::{
5    Aes128,
6    cipher::{BlockCipherEncrypt, KeyInit},
7};
8use bytemuck::{cast_slice, cast_slice_mut};
9use cryprot_core::{
10    AES_PAR_BLOCKS, Block,
11    aes_hash::FIXED_KEY_HASH,
12    aes_rng::AesRng,
13    alloc::allocate_zeroed_vec,
14    buf::Buf,
15    tokio_rayon::spawn_compute,
16    utils::{log2_ceil, xor_inplace},
17};
18use cryprot_net::{Connection, ConnectionError};
19use futures::{SinkExt, StreamExt};
20use ndarray::{Array2, ArrayView2};
21use rand::{CryptoRng, Rng, RngCore, SeedableRng, distr::Uniform, prelude::Distribution};
22use serde::{Deserialize, Serialize};
23use tokio::sync::mpsc::unbounded_channel;
24use tracing::Level;
25
26/// Sender for PPRF expansion.
27pub struct RegularPprfSender {
28    conn: Connection,
29    conf: PprfConfig,
30    base_ots: Array2<[Block; 2]>,
31}
32
33/// Receiver for the PPRF expansion.
34pub struct RegularPprfReceiver {
35    conn: Connection,
36    conf: PprfConfig,
37    base_ots: Array2<Block>,
38    base_choices: Array2<u8>,
39}
40
/// Memory layout of the expanded PPRF output buffer.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum OutFormat {
    /// Leaf `l` of tree `t` is stored at index `l * pnt_count + t`.
    ByLeafIndex,
    /// TODO: currently unimplemented
    ByTreeIndex,
    /// Trees are grouped by `PARALLEL_TREES`; leaf `l` of tree `t` is stored at
    /// `((t / PARALLEL_TREES) * domain + l) * PARALLEL_TREES + t % PARALLEL_TREES`.
    Interleaved,
}
48
49/// Errors returned by the PPRF expansion.
50#[derive(thiserror::Error, Debug)]
51pub enum Error {
52    #[error("unable to establish sub-stream to pprf peer")]
53    Connection(#[from] ConnectionError),
54    #[error("error in sending data to pprf peer")]
55    Send(#[source] io::Error),
56    #[error("error in receiving data from pprf peer")]
57    Receive(#[source] io::Error),
58}
59
60/// Number of trees that are expanded in parallel using AES ILP.
61///
62/// The concrete value depends on the target architecture.
63pub const PARALLEL_TREES: usize = AES_PAR_BLOCKS;
64
65/// Communication phase for [`cryprot_net::metrics`].
66pub const COMMUNICATION_PHASE: &str = "pprf-expansion";
67
68/// Config for a PPRF expansion.
69#[derive(Debug, Clone, Copy, PartialEq, Eq)]
70pub struct PprfConfig {
71    pnt_count: usize,
72    domain: usize,
73    depth: usize,
74}
75
76impl PprfConfig {
77    /// Create a PprfConfig
78    ///
79    /// # Panics
80    /// - if `domain < 2`
81    /// - if `domain % 2 != 0`
82    /// - if `pnt_count % `[`PARALLEL_TREES`]` != 0`
83    pub fn new(domain: usize, pnt_count: usize) -> Self {
84        assert!(domain >= 2, "domain must be at least 2");
85        assert_eq!(0, domain % 2, "domain must be even");
86        assert_eq!(
87            0,
88            pnt_count % PARALLEL_TREES,
89            "pnt_count must be divisable by {PARALLEL_TREES}"
90        );
91        let depth = log2_ceil(domain);
92        Self {
93            pnt_count,
94            domain,
95            depth,
96        }
97    }
98
99    /// Number of base OTs needed for the configured PPRF expansion.
100    pub fn base_ot_count(&self) -> usize {
101        self.depth * self.pnt_count
102    }
103
104    pub fn pnt_count(&self) -> usize {
105        self.pnt_count
106    }
107
108    pub fn domain(&self) -> usize {
109        self.domain
110    }
111
112    pub fn depth(&self) -> usize {
113        self.depth
114    }
115
116    pub fn size(&self) -> usize {
117        self.domain() * self.pnt_count()
118    }
119
120    /// Base OT choice bits needed for this PPRF expansion.
121    ///
122    /// Every `u8` element is either `0` or `1`.
123    pub fn sample_choice_bits<R: RngCore + CryptoRng>(&self, rng: &mut R) -> Vec<u8> {
124        let mut choices = vec![0_u8; self.pnt_count() * self.depth()];
125        let dist = Uniform::new(0, self.domain()).expect("correct range");
126        for choice in choices.chunks_exact_mut(self.depth()) {
127            let mut idx = dist.sample(rng);
128            for choice_bit in choice {
129                *choice_bit = (idx & 1) as u8;
130                idx >>= 1;
131            }
132        }
133        choices
134    }
135
136    pub fn get_points(&self, out_fmt: OutFormat, base_choices: &[u8]) -> Vec<usize> {
137        match out_fmt {
138            OutFormat::Interleaved => {
139                let mut points = self.get_points(OutFormat::ByLeafIndex, base_choices);
140                for (i, point) in points.iter_mut().enumerate() {
141                    *point = interleave_point(*point, i, self.domain())
142                }
143                points
144            }
145            OutFormat::ByLeafIndex => {
146                let base_choices =
147                    ArrayView2::from_shape([self.pnt_count(), self.depth()], base_choices)
148                        .expect("base_choices has wrong size for this conf");
149
150                base_choices
151                    .rows()
152                    .into_iter()
153                    .map(|choice_bits| {
154                        debug_assert_eq!(self.depth(), choice_bits.len());
155                        let point = get_active_path(choice_bits.iter().copied());
156                        debug_assert!(point < self.domain());
157                        point
158                    })
159                    .collect()
160            }
161            _ => todo!(),
162        }
163    }
164}
165
166#[derive(Serialize, Deserialize, Default, Clone, Debug)]
167struct TreeGrp {
168    g: usize,
169    sums: [Vec<[Block; PARALLEL_TREES]>; 2],
170    last_ots: Vec<[Block; 4]>,
171}
172
173impl RegularPprfSender {
174    /// Create a new `RegularPprfSender`.
175    ///
176    /// # Panics
177    /// If `base_ots.len() != conf.base_ot_count()`.
178    pub fn new_with_conf(conn: Connection, conf: PprfConfig, base_ots: Vec<[Block; 2]>) -> Self {
179        assert_eq!(
180            conf.base_ot_count(),
181            base_ots.len(),
182            "wrong number of base OTs"
183        );
184        let base_ots = Array2::from_shape_vec([conf.pnt_count(), conf.depth()], base_ots)
185            .expect("base_ots.len() is checked before");
186        Self {
187            conn,
188            conf,
189            base_ots,
190        }
191    }
192
193    /// Expand the PPRF into out.
194    ///
195    /// Note that this method temporarily moves out the buffer pointed to by
196    /// `out`. If the future returned by `expand` is dropped or panics, out
197    /// might point to a different buffer.
198    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = COMMUNICATION_PHASE))]
199    pub async fn expand(
200        mut self,
201        value: Block,
202        seed: Block,
203        out_fmt: OutFormat,
204        out: &mut impl Buf<Block>,
205    ) -> Result<(), Error> {
206        let size = self.conf.size();
207        let mut output = mem::take(out);
208        let (mut tx, _) = self.conn.stream().await?;
209        let (send, mut recv) = unbounded_channel();
210        let jh = spawn_compute(move || {
211            if output.len() < size {
212                output.grow_zeroed(size);
213            }
214            let aes = create_fixed_aes();
215            let depth = self.conf.depth();
216            let pnt_count = self.conf.pnt_count();
217            let domain = self.conf.domain();
218
219            let mut rng = AesRng::from_seed(seed);
220            let dd = match out_fmt {
221                OutFormat::Interleaved => depth,
222                _ => depth + 1,
223            };
224
225            let mut tree: Vec<[Block; PARALLEL_TREES]> = Vec::zeroed(2_usize.pow(dd as u32));
226
227            for g in (0..pnt_count).step_by(PARALLEL_TREES) {
228                let mut tree_grp = TreeGrp {
229                    g,
230                    ..Default::default()
231                };
232                let min = PARALLEL_TREES.min(pnt_count - g);
233                let level: &mut [u8] = cast_slice_mut(get_level(&mut tree, 0));
234                rng.fill_bytes(level);
235                tree_grp.sums[0].resize(depth, Default::default());
236                tree_grp.sums[1].resize(depth, Default::default());
237
238                for d in 0..depth {
239                    let (lvl0, lvl1) = if out_fmt == OutFormat::Interleaved && d + 1 == depth {
240                        (
241                            get_level(&mut tree, d),
242                            get_level_output(&mut output, g, domain),
243                        )
244                    } else {
245                        get_cons_levels(&mut tree, d)
246                    };
247
248                    let width = lvl1.len();
249                    let mut child_idx = 0;
250                    while child_idx < width {
251                        let parent_idx = child_idx >> 1;
252                        let parent = &lvl0[parent_idx];
253                        for (aes, sums) in aes.iter().zip(&mut tree_grp.sums) {
254                            let child = &mut lvl1[child_idx];
255                            let sum = &mut sums[d];
256                            aes.encrypt_blocks_b2b(cast_slice(parent), cast_slice_mut(child))
257                                .expect("parent and child have same len");
258                            xor_inplace(child, parent);
259                            xor_inplace(sum, child);
260                            child_idx += 1;
261                        }
262                    }
263                }
264
265                let mut mask_sums = |idx: usize| {
266                    for (d, sums) in tree_grp.sums[idx].iter_mut().take(depth - 1).enumerate() {
267                        for (j, sum) in sums.iter_mut().enumerate().take(min) {
268                            *sum ^= self.base_ots[(g + j, depth - 1 - d)][idx ^ 1];
269                        }
270                    }
271                };
272                mask_sums(0);
273                mask_sums(1);
274
275                let d = depth - 1;
276                tree_grp.last_ots.resize(min, Default::default());
277                for j in 0..min {
278                    tree_grp.last_ots[j][0] = tree_grp.sums[0][d][j];
279                    tree_grp.last_ots[j][1] = tree_grp.sums[1][d][j] ^ value;
280                    tree_grp.last_ots[j][2] = tree_grp.sums[1][d][j];
281                    tree_grp.last_ots[j][3] = tree_grp.sums[0][d][j] ^ value;
282
283                    let mask_in = [
284                        self.base_ots[(g + j, 0)][1],
285                        self.base_ots[(g + j, 0)][1] ^ Block::ONES,
286                        self.base_ots[(g + j, 0)][0],
287                        self.base_ots[(g + j, 0)][0] ^ Block::ONES,
288                    ];
289                    let masks = FIXED_KEY_HASH.cr_hash_blocks(&mask_in);
290                    xor_inplace(&mut tree_grp.last_ots[j], &masks);
291                }
292                tree_grp.sums[0].truncate(depth - 1);
293                tree_grp.sums[1].truncate(depth - 1);
294
295                if send.send(tree_grp).is_err() {
296                    // receiver in async task is dropped, so we stop compute task by returning.
297                    // output will be dropped and not put back into initial &mut out parameter
298                    return output;
299                }
300                if out_fmt != OutFormat::Interleaved {
301                    let last_lvl = get_level(&mut tree, depth);
302                    copy_out(last_lvl, &mut output, g, out_fmt, self.conf);
303                }
304            }
305            output
306        });
307
308        while let Some(tree_group) = recv.recv().await {
309            tx.send(tree_group).await.map_err(Error::Send)?;
310        }
311
312        *out = jh.await.expect("panic in worker thread");
313        Ok(())
314    }
315}
316
317impl RegularPprfReceiver {
318    /// Create a new `RegularPprfReceiver`.
319    ///
320    /// # Panics
321    /// If:
322    /// - `base_ots.len() != conf.base_ot_count()` or
323    /// - `base_choices.len() != conf.base_ot_count()` or
324    pub fn new_with_conf(
325        conn: Connection,
326        conf: PprfConfig,
327        base_ots: Vec<Block>,
328        base_choices: Vec<u8>,
329    ) -> Self {
330        assert_eq!(
331            conf.base_ot_count(),
332            base_ots.len(),
333            "wrong number of base OTs"
334        );
335        assert_eq!(
336            conf.base_ot_count(),
337            base_choices.len(),
338            "wrong number of base choices"
339        );
340        let base_ots = Array2::from_shape_vec([conf.pnt_count(), conf.depth()], base_ots)
341            .expect("base_ots.len() is checked before");
342        let base_choices = Array2::from_shape_vec([conf.pnt_count(), conf.depth()], base_choices)
343            .expect("base_ots.len() is checked before");
344        Self {
345            conn,
346            conf,
347            base_ots,
348            base_choices,
349        }
350    }
351
352    /// Expand the PPRF into out.
353    ///
354    /// Note that this method temporarily moves out the buffer pointed to by
355    /// `out`. If the future returned by `expand` is dropped or panics, out
356    /// might point to a different buffer.
357    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = COMMUNICATION_PHASE))]
358    pub async fn expand(
359        mut self,
360        out_fmt: OutFormat,
361        out: &mut impl Buf<Block>,
362    ) -> Result<(), Error> {
363        let size = self.conf.size();
364        let mut output = mem::take(out);
365        let (_, mut rx) = self.conn.stream().await?;
366        let (send, recv) = std::sync::mpsc::channel();
367        let jh = spawn_compute(move || {
368            if output.len() < size {
369                output.grow_zeroed(size);
370            }
371            let aes = create_fixed_aes();
372            let points = self.conf.get_points(
373                OutFormat::ByLeafIndex,
374                self.base_choices
375                    .as_slice()
376                    .expect("array order is unchanged"),
377            );
378            let depth = self.conf.depth();
379            let pnt_count = self.conf.pnt_count();
380            let domain = self.conf.domain();
381            let dd = match out_fmt {
382                OutFormat::Interleaved => depth,
383                _ => depth + 1,
384            };
385            let mut tree: Vec<[Block; PARALLEL_TREES]> =
386                allocate_zeroed_vec(2_usize.pow(dd as u32));
387
388            for g in (0..pnt_count).step_by(PARALLEL_TREES) {
389                let Ok(tree_grp): Result<TreeGrp, _> = recv.recv() else {
390                    // Async task is dropped, so we simply return the output which will be dropped
391                    return output;
392                };
393                assert_eq!(g, tree_grp.g);
394
395                if depth > 1 {
396                    let lvl1 = get_level(&mut tree, 1);
397                    #[allow(clippy::needless_range_loop)]
398                    for i in 0..PARALLEL_TREES {
399                        let active = self.base_choices[(i + g, depth - 1)] as usize;
400                        lvl1[active ^ 1][i] =
401                            self.base_ots[(i + g, depth - 1)] ^ tree_grp.sums[active ^ 1][0][i];
402                        lvl1[active][i] = Block::ZERO;
403                    }
404                }
405
406                let mut my_sums = [[Block::ZERO; PARALLEL_TREES]; 2];
407
408                for d in 1..depth {
409                    let (lvl0, lvl1) = if out_fmt == OutFormat::Interleaved && d + 1 == depth {
410                        (
411                            get_level(&mut tree, d),
412                            get_level_output(&mut output, g, domain),
413                        )
414                    } else {
415                        get_cons_levels(&mut tree, d)
416                    };
417
418                    my_sums = [[Block::ZERO; PARALLEL_TREES]; 2];
419
420                    let width = lvl1.len();
421                    let mut child_idx = 0;
422                    while child_idx < width {
423                        let parent_idx = child_idx >> 1;
424                        let parent = &lvl0[parent_idx];
425                        for (aes, sum) in aes.iter().zip(&mut my_sums) {
426                            let child = &mut lvl1[child_idx];
427                            aes.encrypt_blocks_b2b(cast_slice(parent), cast_slice_mut(child))
428                                .expect("parent and child have same len");
429                            xor_inplace(child, parent);
430                            xor_inplace(sum, child);
431                            child_idx += 1;
432                        }
433                    }
434
435                    if d != depth - 1 {
436                        for i in 0..PARALLEL_TREES {
437                            let leaf_idx = points[i + g];
438                            let active_child_idx = leaf_idx >> (depth - 1 - d);
439                            let inactive_child_idx = active_child_idx ^ 1;
440                            let not_ai = inactive_child_idx & 1;
441                            let inactive_child = &mut lvl1[inactive_child_idx][i];
442                            let correct_sum = *inactive_child ^ tree_grp.sums[not_ai][d][i];
443                            *inactive_child = correct_sum
444                                ^ my_sums[not_ai][i]
445                                ^ self.base_ots[(i + g, depth - 1 - d)];
446                        }
447                    }
448                }
449                let lvl = if out_fmt == OutFormat::Interleaved {
450                    get_level_output(&mut output, g, domain)
451                } else {
452                    get_level(&mut tree, depth)
453                };
454
455                for j in 0..PARALLEL_TREES {
456                    let active_child_idx = points[j + g];
457                    let inactive_child_idx = active_child_idx ^ 1;
458                    let not_ai = inactive_child_idx & 1;
459
460                    let mask_in = [
461                        self.base_ots[(g + j, 0)],
462                        self.base_ots[(g + j, 0)] ^ Block::ONES,
463                    ];
464                    let masks = FIXED_KEY_HASH.cr_hash_blocks(&mask_in);
465
466                    let ots: [Block; 2] =
467                        array::from_fn(|i| tree_grp.last_ots[j][2 * not_ai + i] ^ masks[i]);
468
469                    let [inactive_child, active_child] =
470                        get_inactive_active_child(j, lvl, inactive_child_idx, active_child_idx);
471
472                    // Fix the sums we computed previously to not include the
473                    // incorrect child values.
474                    let inactive_sum = my_sums[not_ai][j] ^ *inactive_child;
475                    let active_sum = my_sums[not_ai ^ 1][j] ^ *active_child;
476                    *inactive_child = ots[0] ^ inactive_sum;
477                    *active_child = ots[1] ^ active_sum;
478                }
479                if out_fmt != OutFormat::Interleaved {
480                    let last_lvl = get_level(&mut tree, depth);
481                    copy_out(last_lvl, &mut output, g, out_fmt, self.conf);
482                }
483            }
484            output
485        });
486
487        while let Some(tree_grp) = rx.next().await {
488            let tree_grp = tree_grp.map_err(Error::Receive)?;
489            if send.send(tree_grp).is_err() {
490                // panic in the worker thread, so we break from receiving more data
491                break;
492            }
493        }
494
495        *out = jh.await.expect("panic in worker thread");
496        Ok(())
497    }
498}
499
500// Returns the i'th level of the current PARALLEL_TREES trees. The
501// children of node j on level i are located at 2*j and
502// 2*j+1  on level i+1.
503fn get_level(tree: &mut [[Block; PARALLEL_TREES]], i: usize) -> &mut [[Block; PARALLEL_TREES]] {
504    let size = 1 << i;
505    let offset = size - 1;
506    &mut tree[offset..offset + size]
507}
508
509// Return the i'th and (i+1)'th level
510fn get_cons_levels(
511    tree: &mut [[Block; PARALLEL_TREES]],
512    i: usize,
513) -> (
514    &mut [[Block; PARALLEL_TREES]],
515    &mut [[Block; PARALLEL_TREES]],
516) {
517    let size0 = 1 << i;
518    let offset0 = size0 - 1;
519    let tree = &mut tree[offset0..];
520    let (level0, rest) = tree.split_at_mut(size0);
521    let size1 = 1 << (i + 1);
522    debug_assert_eq!(size0 + offset0, size1 - 1);
523    let level1 = &mut rest[..size1];
524    (level0, level1)
525}
526
527fn get_level_output(
528    output: &mut [Block],
529    tree_idx: usize,
530    domain: usize,
531) -> &mut [[Block; PARALLEL_TREES]] {
532    let out = cast_slice_mut(output);
533    let forest = tree_idx / PARALLEL_TREES;
534    debug_assert_eq!(tree_idx % PARALLEL_TREES, 0);
535    let start = forest * domain;
536    &mut out[start..start + domain]
537}
538
/// Assembles a leaf index from its choice bits, least significant bit first.
fn get_active_path<I>(choice_bits: I) -> usize
where
    I: Iterator<Item = u8> + ExactSizeIterator,
{
    let mut leaf = 0;
    for (shift, bit) in choice_bits.enumerate() {
        leaf |= usize::from(bit) << shift;
    }
    leaf
}
547
548fn get_inactive_active_child(
549    tree: usize,
550    lvl: &mut [[Block; PARALLEL_TREES]],
551    inactive_child_idx: usize,
552    active_child_idx: usize,
553) -> [&mut Block; 2] {
554    let children = match active_child_idx.cmp(&inactive_child_idx) {
555        Ordering::Less => {
556            let (left, right) = lvl.split_at_mut(inactive_child_idx);
557            [&mut right[0], &mut left[active_child_idx]]
558        }
559        Ordering::Greater => {
560            let (left, right) = lvl.split_at_mut(active_child_idx);
561            [&mut left[inactive_child_idx], &mut right[0]]
562        }
563        Ordering::Equal => {
564            unreachable!("Impossible, active and inactive indices are always different")
565        }
566    };
567    children.map(|arr| &mut arr[tree])
568}
569
570fn interleave_point(point: usize, tree_idx: usize, domain: usize) -> usize {
571    let sub_tree = tree_idx % PARALLEL_TREES;
572    let forest = tree_idx / PARALLEL_TREES;
573    (forest * domain + point) * PARALLEL_TREES + sub_tree
574}
575
576fn copy_out(
577    last_lvl: &[[Block; PARALLEL_TREES]],
578    output: &mut [Block],
579    tree_idx: usize,
580    out_fmt: OutFormat,
581    conf: PprfConfig,
582) {
583    let total_trees = conf.pnt_count();
584    let curr_size = PARALLEL_TREES.min(total_trees - tree_idx);
585    let last_lvl: &[Block] = cast_slice(last_lvl);
586    // assert_eq!(conf.domain(), last_lvl.len() / PARALLEL_TREES);
587    let domain = conf.domain();
588    match out_fmt {
589        OutFormat::ByLeafIndex => {
590            for leaf_idx in 0..domain {
591                let o_idx = total_trees * leaf_idx + tree_idx;
592                let i_idx = leaf_idx * PARALLEL_TREES;
593                // todo copy from slice
594                output[o_idx..curr_size + o_idx]
595                    .copy_from_slice(&last_lvl[i_idx..curr_size + i_idx]);
596            }
597        }
598        OutFormat::ByTreeIndex => todo!(),
599        OutFormat::Interleaved => panic!("Do not copy_out for OutFormat::Interleaved"),
600    }
601}
602
603// Create a pair of fixed key aes128 ciphers
604fn create_fixed_aes() -> [Aes128; 2] {
605    [
606        Aes128::new(
607            &91389970179024809574621370423327856399_u128
608                .to_le_bytes()
609                .into(),
610        ),
611        Aes128::new(
612            &297966570818470707816499469807199042980_u128
613                .to_le_bytes()
614                .into(),
615        ),
616    ]
617}
618
619/// Intended for testing. Generates suitable OTs and choice bits for a pprf
620/// evaluation.
621#[doc(hidden)]
622pub fn fake_base<R: RngCore + CryptoRng>(
623    conf: PprfConfig,
624    rng: &mut R,
625) -> (Vec<[Block; 2]>, Vec<Block>, Vec<u8>) {
626    let base_ot_count = conf.base_ot_count();
627    let msg2: Vec<[Block; 2]> = (0..base_ot_count).map(|_| rng.random()).collect();
628    let choices = conf.sample_choice_bits(rng);
629    let msg = msg2
630        .iter()
631        .zip(choices.iter())
632        .map(|(m, c)| m[*c as usize])
633        .collect();
634    (msg2, msg, choices)
635}
636
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use cryprot_core::{Block, alloc::HugePageMemory, buf::Buf, utils::xor_inplace};
    use cryprot_net::testing::local_conn;
    use rand::{Rng, SeedableRng, rngs::StdRng};

    use crate::{
        OutFormat, PARALLEL_TREES, PprfConfig, RegularPprfReceiver, RegularPprfSender, fake_base,
    };

    /// Runs a sender/receiver expansion for `conf` over a local connection with
    /// `value = Block::ONES`. Returns the punctured points in `out_fmt` layout
    /// and the XOR of both parties' outputs.
    async fn expand_and_xor<B: Buf<Block>>(
        conf: PprfConfig,
        out_fmt: OutFormat,
    ) -> (Vec<usize>, B) {
        let mut rng = StdRng::seed_from_u64(42);

        let (c1, c2) = local_conn().await.unwrap();
        let (sender_base_ots, receiver_base_ots, base_choices) = fake_base(conf, &mut rng);
        let points = conf.get_points(out_fmt, &base_choices);

        let sender = RegularPprfSender::new_with_conf(c1, conf, sender_base_ots);
        let receiver =
            RegularPprfReceiver::new_with_conf(c2, conf, receiver_base_ots, base_choices);
        let mut s_out = B::zeroed(conf.size());
        let mut r_out = B::zeroed(conf.size());
        let seed = rng.random();
        tokio::try_join!(
            sender.expand(Block::ONES, seed, out_fmt, &mut s_out),
            receiver.expand(out_fmt, &mut r_out)
        )
        .unwrap();

        xor_inplace(&mut *s_out, &*r_out);
        (points, s_out)
    }

    /// Asserts that `xored` is `Block::ONES` exactly at the indices in `points`
    /// and `Block::ZERO` everywhere else.
    fn assert_only_points_set(xored: &[Block], points: &[usize]) {
        let points: HashSet<usize> = points.iter().copied().collect();
        for (i, blk) in xored.iter().enumerate() {
            let exp = if points.contains(&i) {
                Block::ONES
            } else {
                Block::ZERO
            };
            assert_eq!(exp, *blk, "block {i} not as expected");
        }
    }

    #[tokio::test]
    async fn test_pprf_by_leaf() {
        let conf = PprfConfig::new(334, 5 * PARALLEL_TREES);
        let (points, xored) =
            expand_and_xor::<HugePageMemory<Block>>(conf, OutFormat::ByLeafIndex).await;

        // ByLeafIndex stores leaf `l` of tree `t` at `l * pnt_count + t`.
        let set_indices: Vec<usize> = points
            .iter()
            .enumerate()
            .map(|(tree, &leaf)| leaf * points.len() + tree)
            .collect();
        assert_only_points_set(&xored, &set_indices);
    }

    #[tokio::test]
    async fn test_pprf_interleaved_simple() {
        // Smallest possible configuration: a single group of depth-1 trees.
        let conf = PprfConfig::new(2, PARALLEL_TREES);
        let (points, xored) = expand_and_xor::<Vec<Block>>(conf, OutFormat::Interleaved).await;
        assert_only_points_set(&xored, &points);
    }

    #[tokio::test]
    async fn test_pprf_interleaved() {
        let conf = PprfConfig::new(334, 5 * PARALLEL_TREES);
        let (points, xored) =
            expand_and_xor::<HugePageMemory<Block>>(conf, OutFormat::Interleaved).await;
        assert_only_points_set(&xored, &points);
    }
}