1use std::{io, mem::size_of};
8
9use cryprot_core::{Block, buf::Buf, random_oracle::RandomOracle};
10use cryprot_net::{Connection, ConnectionError};
11use futures::{SinkExt, StreamExt};
12use hybrid_array::typenum::Unsigned;
13use ml_kem::{
14 Ciphertext as MlKemCiphertext, InvalidKey, Kem, KeyExport, KeySizeUser, ParameterSet,
15 SharedKey,
16 kem::{Decapsulate, DecapsulationKey, Encapsulate, EncapsulationKey as MlKemEncapsulationKey},
17};
18cfg_select! {
21 feature = "ml-kem-base-ot-1024" => {
22 use ml_kem::{MlKem1024 as MlKem};
23 }
24 feature = "ml-kem-base-ot-768" => {
25 use ml_kem::{MlKem768 as MlKem};
26 }
27 feature = "ml-kem-base-ot-512" => {
28 use ml_kem::{MlKem512 as MlKem};
29 }
30}
31
32use module_lattice::{Encode, Field, NttPolynomial};
33use rand::{RngExt, rngs::StdRng};
34use serde::{Deserialize, Serialize};
35use sha3::Digest;
36use shake::{ExtendableOutput, Shake128, Update, XofReader};
37use subtle::{Choice, ConditionallySelectable};
38use tracing::Level;
39
40use crate::{Connected, RotReceiver, RotSender, SemiHonest, phase};
41
42module_lattice::define_field!(MlKemField, u16, u32, u64, 3329);
44
45type K = <MlKem as ParameterSet>::K;
47
48type NttVector = module_lattice::NttVector<MlKemField, K>;
49
50type U12 = hybrid_array::typenum::U12;
51
52const ENCAPSULATION_KEY_LEN: usize = <MlKemEncapsulationKey<MlKem> as KeySizeUser>::KeySize::USIZE;
53const CIPHERTEXT_LEN: usize = <MlKem as Kem>::CiphertextSize::USIZE;
54const HASH_DOMAIN_SEPARATOR: &[u8] = b"MlKemOt";
55
56const NUM_COEFFICIENTS: usize = 256;
58
59type Seed = [u8; 32];
60
61type Rho = [u8; 32];
62
63const T_HAT_BYTES_LEN: usize = ENCAPSULATION_KEY_LEN - size_of::<Rho>();
65
66struct EncapsulationKey {
68 t_hat: NttVector,
69 rho: Rho,
70}
71
72impl EncapsulationKey {
73 fn from_bytes(bytes: &[u8; ENCAPSULATION_KEY_LEN]) -> Self {
74 let enc = bytes[..T_HAT_BYTES_LEN]
75 .try_into()
76 .expect("t_hat length mismatch");
77 let t_hat = <NttVector as Encode<U12>>::decode(enc);
78 let rho = bytes[T_HAT_BYTES_LEN..]
79 .try_into()
80 .expect("rho length mismatch");
81 Self { t_hat, rho }
82 }
83
84 fn to_bytes(&self) -> [u8; ENCAPSULATION_KEY_LEN] {
85 let encoded = <NttVector as Encode<U12>>::encode(&self.t_hat);
86 let mut out = [0u8; ENCAPSULATION_KEY_LEN];
87 out[..T_HAT_BYTES_LEN].copy_from_slice(encoded.as_slice());
88 out[T_HAT_BYTES_LEN..].copy_from_slice(&self.rho);
89 out
90 }
91}
92
93impl std::ops::Sub<&NttVector> for &EncapsulationKey {
94 type Output = EncapsulationKey;
95
96 fn sub(self, rhs: &NttVector) -> EncapsulationKey {
97 EncapsulationKey {
98 t_hat: &self.t_hat - rhs,
99 rho: self.rho,
100 }
101 }
102}
103
104impl std::ops::Add<&NttVector> for &EncapsulationKey {
105 type Output = EncapsulationKey;
106
107 fn add(self, rhs: &NttVector) -> EncapsulationKey {
108 EncapsulationKey {
109 t_hat: &self.t_hat + rhs,
110 rho: self.rho,
111 }
112 }
113}
114
115fn xof(seed: &Seed, i: u8, j: u8) -> impl XofReader {
117 let mut h = Shake128::default();
118 h.update(seed);
119 h.update(&[i, j]);
120 h.finalize_xof()
121}
122
123fn sample_ntt_poly(xof: &mut impl XofReader) -> NttPolynomial<MlKemField> {
129 const Q: u16 = MlKemField::Q;
130 const BUF_LEN: usize = 32 * 3;
133 let mut poly = NttPolynomial::<MlKemField>::default();
134 let mut buf = [0u8; BUF_LEN];
135 xof.read(&mut buf);
136 let mut pos = 0;
137 let mut i = 0;
138
139 while i < NUM_COEFFICIENTS {
140 if pos >= BUF_LEN {
142 xof.read(&mut buf);
143 pos = 0;
144 }
145
146 let d1 = u16::from(buf[pos]) | ((u16::from(buf[pos + 1]) & 0x0F) << 8);
147 let d2 = (u16::from(buf[pos + 1]) >> 4) | (u16::from(buf[pos + 2]) << 4);
148 pos += 3;
149
150 if d1 < Q {
151 poly.0[i] = module_lattice::Elem::new(d1);
152 i += 1;
153 }
154 if i < NUM_COEFFICIENTS && d2 < Q {
155 poly.0[i] = module_lattice::Elem::new(d2);
156 i += 1;
157 }
158 }
159
160 poly
161}
162
163fn sample_ntt_vector(seed: &Seed) -> NttVector {
166 NttVector::new(
167 (0..K::USIZE)
168 .map(|j| {
169 let mut reader = xof(seed, 0, j as u8);
170 sample_ntt_poly(&mut reader)
171 })
172 .collect(),
173 )
174}
175
176fn hash_ek(ek: &EncapsulationKey) -> NttVector {
180 let encoded = <NttVector as Encode<U12>>::encode(&ek.t_hat);
181 let seed: Seed = sha3::Sha3_256::digest(encoded.as_slice()).into();
182 sample_ntt_vector(&seed)
183}
184
185fn random_ek(rng: &mut StdRng, rho: Rho) -> EncapsulationKey {
191 let seed: Seed = rng.random();
192 EncapsulationKey {
193 t_hat: sample_ntt_vector(&seed),
194 rho,
195 }
196}
197
198#[derive(thiserror::Error, Debug)]
199pub enum Error {
200 #[error("quic connection error")]
201 Connection(#[from] ConnectionError),
202 #[error("io communication error")]
203 Io(#[from] io::Error),
204 #[error(
205 "invalid count of keys/ciphertexts received. expected: {expected}, actual_0: {actual_0}, actual_1: {actual_1}"
206 )]
207 InvalidDataCount {
208 expected: usize,
209 actual_0: usize,
210 actual_1: usize,
211 },
212 #[error("expected message but stream is closed")]
213 ClosedStream,
214 #[error("ML-KEM decapsulation failed")]
215 Decapsulation,
216 #[error("EncapsulationKey Validation failed")]
217 EncapsKeyValidation(#[source] InvalidKey),
218}
219
220#[derive(Copy, Clone, Serialize, Deserialize)]
221struct EncapsulationKeyBytes(#[serde(with = "serde_bytes")] [u8; ENCAPSULATION_KEY_LEN]);
222
223impl From<EncapsulationKey> for EncapsulationKeyBytes {
224 fn from(ek: EncapsulationKey) -> Self {
225 Self(ek.to_bytes())
226 }
227}
228
229impl ConditionallySelectable for EncapsulationKeyBytes {
230 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
231 Self(<[u8; ENCAPSULATION_KEY_LEN]>::conditional_select(
232 &a.0, &b.0, choice,
233 ))
234 }
235}
236
237#[derive(Copy, Clone, Serialize, Deserialize)]
238struct CiphertextBytes(#[serde(with = "serde_bytes")] [u8; CIPHERTEXT_LEN]);
239
240impl ConditionallySelectable for CiphertextBytes {
241 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
242 Self(<[u8; CIPHERTEXT_LEN]>::conditional_select(
243 &a.0, &b.0, choice,
244 ))
245 }
246}
247
248#[derive(Serialize, Deserialize)]
250struct EncapsulationKeysMessage {
251 rs_0: Vec<EncapsulationKeyBytes>,
252 rs_1: Vec<EncapsulationKeyBytes>,
253}
254
255#[derive(Serialize, Deserialize)]
257struct CiphertextsMessage {
258 cts_0: Vec<CiphertextBytes>,
259 cts_1: Vec<CiphertextBytes>,
260}
261
262pub struct MlKemOt {
263 rng: StdRng,
264 conn: Connection,
265}
266
267impl SemiHonest for MlKemOt {}
268
269impl MlKemOt {
270 pub fn new(connection: Connection) -> Self {
271 Self::new_with_rng(connection, rand::make_rng())
272 }
273
274 pub fn new_with_rng(connection: Connection, rng: StdRng) -> MlKemOt {
275 Self {
276 conn: connection,
277 rng,
278 }
279 }
280}
281
282impl Connected for MlKemOt {
283 fn connection(&mut self) -> &mut Connection {
284 &mut self.conn
285 }
286}
287
288impl RotSender for MlKemOt {
289 type Error = Error;
290
291 #[tracing::instrument(level = Level::DEBUG, skip_all, fields(count = ots.len()))]
292 #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::BASE_OT))]
293 async fn send_into(&mut self, ots: &mut impl Buf<[Block; 2]>) -> Result<(), Self::Error> {
294 let count = ots.len();
295 let (mut send, mut recv) = self.conn.byte_stream().await?;
296
297 let receiver_msg: EncapsulationKeysMessage = {
298 let mut recv_stream = recv.as_stream();
299 recv_stream.next().await.ok_or(Error::ClosedStream)??
300 };
301
302 if receiver_msg.rs_0.len() != count || receiver_msg.rs_1.len() != count {
303 return Err(Error::InvalidDataCount {
304 expected: count,
305 actual_0: receiver_msg.rs_0.len(),
306 actual_1: receiver_msg.rs_1.len(),
307 });
308 }
309
310 let mut cts_0 = Vec::with_capacity(count);
311 let mut cts_1 = Vec::with_capacity(count);
312 for (i, (r_0_bytes, r_1_bytes)) in receiver_msg
313 .rs_0
314 .iter()
315 .zip(receiver_msg.rs_1.iter())
316 .enumerate()
317 {
318 let r_0 = EncapsulationKey::from_bytes(&r_0_bytes.0);
320 let r_1 = EncapsulationKey::from_bytes(&r_1_bytes.0);
321
322 let ek_0 = &r_0 + &hash_ek(&r_1);
324 let ek_1 = &r_1 + &hash_ek(&r_0);
325
326 let (ct_0, ss_0) = encapsulate(ek_0.into(), &mut self.rng)?;
328 let (ct_1, ss_1) = encapsulate(ek_1.into(), &mut self.rng)?;
329
330 let key_0 = derive_ot_key(&ss_0, i);
332 let key_1 = derive_ot_key(&ss_1, i);
333
334 cts_0.push(ct_0);
335 cts_1.push(ct_1);
336 ots[i] = [key_0, key_1];
337 }
338
339 let sender_msg = CiphertextsMessage { cts_0, cts_1 };
340 {
341 let mut send_stream = send.as_stream();
342 send_stream.send(sender_msg).await?;
343 }
344
345 Ok(())
346 }
347}
348
349impl RotReceiver for MlKemOt {
350 type Error = Error;
351
352 #[tracing::instrument(level = Level::DEBUG, skip_all, fields(count = ots.len()))]
353 #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::BASE_OT))]
354 async fn receive_into(
355 &mut self,
356 ots: &mut impl Buf<Block>,
357 choices: &[Choice],
358 ) -> Result<(), Self::Error> {
359 let count = ots.len();
360 assert_eq!(choices.len(), count);
361
362 let (mut send, mut recv) = self.conn.byte_stream().await?;
363
364 let mut decap_keys: Vec<DecapsulationKey<MlKem>> = Vec::with_capacity(count);
365 let mut rs_0 = Vec::with_capacity(count);
366 let mut rs_1 = Vec::with_capacity(count);
367
368 for choice in choices.iter() {
369 let (dk, ek) = MlKem::generate_keypair_from_rng(&mut self.rng);
371 let ek_bytes: [u8; ENCAPSULATION_KEY_LEN] = ek
372 .to_bytes()
373 .as_slice()
374 .try_into()
375 .expect("incorrect encapsulation key size");
376 let ek = EncapsulationKey::from_bytes(&ek_bytes);
377
378 let r_1_b = random_ek(&mut self.rng, ek.rho);
380
381 let r_b = &ek - &hash_ek(&r_1_b);
383 let r_b_bytes: EncapsulationKeyBytes = r_b.into();
384 let r_1_b_bytes: EncapsulationKeyBytes = r_1_b.into();
385
386 let r_0 = EncapsulationKeyBytes::conditional_select(&r_b_bytes, &r_1_b_bytes, *choice);
390 let r_1 = EncapsulationKeyBytes::conditional_select(&r_1_b_bytes, &r_b_bytes, *choice);
391
392 decap_keys.push(dk);
393 rs_0.push(r_0);
394 rs_1.push(r_1);
395 }
396
397 let receiver_msg = EncapsulationKeysMessage { rs_0, rs_1 };
398 {
399 let mut send_stream = send.as_stream();
400 send_stream.send(receiver_msg).await?;
401 }
402
403 let sender_msg: CiphertextsMessage = {
404 let mut recv_stream = recv.as_stream();
405 recv_stream.next().await.ok_or(Error::ClosedStream)??
406 };
407
408 if sender_msg.cts_0.len() != count || sender_msg.cts_1.len() != count {
409 return Err(Error::InvalidDataCount {
410 expected: count,
411 actual_0: sender_msg.cts_0.len(),
412 actual_1: sender_msg.cts_1.len(),
413 });
414 }
415
416 for (i, ((dk, choice), (ct_0, ct_1))) in decap_keys
418 .iter()
419 .zip(choices.iter())
420 .zip(sender_msg.cts_0.iter().zip(sender_msg.cts_1.iter()))
421 .enumerate()
422 {
423 let ct_b_bytes = CiphertextBytes::conditional_select(ct_0, ct_1, *choice).0;
424 let ct_b: MlKemCiphertext<MlKem> = ct_b_bytes
425 .as_slice()
426 .try_into()
427 .expect("incorrect ciphertext size");
428 let shared_secret = dk.decapsulate(&ct_b);
429 let key_b = derive_ot_key(&shared_secret, i);
430 ots[i] = key_b;
431 }
432
433 Ok(())
434 }
435}
436
437fn encapsulate(
439 ek: EncapsulationKeyBytes,
440 rng: &mut StdRng,
441) -> Result<(CiphertextBytes, SharedKey), Error> {
442 let parsed_ek =
443 MlKemEncapsulationKey::<MlKem>::new(&ek.0.into()).map_err(Error::EncapsKeyValidation)?;
444 let (ct, ss): (MlKemCiphertext<MlKem>, SharedKey) = parsed_ek.encapsulate_with_rng(rng);
445 Ok((
446 CiphertextBytes(ct.as_slice().try_into().expect("incorrect ciphertext size")),
447 ss,
448 ))
449}
450
451fn derive_ot_key(key: &SharedKey, tweak: usize) -> Block {
454 let mut ro = RandomOracle::new();
455 ro.update(HASH_DOMAIN_SEPARATOR);
456 ro.update(key.as_slice());
457 ro.update(&tweak.to_le_bytes());
458 let mut out = ro.finalize_xof();
459 let mut block = Block::ZERO;
460 out.fill(block.as_mut_bytes());
461 block
462}
463
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use cryprot_net::testing::{init_tracing, local_conn};
    use rand::{SeedableRng, rngs::StdRng};

    use super::MlKemOt;
    use crate::{RotReceiver, RotSender, random_choices};

    /// Runs one sender/receiver exchange over a local connection.
    ///
    /// Asserts that every receiver output equals the sender output selected by the
    /// matching choice bit.
    async fn run_and_check(
        sender_rng: StdRng,
        receiver_rng: StdRng,
        choices: Vec<subtle::Choice>,
    ) -> Result<()> {
        let (con1, con2) = local_conn().await?;
        let mut sender = MlKemOt::new_with_rng(con1, sender_rng);
        let mut receiver = MlKemOt::new_with_rng(con2, receiver_rng);
        let (s_ot, r_ot) =
            tokio::try_join!(sender.send(choices.len()), receiver.receive(&choices))?;

        for ((r, s), c) in r_ot.into_iter().zip(s_ot).zip(choices) {
            assert_eq!(r, s[c.unwrap_u8() as usize]);
        }
        Ok(())
    }

    #[tokio::test]
    async fn mlkem_base_rot_random_choices() -> Result<()> {
        let _g = init_tracing();
        // The sender's RNG also generates the choices before it is handed over.
        let mut sender_rng = StdRng::seed_from_u64(42);
        let choices = random_choices(128, &mut sender_rng);
        run_and_check(sender_rng, StdRng::seed_from_u64(42 * 42), choices).await
    }

    #[tokio::test]
    async fn mlkem_base_rot_zero_choices() -> Result<()> {
        let _g = init_tracing();
        let choices = vec![subtle::Choice::from(0); 128];
        run_and_check(StdRng::seed_from_u64(123), StdRng::seed_from_u64(456), choices).await
    }

    #[tokio::test]
    async fn mlkem_base_rot_one_choices() -> Result<()> {
        let _g = init_tracing();
        let choices = vec![subtle::Choice::from(1); 128];
        run_and_check(StdRng::seed_from_u64(789), StdRng::seed_from_u64(101112), choices).await
    }

    #[tokio::test]
    async fn mlkem_base_rot_single_ot() -> Result<()> {
        let _g = init_tracing();
        let choices = vec![subtle::Choice::from(1)];
        run_and_check(StdRng::seed_from_u64(42), StdRng::seed_from_u64(43), choices).await
    }
}