1use std::{
5 fmt,
6 ops::{Add, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
7};
8
9use aes::cipher::{self, array::sizes};
10use bytemuck::{Pod, Zeroable};
11use rand::{Rng, distr::StandardUniform, prelude::Distribution};
12use serde::{Deserialize, Serialize};
13use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
14use thiserror::Error;
15use wide::u8x16;
16
17use crate::random_oracle::{self, RandomOracle};
18
19pub mod gf128;
20
21#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, Pod, Zeroable)]
23#[repr(transparent)]
24pub struct Block(u8x16);
25
26impl Block {
27 pub const ZERO: Self = Self(u8x16::ZERO);
29 pub const ONES: Self = Self(u8x16::MAX);
31 pub const ONE: Self = Self::new(1_u128.to_ne_bytes());
33 pub const MASK_LSB: Self = Self::pack(u64::MAX << 1, u64::MAX);
41
42 pub const BYTES: usize = 16;
44 pub const BITS: usize = 128;
46
47 #[inline]
49 pub const fn new(bytes: [u8; 16]) -> Self {
50 Self(u8x16::new(bytes))
51 }
52
53 #[inline]
55 pub const fn splat(byte: u8) -> Self {
56 Self::new([byte; 16])
57 }
58
59 #[inline]
64 pub const fn pack(low: u64, high: u64) -> Self {
65 let mut bytes = [0; 16];
66 let low = low.to_ne_bytes();
67 let mut i = 0;
68 while i < low.len() {
69 bytes[i] = low[i];
70 i += 1;
71 }
72
73 let high = high.to_ne_bytes();
74 let mut i = 0;
75 while i < high.len() {
76 bytes[i + 8] = high[i];
77 i += 1;
78 }
79
80 Self::new(bytes)
81 }
82
83 #[inline]
85 pub fn as_bytes(&self) -> &[u8; 16] {
86 self.0.as_array_ref()
87 }
88
89 #[inline]
91 pub fn as_mut_bytes(&mut self) -> &mut [u8; 16] {
92 self.0.as_array_mut()
93 }
94
95 #[inline]
97 pub fn ro_hash(&self) -> random_oracle::Hash {
98 let mut ro = RandomOracle::new();
99 ro.update(self.as_bytes());
100 ro.finalize()
101 }
102
103 #[inline]
108 pub fn from_choices(choices: &[Choice]) -> Self {
109 assert_eq!(128, choices.len(), "choices.len() must be 128");
110 let mut bytes = [0_u8; 16];
111 for (chunk, byte) in choices.chunks_exact(8).zip(&mut bytes) {
112 for (i, choice) in chunk.iter().enumerate() {
113 *byte ^= choice.unwrap_u8() << i;
114 }
115 }
116 Self::new(bytes)
117 }
118
119 #[inline]
121 pub fn low(&self) -> u64 {
122 u64::from_ne_bytes(self.as_bytes()[..8].try_into().expect("correct len"))
123 }
124
125 #[inline]
127 pub fn high(&self) -> u64 {
128 u64::from_ne_bytes(self.as_bytes()[8..].try_into().expect("correct len"))
129 }
130
131 #[inline]
133 pub fn lsb(&self) -> bool {
134 *self & Block::ONE == Block::ONE
135 }
136
137 #[inline]
139 pub fn bits(&self) -> impl Iterator<Item = bool> {
140 struct BitIter {
141 blk: Block,
142 idx: usize,
143 }
144 impl Iterator for BitIter {
145 type Item = bool;
146
147 #[inline]
148 fn next(&mut self) -> Option<Self::Item> {
149 if self.idx < Block::BITS {
150 self.idx += 1;
151 let bit = (self.blk >> (self.idx - 1)) & Block::ONE != Block::ZERO;
152 Some(bit)
153 } else {
154 None
155 }
156 }
157 }
158 BitIter { blk: *self, idx: 0 }
159 }
160}
161
162impl BitAnd for Block {
164 type Output = Self;
165
166 #[inline]
167 fn bitand(self, rhs: Self) -> Self {
168 Self(self.0 & rhs.0)
169 }
170}
171
172impl BitAndAssign for Block {
173 #[inline]
174 fn bitand_assign(&mut self, rhs: Self) {
175 *self = *self & rhs;
176 }
177}
178
179impl BitOr for Block {
180 type Output = Self;
181
182 #[inline]
183 fn bitor(self, rhs: Self) -> Self {
184 Self(self.0 | rhs.0)
185 }
186}
187
188impl BitOrAssign for Block {
189 #[inline]
190 fn bitor_assign(&mut self, rhs: Self) {
191 *self = *self | rhs;
192 }
193}
194
195impl BitXor for Block {
196 type Output = Self;
197
198 #[inline]
199 fn bitxor(self, rhs: Self) -> Self {
200 Self(self.0 ^ rhs.0)
201 }
202}
203
204impl BitXorAssign for Block {
205 #[inline]
206 fn bitxor_assign(&mut self, rhs: Self) {
207 *self = *self ^ rhs;
208 }
209}
210
211impl<Rhs> Shl<Rhs> for Block
212where
213 u128: Shl<Rhs, Output = u128>,
214{
215 type Output = Block;
216
217 #[inline]
218 fn shl(self, rhs: Rhs) -> Self::Output {
219 Self::from(u128::from(self) << rhs)
220 }
221}
222
223impl<Rhs> Shr<Rhs> for Block
224where
225 u128: Shr<Rhs, Output = u128>,
226{
227 type Output = Block;
228
229 #[inline]
230 fn shr(self, rhs: Rhs) -> Self::Output {
231 Self::from(u128::from(self) >> rhs)
232 }
233}
234
235impl Not for Block {
236 type Output = Self;
237
238 #[inline]
239 fn not(self) -> Self {
240 Self(!self.0)
241 }
242}
243
244impl PartialEq for Block {
245 fn eq(&self, other: &Self) -> bool {
246 let a: u128 = (*self).into();
247 let b: u128 = (*other).into();
248 a.ct_eq(&b).into()
249 }
250}
251
252impl Eq for Block {}
253
254impl Distribution<Block> for StandardUniform {
255 #[inline]
256 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Block {
257 let mut bytes = [0; 16];
258 rng.fill_bytes(&mut bytes);
259 Block::new(bytes)
260 }
261}
262
263impl AsRef<[u8]> for Block {
264 fn as_ref(&self) -> &[u8] {
265 self.as_bytes()
266 }
267}
268
269impl AsMut<[u8]> for Block {
270 #[inline]
271 fn as_mut(&mut self) -> &mut [u8] {
272 self.as_mut_bytes()
273 }
274}
275
276impl From<Block> for cipher::Array<u8, sizes::U16> {
277 #[inline]
278 fn from(value: Block) -> Self {
279 Self(*value.as_bytes())
280 }
281}
282
283impl From<cipher::Array<u8, sizes::U16>> for Block {
284 #[inline]
285 fn from(value: cipher::Array<u8, sizes::U16>) -> Self {
286 Self::new(value.0)
287 }
288}
289
290impl From<[u64; 2]> for Block {
291 #[inline]
292 fn from(value: [u64; 2]) -> Self {
293 bytemuck::cast(value)
294 }
295}
296
297impl From<Block> for [u64; 2] {
298 #[inline]
299 fn from(value: Block) -> Self {
300 bytemuck::cast(value)
301 }
302}
303
304impl From<Block> for u128 {
305 #[inline]
306 fn from(value: Block) -> Self {
307 u128::from_ne_bytes(*value.as_bytes())
309 }
310}
311
312impl From<&Block> for u128 {
313 #[inline]
314 fn from(value: &Block) -> Self {
315 u128::from_ne_bytes(*value.as_bytes())
317 }
318}
319
320impl From<usize> for Block {
321 fn from(value: usize) -> Self {
322 (value as u128).into()
323 }
324}
325
326impl From<u128> for Block {
327 #[inline]
328 fn from(value: u128) -> Self {
329 Self::new(value.to_ne_bytes())
330 }
331}
332
333impl From<&u128> for Block {
334 #[inline]
335 fn from(value: &u128) -> Self {
336 Self::new(value.to_ne_bytes())
337 }
338}
339
340#[derive(Debug, Error)]
341#[error("slice must have length of 16")]
342pub struct WrongLength;
343
344impl TryFrom<&[u8]> for Block {
345 type Error = WrongLength;
346
347 #[inline]
348 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
349 let arr = value.try_into().map_err(|_| WrongLength)?;
350 Ok(Self::new(arr))
351 }
352}
353
354#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
355mod from_arch_impls {
356 #[cfg(target_arch = "x86")]
357 use std::arch::x86::*;
358 #[cfg(target_arch = "x86_64")]
359 use std::arch::x86_64::*;
360
361 use super::Block;
362
363 impl From<__m128i> for Block {
364 #[inline]
365 fn from(value: __m128i) -> Self {
366 bytemuck::must_cast(value)
367 }
368 }
369
370 impl From<&__m128i> for Block {
371 #[inline]
372 fn from(value: &__m128i) -> Self {
373 bytemuck::must_cast(*value)
374 }
375 }
376
377 impl From<Block> for __m128i {
378 #[inline]
379 fn from(value: Block) -> Self {
380 bytemuck::must_cast(value)
381 }
382 }
383
384 impl From<&Block> for __m128i {
385 #[inline]
386 fn from(value: &Block) -> Self {
387 bytemuck::must_cast(*value)
388 }
389 }
390}
391
392impl ConditionallySelectable for Block {
393 #[inline]
394 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
396 let mask = Block::new((-(choice.unwrap_u8() as i128)).to_le_bytes());
399 *a ^ (mask & (*a ^ *b))
400 }
401}
402
403impl Add for Block {
404 type Output = Block;
405
406 #[inline]
407 fn add(self, rhs: Self) -> Self::Output {
408 let a: u128 = self.into();
410 let b: u128 = rhs.into();
411 Self::from(a.wrapping_add(b))
412 }
413}
414
415impl fmt::Binary for Block {
416 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
417 fmt::Binary::fmt(&u128::from(*self), f)
418 }
419}
420
/// `num_traits::Zero` for blocks, with [`Block::ZERO`] as the zero value.
#[cfg(feature = "num-traits")]
impl num_traits::Zero for Block {
    fn zero() -> Self {
        Block::ZERO
    }

