1use std::io;
4
5use bitvec::{order::Lsb0, vec::BitVec};
6use cryprot_core::{Block, buf::Buf};
7use cryprot_net::ConnectionError;
8use futures::{SinkExt, StreamExt};
9use subtle::ConditionallySelectable;
10use thiserror::Error;
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12
13use crate::{
14 Connected, CotReceiver, CotSender, Malicious, RandChoiceRotReceiver, RandChoiceRotSender,
15 RotReceiver, RotSender, SemiHonest,
16};
17
18#[derive(Debug)]
25pub struct ChosenChoice<P>(P);
26
27impl<P> ChosenChoice<P> {
28 pub fn new(ot_protocol: P) -> Self {
29 Self(ot_protocol)
30 }
31}
32
33impl<P: Connected> Connected for ChosenChoice<P> {
34 fn connection(&mut self) -> &mut cryprot_net::Connection {
35 self.0.connection()
36 }
37}
38
39impl<P: SemiHonest> SemiHonest for ChosenChoice<P> {}
40
41impl<P: Malicious> Malicious for ChosenChoice<P> {}
43
44#[derive(Error, Debug)]
45pub enum Error<E> {
46 #[error("unable to perform R-OTs")]
47 Rot(E),
48 #[error("error in sending correction values for C-OT")]
49 Correction(io::Error),
50 #[error("expected correction values but receiver is closed")]
51 MissingCorrection,
52 #[error("connection error to peer")]
53 Connecion(#[from] ConnectionError),
54}
55
56impl<R: RandChoiceRotReceiver> RotReceiver for ChosenChoice<R> {
57 type Error = Error<R::Error>;
58
59 async fn receive_into(
60 &mut self,
61 ots: &mut impl cryprot_core::buf::Buf<cryprot_core::Block>,
62 choices: &[subtle::Choice],
63 ) -> Result<(), Self::Error> {
64 let mut rand_choices = self
65 .0
66 .rand_choice_receive_into(ots)
67 .await
68 .map_err(Error::Rot)?;
69 for (c1, c2) in rand_choices.iter_mut().zip(choices) {
70 *c1 ^= *c2;
71 }
72 let mut bv: BitVec<u8, Lsb0> = BitVec::with_capacity(choices.len());
73 bv.extend(rand_choices.iter().map(|c| c.unwrap_u8() != 0));
74
75 let (mut tx, _) = self.connection().stream().await?;
76 tx.send(bv).await.map_err(Error::Correction)?;
77 Ok(())
78 }
79}
80
81impl<S: RotSender + RandChoiceRotSender + Send> RotSender for ChosenChoice<S> {
82 type Error = Error<S::Error>;
83
84 async fn send_into(
85 &mut self,
86 ots: &mut impl cryprot_core::buf::Buf<[cryprot_core::Block; 2]>,
87 ) -> Result<(), Self::Error> {
88 self.0.send_into(ots).await.map_err(Error::Rot)?;
89 let (_, mut rx) = self.connection().stream().await?;
90 let correction: BitVec<u8, Lsb0> = rx
91 .next()
92 .await
93 .ok_or(Error::MissingCorrection)?
94 .map_err(Error::Correction)?;
95
96 for (ots, c_bit) in ots.iter_mut().zip(correction) {
97 let tmp = *ots;
98 ots[0] = tmp[c_bit as usize];
99 ots[1] = tmp[!c_bit as usize];
100 }
101 Ok(())
102 }
103}
104
105#[derive(Debug)]
168pub struct CorrelatedFromRandom<P>(P);
169
170impl<P> CorrelatedFromRandom<P> {
171 pub fn new(protocol: P) -> Self {
172 Self(protocol)
173 }
174}
175
176impl<P: Connected> Connected for CorrelatedFromRandom<P> {
177 fn connection(&mut self) -> &mut cryprot_net::Connection {
178 self.0.connection()
179 }
180}
181
182impl<P: SemiHonest> SemiHonest for CorrelatedFromRandom<P> {}
183
184impl<P: Malicious> Malicious for CorrelatedFromRandom<P> {}
186
187const COR_CHUNK_SIZE: usize = 8500 / Block::BYTES;
189
190impl<S: RotSender> CotSender for CorrelatedFromRandom<S>
191where
192 S::Error: From<ConnectionError> + From<std::io::Error>,
193{
194 type Error = S::Error;
195
196 async fn correlated_send_into<B, F>(
197 &mut self,
198 ots: &mut B,
199 mut correlation: F,
200 ) -> Result<(), Self::Error>
201 where
202 B: Buf<Block>,
203 F: FnMut(usize) -> Block + Send,
204 {
205 let mut r_ots = B::zeroed_arr2(ots.len());
206 self.0.send_into(&mut r_ots).await?;
207 let mut send_buf: Vec<Block> = Vec::zeroed(COR_CHUNK_SIZE);
208 let (mut tx, _) = self.connection().byte_stream().await?;
209 for (chunk_idx, (ot_chunk, rot_chunk)) in ots
214 .chunks_mut(send_buf.len())
215 .zip(r_ots.chunks(send_buf.len()))
216 .enumerate()
217 {
218 for (idx, ((ot, r_ot), correction)) in ot_chunk
219 .iter_mut()
220 .zip(rot_chunk)
221 .zip(&mut send_buf)
222 .enumerate()
223 {
224 *ot = r_ot[0];
225 *correction = r_ot[1] ^ r_ot[0] ^ correlation(chunk_idx * COR_CHUNK_SIZE + idx);
226 }
227 tx.write_all(bytemuck::must_cast_slice_mut(
228 &mut send_buf[..ot_chunk.len()],
229 ))
230 .await?;
231 }
232 Ok(())
233 }
234}
235
236impl<R: RotReceiver> CotReceiver for CorrelatedFromRandom<R>
237where
238 R::Error: From<ConnectionError> + From<std::io::Error>,
239{
240 type Error = R::Error;
241
242 async fn correlated_receive_into<B>(
243 &mut self,
244 ots: &mut B,
245 choices: &[subtle::Choice],
246 ) -> Result<(), Self::Error>
247 where
248 B: Buf<Block>,
249 {
250 self.0.receive_into(ots, choices).await?;
251 let mut recv_buf: Vec<Block> = Vec::zeroed(COR_CHUNK_SIZE);
252 let (_, mut rx) = self.connection().byte_stream().await?;
253 for (ot_chunk, choice_chunk) in ots
254 .chunks_mut(COR_CHUNK_SIZE)
255 .zip(choices.chunks(COR_CHUNK_SIZE))
256 {
257 rx.read_exact(bytemuck::must_cast_slice_mut(
258 &mut recv_buf[..ot_chunk.len()],
259 ))
260 .await?;
261 for ((ot, correction), choice) in ot_chunk.iter_mut().zip(&recv_buf).zip(choice_chunk) {
262 let use_correction = Block::conditional_select(&Block::ZERO, &Block::ONES, *choice);
263 *ot ^= use_correction & *correction;
264 }
265 }
266 Ok(())
267 }
268}
269
#[cfg(test)]
mod tests {
    use cryprot_net::testing::{init_tracing, local_conn};
    use rand::{SeedableRng, rngs::StdRng};

    use crate::{
        RotReceiver, RotSender,
        adapter::ChosenChoice,
        random_choices,
        silent_ot::{SemiHonestSilentOtReceiver, SemiHonestSilentOtSender},
    };

    /// Runs the chosen-choice adapter over semi-honest silent OT and checks
    /// that each receiver output equals the sender message selected by the
    /// receiver's choice bit.
    #[tokio::test]
    async fn test_chosen_choice_adapter() {
        // Keep the tracing guard alive for the whole test.
        let _guard = init_tracing();
        let (sender_conn, receiver_conn) = local_conn().await.unwrap();
        let mut sender = ChosenChoice::new(SemiHonestSilentOtSender::new(sender_conn));
        let mut receiver = ChosenChoice::new(SemiHonestSilentOtReceiver::new(receiver_conn));

        let count = 1_usize << 10;
        let mut rng = StdRng::seed_from_u64(234);
        let choices = random_choices(count, &mut rng);

        let (sender_ots, receiver_ots) =
            tokio::try_join!(sender.send(count), receiver.receive(&choices)).unwrap();

        for (i, choice) in choices.iter().enumerate() {
            let bit = choice.unwrap_u8();
            assert_eq!(
                sender_ots[i][bit as usize],
                receiver_ots[i],
                "ot {i}, choice: {}, s_ots: {:?}, r_ot: {:?}",
                bit,
                sender_ots[i],
                receiver_ots[i]
            );
        }
    }
}