// cryprot_core/transpose/portable.rs
1use wide::{i8x16, i64x2};
2
3/// Transpose a bit matrix.
4///
5/// # Panics
6/// - If `rows < 16`
7/// - If `rows` is not divisible by 16
8/// - If `input.len()` is not divisible by `rows`
9/// - If the number of columns, computed as `input.len() * 8 / rows` is less
10/// than 16
11/// - If the number of columns is not divisible by 8
12pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
13 assert!(rows >= 16, "rows must be at least 16");
14 assert_eq!(0, rows % 16, "rows must be divisible by 16");
15 assert_eq!(
16 0,
17 input.len() % rows,
18 "input.len() must be divisible by rows"
19 );
20 let cols = input.len() * 8 / rows;
21 assert!(cols >= 16, "columns must be at least 16. Columns {cols}");
22 assert_eq!(
23 0,
24 cols % 8,
25 "Number of bitmatrix columns must be divisable by 8. columns: {cols}"
26 );
27
28 // Transpose a matrix by splitting it into 16x8 blocks (=16 bytes which can be
29 // one SSE2 128 bit vector depending on target support), write transposed
30 // block into output at the transposed position and continue with next
31 // block.
32
33 let mut row: usize = 0;
34 while row <= rows - 16 {
35 let mut col = 0;
36 while col < cols {
37 // Load 16x8 sub-block by loading (row + 0, col) .. (row + 15, col)
38 let mut v = load_bytes(input, row, col, cols);
39 // The ideas is to take the most significant bit of each row of the sub-block
40 // (msb of each byte) and write these 16 bits coming from 16 rows
41 // and one column in the input to the correct row and columns in the
42 // output. Because the `move_mask` instruction gives us the `msb` of each byte
43 // (=row) in our input, we iterate the output_row_offset from large
44 // to small. After each iteration, we shift each byte in the
45 // sub-block one bit to the left and then again get the msb of each byte and
46 // write it to the next row.
47 // Visualization done by Claude 4 Sonnet:
48 // ┌─────────────────────────────────────────────────────────────┐
49 // │ BIT MATRIX TRANSPOSE: 16x8 BLOCK │
50 // ├─────────────────────────────────────────────────────────────┤
51 // │ │
52 // │ INPUT (16×8) OUTPUT (8×16) │
53 // │ │
54 // │ 0 1 ⋯ 6 7 0 1 ⋯ E F │
55 // │ 0 ◆|◇|⋯|▲|△ 0 ◆|◆|⋯|◆|◆ │
56 // │ 1 ◆|◇|⋯|▲|△ 1 ◇|◇|⋯|◇|◇ │
57 // │ ⋮ ⋮ │
58 // │ E ◆|◇|⋯|▲|△ 6 ▲|▲|⋯|▲|▲ │
59 // │ F ◆|◇|⋯|▲|△ 7 △|△|⋯|△|△ │
60 // │ │
61 // │ MOVE_MASK: Extract column bits → Write as rows │
62 // │ │
63 // │ Iter 0: MSB column 7 → output row 7 │
64 // │ Iter 1: << 1, MSB → output row 6 │
65 // │ ⋮ │
66 // │ Iter 7: << 7, MSB → output row 0 │
67 // │ │
68 // │ INPUT[row,col] → OUTPUT[col,row] │
69 // │ │
70 // └─────────────────────────────────────────────────────────────┘
71 for output_row_offset in (0..8).rev() {
72 // get msb of each byte
73 let msbs = v.to_bitmask().to_le_bytes();
74 // write msbs to output at transposed position
75 let idx = out(row, col + output_row_offset, rows) as isize;
76 // This should result in only one bounds check for the output
77 let out_bytes = &mut output[idx as usize..idx as usize + 2];
78 out_bytes[0] = msbs[0];
79 out_bytes[1] = msbs[1];
80
81 // There is no shift impl for i8x16 so we cast to i64x2 and shift these.
82 // The bits shifted to neighbouring bytes are ignored because we iterate
83 // and call move_mask 8 times.
84 let v: &mut i64x2 = bytemuck::must_cast_mut(&mut v);
85 // shift each byte by one to the left (by shifting it as two i64)
86 *v = *v << 1;
87 }
88 col += 8;
89 }
90 row += 16;
91 }
92}
93
#[inline]
/// Byte index holding bit `(x, y)` of the input matrix, which stores
/// `cols / 8` bytes per row in row-major order.
fn inp(x: usize, y: usize, cols: usize) -> usize {
    let row_start = x * cols / 8;
    let byte_in_row = y / 8;
    row_start + byte_in_row
}
#[inline]
/// Byte index holding bit `(x, y)` of the transposed matrix, which stores
/// `rows / 8` bytes per (output) row in row-major order.
fn out(x: usize, y: usize, rows: usize) -> usize {
    let row_start = y * rows / 8;
    let byte_in_row = x / 8;
    row_start + byte_in_row
}
102
103#[inline]
104// get col byte of row to row + 15
105fn load_bytes(b: &[u8], row: usize, col: usize, cols: usize) -> i8x16 {
106 let bytes = std::array::from_fn(|i| b[inp(row + i, col, cols)] as i8);
107 i8x16::from(bytes)
108}
109
#[cfg(test)]
mod tests {

    use proptest::prelude::*;

    use super::*;

    /// Strategy yielding `(data, rows, cols)` for a random bit matrix whose
    /// dimensions are multiples of 16 strictly below the given maxima.
    fn arbitrary_bitmat(max_row: usize, max_col: usize) -> BoxedStrategy<(Vec<u8>, usize, usize)> {
        let dims = (
            // round down to the nearest multiple of 16
            (16..max_row).prop_map(|r| r / 16 * 16),
            (16..max_col).prop_map(|c| c / 16 * 16),
        );
        dims.prop_flat_map(|(rows, cols)| {
            // rows * cols bits == rows * cols / 8 bytes of matrix data
            let data = vec![any::<u8>(); rows * cols / 8];
            (data, Just(rows), Just(cols))
        })
        .boxed()
    }

    proptest! {
        #[cfg(not(miri))]
        #[test]
        fn test_double_transpose((data, rows, cols) in arbitrary_bitmat(16 * 30, 16 * 30)) {
            // Transposing twice must round-trip back to the original matrix.
            let mut once = vec![0; data.len()];
            let mut twice = vec![0; data.len()];
            transpose_bitmatrix(&data, &mut once, rows);
            transpose_bitmatrix(&once, &mut twice, cols);

            prop_assert_eq!(data, twice);
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")] // miri doesn't know the intrinsics on e.g. ARM
    fn test_double_transpose_miri() {
        // Small fixed-size round-trip so the test stays fast under miri.
        let (rows, cols) = (32, 16);
        let data = vec![0; rows * cols];
        let mut once = vec![0; data.len()];
        let mut twice = vec![0; data.len()];
        transpose_bitmatrix(&data, &mut once, rows);
        transpose_bitmatrix(&once, &mut twice, cols);
        assert_eq!(data, twice);
    }
}
153}