cryprot_core/transpose/
portable.rs

1use wide::{i8x16, i64x2};
2
3/// Transpose a bit matrix.
4///
5/// # Panics
6/// - If `rows < 16`
7/// - If `rows` is not divisible by 16
8/// - If `input.len()` is not divisible by `rows`
9/// - If the number of columns, computed as `input.len() * 8 / rows` is less
10///   than 16
11/// - If the number of columns is not divisible by 8
12pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
13    assert!(rows >= 16, "rows must be at least 16");
14    assert_eq!(0, rows % 16, "rows must be divisible by 16");
15    assert_eq!(
16        0,
17        input.len() % rows,
18        "input.len() must be divisible by rows"
19    );
20    let cols = input.len() * 8 / rows;
21    assert!(cols >= 16, "columns must be at least 16. Columns {cols}");
22    assert_eq!(
23        0,
24        cols % 8,
25        "Number of bitmatrix columns must be divisable by 8. columns: {cols}"
26    );
27
28    // Transpose a matrix by splitting it into 16x8 blocks (=16 bytes which can be
29    // one SSE2 128 bit vector depending on target support), write transposed
30    // block into output at the transposed position and continue with next
31    // block.
32
33    let mut row: usize = 0;
34    while row <= rows - 16 {
35        let mut col = 0;
36        while col < cols {
37            // Load 16x8 sub-block by loading (row + 0, col) .. (row + 15, col)
38            let mut v = load_bytes(input, row, col, cols);
39            // The ideas is to take the most significant bit of each row of the sub-block
40            // (msb of each byte) and write these 16 bits coming from 16 rows
41            // and one column in the input to the correct row and columns in the
42            // output. Because the `move_mask` instruction gives us the `msb` of each byte
43            // (=row) in our input, we iterate the output_row_offset from large
44            // to small. After each iteration, we shift each byte in the
45            // sub-block one bit to the left and then again get the msb of each byte and
46            // write it to the next row.
47            // Visualization done by Claude 4 Sonnet:
48            // ┌─────────────────────────────────────────────────────────────┐
49            // │              BIT MATRIX TRANSPOSE: 16x8 BLOCK               │
50            // ├─────────────────────────────────────────────────────────────┤
51            // │                                                             │
52            // │  INPUT (16×8)              OUTPUT (8×16)                    │
53            // │                                                             │
54            // │    0 1 ⋯ 6 7                 0 1 ⋯ E F                      │
55            // │  0 ◆|◇|⋯|▲|△               0 ◆|◆|⋯|◆|◆                      │
56            // │  1 ◆|◇|⋯|▲|△               1 ◇|◇|⋯|◇|◇                      │
57            // │  ⋮                         ⋮                              │
58            // │  E ◆|◇|⋯|▲|△               6 ▲|▲|⋯|▲|▲                      │
59            // │  F ◆|◇|⋯|▲|△               7 △|△|⋯|△|△                      │
60            // │                                                             │
61            // │  MOVE_MASK: Extract column bits → Write as rows             │
62            // │                                                             │
63            // │  Iter 0: MSB column 7 → output row 7                        │
64            // │  Iter 1: << 1, MSB → output row 6                           │
65            // │  ⋮                                                         │
66            // │  Iter 7: << 7, MSB → output row 0                           │
67            // │                                                             │
68            // │  INPUT[row,col] → OUTPUT[col,row]                           │
69            // │                                                             │
70            // └─────────────────────────────────────────────────────────────┘
71            for output_row_offset in (0..8).rev() {
72                // get msb of each byte
73                let msbs = v.to_bitmask().to_le_bytes();
74                // write msbs to output at transposed position
75                let idx = out(row, col + output_row_offset, rows) as isize;
76                // This should result in only one bounds check for the output
77                let out_bytes = &mut output[idx as usize..idx as usize + 2];
78                out_bytes[0] = msbs[0];
79                out_bytes[1] = msbs[1];
80
81                // There is no shift impl for i8x16 so we cast to i64x2 and shift these.
82                // The bits shifted to neighbouring bytes are ignored because we iterate
83                // and call move_mask 8 times.
84                let v: &mut i64x2 = bytemuck::must_cast_mut(&mut v);
85                // shift each byte by one to the left (by shifting it as two i64)
86                *v = *v << 1;
87            }
88            col += 8;
89        }
90        row += 16;
91    }
92}
93
/// Byte index of the byte holding bit `(x, y)` (row `x`, column `y`) in a
/// row-major bit matrix that has `cols` bit-columns per row.
#[inline]
fn inp(x: usize, y: usize, cols: usize) -> usize {
    // Offset (in bytes) of the start of row `x`.
    let row_start = (x * cols) / 8;
    // Offset (in bytes) of column `y` inside that row.
    let byte_in_row = y / 8;
    row_start + byte_in_row
}
/// Byte index in the transposed matrix for input bit `(x, y)`: the byte
/// holding bit `(y, x)` in a row-major bit matrix with `rows` bit-columns
/// per row.
#[inline]
fn out(x: usize, y: usize, rows: usize) -> usize {
    // Offset (in bytes) of the start of output row `y`.
    let row_start = (y * rows) / 8;
    // Offset (in bytes) of output column `x` inside that row.
    let byte_in_row = x / 8;
    row_start + byte_in_row
}
102
103#[inline]
104// get col byte of row to row + 15
105fn load_bytes(b: &[u8], row: usize, col: usize, cols: usize) -> i8x16 {
106    let bytes = std::array::from_fn(|i| b[inp(row + i, col, cols)] as i8);
107    i8x16::from(bytes)
108}
109
#[cfg(test)]
mod tests {

