Skip to main content

cryprot_ot/
mlkem_ot.rs

1//! Post-quantum base OT using ML-KEM.
2//!
3//! Implements the MR19 protocol (Masny-Rindal, ePrint 2019/706, Figure 8)
4//! instantiated with ML-KEM as per Section D.3.
5//! See `docs/mlkem-ot-protocol.md` for the full protocol description.
6
7use std::{io, mem::size_of};
8
9use cryprot_core::{Block, buf::Buf, random_oracle::RandomOracle};
10use cryprot_net::{Connection, ConnectionError};
11use futures::{SinkExt, StreamExt};
12use hybrid_array::typenum::Unsigned;
13use ml_kem::{
14    Ciphertext as MlKemCiphertext, InvalidKey, Kem, KeyExport, KeySizeUser, ParameterSet,
15    SharedKey,
16    kem::{Decapsulate, DecapsulationKey, Encapsulate, EncapsulationKey as MlKemEncapsulationKey},
17};
18// ML-KEM parameter set selection. If multiple features are enabled, the highest
19// wins.
20cfg_select! {
21    feature = "ml-kem-base-ot-1024" => {
22        use ml_kem::{MlKem1024 as MlKem};
23    }
24    feature = "ml-kem-base-ot-768" => {
25        use ml_kem::{MlKem768 as MlKem};
26    }
27    feature = "ml-kem-base-ot-512" => {
28        use ml_kem::{MlKem512 as MlKem};
29    }
30}
31
32use module_lattice::{Encode, Field, NttPolynomial};
33use rand::{RngExt, rngs::StdRng};
34use serde::{Deserialize, Serialize};
35use sha3::Digest;
36use shake::{ExtendableOutput, Shake128, Update, XofReader};
37use subtle::{Choice, ConditionallySelectable};
38use tracing::Level;
39
40use crate::{Connected, RotReceiver, RotSender, SemiHonest, phase};
41
42// Define the ML-KEM base field (q = 3329).
43module_lattice::define_field!(MlKemField, u16, u32, u64, 3329);
44
45// Module dimension derived from the chosen ML-KEM parameter set.
46type K = <MlKem as ParameterSet>::K;
47
48type NttVector = module_lattice::NttVector<MlKemField, K>;
49
50type U12 = hybrid_array::typenum::U12;
51
52const ENCAPSULATION_KEY_LEN: usize = <MlKemEncapsulationKey<MlKem> as KeySizeUser>::KeySize::USIZE;
53const CIPHERTEXT_LEN: usize = <MlKem as Kem>::CiphertextSize::USIZE;
54const HASH_DOMAIN_SEPARATOR: &[u8] = b"MlKemOt";
55
56// Number of coefficients per polynomial (FIPS 203, Section 2: n = 256).
57const NUM_COEFFICIENTS: usize = 256;
58
59type Seed = [u8; 32];
60
61type Rho = [u8; 32];
62
63// Serialized t_hat is the encapsulation key minus the rho suffix.
64const T_HAT_BYTES_LEN: usize = ENCAPSULATION_KEY_LEN - size_of::<Rho>();
65
66// Parsed encapsulation key: ek = (t_hat, rho).
67struct EncapsulationKey {
68    t_hat: NttVector,
69    rho: Rho,
70}
71
72impl EncapsulationKey {
73    fn from_bytes(bytes: &[u8; ENCAPSULATION_KEY_LEN]) -> Self {
74        let enc = bytes[..T_HAT_BYTES_LEN]
75            .try_into()
76            .expect("t_hat length mismatch");
77        let t_hat = <NttVector as Encode<U12>>::decode(enc);
78        let rho = bytes[T_HAT_BYTES_LEN..]
79            .try_into()
80            .expect("rho length mismatch");
81        Self { t_hat, rho }
82    }
83
84    fn to_bytes(&self) -> [u8; ENCAPSULATION_KEY_LEN] {
85        let encoded = <NttVector as Encode<U12>>::encode(&self.t_hat);
86        let mut out = [0u8; ENCAPSULATION_KEY_LEN];
87        out[..T_HAT_BYTES_LEN].copy_from_slice(encoded.as_slice());
88        out[T_HAT_BYTES_LEN..].copy_from_slice(&self.rho);
89        out
90    }
91}
92
93impl std::ops::Sub<&NttVector> for &EncapsulationKey {
94    type Output = EncapsulationKey;
95
96    fn sub(self, rhs: &NttVector) -> EncapsulationKey {
97        EncapsulationKey {
98            t_hat: &self.t_hat - rhs,
99            rho: self.rho,
100        }
101    }
102}
103
104impl std::ops::Add<&NttVector> for &EncapsulationKey {
105    type Output = EncapsulationKey;
106
107    fn add(self, rhs: &NttVector) -> EncapsulationKey {
108        EncapsulationKey {
109            t_hat: &self.t_hat + rhs,
110            rho: self.rho,
111        }
112    }
113}
114
115// XOF: SHAKE-128(seed || i || j), see FIPS 203 Section 4.1.
116fn xof(seed: &Seed, i: u8, j: u8) -> impl XofReader {
117    let mut h = Shake128::default();
118    h.update(seed);
119    h.update(&[i, j]);
120    h.finalize_xof()
121}
122
123// FIPS 203 Algorithm 7: SampleNTT.
124// Rejection sampling from a byte stream to produce a pseudorandom NTT
125// polynomial.
126//
127// Adapted from the ml-kem crate's `sample_ntt`.
128fn sample_ntt_poly(xof: &mut impl XofReader) -> NttPolynomial<MlKemField> {
129    const Q: u16 = MlKemField::Q;
130    // Read 32 triples (3 bytes each) at a time from the XOF.
131    // BUF_LEN must be divisible by 3 so pos always lands exactly on BUF_LEN.
132    const BUF_LEN: usize = 32 * 3;
133    let mut poly = NttPolynomial::<MlKemField>::default();
134    let mut buf = [0u8; BUF_LEN];
135    xof.read(&mut buf);
136    let mut pos = 0;
137    let mut i = 0;
138
139    while i < NUM_COEFFICIENTS {
140        // Refill the buffer from the XOF stream when exhausted.
141        if pos >= BUF_LEN {
142            xof.read(&mut buf);
143            pos = 0;
144        }
145
146        let d1 = u16::from(buf[pos]) | ((u16::from(buf[pos + 1]) & 0x0F) << 8);
147        let d2 = (u16::from(buf[pos + 1]) >> 4) | (u16::from(buf[pos + 2]) << 4);
148        pos += 3;
149
150        if d1 < Q {
151            poly.0[i] = module_lattice::Elem::new(d1);
152            i += 1;
153        }
154        if i < NUM_COEFFICIENTS && d2 < Q {
155            poly.0[i] = module_lattice::Elem::new(d2);
156            i += 1;
157        }
158    }
159
160    poly
161}
162
163// Produces a pseudorandom NttVector from a seed by calling sample_ntt_poly k
164// times, each with a different XOF stream: xof(seed, 0, j).
165fn sample_ntt_vector(seed: &Seed) -> NttVector {
166    NttVector::new(
167        (0..K::USIZE)
168            .map(|j| {
169                let mut reader = xof(seed, 0, j as u8);
170                sample_ntt_poly(&mut reader)
171            })
172            .collect(),
173    )
174}
175
176// Maps an encapsulation key to an NttVector via SHA3-256.
177// Only the t_hat component is used; rho is ignored.
178// Corresponds to libOTe's `pkHash`.
179fn hash_ek(ek: &EncapsulationKey) -> NttVector {
180    let encoded = <NttVector as Encode<U12>>::encode(&ek.t_hat);
181    let seed: Seed = sha3::Sha3_256::digest(encoded.as_slice()).into();
182    sample_ntt_vector(&seed)
183}
184
185// Generate a random encapsulation key using the given randomness and rho.
186//
187// The result is indistinguishable from a real encapsulation key, since a real
188// one has `t_hat = A_hat * s + e` and that is computationally indistinguishable
189// from a pseudorandom vector in `T_q^k`.
190fn random_ek(rng: &mut StdRng, rho: Rho) -> EncapsulationKey {
191    let seed: Seed = rng.random();
192    EncapsulationKey {
193        t_hat: sample_ntt_vector(&seed),
194        rho,
195    }
196}
197
/// Errors returned by the ML-KEM base OT sender and receiver.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Error from the underlying `cryprot_net` connection — presumably raised
    /// when opening the byte stream for the protocol messages; confirm.
    #[error("quic connection error")]
    Connection(#[from] ConnectionError),
    /// I/O error; presumably surfaced while sending or receiving the
    /// serialized protocol messages — NOTE(review): confirm which stream
    /// operations produce `io::Error`.
    #[error("io communication error")]
    Io(#[from] io::Error),
    /// The peer's message did not contain exactly `expected` entries in both of
    /// its vectors; `actual_0` / `actual_1` are the lengths actually received.
    #[error(
        "invalid count of keys/ciphertexts received. expected: {expected}, actual_0: {actual_0}, actual_1: {actual_1}"
    )]
    InvalidDataCount {
        expected: usize,
        actual_0: usize,
        actual_1: usize,
    },
    /// The peer's stream ended before the expected protocol message arrived.
    #[error("expected message but stream is closed")]
    ClosedStream,
    /// ML-KEM decapsulation failure. Not constructed anywhere in this module.
    #[error("ML-KEM decapsulation failed")]
    Decapsulation,
    /// A reconstructed encapsulation key was rejected by ML-KEM key validation
    /// when the sender tried to encapsulate to it.
    #[error("EncapsulationKey Validation failed")]
    EncapsKeyValidation(#[source] InvalidKey),
}
219
220#[derive(Copy, Clone, Serialize, Deserialize)]
221struct EncapsulationKeyBytes(#[serde(with = "serde_bytes")] [u8; ENCAPSULATION_KEY_LEN]);
222
223impl From<EncapsulationKey> for EncapsulationKeyBytes {
224    fn from(ek: EncapsulationKey) -> Self {
225        Self(ek.to_bytes())
226    }
227}
228
229impl ConditionallySelectable for EncapsulationKeyBytes {
230    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
231        Self(<[u8; ENCAPSULATION_KEY_LEN]>::conditional_select(
232            &a.0, &b.0, choice,
233        ))
234    }
235}
236
237#[derive(Copy, Clone, Serialize, Deserialize)]
238struct CiphertextBytes(#[serde(with = "serde_bytes")] [u8; CIPHERTEXT_LEN]);
239
240impl ConditionallySelectable for CiphertextBytes {
241    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
242        Self(<[u8; CIPHERTEXT_LEN]>::conditional_select(
243            &a.0, &b.0, choice,
244        ))
245    }
246}
247
// Message from receiver to sender: two values (r_0, r_1) per OT.
// Entry i of `rs_0` / `rs_1` holds the serialized r_0 / r_1 of OT i; the
// sender rejects the message unless both vectors have exactly `count` entries.
// NOTE(review): field names/order are presumably part of the wire format
// (depends on the serde format used by `cryprot_net`) — do not reorder.
#[derive(Serialize, Deserialize)]
struct EncapsulationKeysMessage {
    rs_0: Vec<EncapsulationKeyBytes>,
    rs_1: Vec<EncapsulationKeyBytes>,
}

