1use std::mem;
13
14use aes::{
15 Aes128,
16 cipher::{BlockCipherEncrypt, KeyInit},
17};
18use rand::{CryptoRng, Rng, RngCore, SeedableRng};
19use rand_core::block::{BlockRng, BlockRngCore, CryptoBlockRng};
20
21use crate::{AES_PAR_BLOCKS, Block};
22
23#[derive(Clone, Debug)]
29pub struct AesRng(BlockRng<AesRngCore>);
30
31impl RngCore for AesRng {
32 #[inline]
33 fn next_u32(&mut self) -> u32 {
34 self.0.next_u32()
35 }
36
37 #[inline]
38 fn next_u64(&mut self) -> u64 {
39 self.0.next_u64()
40 }
41
42 #[inline]
43 fn fill_bytes(&mut self, dest: &mut [u8]) {
44 let block_size = mem::size_of::<aes::Block>();
45 let block_len = dest.len() / block_size * block_size;
46 let (block_bytes, rest_bytes) = dest.split_at_mut(block_len);
47 let blocks = bytemuck::cast_slice_mut::<_, aes::Block>(block_bytes);
50 for chunk in blocks.chunks_mut(AES_PAR_BLOCKS) {
51 for block in chunk.iter_mut() {
52 *block = aes::cipher::Array(self.0.core.state.to_le_bytes());
53 self.0.core.state += 1;
54 }
55 self.0.core.aes.encrypt_blocks(chunk);
56 }
57 self.0.fill_bytes(rest_bytes)
59 }
60}
61
62impl SeedableRng for AesRng {
63 type Seed = Block;
64
65 #[inline]
66 fn from_seed(seed: Self::Seed) -> Self {
67 AesRng(BlockRng::<AesRngCore>::from_seed(seed))
68 }
69}
70
71impl CryptoRng for AesRng {}
72
73impl AesRng {
74 #[inline]
77 pub fn new() -> Self {
78 let seed = rand::random::<Block>();
79 AesRng::from_seed(seed)
80 }
81
82 #[inline]
84 pub fn fork(&mut self) -> Self {
85 let seed = self.random::<Block>();
86 AesRng::from_seed(seed)
87 }
88}
89
90impl Default for AesRng {
91 #[inline]
92 fn default() -> Self {
93 Self::new()
94 }
95}
96
97#[derive(Clone)]
99pub struct AesRngCore {
100 aes: Aes128,
101 state: u128,
102}
103
104impl std::fmt::Debug for AesRngCore {
105 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
106 write!(f, "AesRngCore {{}}")
107 }
108}
109
110impl BlockRngCore for AesRngCore {
111 type Item = u32;
112 type Results = hidden::ParBlockWrapper;
114
115 #[inline]
117 fn generate(&mut self, results: &mut Self::Results) {
118 let blocks = bytemuck::cast_slice_mut::<_, aes::Block>(results.as_mut());
119 blocks.iter_mut().for_each(|blk| {
120 *blk = aes::cipher::Array(self.state.to_le_bytes());
123 self.state += 1;
124 });
125 self.aes.encrypt_blocks(blocks);
126 }
127}
128
mod hidden {
    /// Output buffer filled by one `generate` call: 36 `u32` words
    /// (144 bytes, i.e. nine 16-byte AES blocks).
    #[derive(Clone, Copy)]
    pub struct ParBlockWrapper([u32; 36]);

    // `Default` is implemented by hand because the standard library only
    // provides `Default` for arrays of at most 32 elements.
    impl Default for ParBlockWrapper {
        fn default() -> Self {
            ParBlockWrapper([0u32; 36])
        }
    }

    impl AsMut<[u32]> for ParBlockWrapper {
        fn as_mut(&mut self) -> &mut [u32] {
            self.0.as_mut_slice()
        }
    }

    impl AsRef<[u32]> for ParBlockWrapper {
        fn as_ref(&self) -> &[u32] {
            self.0.as_slice()
        }
    }
}
154
155impl SeedableRng for AesRngCore {
156 type Seed = Block;
157
158 #[inline]
159 fn from_seed(seed: Self::Seed) -> Self {
160 let aes = Aes128::new(&seed.into());
161 AesRngCore {
162 aes,
163 state: Default::default(),
164 }
165 }
166}
167
168impl CryptoBlockRng for AesRngCore {}
169
170impl From<AesRngCore> for AesRng {
171 #[inline]
172 fn from(core: AesRngCore) -> Self {
173 AesRng(BlockRng::new(core))
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate() {
        let mut rng = AesRng::new();
        let a = rng.random::<[Block; 8]>();
        let b = rng.random::<[Block; 8]>();
        assert_ne!(a, b);
    }

    /// Two generators with identical state must emit identical streams,
    /// through both the `fill_bytes` fast path and the buffered word path.
    #[test]
    fn clones_produce_identical_streams() {
        let mut a = AesRng::new();
        let mut b = a.clone();
        let mut buf_a = [0u8; 100];
        let mut buf_b = [0u8; 100];
        a.fill_bytes(&mut buf_a);
        b.fill_bytes(&mut buf_b);
        assert_eq!(buf_a, buf_b);
        assert_eq!(a.next_u64(), b.next_u64());
    }

    /// The hand-written `fill_bytes` fast path bypasses `BlockRng`; for whole
    /// blocks it must yield exactly the bytes `AesRngCore::generate` produces
    /// from the same starting state.
    #[test]
    fn fill_bytes_fast_path_matches_block_core() {
        let rng_core = AesRngCore::from_seed(rand::random());
        let mut rng = AesRng::from(rng_core.clone());
        let mut rng_core = rng_core;

        let mut results = hidden::ParBlockWrapper::default();
        rng_core.generate(&mut results);
        // `generate` writes raw block bytes and reinterprets them as native
        // `u32`s, so `to_ne_bytes` recovers the raw bytes.
        let expected: Vec<u8> = results
            .as_ref()
            .iter()
            .flat_map(|word| word.to_ne_bytes())
            .collect();

        let mut actual = vec![0u8; expected.len()];
        rng.fill_bytes(&mut actual);
        assert_eq!(actual, expected);
    }

    /// Lengths that are not multiples of the block size exercise the split
    /// between the direct path and the buffered tail.
    #[test]
    fn fill_bytes_handles_partial_blocks() {
        let mut rng = AesRng::new();
        for len in [0usize, 1, 15, 16, 17, 33, 200] {
            let mut buf = vec![0u8; len];
            rng.fill_bytes(&mut buf);
            if len >= 16 {
                assert!(buf.iter().any(|&byte| byte != 0));
            }
        }
    }
}