cryprot_core/
block.rs

1//! A 128-bit [`Block`] type.
2//!
3//! Operations on [`Block`]s will use SIMD instructions where possible.
4use std::{
5    fmt,
6    ops::{Add, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
7};
8
9use aes::cipher::{self, array::sizes};
10use bytemuck::{Pod, Zeroable};
11use rand::{Rng, distr::StandardUniform, prelude::Distribution};
12use serde::{Deserialize, Serialize};
13use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
14use thiserror::Error;
15use wide::u8x16;
16
17use crate::random_oracle::{self, RandomOracle};
18
19pub mod gf128;
20
21/// A 128-bit block. Uses SIMD operations where available.
22#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, Pod, Zeroable)]
23#[repr(transparent)]
24pub struct Block(u8x16);
25
26impl Block {
27    /// All bits set to 0.
28    pub const ZERO: Self = Self(u8x16::ZERO);
29    /// All bits set to 1.
30    pub const ONES: Self = Self(u8x16::MAX);
31    /// Lsb set to 1, all others zero.
32    pub const ONE: Self = Self::new(1_u128.to_ne_bytes());
33    /// Mask to mask off the LSB of a Block.
34    /// ```rust
35    /// # use cryprot_core::Block;
36    /// let b = Block::ONES;
37    /// let masked = b & Block::MASK_LSB;
38    /// assert_eq!(masked, Block::ONES << 1)
39    /// ```
40    pub const MASK_LSB: Self = Self::pack(u64::MAX << 1, u64::MAX);
41
42    /// 16 bytes in a Block.
43    pub const BYTES: usize = 16;
44    /// 128 bits in a block.
45    pub const BITS: usize = 128;
46
47    /// Create a new block from bytes.
48    #[inline]
49    pub const fn new(bytes: [u8; 16]) -> Self {
50        Self(u8x16::new(bytes))
51    }
52
53    /// Create a block with all bytes set to `byte`.
54    #[inline]
55    pub const fn splat(byte: u8) -> Self {
56        Self::new([byte; 16])
57    }
58
59    /// Pack two `u64` into a Block. Usable in const context.
60    ///
61    /// In non-const contexts, using `Block::from([low, high])` is likely
62    /// faster.
63    #[inline]
64    pub const fn pack(low: u64, high: u64) -> Self {
65        let mut bytes = [0; 16];
66        let low = low.to_ne_bytes();
67        let mut i = 0;
68        while i < low.len() {
69            bytes[i] = low[i];
70            i += 1;
71        }
72
73        let high = high.to_ne_bytes();
74        let mut i = 0;
75        while i < high.len() {
76            bytes[i + 8] = high[i];
77            i += 1;
78        }
79
80        Self::new(bytes)
81    }
82
83    /// Bytes of the block.
84    #[inline]
85    pub fn as_bytes(&self) -> &[u8; 16] {
86        self.0.as_array_ref()
87    }
88
89    /// Mutable bytes of the block.
90    #[inline]
91    pub fn as_mut_bytes(&mut self) -> &mut [u8; 16] {
92        self.0.as_array_mut()
93    }
94
95    /// Hash the block with a [`random_oracle`].
96    #[inline]
97    pub fn ro_hash(&self) -> random_oracle::Hash {
98        let mut ro = RandomOracle::new();
99        ro.update(self.as_bytes());
100        ro.finalize()
101    }
102
103    ///  Create a block from 128 [`Choice`]s.
104    ///
105    /// # Panics
106    /// If choices.len() != 128
107    #[inline]
108    pub fn from_choices(choices: &[Choice]) -> Self {
109        assert_eq!(128, choices.len(), "choices.len() must be 128");
110        let mut bytes = [0_u8; 16];
111        for (chunk, byte) in choices.chunks_exact(8).zip(&mut bytes) {
112            for (i, choice) in chunk.iter().enumerate() {
113                *byte ^= choice.unwrap_u8() << i;
114            }
115        }
116        Self::new(bytes)
117    }
118
119    /// Low 64 bits of the block.
120    #[inline]
121    pub fn low(&self) -> u64 {
122        u64::from_ne_bytes(self.as_bytes()[..8].try_into().expect("correct len"))
123    }
124
125    /// High 64 bits of the block.
126    #[inline]
127    pub fn high(&self) -> u64 {
128        u64::from_ne_bytes(self.as_bytes()[8..].try_into().expect("correct len"))
129    }
130
131    /// Least significant bit of the block
132    #[inline]
133    pub fn lsb(&self) -> bool {
134        *self & Block::ONE == Block::ONE
135    }
136
137    /// Iterator over bits of the Block.
138    #[inline]
139    pub fn bits(&self) -> impl Iterator<Item = bool> {
140        struct BitIter {
141            blk: Block,
142            idx: usize,
143        }
144        impl Iterator for BitIter {
145            type Item = bool;
146
147            #[inline]
148            fn next(&mut self) -> Option<Self::Item> {
149                if self.idx < Block::BITS {
150                    self.idx += 1;
151                    let bit = (self.blk >> (self.idx - 1)) & Block::ONE != Block::ZERO;
152                    Some(bit)
153                } else {
154                    None
155                }
156            }
157        }
158        BitIter { blk: *self, idx: 0 }
159    }
160}
161
162// Implement standard operators for more ergonomic usage
163impl BitAnd for Block {
164    type Output = Self;
165
166    #[inline]
167    fn bitand(self, rhs: Self) -> Self {
168        Self(self.0 & rhs.0)
169    }
170}
171
172impl BitAndAssign for Block {
173    #[inline]
174    fn bitand_assign(&mut self, rhs: Self) {
175        *self = *self & rhs;
176    }
177}
178
179impl BitOr for Block {
180    type Output = Self;
181
182    #[inline]
183    fn bitor(self, rhs: Self) -> Self {
184        Self(self.0 | rhs.0)
185    }
186}
187
188impl BitOrAssign for Block {
189    #[inline]
190    fn bitor_assign(&mut self, rhs: Self) {
191        *self = *self | rhs;
192    }
193}
194
195impl BitXor for Block {
196    type Output = Self;
197
198    #[inline]
199    fn bitxor(self, rhs: Self) -> Self {
200        Self(self.0 ^ rhs.0)
201    }
202}
203
204impl BitXorAssign for Block {
205    #[inline]
206    fn bitxor_assign(&mut self, rhs: Self) {
207        *self = *self ^ rhs;
208    }
209}
210
211impl<Rhs> Shl<Rhs> for Block
212where
213    u128: Shl<Rhs, Output = u128>,
214{
215    type Output = Block;
216
217    #[inline]
218    fn shl(self, rhs: Rhs) -> Self::Output {
219        Self::from(u128::from(self) << rhs)
220    }
221}
222
223impl<Rhs> Shr<Rhs> for Block
224where
225    u128: Shr<Rhs, Output = u128>,
226{
227    type Output = Block;
228
229    #[inline]
230    fn shr(self, rhs: Rhs) -> Self::Output {
231        Self::from(u128::from(self) >> rhs)
232    }
233}
234
235impl Not for Block {
236    type Output = Self;
237
238    #[inline]
239    fn not(self) -> Self {
240        Self(!self.0)
241    }
242}
243
244impl PartialEq for Block {
245    fn eq(&self, other: &Self) -> bool {
246        let a: u128 = (*self).into();
247        let b: u128 = (*other).into();
248        a.ct_eq(&b).into()
249    }
250}
251
252impl Eq for Block {}
253
254impl Distribution<Block> for StandardUniform {
255    #[inline]
256    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Block {
257        let mut bytes = [0; 16];
258        rng.fill_bytes(&mut bytes);
259        Block::new(bytes)
260    }
261}
262
263impl AsRef<[u8]> for Block {
264    fn as_ref(&self) -> &[u8] {
265        self.as_bytes()
266    }
267}
268
269impl AsMut<[u8]> for Block {
270    #[inline]
271    fn as_mut(&mut self) -> &mut [u8] {
272        self.as_mut_bytes()
273    }
274}
275
276impl From<Block> for cipher::Array<u8, sizes::U16> {
277    #[inline]
278    fn from(value: Block) -> Self {
279        Self(*value.as_bytes())
280    }
281}
282
283impl From<cipher::Array<u8, sizes::U16>> for Block {
284    #[inline]
285    fn from(value: cipher::Array<u8, sizes::U16>) -> Self {
286        Self::new(value.0)
287    }
288}
289
290impl From<[u64; 2]> for Block {
291    #[inline]
292    fn from(value: [u64; 2]) -> Self {
293        bytemuck::cast(value)
294    }
295}
296
297impl From<Block> for [u64; 2] {
298    #[inline]
299    fn from(value: Block) -> Self {
300        bytemuck::cast(value)
301    }
302}
303
304impl From<Block> for u128 {
305    #[inline]
306    fn from(value: Block) -> Self {
307        // todo correct endianness?
308        u128::from_ne_bytes(*value.as_bytes())
309    }
310}
311
312impl From<&Block> for u128 {
313    #[inline]
314    fn from(value: &Block) -> Self {
315        // todo correct endianness?
316        u128::from_ne_bytes(*value.as_bytes())
317    }
318}
319
320impl From<usize> for Block {
321    fn from(value: usize) -> Self {
322        (value as u128).into()
323    }
324}
325
326impl From<u128> for Block {
327    #[inline]
328    fn from(value: u128) -> Self {
329        Self::new(value.to_ne_bytes())
330    }
331}
332
333impl From<&u128> for Block {
334    #[inline]
335    fn from(value: &u128) -> Self {
336        Self::new(value.to_ne_bytes())
337    }
338}
339
340#[derive(Debug, Error)]
341#[error("slice must have length of 16")]
342pub struct WrongLength;
343
344impl TryFrom<&[u8]> for Block {
345    type Error = WrongLength;
346
347    #[inline]
348    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
349        let arr = value.try_into().map_err(|_| WrongLength)?;
350        Ok(Self::new(arr))
351    }
352}
353
354#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
355mod from_arch_impls {
356    #[cfg(target_arch = "x86")]
357    use std::arch::x86::*;
358    #[cfg(target_arch = "x86_64")]
359    use std::arch::x86_64::*;
360
361    use super::Block;
362
363    impl From<__m128i> for Block {
364        #[inline]
365        fn from(value: __m128i) -> Self {
366            bytemuck::must_cast(value)
367        }
368    }
369
370    impl From<&__m128i> for Block {
371        #[inline]
372        fn from(value: &__m128i) -> Self {
373            bytemuck::must_cast(*value)
374        }
375    }
376
377    impl From<Block> for __m128i {
378        #[inline]
379        fn from(value: Block) -> Self {
380            bytemuck::must_cast(value)
381        }
382    }
383
384    impl From<&Block> for __m128i {
385        #[inline]
386        fn from(value: &Block) -> Self {
387            bytemuck::must_cast(*value)
388        }
389    }
390}
391
392impl ConditionallySelectable for Block {
393    #[inline]
394    // adapted from https://github.com/dalek-cryptography/subtle/blob/369e7463e85921377a5f2df80aabcbbc6d57a930/src/lib.rs#L510-L517
395    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
396        // if choice = 0, mask = (-0) = 0000...0000
397        // if choice = 1, mask = (-1) = 1111...1111
398        let mask = Block::new((-(choice.unwrap_u8() as i128)).to_le_bytes());
399        *a ^ (mask & (*a ^ *b))
400    }
401}
402
403impl Add for Block {
404    type Output = Block;
405
406    #[inline]
407    fn add(self, rhs: Self) -> Self::Output {
408        // todo is this a sensible implementation?
409        let a: u128 = self.into();
410        let b: u128 = rhs.into();
411        Self::from(a.wrapping_add(b))
412    }
413}
414
415impl fmt::Binary for Block {
416    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
417        fmt::Binary::fmt(&u128::from(*self), f)
418    }
419}
420
#[cfg(feature = "num-traits")]
impl num_traits::Zero for Block {
    /// The all-zero block, [`Block::ZERO`].
    fn zero() -> Self {
        Block::ZERO
    }

