1use std::{io, mem::size_of};
8
9use cryprot_core::{Block, buf::Buf, rand_compat::RngCompat, random_oracle::RandomOracle};
10use cryprot_net::{Connection, ConnectionError};
11use futures::{SinkExt, StreamExt};
12use hybrid_array::typenum::Unsigned;
13use ml_kem::{
14 Ciphertext as MlKemCiphertext, EncodedSizeUser, KemCore, ParameterSet, SharedKey,
15 kem::{Decapsulate, DecapsulationKey, Encapsulate, EncapsulationKey as MlKemEncapsulationKey},
16};
17cfg_select! {
20 feature = "ml-kem-base-ot-1024" => {
21 use ml_kem::{MlKem1024 as MlKem, MlKem1024Params as MlKemParams};
22 }
23 feature = "ml-kem-base-ot-768" => {
24 use ml_kem::{MlKem768 as MlKem, MlKem768Params as MlKemParams};
25 }
26 feature = "ml-kem-base-ot-512" => {
27 use ml_kem::{MlKem512 as MlKem, MlKem512Params as MlKemParams};
28 }
29}
30
31use module_lattice::{Encode, Field, NttPolynomial};
32use rand::{RngExt, rngs::StdRng};
33use serde::{Deserialize, Serialize};
34use sha3::{
35 Digest, Shake128,
36 digest::{ExtendableOutput, Update, XofReader},
37};
38use subtle::{Choice, ConditionallySelectable};
39use tracing::Level;
40
41use crate::{Connected, RotReceiver, RotSender, SemiHonest, phase};
42
// Prime field Z_q with q = 3329, the ML-KEM modulus. Elements are stored as
// u16; the u32/u64 arguments are presumably the widened types module_lattice
// uses for intermediate arithmetic — TODO confirm against the macro docs.
module_lattice::define_field!(MlKemField, u16, u32, u64, 3329);

/// Module rank `k` of the ML-KEM parameter set selected by the enabled feature.
type K = <MlKemParams as ParameterSet>::K;

/// Length-`K` vector of NTT-domain polynomials; the `t_hat` part of an encapsulation key.
type NttVector = module_lattice::NttVector<MlKemField, K>;

/// Coefficient bit width used when (de)serializing `t_hat` (12 bits per coefficient).
type U12 = hybrid_array::typenum::U12;

/// Byte length of a serialized ML-KEM encapsulation key (`t_hat` encoding followed by `rho`).
const ENCAPSULATION_KEY_LEN: usize =
    <MlKemEncapsulationKey<MlKemParams> as EncodedSizeUser>::EncodedSize::USIZE;
/// Byte length of a serialized ML-KEM ciphertext for the selected parameter set.
const CIPHERTEXT_LEN: usize = <MlKem as KemCore>::CiphertextSize::USIZE;
/// Domain separator fed into the random oracle first when deriving OT keys.
const HASH_DOMAIN_SEPARATOR: &[u8] = b"MlKemOt";

/// Number of coefficients in each polynomial.
const NUM_COEFFICIENTS: usize = 256;

/// 32-byte seed from which an XOF stream (and thus an `NttVector`) is derived.
type Seed = [u8; 32];

/// The trailing 32 bytes of a serialized encapsulation key (in FIPS 203, the
/// seed of the public matrix `A`); carried through the OT arithmetic unchanged.
type Rho = [u8; 32];