    use proptest::prelude::*;

    use super::*;

    /// Strategy yielding `(bytes, rows, cols)` for a random bit matrix.
    ///
    /// `rows` and `cols` are multiples of 16, at least 16 and below `max_row`
    /// / `max_col` respectively; `bytes` holds `rows * cols / 8` arbitrary
    /// bytes.
    fn arbitrary_bitmat(max_row: usize, max_col: usize) -> BoxedStrategy<(Vec<u8>, usize, usize)> {
        (
            (16..max_row).prop_map(|row| row / 16 * 16),
            (16..max_col).prop_map(|col| col / 16 * 16),
        )
            .prop_flat_map(|(rows, cols)| {
                (vec![any::<u8>(); rows * cols / 8], Just(rows), Just(cols))
            })
            .boxed()
    }

    proptest! {
        // Transposing twice (with the row count swapped to the column count
        // for the second call) must give back the original matrix.
        #[cfg(not(miri))]
        #[test]
        fn test_double_transpose((v, rows, cols) in arbitrary_bitmat(16 * 30, 16 * 30)) {
            let mut transposed = vec![0; v.len()];
            let mut double_transposed = vec![0; v.len()];
            transpose_bitmatrix(&v,&mut transposed, rows);
            transpose_bitmatrix(&transposed, &mut double_transposed, cols);

            prop_assert_eq!(v, double_transposed);
        }
    }

    // Small fixed-size round trip that is cheap enough to run under miri.
    #[test]
    #[cfg(target_arch = "x86_64")] // miri doesn't know the intrinsics on e.g. ARM
    fn test_double_transpose_miri() {
        let rows = 32;
        let cols = 16;
        // A `rows x cols` bit matrix occupies `rows * cols / 8` bytes. Using a
        // non-zero, deterministic pattern keeps the round-trip check
        // meaningful (an all-zero matrix is unchanged by any transposition).
        let v: Vec<u8> = (0..rows * cols / 8)
            .map(|i| (i as u8).wrapping_mul(37).wrapping_add(11))
            .collect();
        let mut transposed = vec![0; v.len()];
        let mut double_transposed = vec![0; v.len()];
        transpose_bitmatrix(&v, &mut transposed, rows);
        transpose_bitmatrix(&transposed, &mut double_transposed, cols);
        assert_eq!(v, double_transposed);
    }
}