    /// Whether every bit of the block is 0.
    fn is_zero(&self) -> bool {
        Block::ZERO.eq(self)
    }
}
431
#[cfg(test)]
mod tests {
    use subtle::{Choice, ConditionallySelectable};

    use crate::Block;

    #[test]
    fn test_block_cond_select() {
        let choice = Choice::from(0);
        assert_eq!(
            Block::ZERO,
            Block::conditional_select(&Block::ZERO, &Block::ONES, choice)
        );
        let choice = Choice::from(1);
        assert_eq!(
            Block::ONES,
            Block::conditional_select(&Block::ZERO, &Block::ONES, choice)
        );
    }

    #[test]
    fn test_block_low_high() {
        let b = Block::from(1_u128);
        assert_eq!(1, b.low());
        assert_eq!(0, b.high());
    }

    #[test]
    fn test_from_into_u64_arr() {
        let b = Block::from([42, 65]);
        assert_eq!(42, b.low());
        assert_eq!(65, b.high());
        assert_eq!([42, 65], <[u64; 2]>::from(b));
    }

    #[test]
    fn test_pack() {
        let b = Block::pack(42, 123);
        assert_eq!(42, b.low());
        assert_eq!(123, b.high());
    }

    #[test]
    fn test_mask_lsb() {
        assert_eq!(Block::ONES ^ Block::ONE, Block::MASK_LSB);
    }

