Skip to main content

cryprot_ot/
mlkem_ot.rs

1//! Post-quantum base OT using ML-KEM.
2//!
3//! Implements the MR19 protocol (Masny-Rindal, ePrint 2019/706, Figure 8)
4//! instantiated with ML-KEM as per Section D.3.
5//! See `docs/mlkem-ot-protocol.md` for the full protocol description.
6
7use std::{io, mem::size_of};
8
9use cryprot_core::{Block, buf::Buf, rand_compat::RngCompat, random_oracle::RandomOracle};
10use cryprot_net::{Connection, ConnectionError};
11use futures::{SinkExt, StreamExt};
12use hybrid_array::typenum::Unsigned;
13use ml_kem::{
14    Ciphertext as MlKemCiphertext, EncodedSizeUser, KemCore, ParameterSet, SharedKey,
15    kem::{Decapsulate, DecapsulationKey, Encapsulate, EncapsulationKey as MlKemEncapsulationKey},
16};
17// ML-KEM parameter set selection. If multiple features are enabled, the highest
18// wins.
19cfg_select! {
20    feature = "ml-kem-base-ot-1024" => {
21        use ml_kem::{MlKem1024 as MlKem, MlKem1024Params as MlKemParams};
22    }
23    feature = "ml-kem-base-ot-768" => {
24        use ml_kem::{MlKem768 as MlKem, MlKem768Params as MlKemParams};
25    }
26    feature = "ml-kem-base-ot-512" => {
27        use ml_kem::{MlKem512 as MlKem, MlKem512Params as MlKemParams};
28    }
29}
30
31use module_lattice::{Encode, Field, NttPolynomial};
32use rand::{RngExt, rngs::StdRng};
33use serde::{Deserialize, Serialize};
34use sha3::{
35    Digest, Shake128,
36    digest::{ExtendableOutput, Update, XofReader},
37};
38use subtle::{Choice, ConditionallySelectable};
39use tracing::Level;
40
41use crate::{Connected, RotReceiver, RotSender, SemiHonest, phase};
42
// Define the ML-KEM base field (q = 3329) as `MlKemField`.
module_lattice::define_field!(MlKemField, u16, u32, u64, 3329);

// Module dimension k, derived from the chosen ML-KEM parameter set.
type K = <MlKemParams as ParameterSet>::K;

// Vector of K NTT-domain polynomials over `MlKemField`.
type NttVector = module_lattice::NttVector<MlKemField, K>;

// Type-level 12, used as the `Encode` parameter when (de)serializing `t_hat`.
// NOTE(review): presumably the number of bits per encoded coefficient —
// confirm against the module_lattice `Encode` documentation.
type U12 = hybrid_array::typenum::U12;

// Byte length of a serialized ML-KEM encapsulation key for the selected
// parameter set.
const ENCAPSULATION_KEY_LEN: usize =
    <MlKemEncapsulationKey<MlKemParams> as EncodedSizeUser>::EncodedSize::USIZE;
// Byte length of an ML-KEM ciphertext for the selected parameter set.
const CIPHERTEXT_LEN: usize = <MlKem as KemCore>::CiphertextSize::USIZE;
// Prefix absorbed before the shared key when deriving OT output keys.
const HASH_DOMAIN_SEPARATOR: &[u8] = b"MlKemOt";

// Number of coefficients per polynomial (FIPS 203, Section 2: n = 256).
const NUM_COEFFICIENTS: usize = 256;

// 32-byte seed that is expanded into an `NttVector` through the XOF.
type Seed = [u8; 32];

// 32-byte `rho` component carried at the end of an encapsulation key.
type Rho = [u8; 32];