/// Byte length of the encoded `t_hat` prefix of a serialized encapsulation key.
const T_HAT_BYTES_LEN: usize = ENCAPSULATION_KEY_LEN - size_of::<Rho>();
67
/// An ML-KEM encapsulation key split into its decoded NTT-domain vector `t_hat`
/// and the trailing seed `rho`, so that vectors can be added to / subtracted
/// from `t_hat` directly.
struct EncapsulationKey {
    /// Decoded from the first `T_HAT_BYTES_LEN` bytes of the serialized key.
    t_hat: NttVector,
    /// The remaining bytes of the serialized key, kept verbatim.
    rho: Rho,
}
73
74impl EncapsulationKey {
75 fn from_bytes(bytes: &[u8; ENCAPSULATION_KEY_LEN]) -> Self {
76 let enc = bytes[..T_HAT_BYTES_LEN]
77 .try_into()
78 .expect("t_hat length mismatch");
79 let t_hat = <NttVector as Encode<U12>>::decode(enc);
80 let rho = bytes[T_HAT_BYTES_LEN..]
81 .try_into()
82 .expect("rho length mismatch");
83 Self { t_hat, rho }
84 }
85
86 fn to_bytes(&self) -> [u8; ENCAPSULATION_KEY_LEN] {
87 let encoded = <NttVector as Encode<U12>>::encode(&self.t_hat);
88 let mut out = [0u8; ENCAPSULATION_KEY_LEN];
89 out[..T_HAT_BYTES_LEN].copy_from_slice(encoded.as_slice());
90 out[T_HAT_BYTES_LEN..].copy_from_slice(&self.rho);
91 out
92 }
93}
94
95impl std::ops::Sub<&NttVector> for &EncapsulationKey {
96 type Output = EncapsulationKey;
97
98 fn sub(self, rhs: &NttVector) -> EncapsulationKey {
99 EncapsulationKey {
100 t_hat: &self.t_hat - rhs,
101 rho: self.rho,
102 }
103 }
104}
105
106impl std::ops::Add<&NttVector> for &EncapsulationKey {
107 type Output = EncapsulationKey;
108
109 fn add(self, rhs: &NttVector) -> EncapsulationKey {
110 EncapsulationKey {
111 t_hat: &self.t_hat + rhs,
112 rho: self.rho,
113 }
114 }
115}
116
117fn xof(seed: &Seed, i: u8, j: u8) -> impl XofReader {
119 let mut h = Shake128::default();
120 h.update(seed);
121 h.update(&[i, j]);
122 h.finalize_xof()
123}
124
125fn sample_ntt_poly(xof: &mut impl XofReader) -> NttPolynomial<MlKemField> {
131 const Q: u16 = MlKemField::Q;
132 const BUF_LEN: usize = 32 * 3;
135 let mut poly = NttPolynomial::<MlKemField>::default();
136 let mut buf = [0u8; BUF_LEN];
137 xof.read(&mut buf);
138 let mut pos = 0;
139 let mut i = 0;
140
141 while i < NUM_COEFFICIENTS {
142 if pos >= BUF_LEN {
144 xof.read(&mut buf);
145 pos = 0;
146 }
147
148 let d1 = u16::from(buf[pos]) | ((u16::from(buf[pos + 1]) & 0x0F) << 8);
149 let d2 = (u16::from(buf[pos + 1]) >> 4) | (u16::from(buf[pos + 2]) << 4);
150 pos += 3;
151
152 if d1 < Q {
153 poly.0[i] = module_lattice::Elem::new(d1);
154 i += 1;
155 }
156 if i < NUM_COEFFICIENTS && d2 < Q {
157 poly.0[i] = module_lattice::Elem::new(d2);
158 i += 1;
159 }
160 }
161
162 poly
163}
164
165fn sample_ntt_vector(seed: &Seed) -> NttVector {
168 NttVector::new(
169 (0..K::USIZE)
170 .map(|j| {
171 let mut reader = xof(seed, 0, j as u8);
172 sample_ntt_poly(&mut reader)
173 })
174 .collect(),
175 )
176}
177
178fn hash_ek(ek: &EncapsulationKey) -> NttVector {
182 let encoded = <NttVector as Encode<U12>>::encode(&ek.t_hat);
183 let seed: Seed = sha3::Sha3_256::digest(encoded.as_slice()).into();
184 sample_ntt_vector(&seed)
185}
186
187fn random_ek(rng: &mut StdRng, rho: Rho) -> EncapsulationKey {
193 let seed: Seed = rng.random();
194 EncapsulationKey {
195 t_hat: sample_ntt_vector(&seed),
196 rho,
197 }
198}
199
/// Errors returned by the ML-KEM base OT sender and receiver.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// An error reported by the underlying `cryprot_net` connection.
    #[error("quic connection error")]
    Connection(#[from] ConnectionError),
    /// An I/O error, presumably raised while exchanging protocol messages.
    #[error("io communication error")]
    Io(#[from] io::Error),
    /// The peer sent a different number of keys/ciphertexts than the local OT count.
    #[error(
        "invalid count of keys/ciphertexts received. expected: {expected}, actual_0: {actual_0}, actual_1: {actual_1}"
    )]
    InvalidDataCount {
        /// Number of OTs requested locally.
        expected: usize,
        /// Length of the first received vector (`rs_0` or `cts_0`).
        actual_0: usize,
        /// Length of the second received vector (`rs_1` or `cts_1`).
        actual_1: usize,
    },
    /// The stream ended before the expected protocol message arrived.
    #[error("expected message but stream is closed")]
    ClosedStream,
    /// ML-KEM decapsulation of the chosen ciphertext failed.
    #[error("ML-KEM decapsulation failed")]
    Decapsulation,
}
219
220#[derive(Copy, Clone, Serialize, Deserialize)]
221struct EncapsulationKeyBytes(#[serde(with = "serde_bytes")] [u8; ENCAPSULATION_KEY_LEN]);
222
223impl From<&EncapsulationKey> for EncapsulationKeyBytes {
224 fn from(ek: &EncapsulationKey) -> Self {
225 Self(ek.to_bytes())
226 }
227}
228
229impl ConditionallySelectable for EncapsulationKeyBytes {
230 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
231 Self(<[u8; ENCAPSULATION_KEY_LEN]>::conditional_select(
232 &a.0, &b.0, choice,
233 ))
234 }
235}
236
237#[derive(Copy, Clone, Serialize, Deserialize)]
238struct CiphertextBytes(#[serde(with = "serde_bytes")] [u8; CIPHERTEXT_LEN]);
239
240impl ConditionallySelectable for CiphertextBytes {
241 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
242 Self(<[u8; CIPHERTEXT_LEN]>::conditional_select(
243 &a.0, &b.0, choice,
244 ))
245 }
246}
247
/// Receiver → sender message: one pair of (masked) encapsulation keys per OT.
#[derive(Serialize, Deserialize)]
struct EncapsulationKeysMessage {
    /// Key `r_0` of every OT, in OT order.
    rs_0: Vec<EncapsulationKeyBytes>,
    /// Key `r_1` of every OT, in OT order.
    rs_1: Vec<EncapsulationKeyBytes>,
}

