1use std::io;
5
6use bitvec::{order::Lsb0, vec::BitVec};
7use bytemuck::{cast_slice, cast_slice_mut};
8use cryprot_core::{Block, aes_rng::AesRng, buf::Buf, tokio_rayon::spawn_compute};
9use cryprot_net::{Connection, ConnectionError};
10use rand::{Rng, SeedableRng};
11use subtle::{Choice, ConditionallySelectable};
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tracing::Level;
14
15use crate::phase;
16
17#[derive(thiserror::Error, Debug)]
18pub enum Error {
19 #[error("unable to establish sub connection")]
20 Connection(#[from] ConnectionError),
21 #[error("error in sending/receiving noisy vole data")]
22 Io(#[from] io::Error),
23}
24
25pub struct NoisyVoleSender {
26 conn: Connection,
27}
28
29impl NoisyVoleSender {
30 pub fn new(conn: Connection) -> Self {
31 Self { conn }
32 }
33
34 #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::NOISY_VOLE))]
39 pub async fn send(
40 &mut self,
41 size: usize,
42 delta: Block,
43 ots: Vec<Block>,
44 ) -> Result<Vec<Block>, Error> {
45 assert_eq!(Block::BITS, ots.len());
46 let mut msg: Vec<Block> = Vec::zeroed(Block::BITS * size);
47 let (_, mut rx) = self.conn.byte_stream().await?;
48 rx.read_exact(cast_slice_mut(&mut msg)).await?;
49
50 let jh = spawn_compute(move || {
51 let mut b = vec![Block::ZERO; size];
52 let delta_arr = <[u64; 2]>::from(delta);
53 let xb: BitVec<u64, Lsb0> = BitVec::from_slice(&delta_arr);
54 let mut k = 0;
55 for (i, ot) in ots.iter().enumerate() {
56 let mut rng = AesRng::from_seed(*ot);
57
58 for bj in &mut b {
59 let mut tmp: Block = rng.random();
60
61 tmp ^=
62 Block::conditional_select(&Block::ZERO, &msg[k], Choice::from(xb[i] as u8));
63 *bj ^= tmp;
64 k += 1;
65 }
66 }
67 b
68 });
69
70 Ok(jh.await.expect("worker panic"))
71 }
72}
73
74pub struct NoisyVoleReceiver {
75 conn: Connection,
76}
77
78impl NoisyVoleReceiver {
79 pub fn new(conn: Connection) -> Self {
80 Self { conn }
81 }
82
83 #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::NOISY_VOLE))]
89 pub async fn receive(
90 &mut self,
91 c: Vec<Block>,
92 ots: Vec<[Block; 2]>,
93 ) -> Result<Vec<Block>, Error> {
94 let jh = spawn_compute(move || {
95 let mut a = Vec::zeroed(c.len());
96 let mut msg: Vec<Block> = Vec::zeroed(ots.len() * a.len());
97
98 let mut k = 0;
99 for (i, [ot0, ot1]) in ots.into_iter().enumerate() {
100 let mut rng = AesRng::from_seed(ot0);
101 let t1 = Block::ONE << i;
102
103 for (aj, cj) in a.iter_mut().zip(c.iter()) {
104 msg[k] = rng.random();
105 *aj ^= msg[k];
106 let t0 = t1.gf_mul(cj);
107 msg[k] ^= t0;
108 k += 1;
109 }
110
111 let mut rng = AesRng::from_seed(ot1);
112 for m in &mut msg[k - c.len()..k] {
113 let t: Block = rng.random();
114 *m ^= t;
115 }
116 }
117 (msg, a)
118 });
119 let (mut tx, _) = self.conn.byte_stream().await?;
120 let (msg, a) = jh.await.expect("worker panic");
121 tx.write_all(cast_slice(&msg)).await?;
122 Ok(a)
123 }
124}
125
#[cfg(test)]
mod tests {
    use bitvec::{order::Lsb0, slice::BitSlice};
    use cryprot_core::{Block, utils::xor_inplace};
    use cryprot_net::testing::{init_tracing, local_conn};
    use rand::{Rng, SeedableRng, rngs::StdRng};

    use crate::noisy_vole::{NoisyVoleReceiver, NoisyVoleSender};

    /// End-to-end run of sender and receiver over a local connection.
    ///
    /// Checks the VOLE relation `b[j] ^ c[j].gf_mul(&delta) == a[j]`.
    #[tokio::test]
    async fn test_noisy_vole() {
        let _g = init_tracing();
        let (c1, c2) = local_conn().await.unwrap();
        let mut sender = NoisyVoleSender::new(c1);
        let mut receiver = NoisyVoleReceiver::new(c2);
        let mut rng = StdRng::seed_from_u64(423423);

        // One random-OT pair per bit of delta, as both sides require.
        let r_ots: Vec<[Block; 2]> = (0..Block::BITS).map(|_| rng.random()).collect();
        let delta: Block = rng.random();
        // The sender holds, for each pair, the message selected by the
        // corresponding bit of delta.
        let choice = BitSlice::<_, Lsb0>::from_slice(delta.as_bytes());
        let s_ots: Vec<Block> = r_ots
            .iter()
            .zip(choice)
            .map(|(pair, bit)| pair[usize::from(*bit)])
            .collect();

        let size = 200;
        let mut c: Vec<_> = (0..size).map(|_| rng.random()).collect();

        let (mut b, a) = tokio::try_join!(
            sender.send(size, delta, s_ots),
            receiver.receive(c.clone(), r_ots)
        )
        .unwrap();

        assert_eq!(a.len(), size);
        assert_eq!(b.len(), size);

        for ci in &mut c {
            *ci = ci.gf_mul(&delta);
        }

        xor_inplace(&mut b, &c);

        assert_eq!(a, b);
    }
}
168}