1use std::io;
13
14use cryprot_core::{
15 Block,
16 buf::Buf,
17 rand_compat::RngCompat,
18 random_oracle::{Hash, RandomOracle},
19};
20use cryprot_net::{Connection, ConnectionError};
21use curve25519_dalek::{RistrettoPoint, Scalar, constants::RISTRETTO_BASEPOINT_TABLE};
22use futures::{SinkExt, StreamExt};
23use rand::{Rng, SeedableRng, rngs::StdRng};
24use subtle::{Choice, ConditionallySelectable};
25use tracing::Level;
26
27use crate::{Connected, Malicious, RotReceiver, RotSender, SemiHonest, phase};
28
29pub struct SimplestOt {
31 rng: StdRng,
32 conn: Connection,
33}
34
35impl SimplestOt {
36 pub fn new(connection: Connection) -> Self {
37 Self::new_with_rng(connection, StdRng::from_os_rng())
38 }
39
40 pub fn new_with_rng(connection: Connection, rng: StdRng) -> SimplestOt {
41 Self {
42 conn: connection,
43 rng,
44 }
45 }
46}
47
48impl Connected for SimplestOt {
49 fn connection(&mut self) -> &mut Connection {
50 &mut self.conn
51 }
52}
53
54#[derive(thiserror::Error, Debug)]
55pub enum Error {
56 #[error("quic connection error")]
57 Connection(#[from] ConnectionError),
58 #[error("io communicaiton error")]
59 Io(#[from] io::Error),
60 #[error("insufficient points received. expected: {expected}, actual: {actual}")]
61 InsufficientPoints { expected: usize, actual: usize },
62 #[error("expected message but stream is closed")]
63 ClosedStream,
64 #[error("seed commitment and seed hash not equal")]
65 CommitmentHashesNotEqual,
66}
67
// Empty impls of the crate's `SemiHonest` and `Malicious` traits for `SimplestOt`.
// NOTE(review): these appear to be marker traits advertising the adversary models this
// protocol is intended to be used under — confirm against the trait definitions.
impl SemiHonest for SimplestOt {}

impl Malicious for SimplestOt {}
71
72impl RotSender for SimplestOt {
73 type Error = Error;
74
75 #[allow(non_snake_case)]
76 #[tracing::instrument(level = Level::DEBUG, skip_all, fields(count = ots.len()))]
77 #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::BASE_OT))]
78 async fn send_into(&mut self, ots: &mut impl Buf<[Block; 2]>) -> Result<(), Self::Error> {
79 let count = ots.len();
80 let a = Scalar::random(&mut RngCompat(&mut self.rng));
81 let mut A = RISTRETTO_BASEPOINT_TABLE * &a;
82 let seed: Block = self.rng.random();
93 let seed_commitment = seed.ro_hash();
95 let (mut send, mut recv) = self.conn.byte_stream().await?;
96 {
97 let mut send_m1 = send.as_stream();
98 send_m1.send((A, *seed_commitment.as_bytes())).await?;
99 }
100
101 let B_points: Vec<RistrettoPoint> = {
102 let mut recv_m2 = recv.as_stream();
103 recv_m2.next().await.ok_or(Error::ClosedStream)??
104 };
105 if B_points.len() != count {
106 return Err(Error::InsufficientPoints {
107 expected: count,
108 actual: B_points.len(),
109 });
110 }
111 {
113 let mut send_m3 = send.as_stream();
114 send_m3.send(seed).await?;
115 }
116
117 A *= a;
118 for (i, (mut B, ots)) in B_points.into_iter().zip(ots.iter_mut()).enumerate() {
119 B *= a;
120 let k0 = ro_hash_point(&B, i, seed);
121 B -= A;
122 let k1 = ro_hash_point(&B, i, seed);
123 *ots = [k0, k1];
124 }
125 Ok(())
126 }
127}
128
129impl RotReceiver for SimplestOt {
130 type Error = Error;
131
132 #[allow(non_snake_case)]
133 #[tracing::instrument(level = Level::DEBUG, skip_all, fields(count = ots.len()))]
134 #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::BASE_OT))]
135 async fn receive_into(
136 &mut self,
137 ots: &mut impl Buf<Block>,
138 choices: &[Choice],
139 ) -> Result<(), Self::Error> {
140 assert_eq!(choices.len(), ots.len());
141 let (mut send, mut recv) = self.conn.byte_stream().await?;
142 let (A, commitment): (RistrettoPoint, [u8; 32]) = {
143 let mut recv_m1 = recv.as_stream();
144 recv_m1.next().await.ok_or(Error::ClosedStream)??
145 };
146
147 let (b_points, B_points): (Vec<_>, Vec<_>) = choices
148 .iter()
149 .map(|choice| {
150 let b = Scalar::random(&mut RngCompat(&mut self.rng));
151 let B_0 = RISTRETTO_BASEPOINT_TABLE * &b;
152 let B_1 = B_0 + A;
153 let B_choice = RistrettoPoint::conditional_select(&B_0, &B_1, *choice);
154 (b, B_choice)
155 })
156 .unzip();
157 {
158 let mut send_m2 = send.as_stream();
159 send_m2.send(B_points).await?;
160 }
161
162 let seed: Block = {
163 let mut recv_3 = recv.as_stream();
164 recv_3.next().await.ok_or(Error::ClosedStream)??
165 };
166 if Hash::from_bytes(commitment) != seed.ro_hash() {
167 return Err(Error::CommitmentHashesNotEqual);
168 }
169 for (i, (b, ot)) in b_points.into_iter().zip(ots.iter_mut()).enumerate() {
170 let B = A * b;
171 *ot = ro_hash_point(&B, i, seed);
172 }
173 Ok(())
174 }
175}
176
177fn ro_hash_point(point: &RistrettoPoint, tweak: usize, seed: Block) -> Block {
178 let mut ro = RandomOracle::new();
179 ro.update(point.compress().as_bytes());
180 ro.update(&tweak.to_le_bytes());
181 ro.update(seed.as_bytes());
183 let mut out_reader = ro.finalize_xof();
184 let mut ret = Block::ZERO;
185 out_reader.fill(ret.as_mut_bytes());
186 ret
187}
188
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use cryprot_net::testing::{init_tracing, local_conn};
    use rand::{SeedableRng, rngs::StdRng};

    use super::SimplestOt;
    use crate::{RotReceiver, RotSender, random_choices};

    /// Runs sender and receiver concurrently over a local connection and checks that each
    /// receiver output equals the sender message picked by the matching choice bit.
    #[tokio::test]
    async fn base_rot() -> Result<()> {
        let _tracing_guard = init_tracing();
        let (sender_conn, receiver_conn) = local_conn().await?;
        let mut sender_rng = StdRng::seed_from_u64(42);
        let receiver_rng = StdRng::seed_from_u64(42 * 42);
        let ot_count = 128;
        let choices = random_choices(ot_count, &mut sender_rng);

        let mut sender = SimplestOt::new_with_rng(sender_conn, sender_rng);
        let mut receiver = SimplestOt::new_with_rng(receiver_conn, receiver_rng);
        let (sender_ots, receiver_ots) =
            tokio::try_join!(sender.send(ot_count), receiver.receive(&choices))?;

        let triples = receiver_ots.into_iter().zip(sender_ots).zip(choices);
        for ((received, sent_pair), choice) in triples {
            let selected = usize::from(choice.unwrap_u8());
            assert_eq!(received, sent_pair[selected])
        }
        Ok(())
    }
}