cryprot_pprf/
lib.rs

1//! Distributed Puncturable Pseudorandom Function (PPRF) Implementation.
2use std::{array, cmp::Ordering, io, mem};
3
4use aes::{
5    Aes128,
6    cipher::{BlockCipherEncrypt, KeyInit},
7};
8use bytemuck::{cast_slice, cast_slice_mut};
9use cryprot_core::{
10    AES_PAR_BLOCKS, Block,
11    aes_hash::FIXED_KEY_HASH,
12    aes_rng::AesRng,
13    alloc::allocate_zeroed_vec,
14    buf::Buf,
15    tokio_rayon::spawn_compute,
16    utils::{log2_ceil, xor_inplace},
17};
18use cryprot_net::{Connection, ConnectionError};
19use futures::{SinkExt, StreamExt};
20use ndarray::{Array2, ArrayView2};
21use rand::{CryptoRng, Rng, RngCore, SeedableRng, distr::Uniform, prelude::Distribution};
22use serde::{Deserialize, Serialize};
23use tokio::sync::mpsc::unbounded_channel;
24use tracing::Level;
25
/// Sender for PPRF expansion.
///
/// Holds the connection to the receiver, the expansion parameters and the
/// sender's base OT message pairs.
pub struct RegularPprfSender {
    /// Connection on which a sub-stream to the receiver is opened.
    conn: Connection,
    /// Parameters of the expansion.
    conf: PprfConfig,
    /// Both base OT messages, shape `[pnt_count, depth]` with one row per tree.
    base_ots: Array2<[Block; 2]>,
}
32
/// Receiver for the PPRF expansion.
///
/// Holds the connection to the sender, the expansion parameters, the received
/// base OT messages and the corresponding choice bits.
pub struct RegularPprfReceiver {
    /// Connection on which a sub-stream to the sender is opened.
    conn: Connection,
    /// Parameters of the expansion.
    conf: PprfConfig,
    /// Received base OT messages, shape `[pnt_count, depth]` with one row per tree.
    base_ots: Array2<Block>,
    /// Base OT choice bits (`0` or `1`), same shape as `base_ots`. Row `t`
    /// encodes the punctured leaf of tree `t`, least significant bit first.
    base_choices: Array2<u8>,
}
40
/// Memory layout of the expanded PPRF output.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum OutFormat {
    /// Leaf `i` of tree `j` is stored at `out[i * pnt_count + j]`.
    ByLeafIndex,
    /// TODO: currently unimplemented
    ByTreeIndex,
    /// Trees are processed in groups of [`PARALLEL_TREES`]. Leaf `i` of tree
    /// `j` is stored at
    /// `((j / PARALLEL_TREES) * domain + i) * PARALLEL_TREES + j % PARALLEL_TREES`.
    Interleaved,
}
48
/// Errors returned by the PPRF expansion.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Opening the sub-stream on the [`Connection`] failed.
    #[error("unable to establish sub-stream to pprf peer")]
    Connection(#[from] ConnectionError),
    /// Sending data to the peer failed.
    #[error("error in sending data to pprf peer")]
    Send(#[source] io::Error),
    /// Receiving data from the peer failed.
    #[error("error in receiving data from pprf peer")]
    Receive(#[source] io::Error),
}
59
/// Number of trees that are expanded in parallel using AES ILP.
///
/// The concrete value depends on the target architecture; it is re-exported
/// from `cryprot_core::AES_PAR_BLOCKS`. [`PprfConfig::new`] requires the
/// point count to be a multiple of this value.
pub const PARALLEL_TREES: usize = AES_PAR_BLOCKS;
64
/// Communication phase for [`cryprot_net::metrics`].
///
/// Recorded as the `phase` field of the tracing spans of both `expand` methods.
pub const COMMUNICATION_PHASE: &str = "pprf-expansion";
67
/// Config for a PPRF expansion.
///
/// Construct it with [`PprfConfig::new`], which validates the parameters and
/// derives the tree depth.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PprfConfig {
    /// Number of trees, i.e. number of punctured points.
    pnt_count: usize,
    /// Number of leaves per tree.
    domain: usize,
    /// Tree depth, `log2_ceil(domain)`.
    depth: usize,
}
75
76impl PprfConfig {
77    /// Create a PprfConfig
78    ///
79    /// # Panics
80    /// - if `domain < 2`
81    /// - if `domain % 2 != 0`
82    /// - if `pnt_count % `[`PARALLEL_TREES`]` != 0`
83    pub fn new(domain: usize, pnt_count: usize) -> Self {
84        assert!(domain >= 2, "domain must be at least 2");
85        assert_eq!(0, domain % 2, "domain must be even");
86        assert_eq!(
87            0,
88            pnt_count % PARALLEL_TREES,
89            "pnt_count must be divisable by {PARALLEL_TREES}"
90        );
91        let depth = log2_ceil(domain);
92        Self {
93            pnt_count,
94            domain,
95            depth,
96        }
97    }
98
99    /// Number of base OTs needed for the configured PPRF expansion.
100    pub fn base_ot_count(&self) -> usize {
101        self.depth * self.pnt_count
102    }
103
104    pub fn pnt_count(&self) -> usize {
105        self.pnt_count
106    }
107
108    pub fn domain(&self) -> usize {
109        self.domain
110    }
111
112    pub fn depth(&self) -> usize {
113        self.depth
114    }
115
116    pub fn size(&self) -> usize {
117        self.domain() * self.pnt_count()
118    }
119
120    /// Base OT choice bits needed for this PPRF expansion.
121    ///
122    /// Every `u8` element is either `0` or `1`.
123    pub fn sample_choice_bits<R: RngCore + CryptoRng>(&self, rng: &mut R) -> Vec<u8> {
124        let mut choices = vec![0_u8; self.pnt_count() * self.depth()];
125        let dist = Uniform::new(0, self.domain()).expect("correct range");
126        for choice in choices.chunks_exact_mut(self.depth()) {
127            let mut idx = dist.sample(rng);
128            for choice_bit in choice {
129                *choice_bit = (idx & 1) as u8;
130                idx >>= 1;
131            }
132        }
133        choices
134    }
135
136    pub fn get_points(&self, out_fmt: OutFormat, base_choices: &[u8]) -> Vec<usize> {
137        match out_fmt {
138            OutFormat::Interleaved => {
139                let mut points = self.get_points(OutFormat::ByLeafIndex, base_choices);
140                for (i, point) in points.iter_mut().enumerate() {
141                    *point = interleave_point(*point, i, self.domain())
142                }
143                points
144            }
145            OutFormat::ByLeafIndex => {
146                let base_choices =
147                    ArrayView2::from_shape([self.pnt_count(), self.depth()], base_choices)
148                        .expect("base_choices has wrong size for this conf");
149
150                base_choices
151                    .rows()
152                    .into_iter()
153                    .map(|choice_bits| {
154                        debug_assert_eq!(self.depth(), choice_bits.len());
155                        let point = get_active_path(choice_bits.iter().copied());
156                        debug_assert!(point < self.domain());
157                        point
158                    })
159                    .collect()
160            }
161            _ => todo!(),
162        }
163    }
164}
165
/// Message sent from the PPRF sender to the receiver for one group of
/// [`PARALLEL_TREES`] trees.
///
/// NOTE(review): field order may be significant for the serialized wire
/// format (depends on the serde format used by `cryprot_net`); do not reorder
/// without checking.
#[derive(Serialize, Deserialize, Default, Clone, Debug)]
struct TreeGrp {
    /// Index of the first tree of this group.
    g: usize,
    /// Masked XOR sums of the left (`[0]`) and right (`[1]`) children per
    /// level; entry `d` covers the children on level `d + 1`. The last level
    /// is removed before sending and transferred via `last_ots` instead.
    sums: [Vec<[Block; PARALLEL_TREES]>; 2],
    /// Per tree, two masked message pairs for the last level; the receiver can
    /// unmask exactly one pair.
    last_ots: Vec<[Block; 4]>,
}
172
impl RegularPprfSender {
    /// Create a new `RegularPprfSender`.
    ///
    /// `base_ots` is reshaped into a `[pnt_count, depth]` array with one row
    /// of base OT message pairs per tree.
    ///
    /// # Panics
    /// If `base_ots.len() != conf.base_ot_count()`.
    pub fn new_with_conf(conn: Connection, conf: PprfConfig, base_ots: Vec<[Block; 2]>) -> Self {
        assert_eq!(
            conf.base_ot_count(),
            base_ots.len(),
            "wrong number of base OTs"
        );
        let base_ots = Array2::from_shape_vec([conf.pnt_count(), conf.depth()], base_ots)
            .expect("base_ots.len() is checked before");
        Self {
            conn,
            conf,
            base_ots,
        }
    }

    /// Expand the PPRF into out.
    ///
    /// - The tree roots are sampled from an [`AesRng`] seeded with `seed`.
    /// - `value` is XORed into the last-level messages sent to the receiver.
    /// - If `out` is shorter than `conf.size()` blocks, it is grown (zeroed)
    ///   to that size.
    ///
    /// The expansion runs on a compute thread (`spawn_compute`). That thread
    /// hands every finished [`TreeGrp`] to this task through a channel, and
    /// this task forwards it to the receiver.
    ///
    /// Note that this method temporarily moves out the buffer pointed to by
    /// `out`. If the future returned by `expand` is dropped or panics, out
    /// might point to a different buffer. If an error is returned, `out` is
    /// left holding the default value of its buffer type.
    ///
    /// # Errors
    /// - [`Error::Connection`] if the sub-stream cannot be opened.
    /// - [`Error::Send`] if sending a tree group fails.
    ///
    /// # Panics
    /// If the compute thread panics, e.g. for `OutFormat::ByTreeIndex`, which
    /// is not implemented in `copy_out`.
    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = COMMUNICATION_PHASE))]
    pub async fn expand(
        mut self,
        value: Block,
        seed: Block,
        out_fmt: OutFormat,
        out: &mut impl Buf<Block>,
    ) -> Result<(), Error> {
        let size = self.conf.size();
        let mut output = mem::take(out);
        let (mut tx, _) = self.conn.stream().await?;
        let (send, mut recv) = unbounded_channel();
        let jh = spawn_compute(move || {
            if output.len() < size {
                output.grow_zeroed(size);
            }
            let aes = create_fixed_aes();
            let depth = self.conf.depth();
            let pnt_count = self.conf.pnt_count();
            let domain = self.conf.domain();

            let mut rng = AesRng::from_seed(seed);
            // With Interleaved output the last level is written straight into
            // `output` (see `get_level_output`), so the scratch tree only needs
            // levels `0..depth`. Otherwise it also holds level `depth`, the leaves.
            let dd = match out_fmt {
                OutFormat::Interleaved => depth,
                _ => depth + 1,
            };

            let mut tree: Vec<[Block; PARALLEL_TREES]> = Vec::zeroed(2_usize.pow(dd as u32));

            // Expand PARALLEL_TREES trees at a time.
            for g in (0..pnt_count).step_by(PARALLEL_TREES) {
                let mut tree_grp = TreeGrp {
                    g,
                    ..Default::default()
                };
                let min = PARALLEL_TREES.min(pnt_count - g);
                // Random roots for all trees of this group.
                let level: &mut [u8] = cast_slice_mut(get_level(&mut tree, 0));
                rng.fill_bytes(level);
                tree_grp.sums[0].resize(depth, Default::default());
                tree_grp.sums[1].resize(depth, Default::default());

                for d in 0..depth {
                    let (lvl0, lvl1) = if out_fmt == OutFormat::Interleaved && d + 1 == depth {
                        (
                            get_level(&mut tree, d),
                            get_level_output(&mut output, g, domain),
                        )
                    } else {
                        get_cons_levels(&mut tree, d)
                    };

                    // Parent `p` yields child `2p` via `aes[0]` and child
                    // `2p + 1` via `aes[1]`, each as `AES_k(parent) ^ parent`.
                    // `sums[0][d]` / `sums[1][d]` accumulate the XOR of all
                    // left / right children of this level.
                    let width = lvl1.len();
                    let mut child_idx = 0;
                    while child_idx < width {
                        let parent_idx = child_idx >> 1;
                        let parent = &lvl0[parent_idx];
                        for (aes, sums) in aes.iter().zip(&mut tree_grp.sums) {
                            let child = &mut lvl1[child_idx];
                            let sum = &mut sums[d];
                            aes.encrypt_blocks_b2b(cast_slice(parent), cast_slice_mut(child))
                                .expect("parent and child have same len");
                            xor_inplace(child, parent);
                            xor_inplace(sum, child);
                            child_idx += 1;
                        }
                    }
                }

                // Mask the sums of every level but the last: `sums[idx][d]` is
                // XORed with message `idx ^ 1` of the base OT in column
                // `depth - 1 - d`, so a receiver holding message `c` can only
                // unmask `sums[c ^ 1][d]`.
                let mut mask_sums = |idx: usize| {
                    for (d, sums) in tree_grp.sums[idx].iter_mut().take(depth - 1).enumerate() {
                        for (j, sum) in sums.iter_mut().enumerate().take(min) {
                            *sum ^= self.base_ots[(g + j, depth - 1 - d)][idx ^ 1];
                        }
                    }
                };
                mask_sums(0);
                mask_sums(1);

                // Last level: entries [0, 1] are masked with hashes derived from
                // message 1 of base OT column 0, entries [2, 3] with hashes
                // derived from message 0. The receiver can unmask exactly one pair.
                let d = depth - 1;
                tree_grp.last_ots.resize(min, Default::default());
                for j in 0..min {
                    tree_grp.last_ots[j][0] = tree_grp.sums[0][d][j];
                    tree_grp.last_ots[j][1] = tree_grp.sums[1][d][j] ^ value;
                    tree_grp.last_ots[j][2] = tree_grp.sums[1][d][j];
                    tree_grp.last_ots[j][3] = tree_grp.sums[0][d][j] ^ value;

                    let mask_in = [
                        self.base_ots[(g + j, 0)][1],
                        self.base_ots[(g + j, 0)][1] ^ Block::ONES,
                        self.base_ots[(g + j, 0)][0],
                        self.base_ots[(g + j, 0)][0] ^ Block::ONES,
                    ];
                    let masks = FIXED_KEY_HASH.cr_hash_blocks(&mask_in);
                    xor_inplace(&mut tree_grp.last_ots[j], &masks);
                }
                // The last-level sums are transferred via `last_ots`, so they
                // are not sent in plain form.
                tree_grp.sums[0].truncate(depth - 1);
                tree_grp.sums[1].truncate(depth - 1);

                if send.send(tree_grp).is_err() {
                    // receiver in async task is dropped, so we stop compute task by returning.
                    // output will be dropped and not put back into initial &mut out parameter
                    return output;
                }
                if out_fmt != OutFormat::Interleaved {
                    let last_lvl = get_level(&mut tree, depth);
                    copy_out(last_lvl, &mut output, g, out_fmt, self.conf);
                }
            }
            output
        });

        // Forward tree groups until the compute thread drops its channel
        // sender (finished or panicked).
        while let Some(tree_group) = recv.recv().await {
            tx.send(tree_group).await.map_err(Error::Send)?;
        }

        *out = jh.await.expect("panic in worker thread");
        Ok(())
    }
}
316
317impl RegularPprfReceiver {
318    /// Create a new `RegularPprfReceiver`.
319    ///
320    /// # Panics
321    /// If:
322    /// - `base_ots.len() != conf.base_ot_count()` or
323    /// - `base_choices.len() != conf.base_ot_count()` or
324    pub fn new_with_conf(
325        conn: Connection,
326        conf: PprfConfig,
327        base_ots: Vec<Block>,
328        base_choices: Vec<u8>,
329    ) -> Self {
330        assert_eq!(
331            conf.base_ot_count(),
332            base_ots.len(),
333            "wrong number of base OTs"
334        );
335        assert_eq!(
336            conf.base_ot_count(),
337            base_choices.len(),
338            "wrong number of base choices"
339        );
340        let base_ots = Array2::from_shape_vec([conf.pnt_count(), conf.depth()], base_ots)
341            .expect("base_ots.len() is checked before");
342        let base_choices = Array2::from_shape_vec([conf.pnt_count(), conf.depth()], base_choices)
343            .expect("base_ots.len() is checked before");
344        Self {
345            conn,
346            conf,
347            base_ots,
348            base_choices,
349        }
350    }
351
352    /// Expand the PPRF into out.
353    ///
354    /// Note that this method temporarily moves out the buffer pointed to by
355    /// `out`. If the future returned by `expand` is dropped or panics, out
356    /// might point to a different buffer.
357    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = COMMUNICATION_PHASE))]
358    pub async fn expand(
359        mut self,
360        out_fmt: OutFormat,
361        out: &mut impl Buf<Block>,
362    ) -> Result<(), Error> {
363        let size = self.conf.size();
364        let mut output = mem::take(out);
365        let (_, mut rx) = self.conn.stream().await?;
366        let (send, recv) = std::sync::mpsc::channel();
367        let jh = spawn_compute(move || {
368            if output.len() < size {
369                output.grow_zeroed(size);
370            }
371            let aes = create_fixed_aes();
372            let points = self.conf.get_points(
373                OutFormat::ByLeafIndex,
374                self.base_choices
375                    .as_slice()
376                    .expect("array order is unchanged"),
377            );
378            let depth = self.conf.depth();
379            let pnt_count = self.conf.pnt_count();
380            let domain = self.conf.domain();
381            let dd = match out_fmt {
382                OutFormat::Interleaved => depth,
383                _ => depth + 1,
384            };
385            let mut tree: Vec<[Block; PARALLEL_TREES]> =
386                allocate_zeroed_vec(2_usize.pow(dd as u32));
387
388            for g in (0..pnt_count).step_by(PARALLEL_TREES) {
389                let Ok(tree_grp): Result<TreeGrp, _> = recv.recv() else {
390                    // Async task is dropped, so we simply return the output which will be dropped
391                    return output;
392                };
393                assert_eq!(g, tree_grp.g);
394
395                if depth > 1 {
396                    let lvl1 = get_level(&mut tree, 1);
397                    for i in 0..PARALLEL_TREES {
398                        let active = self.base_choices[(i + g, depth - 1)] as usize;
399                        lvl1[active ^ 1][i] =
400                            self.base_ots[(i + g, depth - 1)] ^ tree_grp.sums[active ^ 1][0][i];
401                        lvl1[active][i] = Block::ZERO;
402                    }
403                }
404
405                let mut my_sums = [[Block::ZERO; PARALLEL_TREES]; 2];
406
407                for d in 1..depth {
408                    let (lvl0, lvl1) = if out_fmt == OutFormat::Interleaved && d + 1 == depth {
409                        (
410                            get_level(&mut tree, d),
411                            get_level_output(&mut output, g, domain),
412                        )
413                    } else {
414                        get_cons_levels(&mut tree, d)
415                    };
416
417                    my_sums = [[Block::ZERO; PARALLEL_TREES]; 2];
418
419                    let width = lvl1.len();
420                    let mut child_idx = 0;
421                    while child_idx < width {
422                        let parent_idx = child_idx >> 1;
423                        let parent = &lvl0[parent_idx];
424                        for (aes, sum) in aes.iter().zip(&mut my_sums) {
425                            let child = &mut lvl1[child_idx];
426                            aes.encrypt_blocks_b2b(cast_slice(parent), cast_slice_mut(child))
427                                .expect("parent and child have same len");
428                            xor_inplace(child, parent);
429                            xor_inplace(sum, child);
430                            child_idx += 1;
431                        }
432                    }
433
434                    if d != depth - 1 {
435                        for i in 0..PARALLEL_TREES {
436                            let leaf_idx = points[i + g];
437                            let active_child_idx = leaf_idx >> (depth - 1 - d);
438                            let inactive_child_idx = active_child_idx ^ 1;
439                            let not_ai = inactive_child_idx & 1;
440                            let inactive_child = &mut lvl1[inactive_child_idx][i];
441                            let correct_sum = *inactive_child ^ tree_grp.sums[not_ai][d][i];
442                            *inactive_child = correct_sum
443                                ^ my_sums[not_ai][i]
444                                ^ self.base_ots[(i + g, depth - 1 - d)];
445                        }
446                    }
447                }
448                let lvl = if out_fmt == OutFormat::Interleaved {
449                    get_level_output(&mut output, g, domain)
450                } else {
451                    get_level(&mut tree, depth)
452                };
453
454                for j in 0..PARALLEL_TREES {
455                    let active_child_idx = points[j + g];
456                    let inactive_child_idx = active_child_idx ^ 1;
457                    let not_ai = inactive_child_idx & 1;
458
459                    let mask_in = [
460                        self.base_ots[(g + j, 0)],
461                        self.base_ots[(g + j, 0)] ^ Block::ONES,
462                    ];
463                    let masks = FIXED_KEY_HASH.cr_hash_blocks(&mask_in);
464
465                    let ots: [Block; 2] =
466                        array::from_fn(|i| tree_grp.last_ots[j][2 * not_ai + i] ^ masks[i]);
467
468                    let [inactive_child, active_child] =
469                        get_inactive_active_child(j, lvl, inactive_child_idx, active_child_idx);
470
471                    // Fix the sums we computed previously to not include the
472                    // incorrect child values.
473                    let inactive_sum = my_sums[not_ai][j] ^ *inactive_child;
474                    let active_sum = my_sums[not_ai ^ 1][j] ^ *active_child;
475                    *inactive_child = ots[0] ^ inactive_sum;
476                    *active_child = ots[1] ^ active_sum;
477                }
478                if out_fmt != OutFormat::Interleaved {
479                    let last_lvl = get_level(&mut tree, depth);
480                    copy_out(last_lvl, &mut output, g, out_fmt, self.conf);
481                }
482            }
483            output
484        });
485
486        while let Some(tree_grp) = rx.next().await {
487            let tree_grp = tree_grp.map_err(Error::Receive)?;
488            if send.send(tree_grp).is_err() {
489                // panic in the worker thread, so we break from receiving more data
490                break;
491            }
492        }
493
494        *out = jh.await.expect("panic in worker thread");
495        Ok(())
496    }
497}
498
499// Returns the i'th level of the current PARALLEL_TREES trees. The
500// children of node j on level i are located at 2*j and
501// 2*j+1  on level i+1.
502fn get_level(tree: &mut [[Block; PARALLEL_TREES]], i: usize) -> &mut [[Block; PARALLEL_TREES]] {
503    let size = 1 << i;
504    let offset = size - 1;
505    &mut tree[offset..offset + size]
506}
507
508// Return the i'th and (i+1)'th level
509fn get_cons_levels(
510    tree: &mut [[Block; PARALLEL_TREES]],
511    i: usize,
512) -> (
513    &mut [[Block; PARALLEL_TREES]],
514    &mut [[Block; PARALLEL_TREES]],
515) {
516    let size0 = 1 << i;
517    let offset0 = size0 - 1;
518    let tree = &mut tree[offset0..];
519    let (level0, rest) = tree.split_at_mut(size0);
520    let size1 = 1 << (i + 1);
521    debug_assert_eq!(size0 + offset0, size1 - 1);
522    let level1 = &mut rest[..size1];
523    (level0, level1)
524}
525
526fn get_level_output(
527    output: &mut [Block],
528    tree_idx: usize,
529    domain: usize,
530) -> &mut [[Block; PARALLEL_TREES]] {
531    let out = cast_slice_mut(output);
532    let forest = tree_idx / PARALLEL_TREES;
533    debug_assert_eq!(tree_idx % PARALLEL_TREES, 0);
534    let start = forest * domain;
535    &mut out[start..start + domain]
536}
537
/// Reassembles a leaf index from its per-level choice bits.
///
/// The `i`-th item of `choice_bits` becomes bit `i` of the result, so the
/// first item is the least significant bit.
fn get_active_path<I>(choice_bits: I) -> usize
where
    I: Iterator<Item = u8> + ExactSizeIterator,
{
    let mut leaf = 0_usize;
    for (shift, bit) in choice_bits.enumerate() {
        leaf |= usize::from(bit) << shift;
    }
    leaf
}
546
547fn get_inactive_active_child(
548    tree: usize,
549    lvl: &mut [[Block; PARALLEL_TREES]],
550    inactive_child_idx: usize,
551    active_child_idx: usize,
552) -> [&mut Block; 2] {
553    let children = match active_child_idx.cmp(&inactive_child_idx) {
554        Ordering::Less => {
555            let (left, right) = lvl.split_at_mut(inactive_child_idx);
556            [&mut right[0], &mut left[active_child_idx]]
557        }
558        Ordering::Greater => {
559            let (left, right) = lvl.split_at_mut(active_child_idx);
560            [&mut left[inactive_child_idx], &mut right[0]]
561        }
562        Ordering::Equal => {
563            unreachable!("Impossible, active and inactive indices are always different")
564        }
565    };
566    children.map(|arr| &mut arr[tree])
567}
568
569fn interleave_point(point: usize, tree_idx: usize, domain: usize) -> usize {
570    let sub_tree = tree_idx % PARALLEL_TREES;
571    let forest = tree_idx / PARALLEL_TREES;
572    (forest * domain + point) * PARALLEL_TREES + sub_tree
573}
574
575fn copy_out(
576    last_lvl: &[[Block; PARALLEL_TREES]],
577    output: &mut [Block],
578    tree_idx: usize,
579    out_fmt: OutFormat,
580    conf: PprfConfig,
581) {
582    let total_trees = conf.pnt_count();
583    let curr_size = PARALLEL_TREES.min(total_trees - tree_idx);
584    let last_lvl: &[Block] = cast_slice(last_lvl);
585    // assert_eq!(conf.domain(), last_lvl.len() / PARALLEL_TREES);
586    let domain = conf.domain();
587    match out_fmt {
588        OutFormat::ByLeafIndex => {
589            for leaf_idx in 0..domain {
590                let o_idx = total_trees * leaf_idx + tree_idx;
591                let i_idx = leaf_idx * PARALLEL_TREES;
592                // todo copy from slice
593                output[o_idx..curr_size + o_idx]
594                    .copy_from_slice(&last_lvl[i_idx..curr_size + i_idx]);
595            }
596        }
597        OutFormat::ByTreeIndex => todo!(),
598        OutFormat::Interleaved => panic!("Do not copy_out for OutFormat::Interleaved"),
599    }
600}
601
602// Create a pair of fixed key aes128 ciphers
603fn create_fixed_aes() -> [Aes128; 2] {
604    [
605        Aes128::new(
606            &91389970179024809574621370423327856399_u128
607                .to_le_bytes()
608                .into(),
609        ),
610        Aes128::new(
611            &297966570818470707816499469807199042980_u128
612                .to_le_bytes()
613                .into(),
614        ),
615    ]
616}
617
618/// Intended for testing. Generates suitable OTs and choice bits for a pprf
619/// evaluation.
620#[doc(hidden)]
621pub fn fake_base<R: RngCore + CryptoRng>(
622    conf: PprfConfig,
623    rng: &mut R,
624) -> (Vec<[Block; 2]>, Vec<Block>, Vec<u8>) {
625    let base_ot_count = conf.base_ot_count();
626    let msg2: Vec<[Block; 2]> = (0..base_ot_count).map(|_| rng.random()).collect();
627    let choices = conf.sample_choice_bits(rng);
628    let msg = msg2
629        .iter()
630        .zip(choices.iter())
631        .map(|(m, c)| m[*c as usize])
632        .collect();
633    (msg2, msg, choices)
634}
635
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use cryprot_core::{Block, alloc::HugePageMemory, buf::Buf, utils::xor_inplace};
    use cryprot_net::testing::local_conn;
    use rand::{Rng, SeedableRng, rngs::StdRng};

    use crate::{
        OutFormat, PARALLEL_TREES, PprfConfig, RegularPprfReceiver, RegularPprfSender, fake_base,
    };

    /// Runs a full sender/receiver expansion over a local connection.
    ///
    /// The sender expands with `value = Block::ONES`. Afterwards `s_out`
    /// holds the XOR of both outputs, which must be `Block::ONES` exactly at
    /// the returned punctured points and zero elsewhere.
    async fn expand_and_xor(
        conf: PprfConfig,
        out_fmt: OutFormat,
        s_out: &mut impl Buf<Block>,
        r_out: &mut impl Buf<Block>,
    ) -> Vec<usize> {
        let mut rng = StdRng::seed_from_u64(42);

        let (c1, c2) = local_conn().await.unwrap();
        let (sender_base_ots, receiver_base_ots, base_choices) = fake_base(conf, &mut rng);
        let points = conf.get_points(out_fmt, &base_choices);

        let sender = RegularPprfSender::new_with_conf(c1, conf, sender_base_ots);
        let receiver =
            RegularPprfReceiver::new_with_conf(c2, conf, receiver_base_ots, base_choices);
        let seed = rng.random();
        tokio::try_join!(
            sender.expand(Block::ONES, seed, out_fmt, &mut *s_out),
            receiver.expand(out_fmt, &mut *r_out)
        )
        .unwrap();

        xor_inplace(&mut s_out[..], &r_out[..]);
        points
    }

    /// Asserts that `xored` is `Block::ONES` exactly at `points`, zero elsewhere.
    fn assert_interleaved(points: &[usize], xored: &[Block]) {
        let points: HashSet<usize> = points.iter().copied().collect();
        for (i, blk) in xored.iter().enumerate() {
            let exp = if points.contains(&i) {
                Block::ONES
            } else {
                Block::ZERO
            };
            assert_eq!(exp, *blk, "block {i} not as expected");
        }
    }

    #[tokio::test]
    async fn test_pprf_by_leaf() {
        let conf = PprfConfig::new(334, 5 * PARALLEL_TREES);
        let mut s_out = HugePageMemory::zeroed(conf.size());
        let mut r_out = HugePageMemory::zeroed(conf.size());
        let points = expand_and_xor(conf, OutFormat::ByLeafIndex, &mut s_out, &mut r_out).await;

        // ByLeafIndex stores leaf `i` of tree `j` at `i * pnt_count + j`.
        for (j, &point) in points.iter().enumerate() {
            for i in 0..conf.domain() {
                let idx = i * points.len() + j;
                let exp = if point == i { Block::ONES } else { Block::ZERO };
                assert_eq!(exp, s_out[idx], "tree {j}, leaf {i} not as expected");
            }
        }
    }

    #[tokio::test]
    async fn test_pprf_interleaved_simple() {
        // Smallest possible configuration: a single group of depth-1 trees.
        let conf = PprfConfig::new(2, PARALLEL_TREES);
        let mut s_out = Vec::zeroed(conf.size());
        let mut r_out = Vec::zeroed(conf.size());
        let points = expand_and_xor(conf, OutFormat::Interleaved, &mut s_out, &mut r_out).await;
        assert_interleaved(&points, &s_out);
    }

    #[tokio::test]
    async fn test_pprf_interleaved() {
        let conf = PprfConfig::new(334, 5 * PARALLEL_TREES);
        let mut s_out = HugePageMemory::zeroed(conf.size());
        let mut r_out = HugePageMemory::zeroed(conf.size());
        let points = expand_and_xor(conf, OutFormat::Interleaved, &mut s_out, &mut r_out).await;
        assert_interleaved(&points, &s_out);
    }
}