// Serialized t_hat is the encapsulation key minus the rho suffix.
const T_HAT_BYTES_LEN: usize = ENCAPSULATION_KEY_LEN - size_of::<Rho>();
67
68// Parsed encapsulation key: ek = (t_hat, rho).
69struct EncapsulationKey {
70    t_hat: NttVector,
71    rho: Rho,
72}
73
74impl EncapsulationKey {
75    fn from_bytes(bytes: &[u8; ENCAPSULATION_KEY_LEN]) -> Self {
76        let enc = bytes[..T_HAT_BYTES_LEN]
77            .try_into()
78            .expect("t_hat length mismatch");
79        let t_hat = <NttVector as Encode<U12>>::decode(enc);
80        let rho = bytes[T_HAT_BYTES_LEN..]
81            .try_into()
82            .expect("rho length mismatch");
83        Self { t_hat, rho }
84    }
85
86    fn to_bytes(&self) -> [u8; ENCAPSULATION_KEY_LEN] {
87        let encoded = <NttVector as Encode<U12>>::encode(&self.t_hat);
88        let mut out = [0u8; ENCAPSULATION_KEY_LEN];
89        out[..T_HAT_BYTES_LEN].copy_from_slice(encoded.as_slice());
90        out[T_HAT_BYTES_LEN..].copy_from_slice(&self.rho);
91        out
92    }
93}
94
95impl std::ops::Sub<&NttVector> for &EncapsulationKey {
96    type Output = EncapsulationKey;
97
98    fn sub(self, rhs: &NttVector) -> EncapsulationKey {
99        EncapsulationKey {
100            t_hat: &self.t_hat - rhs,
101            rho: self.rho,
102        }
103    }
104}
105
106impl std::ops::Add<&NttVector> for &EncapsulationKey {
107    type Output = EncapsulationKey;
108
109    fn add(self, rhs: &NttVector) -> EncapsulationKey {
110        EncapsulationKey {
111            t_hat: &self.t_hat + rhs,
112            rho: self.rho,
113        }
114    }
115}
116
117// XOF: SHAKE-128(seed || i || j), see FIPS 203 Section 4.1.
118fn xof(seed: &Seed, i: u8, j: u8) -> impl XofReader {
119    let mut h = Shake128::default();
120    h.update(seed);
121    h.update(&[i, j]);
122    h.finalize_xof()
123}
124
125// FIPS 203 Algorithm 7: SampleNTT.
126// Rejection sampling from a byte stream to produce a pseudorandom NTT
127// polynomial.
128//
129// Adapted from the ml-kem crate's `sample_ntt`.
130fn sample_ntt_poly(xof: &mut impl XofReader) -> NttPolynomial<MlKemField> {
131    const Q: u16 = MlKemField::Q;
132    // Read 32 triples (3 bytes each) at a time from the XOF.
133    // BUF_LEN must be divisible by 3 so pos always lands exactly on BUF_LEN.
134    const BUF_LEN: usize = 32 * 3;
135    let mut poly = NttPolynomial::<MlKemField>::default();
136    let mut buf = [0u8; BUF_LEN];
137    xof.read(&mut buf);
138    let mut pos = 0;
139    let mut i = 0;
140
141    while i < NUM_COEFFICIENTS {
142        // Refill the buffer from the XOF stream when exhausted.
143        if pos >= BUF_LEN {
144            xof.read(&mut buf);
145            pos = 0;
146        }
147
148        let d1 = u16::from(buf[pos]) | ((u16::from(buf[pos + 1]) & 0x0F) << 8);
149        let d2 = (u16::from(buf[pos + 1]) >> 4) | (u16::from(buf[pos + 2]) << 4);
150        pos += 3;
151
152        if d1 < Q {
153            poly.0[i] = module_lattice::Elem::new(d1);
154            i += 1;
155        }
156        if i < NUM_COEFFICIENTS && d2 < Q {
157            poly.0[i] = module_lattice::Elem::new(d2);
158            i += 1;
159        }
160    }
161
162    poly
163}
164
165// Produces a pseudorandom NttVector from a seed by calling sample_ntt_poly k
166// times, each with a different XOF stream: xof(seed, 0, j).
167fn sample_ntt_vector(seed: &Seed) -> NttVector {
168    NttVector::new(
169        (0..K::USIZE)
170            .map(|j| {
171                let mut reader = xof(seed, 0, j as u8);
172                sample_ntt_poly(&mut reader)
173            })
174            .collect(),
175    )
176}
177
178// Maps an encapsulation key to an NttVector via SHA3-256.
179// Only the t_hat component is used; rho is ignored.
180// Corresponds to libOTe's `pkHash`.
181fn hash_ek(ek: &EncapsulationKey) -> NttVector {
182    let encoded = <NttVector as Encode<U12>>::encode(&ek.t_hat);
183    let seed: Seed = sha3::Sha3_256::digest(encoded.as_slice()).into();
184    sample_ntt_vector(&seed)
185}
186
187// Generate a random encapsulation key using the given randomness and rho.
188//
189// The result is indistinguishable from a real encapsulation key, since a real
190// one has `t_hat = A_hat * s + e` and that is computationally indistinguishable
191// from a pseudorandom vector in `T_q^k`.
192fn random_ek(rng: &mut StdRng, rho: Rho) -> EncapsulationKey {
193    let seed: Seed = rng.random();
194    EncapsulationKey {
195        t_hat: sample_ntt_vector(&seed),
196        rho,
197    }
198}
199
/// Errors returned by the ML-KEM base OT sender and receiver.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Wraps a [`ConnectionError`] from the underlying connection.
    #[error("quic connection error")]
    Connection(#[from] ConnectionError),
    /// Wraps an [`io::Error`] raised while sending or receiving a message.
    #[error("io communication error")]
    Io(#[from] io::Error),
    /// The peer's message did not contain exactly `expected` entries in both
    /// of its vectors; `actual_0` and `actual_1` are the lengths received.
    #[error(
        "invalid count of keys/ciphertexts received. expected: {expected}, actual_0: {actual_0}, actual_1: {actual_1}"
    )]
    InvalidDataCount {
        expected: usize,
        actual_0: usize,
        actual_1: usize,
    },
    /// The receive stream ended before the expected message arrived.
    #[error("expected message but stream is closed")]
    ClosedStream,
    /// Decapsulating the selected ciphertext returned an error.
    #[error("ML-KEM decapsulation failed")]
    Decapsulation,
}
219
220#[derive(Copy, Clone, Serialize, Deserialize)]
221struct EncapsulationKeyBytes(#[serde(with = "serde_bytes")] [u8; ENCAPSULATION_KEY_LEN]);
222
223impl From<&EncapsulationKey> for EncapsulationKeyBytes {
224    fn from(ek: &EncapsulationKey) -> Self {
225        Self(ek.to_bytes())
226    }
227}
228
229impl ConditionallySelectable for EncapsulationKeyBytes {
230    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
231        Self(<[u8; ENCAPSULATION_KEY_LEN]>::conditional_select(
232            &a.0, &b.0, choice,
233        ))
234    }
235}
236
237#[derive(Copy, Clone, Serialize, Deserialize)]
238struct CiphertextBytes(#[serde(with = "serde_bytes")] [u8; CIPHERTEXT_LEN]);
239
240impl ConditionallySelectable for CiphertextBytes {
241    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
242        Self(<[u8; CIPHERTEXT_LEN]>::conditional_select(
243            &a.0, &b.0, choice,
244        ))
245    }
246}
247
// Message from receiver to sender: two values (r_0, r_1) per OT.
// `rs_0[i]` and `rs_1[i]` belong to the i-th OT, so both vectors must have
// the same length as the number of OTs.
#[derive(Serialize, Deserialize)]
struct EncapsulationKeysMessage {
    rs_0: Vec<EncapsulationKeyBytes>,
    rs_1: Vec<EncapsulationKeyBytes>,
}

