cryprot_core/transpose/portable.rs

use wide::{i8x16, i64x2};

/// Transpose a bit matrix.
///
/// # Panics
/// Panics if `rows < 16`, if `rows` is not divisible by 16, if `input.len()`
/// is not a multiple of `rows`, if `output.len() != input.len()`, or if the
/// resulting number of columns is less than 16.
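///
/// # Example
/// A minimal usage sketch (the re-export path of this function is an
/// assumption here, hence the ignored doctest):
/// ```ignore
/// use cryprot_core::transpose::transpose_bitmatrix;
///
/// let (rows, cols) = (16, 16);
/// let input = vec![0b1010_1010u8; rows * cols / 8];
/// let mut output = vec![0u8; input.len()];
/// transpose_bitmatrix(&input, &mut output, rows);
/// ```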
pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
    assert!(rows >= 16);
    assert_eq!(0, rows % 16);
    assert_eq!(0, input.len() % rows);
    // the unsafe writes below assume the output buffer is as large as the input
    assert_eq!(input.len(), output.len());
    let cols = input.len() * 8 / rows;
    assert!(cols >= 16);
    assert_eq!(
        0,
        cols % 8,
        "Number of bitmatrix columns must be divisible by 8. columns: {cols}"
    );

    unsafe {
        let mut row: usize = 0;
        while row <= rows - 16 {
            let mut col = 0;
            while col < cols {
                let mut v = load_bytes(input, row, col, cols);
                // iterate in reverse because the first write takes the msb of each
                // byte; after seven left shifts, at i = 0, we write what was
                // originally the lsb
                for i in (0..8).rev() {
                    // get msb of each byte
                    let msbs = v.move_mask().to_le_bytes();
                    // write msbs to output at the transposed position as one i16
                    let msb_i16 = i16::from_ne_bytes([msbs[0], msbs[1]]);
                    let idx = out(row, col + i, rows) as isize;
                    let out_ptr = output.as_mut_ptr().offset(idx) as *mut i16;
                    // ptr is potentially unaligned
                    out_ptr.write_unaligned(msb_i16);

                    // SAFETY: i8x16 and i64x2 have the same layout;
                    // we need to cast because there is no shift impl for i8x16
                    let v_i64x2 = &mut v as *mut _ as *mut i64x2;
                    // shift each byte left by one (by shifting the vector as two i64)
                    *v_i64x2 = *v_i64x2 << 1;
                }
                col += 8;
            }
            row += 16;
        }
    }
}

/// Byte index of bit (x, y) in the row-major input with `cols` bit columns.
#[inline]
fn inp(x: usize, y: usize, cols: usize) -> usize {
    x * cols / 8 + y / 8
}

/// Byte index in the row-major output (which has `rows` bit columns) that
/// input bit (x, y) lands in after transposition, i.e. position (y, x).
#[inline]
fn out(x: usize, y: usize, rows: usize) -> usize {
    y * rows / 8 + x / 8
}
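
// Illustrative index check (arbitrary coordinates): in a 16x16 bit matrix, bit
// (row 3, col 5) lives in input byte inp(3, 5, 16) = 3 * 2 + 0 = 6 and, after
// transposition, in output byte out(3, 5, 16) = 5 * 2 + 0 = 10.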

#[inline]
// get the byte at bit column `col` of rows `row` to `row + 15`
// SAFETY: callers must ensure that `inp(row + 15, col, cols) < b.len()`
unsafe fn load_bytes(b: &[u8], row: usize, col: usize, cols: usize) -> i8x16 {
    unsafe {
        // with sse2 we use _mm_setr_epi8 and transmute to convert the bytes,
        // which is faster than the `From` impl
        #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
        {
            use std::{arch::x86_64::_mm_setr_epi8, mem::transmute};
            let v = _mm_setr_epi8(
                *b.get_unchecked(inp(row, col, cols)) as i8,
                *b.get_unchecked(inp(row + 1, col, cols)) as i8,
                *b.get_unchecked(inp(row + 2, col, cols)) as i8,
                *b.get_unchecked(inp(row + 3, col, cols)) as i8,
                *b.get_unchecked(inp(row + 4, col, cols)) as i8,
                *b.get_unchecked(inp(row + 5, col, cols)) as i8,
                *b.get_unchecked(inp(row + 6, col, cols)) as i8,
                *b.get_unchecked(inp(row + 7, col, cols)) as i8,
                *b.get_unchecked(inp(row + 8, col, cols)) as i8,
                *b.get_unchecked(inp(row + 9, col, cols)) as i8,
                *b.get_unchecked(inp(row + 10, col, cols)) as i8,
                *b.get_unchecked(inp(row + 11, col, cols)) as i8,
                *b.get_unchecked(inp(row + 12, col, cols)) as i8,
                *b.get_unchecked(inp(row + 13, col, cols)) as i8,
                *b.get_unchecked(inp(row + 14, col, cols)) as i8,
                *b.get_unchecked(inp(row + 15, col, cols)) as i8,
            );
            transmute(v)
        }
        #[cfg(not(all(target_arch = "x86_64", target_feature = "sse2")))]
        {
            let bytes = std::array::from_fn(|i| *b.get_unchecked(inp(row + i, col, cols)) as i8);
            i8x16::from(bytes)
        }
    }
}

#[cfg(test)]
mod tests {

    use proptest::prelude::*;

    use super::*;

    fn arbitrary_bitmat(max_row: usize, max_col: usize) -> BoxedStrategy<(Vec<u8>, usize, usize)> {
        (
            (16..max_row).prop_map(|row| row / 16 * 16),
            (16..max_col).prop_map(|col| col / 16 * 16),
        )
            .prop_flat_map(|(rows, cols)| {
                (vec![any::<u8>(); rows * cols / 8], Just(rows), Just(cols))
            })
            .boxed()
    }

    proptest! {
        #[cfg(not(miri))]
        #[test]
        fn test_double_transpose((v, rows, cols) in arbitrary_bitmat(16 * 30, 16 * 30)) {
            let mut transposed = vec![0; v.len()];
            let mut double_transposed = vec![0; v.len()];
            transpose_bitmatrix(&v, &mut transposed, rows);
            transpose_bitmatrix(&transposed, &mut double_transposed, cols);

            prop_assert_eq!(v, double_transposed);
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")] // miri doesn't know the intrinsics on e.g. ARM
    fn test_double_transpose_miri() {
        let rows = 32;
        let cols = 16;
        let v = vec![0; rows * cols / 8];
        let mut transposed = vec![0; v.len()];
        let mut double_transposed = vec![0; v.len()];
        transpose_bitmatrix(&v, &mut transposed, rows);
        transpose_bitmatrix(&transposed, &mut double_transposed, cols);
        assert_eq!(v, double_transposed);
    }
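
    // Bit-level sanity check: with the lsb-first bit order within each byte
    // used by this implementation, the input bit at (row r, col c) must end up
    // at (row c, col r) of the output. The concrete coordinates are arbitrary.
    #[test]
    fn test_single_bit_position() {
        let (rows, cols) = (16, 16);
        let mut input = vec![0u8; rows * cols / 8];
        // set bit (row 3, col 5) of the input
        input[3 * cols / 8 + 5 / 8] |= 1 << (5 % 8);
        let mut output = vec![0u8; input.len()];
        transpose_bitmatrix(&input, &mut output, rows);
        // the bit must now be at (row 5, col 3); the output has `rows` bit columns
        assert_eq!(1, (output[5 * rows / 8 + 3 / 8] >> (3 % 8)) & 1);
    }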
}