cryprot_ot/
noisy_vole.rs

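//! Noisy VOLE: for a receiver-chosen vector `c` and a sender-chosen `delta`,
//! computes vectors `a` (held by the receiver) and `b` (held by the sender)
//! such that `a = b + c * delta` in GF(2^128).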
3
4use std::io;
5
6use bitvec::{order::Lsb0, vec::BitVec};
7use bytemuck::{cast_slice, cast_slice_mut};
8use cryprot_core::{Block, aes_rng::AesRng, buf::Buf, tokio_rayon::spawn_compute};
9use cryprot_net::{Connection, ConnectionError};
10use rand::{Rng, SeedableRng};
11use subtle::{Choice, ConditionallySelectable};
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tracing::Level;
14
15use crate::phase;
16
17#[derive(thiserror::Error, Debug)]
18pub enum Error {
19    #[error("unable to establish sub connection")]
20    Connection(#[from] ConnectionError),
21    #[error("error in sending/receiving noisy vole data")]
22    Io(#[from] io::Error),
23}
24
25pub struct NoisyVoleSender {
26    conn: Connection,
27}
28
29impl NoisyVoleSender {
30    pub fn new(conn: Connection) -> Self {
31        Self { conn }
32    }
33
34    /// For chosen delta compute b s.t. a = b + c * delta.
35    ///
36    /// Operations are performed in GF(2^128). Note that the bits of `delta`
37    /// must be equal to the choice bits for the passed base OTs.
38    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::NOISY_VOLE))]
39    pub async fn send(
40        &mut self,
41        size: usize,
42        delta: Block,
43        ots: Vec<Block>,
44    ) -> Result<Vec<Block>, Error> {
45        assert_eq!(Block::BITS, ots.len());
46        let mut msg: Vec<Block> = Vec::zeroed(Block::BITS * size);
47        let (_, mut rx) = self.conn.byte_stream().await?;
48        rx.read_exact(cast_slice_mut(&mut msg)).await?;
49
50        let jh = spawn_compute(move || {
51            let mut b = vec![Block::ZERO; size];
52            let delta_arr = <[u64; 2]>::from(delta);
53            let xb: BitVec<u64, Lsb0> = BitVec::from_slice(&delta_arr);
54            let mut k = 0;
55            for (i, ot) in ots.iter().enumerate() {
56                let mut rng = AesRng::from_seed(*ot);
57
58                for bj in &mut b {
59                    let mut tmp: Block = rng.random();
60
61                    tmp ^=
62                        Block::conditional_select(&Block::ZERO, &msg[k], Choice::from(xb[i] as u8));
63                    *bj ^= tmp;
64                    k += 1;
65                }
66            }
67            b
68        });
69
70        Ok(jh.await.expect("worker panic"))
71    }
72}
73
74pub struct NoisyVoleReceiver {
75    conn: Connection,
76}
77
78impl NoisyVoleReceiver {
79    pub fn new(conn: Connection) -> Self {
80        Self { conn }
81    }
82
83    /// For chosen c compute a s.t. a = b + c * delta.
84    ///
85    /// Operations are performed in GF(2^128). Note that the bits of `delta` for
86    /// the [`NoisyVoleSender`] must be equal to the choice bits for the
87    /// passed base OTs.
88    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::NOISY_VOLE))]
89    pub async fn receive(
90        &mut self,
91        c: Vec<Block>,
92        ots: Vec<[Block; 2]>,
93    ) -> Result<Vec<Block>, Error> {
94        let jh = spawn_compute(move || {
95            let mut a = Vec::zeroed(c.len());
96            let mut msg: Vec<Block> = Vec::zeroed(ots.len() * a.len());
97
98            let mut k = 0;
99            for (i, [ot0, ot1]) in ots.into_iter().enumerate() {
100                let mut rng = AesRng::from_seed(ot0);
101                let t1 = Block::ONE << i;
102
103                for (aj, cj) in a.iter_mut().zip(c.iter()) {
104                    msg[k] = rng.random();
105                    *aj ^= msg[k];
106                    let t0 = t1.gf_mul(cj);
107                    msg[k] ^= t0;
108                    k += 1;
109                }
110
111                let mut rng = AesRng::from_seed(ot1);
112                for m in &mut msg[k - c.len()..k] {
113                    let t: Block = rng.random();
114                    *m ^= t;
115                }
116            }
117            (msg, a)
118        });
119        let (mut tx, _) = self.conn.byte_stream().await?;
120        let (msg, a) = jh.await.expect("worker panic");
121        tx.write_all(cast_slice(&msg)).await?;
122        Ok(a)
123    }
124}
125
#[cfg(test)]
mod tests {
    use bitvec::{order::Lsb0, slice::BitSlice};
    use cryprot_core::{Block, utils::xor_inplace};
    use cryprot_net::testing::{init_tracing, local_conn};
    use rand::{Rng, SeedableRng, rngs::StdRng};

    use crate::noisy_vole::{NoisyVoleReceiver, NoisyVoleSender};

    /// End-to-end check that `a == b ^ c * delta` for random `c` and `delta`.
    #[tokio::test]
    async fn test_noisy_vole() {
        let _g = init_tracing();
        let (c1, c2) = local_conn().await.unwrap();
        let mut sender = NoisyVoleSender::new(c1);
        let mut receiver = NoisyVoleReceiver::new(c2);
        let mut rng = StdRng::seed_from_u64(423423);
        let r_ots: Vec<[Block; 2]> = (0..Block::BITS).map(|_| rng.random()).collect();
        let delta: Block = rng.random();
        // The sender's base-OT choice bits are the bits of delta (LSB first).
        let choice = BitSlice::<_, Lsb0>::from_slice(delta.as_bytes());
        let s_ots: Vec<Block> = r_ots
            .iter()
            .zip(choice)
            .map(|(ots, bit)| ots[*bit as usize])
            .collect();

        let size = 200;
        let mut c: Vec<Block> = (0..size).map(|_| rng.random()).collect();

        let (mut b, a) = tokio::try_join!(
            sender.send(size, delta, s_ots),
            receiver.receive(c.clone(), r_ots)
        )
        .unwrap();
        assert_eq!(a.len(), size);
        assert_eq!(b.len(), size);

        for ci in &mut c {
            *ci = ci.gf_mul(&delta);
        }

        xor_inplace(&mut b, &c);

        assert_eq!(a, b);
    }
}