    #[test]
    fn test_bits() {
        let b: Block = 0b101_u128.into();
        let mut iter = b.bits();
        assert_eq!(Some(true), iter.next());
        assert_eq!(Some(false), iter.next());
        assert_eq!(Some(true), iter.next());
        for rest in iter {
            assert!(!rest);
        }
    }

    #[test]
    fn test_bits_count() {
        assert_eq!(Block::BITS, Block::ONES.bits().count());
        assert!(Block::ONES.bits().all(|bit| bit));
        assert!(Block::ZERO.bits().all(|bit| !bit));
    }

    #[test]
    fn test_from_choices() {
        let mut choices = vec![Choice::from(0); 128];
        choices[2] = Choice::from(1);
        choices[16] = Choice::from(1);
        let blk = Block::from_choices(&choices);
        assert_eq!(Block::from((1_u128 << 2) | (1_u128 << 16)), blk);
    }

    #[test]
    fn test_lsb() {
        assert!(Block::ONE.lsb());
        assert!(!Block::ZERO.lsb());
        assert!(!Block::MASK_LSB.lsb());
    }

    #[test]
    fn test_shifts() {
        assert_eq!(Block::from(1_u128 << 5), Block::ONE << 5);
        assert_eq!(Block::ONE, Block::from(1_u128 << 5) >> 5);
    }

    #[test]
    fn test_add_wraps() {
        assert_eq!(Block::ZERO, Block::ONES + Block::ONE);
    }

    #[test]
    fn test_u128_round_trip() {
        let x = 0x0123_4567_89ab_cdef_fedc_ba98_7654_3210_u128;
        assert_eq!(x, u128::from(Block::from(x)));
    }

    #[test]
    fn test_try_from_slice() {
        let bytes = [7_u8; 16];
        assert_eq!(Block::splat(7), Block::try_from(&bytes[..]).unwrap());
        assert!(Block::try_from(&bytes[..15]).is_err());
    }
}