cryprot_ot/
adapter.rs

1//! Adapters for OT types.
2
3use std::io;
4
5use bitvec::{order::Lsb0, vec::BitVec};
6use cryprot_core::{Block, buf::Buf};
7use cryprot_net::ConnectionError;
8use futures::{SinkExt, StreamExt};
9use subtle::ConditionallySelectable;
10use thiserror::Error;
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12
13use crate::{
14    Connected, CotReceiver, CotSender, Malicious, RandChoiceRotReceiver, RandChoiceRotSender,
15    RotReceiver, RotSender, SemiHonest,
16};
17
/// Adapts a [`RandChoiceRotReceiver`] into a [`RotReceiver`] and a
/// [`RandChoiceRotSender`] into a [`RotSender`].
///
/// This adapter can be used to turn the [silent OT](`crate::silent_ot`)
/// protocol, whose receiver choice bits are random, into a protocol with
/// chosen choice bits. The cost is one bit of communication per OT:
/// - The receiver sends the XOR of its random choice bit and its chosen
///   choice bit to the sender.
/// - The sender swaps the two messages of every OT for which that bit is set.
#[derive(Debug)]
pub struct ChosenChoice<P>(P);

impl<P> ChosenChoice<P> {
    /// Wraps `ot_protocol` in the chosen-choice adapter.
    pub fn new(ot_protocol: P) -> Self {
        Self(ot_protocol)
    }
}
32
33impl<P: Connected> Connected for ChosenChoice<P> {
34    fn connection(&mut self) -> &mut cryprot_net::Connection {
35        self.0.connection()
36    }
37}
38
39impl<P: SemiHonest> SemiHonest for ChosenChoice<P> {}
40
41// TODO is there something I can cite that this holds?
42impl<P: Malicious> Malicious for ChosenChoice<P> {}
43
44#[derive(Error, Debug)]
45pub enum Error<E> {
46    #[error("unable to perform R-OTs")]
47    Rot(E),
48    #[error("error in sending correction values for C-OT")]
49    Correction(io::Error),
50    #[error("expected correction values but receiver is closed")]
51    MissingCorrection,
52    #[error("connection error to peer")]
53    Connecion(#[from] ConnectionError),
54}
55
56impl<R: RandChoiceRotReceiver> RotReceiver for ChosenChoice<R> {
57    type Error = Error<R::Error>;
58
59    async fn receive_into(
60        &mut self,
61        ots: &mut impl cryprot_core::buf::Buf<cryprot_core::Block>,
62        choices: &[subtle::Choice],
63    ) -> Result<(), Self::Error> {
64        let mut rand_choices = self
65            .0
66            .rand_choice_receive_into(ots)
67            .await
68            .map_err(Error::Rot)?;
69        for (c1, c2) in rand_choices.iter_mut().zip(choices) {
70            *c1 ^= *c2;
71        }
72        let mut bv: BitVec<u8, Lsb0> = BitVec::with_capacity(choices.len());
73        bv.extend(rand_choices.iter().map(|c| c.unwrap_u8() != 0));
74
75        let (mut tx, _) = self.connection().stream().await?;
76        tx.send(bv).await.map_err(Error::Correction)?;
77        Ok(())
78    }
79}
80
81impl<S: RotSender + RandChoiceRotSender + Send> RotSender for ChosenChoice<S> {
82    type Error = Error<S::Error>;
83
84    async fn send_into(
85        &mut self,
86        ots: &mut impl cryprot_core::buf::Buf<[cryprot_core::Block; 2]>,
87    ) -> Result<(), Self::Error> {
88        self.0.send_into(ots).await.map_err(Error::Rot)?;
89        let (_, mut rx) = self.connection().stream().await?;
90        let correction: BitVec<u8, Lsb0> = rx
91            .next()
92            .await
93            .ok_or(Error::MissingCorrection)?
94            .map_err(Error::Correction)?;
95
96        for (ots, c_bit) in ots.iter_mut().zip(correction) {
97            let tmp = *ots;
98            ots[0] = tmp[c_bit as usize];
99            ots[1] = tmp[!c_bit as usize];
100        }
101        Ok(())
102    }
103}
104
/// Turns any [`RotSender`]/[`RotReceiver`] into a
/// [`CotSender`]/[`CotReceiver`].
///
/// The adapter is also a convenient way to implement the correlated OT traits
/// directly on a protocol type. Since `&mut S: RotSender` holds whenever
/// `S: RotSender`, an implementation of the correlated traits can wrap
/// `&mut self` in a short-lived [`CorrelatedFromRandom`] and delegate to it.
///
/// ```
/// use cryprot_core::{Block, buf::Buf};
///
/// use cryprot_ot::adapter::CorrelatedFromRandom;
/// use cryprot_ot::{Connected, CotSender, RotSender};
///
/// struct MyRotSender;
///
/// # impl Connected for MyRotSender {
/// #     fn connection(&mut self) -> &mut cryprot_net::Connection {
/// #         todo!()
/// #     }
/// # }
///
/// // The adapter requires the error type to implement `From<ConnectionError>`
/// // and `From<io::Error>`.
/// #[derive(thiserror::Error, Debug)]
/// enum Error {
///     #[error("connection")]
///     Connection(#[from] cryprot_net::ConnectionError),
///     #[error("io")]
///     Io(#[from] std::io::Error),
/// }
///
/// impl RotSender for MyRotSender {
///     type Error = Error;
///
///     async fn send_into(
///         &mut self,
///         ots: &mut impl cryprot_core::buf::Buf<[cryprot_core::Block; 2]>,
///     ) -> Result<(), Self::Error> {
///         todo!()
///     }
/// }
///
/// impl CotSender for MyRotSender {
///     type Error = <MyRotSender as RotSender>::Error;
///
///     async fn correlated_send_into<B, F>(
///         &mut self,
///         ots: &mut B,
///         correlation: F,
///     ) -> Result<(), Self::Error>
///     where
///         B: Buf<Block>,
///         F: FnMut(usize) -> Block + Send,
///     {
///         // `&mut self` is itself a `RotSender`, so it can be handed to the adapter
///         CorrelatedFromRandom::new(self)
///             .correlated_send_into(ots, correlation)
///             .await
///     }
/// }
/// ```
#[derive(Debug)]
pub struct CorrelatedFromRandom<P>(P);

impl<P> CorrelatedFromRandom<P> {
    /// Wraps `protocol` so that it can be used as a correlated OT protocol.
    pub fn new(protocol: P) -> Self {
        CorrelatedFromRandom(protocol)
    }
}
175
176impl<P: Connected> Connected for CorrelatedFromRandom<P> {
177    fn connection(&mut self) -> &mut cryprot_net::Connection {
178        self.0.connection()
179    }
180}
181
182impl<P: SemiHonest> SemiHonest for CorrelatedFromRandom<P> {}
183
184// For a discussion of the security of this see https://github.com/osu-crypto/libOTe/issues/167
185impl<P: Malicious> Malicious for CorrelatedFromRandom<P> {}
186
187// should fit in one jumbo frame
188const COR_CHUNK_SIZE: usize = 8500 / Block::BYTES;
189
190impl<S: RotSender> CotSender for CorrelatedFromRandom<S>
191where
192    S::Error: From<ConnectionError> + From<std::io::Error>,
193{
194    type Error = S::Error;
195
196    async fn correlated_send_into<B, F>(
197        &mut self,
198        ots: &mut B,
199        mut correlation: F,
200    ) -> Result<(), Self::Error>
201    where
202        B: Buf<Block>,
203        F: FnMut(usize) -> Block + Send,
204    {
205        let mut r_ots = B::zeroed_arr2(ots.len());
206        self.0.send_into(&mut r_ots).await?;
207        let mut send_buf: Vec<Block> = Vec::zeroed(COR_CHUNK_SIZE);
208        let (mut tx, _) = self.connection().byte_stream().await?;
209        // Using spawn_compute here results in slightly lower performance.
210        // I think there is just not enough work done per byte transmitted here.
211        // This implementation is also simpler and less prone to errors than the
212        // spawn_compute one.
213        for (chunk_idx, (ot_chunk, rot_chunk)) in ots
214            .chunks_mut(send_buf.len())
215            .zip(r_ots.chunks(send_buf.len()))
216            .enumerate()
217        {
218            for (idx, ((ot, r_ot), correction)) in ot_chunk
219                .iter_mut()
220                .zip(rot_chunk)
221                .zip(&mut send_buf)
222                .enumerate()
223            {
224                *ot = r_ot[0];
225                *correction = r_ot[1] ^ r_ot[0] ^ correlation(chunk_idx * COR_CHUNK_SIZE + idx);
226            }
227            tx.write_all(bytemuck::must_cast_slice_mut(
228                &mut send_buf[..ot_chunk.len()],
229            ))
230            .await?;
231        }
232        Ok(())
233    }
234}
235
236impl<R: RotReceiver> CotReceiver for CorrelatedFromRandom<R>
237where
238    R::Error: From<ConnectionError> + From<std::io::Error>,
239{
240    type Error = R::Error;
241
242    async fn correlated_receive_into<B>(
243        &mut self,
244        ots: &mut B,
245        choices: &[subtle::Choice],
246    ) -> Result<(), Self::Error>
247    where
248        B: Buf<Block>,
249    {
250        self.0.receive_into(ots, choices).await?;
251        let mut recv_buf: Vec<Block> = Vec::zeroed(COR_CHUNK_SIZE);
252        let (_, mut rx) = self.connection().byte_stream().await?;
253        for (ot_chunk, choice_chunk) in ots
254            .chunks_mut(COR_CHUNK_SIZE)
255            .zip(choices.chunks(COR_CHUNK_SIZE))
256        {
257            rx.read_exact(bytemuck::must_cast_slice_mut(
258                &mut recv_buf[..ot_chunk.len()],
259            ))
260            .await?;
261            for ((ot, correction), choice) in ot_chunk.iter_mut().zip(&recv_buf).zip(choice_chunk) {
262                let use_correction = Block::conditional_select(&Block::ZERO, &Block::ONES, *choice);
263                *ot ^= use_correction & *correction;
264            }
265        }
266        Ok(())
267    }
268}
269
#[cfg(test)]
mod tests {
    use cryprot_net::testing::{init_tracing, local_conn};
    use rand::{SeedableRng, rngs::StdRng};

    use crate::{
        RotReceiver, RotSender,
        adapter::ChosenChoice,
        random_choices,
        silent_ot::{SemiHonestSilentOtReceiver, SemiHonestSilentOtSender},
    };

    /// Wraps the semi-honest silent OT in [`ChosenChoice`] and checks that
    /// each receiver output equals the sender message selected by the
    /// receiver's chosen bit.
    #[tokio::test]
    async fn test_chosen_choice_adapter() {
        let _g = init_tracing();
        let (c1, c2) = local_conn().await.unwrap();
        let mut sender = ChosenChoice::new(SemiHonestSilentOtSender::new(c1));
        let mut receiver = ChosenChoice::new(SemiHonestSilentOtReceiver::new(c2));

        let count = 1_usize << 10;
        let mut rng = StdRng::seed_from_u64(234);
        let choices = random_choices(count, &mut rng);

        let (s_ots, r_ots) =
            tokio::try_join!(sender.send(count), receiver.receive(&choices)).unwrap();

        for (i, c) in choices.iter().enumerate() {
            let bit = c.unwrap_u8();
            let s_ot = &s_ots[i];
            let r_ot = &r_ots[i];
            assert_eq!(
                s_ot[bit as usize],
                *r_ot,
                "ot {i}, choice: {}, s_ots: {:?}, r_ot: {:?}",
                bit,
                s_ot,
                r_ot
            );
        }
    }
}