// Message from sender to receiver: two ciphertexts per OT.
// `cts_0[i]` and `cts_1[i]` belong to the i-th OT, so both vectors must have
// the same length as the number of OTs.
#[derive(Serialize, Deserialize)]
struct CiphertextsMessage {
    cts_0: Vec<CiphertextBytes>,
    cts_1: Vec<CiphertextBytes>,
}
261
262pub struct MlKemOt {
263    rng: StdRng,
264    conn: Connection,
265}
266
267impl SemiHonest for MlKemOt {}
268
269impl MlKemOt {
270    pub fn new(connection: Connection) -> Self {
271        Self::new_with_rng(connection, rand::make_rng())
272    }
273
274    pub fn new_with_rng(connection: Connection, rng: StdRng) -> MlKemOt {
275        Self {
276            conn: connection,
277            rng,
278        }
279    }
280}
281
282impl Connected for MlKemOt {
283    fn connection(&mut self) -> &mut Connection {
284        &mut self.conn
285    }
286}
287
288impl RotSender for MlKemOt {
289    type Error = Error;
290
291    #[tracing::instrument(level = Level::DEBUG, skip_all, fields(count = ots.len()))]
292    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::BASE_OT))]
293    async fn send_into(&mut self, ots: &mut impl Buf<[Block; 2]>) -> Result<(), Self::Error> {
294        let count = ots.len();
295        let (mut send, mut recv) = self.conn.byte_stream().await?;
296
297        let receiver_msg: EncapsulationKeysMessage = {
298            let mut recv_stream = recv.as_stream();
299            recv_stream.next().await.ok_or(Error::ClosedStream)??
300        };
301
302        if receiver_msg.rs_0.len() != count || receiver_msg.rs_1.len() != count {
303            return Err(Error::InvalidDataCount {
304                expected: count,
305                actual_0: receiver_msg.rs_0.len(),
306                actual_1: receiver_msg.rs_1.len(),
307            });
308        }
309
310        let mut cts_0 = Vec::with_capacity(count);
311        let mut cts_1 = Vec::with_capacity(count);
312        for (i, (r_0_bytes, r_1_bytes)) in receiver_msg
313            .rs_0
314            .iter()
315            .zip(receiver_msg.rs_1.iter())
316            .enumerate()
317        {
318            // Step 5: Receive (r_0, r_1) from the receiver (done above).
319            let r_0 = EncapsulationKey::from_bytes(&r_0_bytes.0);
320            let r_1 = EncapsulationKey::from_bytes(&r_1_bytes.0);
321
322            // Step 6: Reconstruct encapsulation keys: ek_j = r_j + hash_ek(r_{1-j}).
323            let ek_0 = &r_0 + &hash_ek(&r_1);
324            let ek_1 = &r_1 + &hash_ek(&r_0);
325
326            // Step 7: Encapsulate to both reconstructed keys.
327            let (ct_0, ss_0) = encapsulate(&(&ek_0).into(), &mut self.rng);
328            let (ct_1, ss_1) = encapsulate(&(&ek_1).into(), &mut self.rng);
329
330            // Step 8: Derive OT output keys.
331            let key_0 = derive_ot_key(&ss_0, i);
332            let key_1 = derive_ot_key(&ss_1, i);
333
334            cts_0.push(ct_0);
335            cts_1.push(ct_1);
336            ots[i] = [key_0, key_1];
337        }
338
339        let sender_msg = CiphertextsMessage { cts_0, cts_1 };
340        {
341            let mut send_stream = send.as_stream();
342            send_stream.send(sender_msg).await?;
343        }
344
345        Ok(())
346    }
347}
348
349impl RotReceiver for MlKemOt {
350    type Error = Error;
351
352    #[tracing::instrument(level = Level::DEBUG, skip_all, fields(count = ots.len()))]
353    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::BASE_OT))]
354    async fn receive_into(
355        &mut self,
356        ots: &mut impl Buf<Block>,
357        choices: &[Choice],
358    ) -> Result<(), Self::Error> {
359        let count = ots.len();
360        assert_eq!(choices.len(), count);
361
362        let (mut send, mut recv) = self.conn.byte_stream().await?;
363
364        let mut decap_keys: Vec<DecapsulationKey<MlKemParams>> = Vec::with_capacity(count);
365        let mut rs_0 = Vec::with_capacity(count);
366        let mut rs_1 = Vec::with_capacity(count);
367
368        for choice in choices.iter() {
369            // Step 1: Generate real keypair.
370            let (dk, ek) = MlKem::generate(&mut RngCompat(&mut self.rng));
371            let ek_bytes: [u8; ENCAPSULATION_KEY_LEN] = ek
372                .as_bytes()
373                .as_slice()
374                .try_into()
375                .expect("incorrect encapsulation key size");
376            let ek = EncapsulationKey::from_bytes(&ek_bytes);
377
378            // Step 2: Sample random key for position 1-b.
379            let r_1_b = random_ek(&mut self.rng, ek.rho);
380
381            // Step 3: Compute real key: r_b = ek - hash_ek(r_{1-b}).
382            let r_b = &ek - &hash_ek(&r_1_b);
383            let r_b_bytes: EncapsulationKeyBytes = (&r_b).into();
384            let r_1_b_bytes: EncapsulationKeyBytes = (&r_1_b).into();
385
386            // Step 4: Select (r_0, r_1) based on choice bit (constant-time).
387            // If b=0: r_0 = real, r_1 = random.
388            // If b=1: r_0 = random, r_1 = real.
389            let r_0 = EncapsulationKeyBytes::conditional_select(&r_b_bytes, &r_1_b_bytes, *choice);
390            let r_1 = EncapsulationKeyBytes::conditional_select(&r_1_b_bytes, &r_b_bytes, *choice);
391
392            decap_keys.push(dk);
393            rs_0.push(r_0);
394            rs_1.push(r_1);
395        }
396
397        let receiver_msg = EncapsulationKeysMessage { rs_0, rs_1 };
398        {
399            let mut send_stream = send.as_stream();
400            send_stream.send(receiver_msg).await?;
401        }
402
403        let sender_msg: CiphertextsMessage = {
404            let mut recv_stream = recv.as_stream();
405            recv_stream.next().await.ok_or(Error::ClosedStream)??
406        };
407
408        if sender_msg.cts_0.len() != count || sender_msg.cts_1.len() != count {
409            return Err(Error::InvalidDataCount {
410                expected: count,
411                actual_0: sender_msg.cts_0.len(),
412                actual_1: sender_msg.cts_1.len(),
413            });
414        }
415
416        // Step 10-11: Decapsulate the chosen ciphertext and derive OT key.
417        for (i, ((dk, choice), (ct_0, ct_1))) in decap_keys
418            .iter()
419            .zip(choices.iter())
420            .zip(sender_msg.cts_0.iter().zip(sender_msg.cts_1.iter()))
421            .enumerate()
422        {
423            let ct_b_bytes = CiphertextBytes::conditional_select(ct_0, ct_1, *choice).0;
424            let ct_b: MlKemCiphertext<MlKem> = ct_b_bytes
425                .as_slice()
426                .try_into()
427                .expect("incorrect ciphertext size");
428            let shared_secret = dk.decapsulate(&ct_b).map_err(|_| Error::Decapsulation)?;
429            let key_b = derive_ot_key(&shared_secret, i);
430            ots[i] = key_b;
431        }
432
433        Ok(())
434    }
435}
436
437// Encapsulates to the given key, returning the ciphertext and the shared key.
438// Note: ML-KEM encapsulation is infallible - the Result in the ml-kem crate is
439// for API generality.
440fn encapsulate(
441    ek: &EncapsulationKeyBytes,
442    rng: &mut StdRng,
443) -> (CiphertextBytes, SharedKey<MlKem>) {
444    let parsed_ek = MlKemEncapsulationKey::<MlKemParams>::from_bytes((&ek.0).into());
445    let (ct, ss): (MlKemCiphertext<MlKem>, SharedKey<MlKem>) = parsed_ek
446        .encapsulate(&mut RngCompat(rng))
447        .expect("encapsulation failed");
448    (
449        CiphertextBytes(ct.as_slice().try_into().expect("incorrect ciphertext size")),
450        ss,
451    )
452}
453
454// Derive an OT key from the ML-KEM shared key using a random oracle XOF,
455// returning a Block-sized (128-bit) output.
456fn derive_ot_key(key: &SharedKey<MlKem>, tweak: usize) -> Block {
457    let mut ro = RandomOracle::new();
458    ro.update(HASH_DOMAIN_SEPARATOR);
459    ro.update(key.as_slice());
460    ro.update(&tweak.to_le_bytes());
461    let mut out = ro.finalize_xof();
462    let mut block = Block::ZERO;
463    out.fill(block.as_mut_bytes());
464    block
465}
466
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use cryprot_net::testing::{init_tracing, local_conn};
    use rand::{SeedableRng, rngs::StdRng};

    use super::MlKemOt;
    use crate::{RotReceiver, RotSender, random_choices};

    // 128 OTs with pseudorandom choice bits: the receiver's key must equal
    // the sender's key at the chosen position.
    #[tokio::test]
    async fn mlkem_base_rot_random_choices() -> Result<()> {
        let _g = init_tracing();
        let (sender_conn, receiver_conn) = local_conn().await?;
        let mut sender_rng = StdRng::seed_from_u64(42);
        let receiver_rng = StdRng::seed_from_u64(42 * 42);
        let count = 128;
        let choices = random_choices(count, &mut sender_rng);

        let mut sender = MlKemOt::new_with_rng(sender_conn, sender_rng);
        let mut receiver = MlKemOt::new_with_rng(receiver_conn, receiver_rng);
        let (sent, received) = tokio::try_join!(sender.send(count), receiver.receive(&choices))?;

        for ((got, pair), choice) in received.into_iter().zip(sent).zip(choices) {
            let chosen = usize::from(choice.unwrap_u8());
            assert_eq!(got, pair[chosen]);
        }
        Ok(())
    }

    // 128 OTs where every choice bit is 0.
    #[tokio::test]
    async fn mlkem_base_rot_zero_choices() -> Result<()> {
        let _g = init_tracing();
        let (sender_conn, receiver_conn) = local_conn().await?;
        let sender_rng = StdRng::seed_from_u64(123);
        let receiver_rng = StdRng::seed_from_u64(456);
        let count = 128;
        let choices: Vec<_> = (0..count).map(|_| subtle::Choice::from(0)).collect();

        let mut sender = MlKemOt::new_with_rng(sender_conn, sender_rng);
        let mut receiver = MlKemOt::new_with_rng(receiver_conn, receiver_rng);
        let (sent, received) = tokio::try_join!(sender.send(count), receiver.receive(&choices))?;

        for ((got, pair), choice) in received.into_iter().zip(sent).zip(choices) {
            let chosen = usize::from(choice.unwrap_u8());
            assert_eq!(got, pair[chosen]);
        }
        Ok(())
    }

    // 128 OTs where every choice bit is 1.
    #[tokio::test]
    async fn mlkem_base_rot_one_choices() -> Result<()> {
        let _g = init_tracing();
        let (sender_conn, receiver_conn) = local_conn().await?;
        let sender_rng = StdRng::seed_from_u64(789);
        let receiver_rng = StdRng::seed_from_u64(101112);
        let count = 128;
        let choices: Vec<_> = (0..count).map(|_| subtle::Choice::from(1)).collect();

        let mut sender = MlKemOt::new_with_rng(sender_conn, sender_rng);
        let mut receiver = MlKemOt::new_with_rng(receiver_conn, receiver_rng);
        let (sent, received) = tokio::try_join!(sender.send(count), receiver.receive(&choices))?;

        for ((got, pair), choice) in received.into_iter().zip(sent).zip(choices) {
            let chosen = usize::from(choice.unwrap_u8());
            assert_eq!(got, pair[chosen]);
        }
        Ok(())
    }

    // A single OT with choice bit 1: the receiver must get the sender's
    // second key.
    #[tokio::test]
    async fn mlkem_base_rot_single_ot() -> Result<()> {
        let _g = init_tracing();
        let (sender_conn, receiver_conn) = local_conn().await?;
        let sender_rng = StdRng::seed_from_u64(42);
        let receiver_rng = StdRng::seed_from_u64(43);
        let choices = vec![subtle::Choice::from(1)];

        let mut sender = MlKemOt::new_with_rng(sender_conn, sender_rng);
        let mut receiver = MlKemOt::new_with_rng(receiver_conn, receiver_rng);
        let (sent, received) = tokio::try_join!(sender.send(1), receiver.receive(&choices))?;

        assert_eq!(received[0], sent[0][1]);
        Ok(())
    }
}