/// Sender → receiver message: one pair of ciphertexts per OT.
#[derive(Serialize, Deserialize)]
struct CiphertextsMessage {
    /// Ciphertext encapsulated to `ek_0` for every OT, in OT order.
    cts_0: Vec<CiphertextBytes>,
    /// Ciphertext encapsulated to `ek_1` for every OT, in OT order.
    cts_1: Vec<CiphertextBytes>,
}
261
262pub struct MlKemOt {
263 rng: StdRng,
264 conn: Connection,
265}
266
267impl SemiHonest for MlKemOt {}
268
269impl MlKemOt {
270 pub fn new(connection: Connection) -> Self {
271 Self::new_with_rng(connection, rand::make_rng())
272 }
273
274 pub fn new_with_rng(connection: Connection, rng: StdRng) -> MlKemOt {
275 Self {
276 conn: connection,
277 rng,
278 }
279 }
280}
281
282impl Connected for MlKemOt {
283 fn connection(&mut self) -> &mut Connection {
284 &mut self.conn
285 }
286}
287
impl RotSender for MlKemOt {
    type Error = Error;

    /// Sender side of the base random-OT protocol.
    ///
    /// Waits for the receiver's [`EncapsulationKeysMessage`]. For every OT `i` it
    /// forms `ek_0 = r_0 + H(r_1)` and `ek_1 = r_1 + H(r_0)` (with `H` = `hash_ek`),
    /// encapsulates to both keys, and stores
    /// `[derive_ot_key(ss_0, i), derive_ot_key(ss_1, i)]` in `ots[i]`. Both
    /// ciphertext vectors are then sent back in a single [`CiphertextsMessage`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::ClosedStream`] if the stream ends before the receiver's
    /// message arrives, and [`Error::InvalidDataCount`] if either key vector
    /// does not contain exactly `ots.len()` entries. Connection and stream
    /// errors are propagated via `?`.
    #[tracing::instrument(level = Level::DEBUG, skip_all, fields(count = ots.len()))]
    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::BASE_OT))]
    async fn send_into(&mut self, ots: &mut impl Buf<[Block; 2]>) -> Result<(), Self::Error> {
        let count = ots.len();
        let (mut send, mut recv) = self.conn.byte_stream().await?;

        // One message carries both key vectors; a closed stream is an error.
        let receiver_msg: EncapsulationKeysMessage = {
            let mut recv_stream = recv.as_stream();
            recv_stream.next().await.ok_or(Error::ClosedStream)??
        };

        // Validate the peer-supplied lengths before writing into `ots[i]`.
        if receiver_msg.rs_0.len() != count || receiver_msg.rs_1.len() != count {
            return Err(Error::InvalidDataCount {
                expected: count,
                actual_0: receiver_msg.rs_0.len(),
                actual_1: receiver_msg.rs_1.len(),
            });
        }

        let mut cts_0 = Vec::with_capacity(count);
        let mut cts_1 = Vec::with_capacity(count);
        for (i, (r_0_bytes, r_1_bytes)) in receiver_msg
            .rs_0
            .iter()
            .zip(receiver_msg.rs_1.iter())
            .enumerate()
        {
            let r_0 = EncapsulationKey::from_bytes(&r_0_bytes.0);
            let r_1 = EncapsulationKey::from_bytes(&r_1_bytes.0);

            // Each key is shifted by the hash of the other. The receiver built
            // r_b = ek - H(r_{1-b}) for its choice b, so ek_b equals the key it
            // generated. The receiver only keeps one decapsulation key per OT.
            let ek_0 = &r_0 + &hash_ek(&r_1);
            let ek_1 = &r_1 + &hash_ek(&r_0);

            let (ct_0, ss_0) = encapsulate(&(&ek_0).into(), &mut self.rng);
            let (ct_1, ss_1) = encapsulate(&(&ek_1).into(), &mut self.rng);

            // The OT index is hashed in as a tweak, binding each key to its position.
            let key_0 = derive_ot_key(&ss_0, i);
            let key_1 = derive_ot_key(&ss_1, i);

            cts_0.push(ct_0);
            cts_1.push(ct_1);
            ots[i] = [key_0, key_1];
        }

        let sender_msg = CiphertextsMessage { cts_0, cts_1 };
        {
            let mut send_stream = send.as_stream();
            send_stream.send(sender_msg).await?;
        }

        Ok(())
    }
}
348
impl RotReceiver for MlKemOt {
    type Error = Error;