    fn is_zero(&self) -> bool {
        let zero = Self::ZERO;
        *self == zero
    }
}
431
#[cfg(test)]
mod tests {
    use subtle::{Choice, ConditionallySelectable};

    use crate::Block;

    /// A choice of 0 selects the first operand, a choice of 1 the second.
    #[test]
    fn test_block_cond_select() {
        let (first, second) = (Block::ZERO, Block::ONES);

        let picked = Block::conditional_select(&first, &second, Choice::from(0));
        assert_eq!(Block::ZERO, picked);

        let picked = Block::conditional_select(&first, &second, Choice::from(1));
        assert_eq!(Block::ONES, picked);
    }

    /// `low`/`high` of a block built from `1_u128`.
    #[test]
    fn test_block_low_high() {
        let one: Block = 1_u128.into();
        assert_eq!(1, one.low());
        assert_eq!(0, one.high());
    }

    /// `[u64; 2]` converts into a block and back unchanged.
    #[test]
    fn test_from_into_u64_arr() {
        let halves: [u64; 2] = [42, 65];
        let blk = Block::from(halves);
        assert_eq!(42, blk.low());
        assert_eq!(65, blk.high());
        let back: [u64; 2] = blk.into();
        assert_eq!([42, 65], back);
    }

    /// `pack` stores its arguments as the low and high halves.
    #[test]
    fn test_pack() {
        let packed = Block::pack(42, 123);
        assert_eq!(42, packed.low());
        assert_eq!(123, packed.high());
    }

    /// `MASK_LSB` equals all-ones with the bit of `Block::ONE` removed.
    #[test]
    fn test_mask_lsb() {
        let expected = Block::ONES ^ Block::ONE;
        assert_eq!(expected, Block::MASK_LSB);
    }

    /// `bits` yields the value's bits starting from the least-significant one.
    #[test]
    fn test_bits() {
        let blk = Block::from(0b101_u128);
        let mut bit_iter = blk.bits();
        assert_eq!(Some(true), bit_iter.next());
        assert_eq!(Some(false), bit_iter.next());
        assert_eq!(Some(true), bit_iter.next());
        // Every remaining bit must be unset.
        for remaining in bit_iter {
            assert!(!remaining);
        }
    }

    /// Choice `k` ends up as bit `k` of the resulting block.
    #[test]
    fn test_from_choices() {
        let mut choices = vec![Choice::from(0); 128];
        for set_bit in [2, 16] {
            choices[set_bit] = Choice::from(1);
        }
        let expected = Block::from((1_u128 << 2) | (1_u128 << 16));
        assert_eq!(expected, Block::from_choices(&choices));
    }
}