// Message from sender to receiver: two ciphertexts per OT.
// Entry i of `cts_0` / `cts_1` is the ciphertext encapsulated to ek_0 / ek_1
// of OT i; the receiver decapsulates only the one at its choice position.
#[derive(Serialize, Deserialize)]
struct CiphertextsMessage {
    cts_0: Vec<CiphertextBytes>,
    cts_1: Vec<CiphertextBytes>,
}
261
262pub struct MlKemOt {
263    rng: StdRng,
264    conn: Connection,
265}
266
267impl SemiHonest for MlKemOt {}
268
269impl MlKemOt {
270    pub fn new(connection: Connection) -> Self {
271        Self::new_with_rng(connection, rand::make_rng())
272    }
273
274    pub fn new_with_rng(connection: Connection, rng: StdRng) -> MlKemOt {
275        Self {
276            conn: connection,
277            rng,
278        }
279    }
280}
281
282impl Connected for MlKemOt {
283    fn connection(&mut self) -> &mut Connection {
284        &mut self.conn
285    }
286}
287
288impl RotSender for MlKemOt {
289    type Error = Error;
290
291    #[tracing::instrument(level = Level::DEBUG, skip_all, fields(count = ots.len()))]
292    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::BASE_OT))]
293    async fn send_into(&mut self, ots: &mut impl Buf<[Block; 2]>) -> Result<(), Self::Error> {
294        let count = ots.len();
295        let (mut send, mut recv) = self.conn.byte_stream().await?;
296
297        let receiver_msg: EncapsulationKeysMessage = {
298            let mut recv_stream = recv.as_stream();
299            recv_stream.next().await.ok_or(Error::ClosedStream)??
300        };
301
302        if receiver_msg.rs_0.len() != count || receiver_msg.rs_1.len() != count {
303            return Err(Error::InvalidDataCount {
304                expected: count,
305                actual_0: receiver_msg.rs_0.len(),
306                actual_1: receiver_msg.rs_1.len(),
307            });
308        }
309
310        let mut cts_0 = Vec::with_capacity(count);
311        let mut cts_1 = Vec::with_capacity(count);
312        for (i, (r_0_bytes, r_1_bytes)) in receiver_msg
313            .rs_0
314            .iter()
315            .zip(receiver_msg.rs_1.iter())
316            .enumerate()
317        {
318            // Step 5: Receive (r_0, r_1) from the receiver (done above).
319            let r_0 = EncapsulationKey::from_bytes(&r_0_bytes.0);
320            let r_1 = EncapsulationKey::from_bytes(&r_1_bytes.0);
321
322            // Step 6: Reconstruct encapsulation keys: ek_j = r_j + hash_ek(r_{1-j}).
323            let ek_0 = &r_0 + &hash_ek(&r_1);
324            let ek_1 = &r_1 + &hash_ek(&r_0);
325
326            // Step 7: Encapsulate to both reconstructed keys.
327            let (ct_0, ss_0) = encapsulate(ek_0.into(), &mut self.rng)?;
328            let (ct_1, ss_1) = encapsulate(ek_1.into(), &mut self.rng)?;
329
330            // Step 8: Derive OT output keys.
331            let key_0 = derive_ot_key(&ss_0, i);
332            let key_1 = derive_ot_key(&ss_1, i);
333
334            cts_0.push(ct_0);
335            cts_1.push(ct_1);
336            ots[i] = [key_0, key_1];
337        }
338
339        let sender_msg = CiphertextsMessage { cts_0, cts_1 };
340        {
341            let mut send_stream = send.as_stream();
342            send_stream.send(sender_msg).await?;
343        }
344
345        Ok(())
346    }
347}
348
349impl RotReceiver for MlKemOt {
350    type Error = Error;
351
352    #[tracing::instrument(level = Level::DEBUG, skip_all, fields(count = ots.len()))]
353    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::BASE_OT))]
354    async fn receive_into(
355        &mut self,
356        ots: &mut impl Buf<Block>,
357        choices: &[Choice],
358    ) -> Result<(), Self::Error> {
359        let count = ots.len();
360        assert_eq!(choices.len(), count);
361
362        let (mut send, mut recv) = self.conn.byte_stream().await?;
363
364        let mut decap_keys: Vec<DecapsulationKey<MlKem>> = Vec::with_capacity(count);
365        let mut rs_0 = Vec::with_capacity(count);
366        let mut rs_1 = Vec::with_capacity(count);
367
368        for choice in choices.iter() {
369            // Step 1: Generate real keypair.
370            let (dk, ek) = MlKem::generate_keypair_from_rng(&mut self.rng);
371            let ek_bytes: [u8; ENCAPSULATION_KEY_LEN] = ek
372                .to_bytes()
373                .as_slice()
374                .try_into()
375                .expect("incorrect encapsulation key size");
376            let ek = EncapsulationKey::from_bytes(&ek_bytes);
377
378            // Step 2: Sample random key for position 1-b.
379            let r_1_b = random_ek(&mut self.rng, ek.rho);
380
381            // Step 3: Compute real key: r_b = ek - hash_ek(r_{1-b}).
382            let r_b = &ek - &hash_ek(&r_1_b);
383            let r_b_bytes: EncapsulationKeyBytes = r_b.into();
384            let r_1_b_bytes: EncapsulationKeyBytes = r_1_b.into();
385
386            // Step 4: Select (r_0, r_1) based on choice bit (constant-time).
387            // If b=0: r_0 = real, r_1 = random.
388            // If b=1: r_0 = random, r_1 = real.
389            let r_0 = EncapsulationKeyBytes::conditional_select(&r_b_bytes, &r_1_b_bytes, *choice);
390            let r_1 = EncapsulationKeyBytes::conditional_select(&r_1_b_bytes, &r_b_bytes, *choice);
391
392            decap_keys.push(dk);
393            rs_0.push(r_0);
394            rs_1.push(r_1);
395        }
396
397        let receiver_msg = EncapsulationKeysMessage { rs_0, rs_1 };
398        {
399            let mut send_stream = send.as_stream();
400            send_stream.send(receiver_msg).await?;
401        }
402
403        let sender_msg: CiphertextsMessage = {
404            let mut recv_stream = recv.as_stream();
405            recv_stream.next().await.ok_or(Error::ClosedStream)??
406        };
407
408        if sender_msg.cts_0.len() != count || sender_msg.cts_1.len() != count {
409            return Err(Error::InvalidDataCount {
410                expected: count,
411                actual_0: sender_msg.cts_0.len(),
412                actual_1: sender_msg.cts_1.len(),
413            });
414        }
415
416        // Step 10-11: Decapsulate the chosen ciphertext and derive OT key.
417        for (i, ((dk, choice), (ct_0, ct_1))) in decap_keys
418            .iter()
419            .zip(choices.iter())
420            .zip(sender_msg.cts_0.iter().zip(sender_msg.cts_1.iter()))
421            .enumerate()
422        {
423            let ct_b_bytes = CiphertextBytes::conditional_select(ct_0, ct_1, *choice).0;
424            let ct_b: MlKemCiphertext<MlKem> = ct_b_bytes
425                .as_slice()
426                .try_into()
427                .expect("incorrect ciphertext size");
428            let shared_secret = dk.decapsulate(&ct_b);
429            let key_b = derive_ot_key(&shared_secret, i);
430            ots[i] = key_b;
431        }
432
433        Ok(())
434    }
435}
436
437// Encapsulates to the given key, returning the ciphertext and the shared key.
438fn encapsulate(
439    ek: EncapsulationKeyBytes,
440    rng: &mut StdRng,
441) -> Result<(CiphertextBytes, SharedKey), Error> {
442    let parsed_ek =
443        MlKemEncapsulationKey::<MlKem>::new(&ek.0.into()).map_err(Error::EncapsKeyValidation)?;
444    let (ct, ss): (MlKemCiphertext<MlKem>, SharedKey) = parsed_ek.encapsulate_with_rng(rng);
445    Ok((
446        CiphertextBytes(ct.as_slice().try_into().expect("incorrect ciphertext size")),
447        ss,
448    ))
449}
450
451// Derive an OT key from the ML-KEM shared key using a random oracle XOF,
452// returning a Block-sized (128-bit) output.
453fn derive_ot_key(key: &SharedKey, tweak: usize) -> Block {
454    let mut ro = RandomOracle::new();
455    ro.update(HASH_DOMAIN_SEPARATOR);
456    ro.update(key.as_slice());
457    ro.update(&tweak.to_le_bytes());
458    let mut out = ro.finalize_xof();
459    let mut block = Block::ZERO;
460    out.fill(block.as_mut_bytes());
461    block
462}
463
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use cryprot_net::testing::{init_tracing, local_conn};
    use rand::{SeedableRng, rngs::StdRng};

    use super::MlKemOt;
    use crate::{RotReceiver, RotSender, random_choices};

    /// Runs one base-OT session over a local connection with the given RNGs.
    /// Checks that each receiver output equals the sender output selected by
    /// the corresponding choice bit.
    async fn run_and_check(
        sender_rng: StdRng,
        receiver_rng: StdRng,
        choices: Vec<subtle::Choice>,
    ) -> Result<()> {
        let _g = init_tracing();
        let (con1, con2) = local_conn().await?;
        let mut sender = MlKemOt::new_with_rng(con1, sender_rng);
        let mut receiver = MlKemOt::new_with_rng(con2, receiver_rng);
        let count = choices.len();
        let (s_ot, r_ot) = tokio::try_join!(sender.send(count), receiver.receive(&choices))?;

        for ((r, s), c) in r_ot.into_iter().zip(s_ot).zip(choices) {
            assert_eq!(r, s[c.unwrap_u8() as usize])
        }
        Ok(())
    }

    #[tokio::test]
    async fn mlkem_base_rot_random_choices() -> Result<()> {
        // The sender RNG is advanced by generating the choices before it is handed over.
        let mut rng1 = StdRng::seed_from_u64(42);
        let choices = random_choices(128, &mut rng1);
        run_and_check(rng1, StdRng::seed_from_u64(42 * 42), choices).await
    }

    #[tokio::test]
    async fn mlkem_base_rot_zero_choices() -> Result<()> {
        let choices = vec![subtle::Choice::from(0); 128];
        run_and_check(StdRng::seed_from_u64(123), StdRng::seed_from_u64(456), choices).await
    }

    #[tokio::test]
    async fn mlkem_base_rot_one_choices() -> Result<()> {
        let choices = vec![subtle::Choice::from(1); 128];
        run_and_check(StdRng::seed_from_u64(789), StdRng::seed_from_u64(101112), choices).await
    }

    #[tokio::test]
    async fn mlkem_base_rot_single_ot() -> Result<()> {
        let _g = init_tracing();
        let (con1, con2) = local_conn().await?;
        let mut sender = MlKemOt::new_with_rng(con1, StdRng::seed_from_u64(42));
        let mut receiver = MlKemOt::new_with_rng(con2, StdRng::seed_from_u64(43));
        let choices = vec![subtle::Choice::from(1)];
        let (s_ot, r_ot) = tokio::try_join!(sender.send(1), receiver.receive(&choices))?;

        // The only choice bit is 1, so the receiver must hold the sender's second key.
        assert_eq!(r_ot[0], s_ot[0][1]);
        Ok(())
    }
}