    /// Receiver side of the base random-OT protocol.
    ///
    /// For every OT the receiver generates a real ML-KEM key pair `(dk, ek)` and
    /// a random key `r_{1-b}` sharing `ek`'s `rho`. It then sets
    /// `r_b = ek - H(r_{1-b})`, where `b` is the choice bit. The pair is placed
    /// into `(rs_0, rs_1)` according to `b` using constant-time selects. After
    /// the sender's ciphertexts arrive, it decapsulates the ciphertext at
    /// position `b` and writes `derive_ot_key(ss, i)` to `ots[i]`.
    ///
    /// # Panics
    ///
    /// If `choices.len() != ots.len()`.
    ///
    /// # Errors
    ///
    /// - [`Error::ClosedStream`] if the stream ends before the sender's message arrives.
    /// - [`Error::InvalidDataCount`] if either ciphertext vector does not have
    ///   `ots.len()` entries.
    /// - [`Error::Decapsulation`] if decapsulation fails.
    /// - Connection and stream errors, propagated via `?`.
    #[tracing::instrument(level = Level::DEBUG, skip_all, fields(count = ots.len()))]
    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::BASE_OT))]
    async fn receive_into(
        &mut self,
        ots: &mut impl Buf<Block>,
        choices: &[Choice],
    ) -> Result<(), Self::Error> {
        let count = ots.len();
        assert_eq!(choices.len(), count);

        let (mut send, mut recv) = self.conn.byte_stream().await?;

        let mut decap_keys: Vec<DecapsulationKey<MlKemParams>> = Vec::with_capacity(count);
        let mut rs_0 = Vec::with_capacity(count);
        let mut rs_1 = Vec::with_capacity(count);

        for choice in choices.iter() {
            let (dk, ek) = MlKem::generate(&mut RngCompat(&mut self.rng));
            let ek_bytes: [u8; ENCAPSULATION_KEY_LEN] = ek
                .as_bytes()
                .as_slice()
                .try_into()
                .expect("incorrect encapsulation key size");
            let ek = EncapsulationKey::from_bytes(&ek_bytes);

            // Decoy key for the non-chosen slot; it reuses the real key's rho.
            let r_1_b = random_ek(&mut self.rng, ek.rho);

            // Mask the real key so that the sender's r_b + H(r_{1-b}) recovers `ek`.
            let r_b = &ek - &hash_ek(&r_1_b);
            let r_b_bytes: EncapsulationKeyBytes = (&r_b).into();
            let r_1_b_bytes: EncapsulationKeyBytes = (&r_1_b).into();

            // choice = 0 -> (r_0, r_1) = (r_b, r_1_b); choice = 1 -> swapped.
            let r_0 = EncapsulationKeyBytes::conditional_select(&r_b_bytes, &r_1_b_bytes, *choice);
            let r_1 = EncapsulationKeyBytes::conditional_select(&r_1_b_bytes, &r_b_bytes, *choice);

            decap_keys.push(dk);
            rs_0.push(r_0);
            rs_1.push(r_1);
        }

        let receiver_msg = EncapsulationKeysMessage { rs_0, rs_1 };
        {
            let mut send_stream = send.as_stream();
            send_stream.send(receiver_msg).await?;
        }

        let sender_msg: CiphertextsMessage = {
            let mut recv_stream = recv.as_stream();
            recv_stream.next().await.ok_or(Error::ClosedStream)??
        };

        // Validate the peer-supplied lengths before writing into `ots[i]`.
        if sender_msg.cts_0.len() != count || sender_msg.cts_1.len() != count {
            return Err(Error::InvalidDataCount {
                expected: count,
                actual_0: sender_msg.cts_0.len(),
                actual_1: sender_msg.cts_1.len(),
            });
        }

        for (i, ((dk, choice), (ct_0, ct_1))) in decap_keys
            .iter()
            .zip(choices.iter())
            .zip(sender_msg.cts_0.iter().zip(sender_msg.cts_1.iter()))
            .enumerate()
        {
            // Pick the ciphertext for the chosen slot without branching on `choice`.
            let ct_b_bytes = CiphertextBytes::conditional_select(ct_0, ct_1, *choice).0;
            let ct_b: MlKemCiphertext<MlKem> = ct_b_bytes
                .as_slice()
                .try_into()
                .expect("incorrect ciphertext size");
            let shared_secret = dk.decapsulate(&ct_b).map_err(|_| Error::Decapsulation)?;
            let key_b = derive_ot_key(&shared_secret, i);
            ots[i] = key_b;
        }

        Ok(())
    }
}
436
437fn encapsulate(
441 ek: &EncapsulationKeyBytes,
442 rng: &mut StdRng,
443) -> (CiphertextBytes, SharedKey<MlKem>) {
444 let parsed_ek = MlKemEncapsulationKey::<MlKemParams>::from_bytes((&ek.0).into());
445 let (ct, ss): (MlKemCiphertext<MlKem>, SharedKey<MlKem>) = parsed_ek
446 .encapsulate(&mut RngCompat(rng))
447 .expect("encapsulation failed");
448 (
449 CiphertextBytes(ct.as_slice().try_into().expect("incorrect ciphertext size")),
450 ss,
451 )
452}
453
454fn derive_ot_key(key: &SharedKey<MlKem>, tweak: usize) -> Block {
457 let mut ro = RandomOracle::new();
458 ro.update(HASH_DOMAIN_SEPARATOR);
459 ro.update(key.as_slice());
460 ro.update(&tweak.to_le_bytes());
461 let mut out = ro.finalize_xof();
462 let mut block = Block::ZERO;
463 out.fill(block.as_mut_bytes());
464 block
465}
466
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use cryprot_net::testing::{init_tracing, local_conn};
    use rand::{SeedableRng, rngs::StdRng};
    use subtle::Choice;

    use super::MlKemOt;
    use crate::{RotReceiver, RotSender, random_choices};

    /// Runs one batch of base ROTs (one per entry of `choices`) over a local
    /// connection. Asserts that each receiver output equals the sender output
    /// selected by the corresponding choice bit.
    async fn run_and_check(
        sender_rng: StdRng,
        receiver_rng: StdRng,
        choices: &[Choice],
    ) -> Result<()> {
        let (con1, con2) = local_conn().await?;
        let mut sender = MlKemOt::new_with_rng(con1, sender_rng);
        let mut receiver = MlKemOt::new_with_rng(con2, receiver_rng);
        let (s_ot, r_ot) =
            tokio::try_join!(sender.send(choices.len()), receiver.receive(choices))?;

        for ((r, s), c) in r_ot.into_iter().zip(s_ot).zip(choices) {
            assert_eq!(r, s[c.unwrap_u8() as usize])
        }
        Ok(())
    }

    #[tokio::test]
    async fn mlkem_base_rot_random_choices() -> Result<()> {
        let _g = init_tracing();
        let mut sender_rng = StdRng::seed_from_u64(42);
        let receiver_rng = StdRng::seed_from_u64(42 * 42);
        // The choices are drawn from the sender's RNG before it is handed over.
        let choices = random_choices(128, &mut sender_rng);
        run_and_check(sender_rng, receiver_rng, &choices).await
    }

    #[tokio::test]
    async fn mlkem_base_rot_zero_choices() -> Result<()> {
        let _g = init_tracing();
        let choices = vec![Choice::from(0); 128];
        run_and_check(
            StdRng::seed_from_u64(123),
            StdRng::seed_from_u64(456),
            &choices,
        )
        .await
    }

    #[tokio::test]
    async fn mlkem_base_rot_one_choices() -> Result<()> {
        let _g = init_tracing();
        let choices = vec![Choice::from(1); 128];
        run_and_check(
            StdRng::seed_from_u64(789),
            StdRng::seed_from_u64(101112),
            &choices,
        )
        .await
    }

    #[tokio::test]
    async fn mlkem_base_rot_single_ot() -> Result<()> {
        let _g = init_tracing();
        let (con1, con2) = local_conn().await?;
        let choices = [Choice::from(1)];

        let mut sender = MlKemOt::new_with_rng(con1, StdRng::seed_from_u64(42));
        let mut receiver = MlKemOt::new_with_rng(con2, StdRng::seed_from_u64(43));
        let (s_ot, r_ot) = tokio::try_join!(sender.send(1), receiver.receive(&choices))?;

        // With choice bit 1, the receiver must hold the sender's second key.
        assert_eq!(r_ot[0], s_ot[0][1]);
        Ok(())
    }
}