cryprot_core/transpose/avx2.rs

//! Implementation of AVX2 BitMatrix transpose based on libOTe.
use std::{arch::x86_64::*, hint::unreachable_unchecked};

#[inline]
#[target_feature(enable = "avx2")]
/// Must be called with `matches!(shift, 2 | 4 | 8 | 16 | 32)`
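///
/// The shift intrinsics take their shift amount as a const generic, so a
/// runtime `shift` value has to be dispatched to the matching constant below.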
unsafe fn _mm256_slli_epi64_var_shift(a: __m256i, shift: usize) -> __m256i {
    debug_assert!(
        matches!(shift, 2 | 4 | 8 | 16 | 32),
        "Must be called with correct shift"
    );
    unsafe {
        match shift {
            2 => _mm256_slli_epi64::<2>(a),
            4 => _mm256_slli_epi64::<4>(a),
            8 => _mm256_slli_epi64::<8>(a),
            16 => _mm256_slli_epi64::<16>(a),
            32 => _mm256_slli_epi64::<32>(a),
            _ => unreachable_unchecked(),
        }
    }
}

#[inline]
#[target_feature(enable = "avx2")]
/// Must be called with `matches!(shift, 2 | 4 | 8 | 16 | 32)`
unsafe fn _mm256_srli_epi64_var_shift(a: __m256i, shift: usize) -> __m256i {
    debug_assert!(
        matches!(shift, 2 | 4 | 8 | 16 | 32),
        "Must be called with correct shift"
    );
    unsafe {
        match shift {
            2 => _mm256_srli_epi64::<2>(a),
            4 => _mm256_srli_epi64::<4>(a),
            8 => _mm256_srli_epi64::<8>(a),
            16 => _mm256_srli_epi64::<16>(a),
            32 => _mm256_srli_epi64::<32>(a),
            _ => unreachable_unchecked(),
        }
    }
}

// Transpose a 2^block_size_shift x 2^block_size_shift block within a larger
// matrix. Only handles the first two rows out of every 2^block_rows_shift rows
// in each block.
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn avx_transpose_block_iter1(
    in_out: *mut __m256i,
    block_size_shift: usize,
    block_rows_shift: usize,
    j: usize,
) {
    if j < (1 << block_size_shift) && block_size_shift == 6 {
        unsafe {
            let x = in_out.add(j / 2);
            let y = in_out.add(j / 2 + 32);

            let out_x = _mm256_unpacklo_epi64(*x, *y);
            let out_y = _mm256_unpackhi_epi64(*x, *y);
            *x = out_x;
            *y = out_y;
            return;
        }
    }

    if block_size_shift == 0 || block_size_shift >= 6 || block_rows_shift < 1 {
        return;
    }

    // Calculate mask for the current block size
    let mut mask = (!0u64) << 32;
    for k in (block_size_shift as i32..=4).rev() {
        mask ^= mask >> (1 << k);
    }
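    // E.g. block_size_shift == 1 yields 0xCCCC_CCCC_CCCC_CCCC (the high two
    // bits of every 4-bit group), 2 yields 0xF0F0_F0F0_F0F0_F0F0, and 5 leaves
    // the initial 0xFFFF_FFFF_0000_0000 unchanged.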

    unsafe {
        let x = in_out.add(j / 2);
        let y = in_out.add(j / 2 + (1 << (block_size_shift - 1)));

        // Special case for 2x2 blocks (block_size_shift == 1)
        if block_size_shift == 1 {
            let u = _mm256_permute2x128_si256::<0x20>(*x, *y);
            let v = _mm256_permute2x128_si256::<0x31>(*x, *y);

            let mut diff = _mm256_xor_si256(u, _mm256_slli_epi16::<1>(v));
            diff = _mm256_and_si256(diff, _mm256_set1_epi16(0b1010101010101010_u16 as i16));
            let u = _mm256_xor_si256(u, diff);
            let v = _mm256_xor_si256(v, _mm256_srli_epi16::<1>(diff));

            *x = _mm256_permute2x128_si256::<0x20>(u, v);
            *y = _mm256_permute2x128_si256::<0x31>(u, v);
        }

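        // Delta-swap: exchange the mask-selected bits of `*x` with the
        // corresponding bits of `*y`, shifted by the block width.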
        let mut diff = _mm256_xor_si256(*x, _mm256_slli_epi64_var_shift(*y, 1 << block_size_shift));
        diff = _mm256_and_si256(diff, _mm256_set1_epi64x(mask as i64));
        *x = _mm256_xor_si256(*x, diff);
        *y = _mm256_xor_si256(*y, _mm256_srli_epi64_var_shift(diff, 1 << block_size_shift));
    }
}

// Process a range of rows in the matrix.
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn avx_transpose_block_iter2(
    in_out: *mut __m256i,
    block_size_shift: usize,
    block_rows_shift: usize,
    n_rows: usize,
) {
    let mat_size = 1 << (block_size_shift + 1);

    for i in (0..n_rows).step_by(mat_size) {
        for j in (0..(1 << block_size_shift)).step_by(1 << block_rows_shift) {
            unsafe {
                avx_transpose_block_iter1(in_out.add(i / 2), block_size_shift, block_rows_shift, j);
            }
        }
    }
}

// Main transpose function for blocks within the matrix. Transposes the current
// block size, then recurses on doubled block sizes until the full matrix size
// is reached.
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn avx_transpose_block(
    in_out: *mut __m256i,
    block_size_shift: usize,
    mat_size_shift: usize,
    block_rows_shift: usize,
    mat_rows_shift: usize,
) {
    if block_size_shift >= mat_size_shift {
        return;
    }

    // Process current block size
    let total_rows = 1 << (mat_rows_shift + mat_size_shift);

    unsafe {
        avx_transpose_block_iter2(in_out, block_size_shift, block_rows_shift, total_rows);

        // Recursively process larger blocks
        avx_transpose_block(
            in_out,
            block_size_shift + 1,
            mat_size_shift,
            block_rows_shift,
            mat_rows_shift,
        );
    }
}
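
// For intuition, the same delta-swap recursion written as a scalar 64x64 bit
// matrix transpose (a sketch after the classic Hacker's Delight routine; not
// used by the AVX2 path above):
#[allow(dead_code)]
fn transpose64(a: &mut [u64; 64]) {
    let mut j = 32;
    let mut m: u64 = 0x0000_0000_FFFF_FFFF;
    while j != 0 {
        let mut k = 0;
        while k < 64 {
            // Delta-swap the j x j sub-block selected by `m` between rows
            // k and k | j.
            let t = (a[k] ^ (a[k | j] >> j)) & m;
            a[k] ^= t;
            a[k | j] ^= t << j;
            k = ((k | j) + 1) & !j;
        }
        j >>= 1;
        m ^= m << j;
    }
}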

const AVX_BLOCK_SHIFT: usize = 4;
const AVX_BLOCK_SIZE: usize = 1 << AVX_BLOCK_SHIFT;

/// Transpose a 128x128 bit matrix using AVX2.
///
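/// The matrix is stored row-major: `in_out[i]` packs bit-rows `2 * i` and
/// `2 * i + 1`, each 128 bits wide.
///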
/// # Safety
/// AVX2 needs to be enabled.
#[target_feature(enable = "avx2")]
pub fn avx_transpose128x128(in_out: &mut [__m256i; 64]) {
    const MAT_SIZE_SHIFT: usize = 7;
    unsafe {
        let in_out = in_out.as_mut_ptr();
        for i in (0..64).step_by(AVX_BLOCK_SIZE) {
            avx_transpose_block(
                in_out.add(i),
                1,
                MAT_SIZE_SHIFT - AVX_BLOCK_SHIFT,
                1,
                AVX_BLOCK_SHIFT + 1 - (MAT_SIZE_SHIFT - AVX_BLOCK_SHIFT),
            );
        }

        // Process larger blocks
        let block_size_shift = MAT_SIZE_SHIFT - AVX_BLOCK_SHIFT;

        // Special case for full matrix
        for i in 0..(1 << (block_size_shift - 1)) {
            avx_transpose_block(
                in_out.add(i),
                block_size_shift,
                MAT_SIZE_SHIFT,
                block_size_shift,
                0,
            );
        }
    }
}

/// Transpose a bit matrix.
///
/// # Panics
/// Panics if the number of columns (`input.len() * 8 / rows`) or `rows` is not
/// divisible by 128, or if `input.len() != output.len()`.
///
/// # Safety
/// The AVX2 instruction set must be available.
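///
/// # Example
/// A minimal usage sketch (not compiled as a doc-test), guarding the call with
/// runtime feature detection:
/// ```ignore
/// let rows = 128;
/// let input = vec![0u8; rows * 128 / 8];
/// let mut output = vec![0u8; input.len()];
/// if is_x86_feature_detected!("avx2") {
///     // SAFETY: AVX2 support was detected at runtime.
///     unsafe { transpose_bitmatrix(&input, &mut output, rows) };
/// }
/// ```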
#[target_feature(enable = "avx2")]
pub unsafe fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
    assert_eq!(input.len(), output.len());
    let cols = input.len() * 8 / rows;
    assert_eq!(0, cols % 128);
    assert_eq!(0, rows % 128);
    #[allow(unused_unsafe)]
    let mut buf = [unsafe { _mm256_setzero_si256() }; 64];
    let in_stride = cols / 8;
    let out_stride = rows / 8;

    // Number of 128x128 bit squares
    let r_main = rows / 128;
    let c_main = cols / 128;

    for i in 0..r_main {
        for j in 0..c_main {
            // Process each 128x128 bit square
            unsafe {
                let src_ptr = input.as_ptr().add(i * 128 * in_stride + j * 16);

                let buf_u8_ptr = buf.as_mut_ptr() as *mut u8;

                // Copy 128 rows into buffer
                for k in 0..128 {
                    let src_row = src_ptr.add(k * in_stride);
                    std::ptr::copy_nonoverlapping(src_row, buf_u8_ptr.add(k * 16), 16);
                }
            }
            // Transpose the 128x128 bit square
            avx_transpose128x128(&mut buf);

            unsafe {
                // The pointer must be recreated because the previous `&mut`
                // borrow of `buf` invalidated it.
                let buf_u8_ptr = buf.as_mut_ptr() as *mut u8;
                // Copy transposed data to output; square (i, j) of the input
                // lands at square (j, i) of the output.
                let dst_ptr = output.as_mut_ptr().add(j * 128 * out_stride + i * 16);
                for k in 0..128 {
                    let dst_row = dst_ptr.add(k * out_stride);
                    std::ptr::copy_nonoverlapping(buf_u8_ptr.add(k * 16), dst_row, 16);
                }
            }
        }
    }
}

#[cfg(all(test, target_feature = "avx2"))]
mod tests {
    use std::arch::x86_64::_mm256_setzero_si256;

    use rand::{RngCore, SeedableRng, rngs::StdRng};

    use super::{avx_transpose128x128, transpose_bitmatrix};

    #[test]
    fn test_avx_transpose128() {
        unsafe {
            let mut v = [_mm256_setzero_si256(); 64];
            StdRng::seed_from_u64(42).fill_bytes(bytemuck::cast_slice_mut(&mut v));

            let orig = v.clone();
            // Transposing twice must return the original matrix.
            avx_transpose128x128(&mut v);
            avx_transpose128x128(&mut v);
            let mut failed = false;
            for (i, (o, t)) in orig.into_iter().zip(v).enumerate() {
                let o = bytemuck::cast::<_, [u128; 2]>(o);
                let t = bytemuck::cast::<_, [u128; 2]>(t);
                if o != t {
                    eprintln!("difference in block {i}");
                    eprintln!("orig: {o:?}");
                    eprintln!("tran: {t:?}");
                    failed = true;
                }
            }
            if failed {
                panic!("double transposed is different from original")
            }
        }
    }

    #[test]
    fn test_avx_transpose() {
        let rows = 128 * 2;
        let cols = 128 * 2;
        let mut v = vec![0_u8; rows * cols / 8];
        StdRng::seed_from_u64(42).fill_bytes(&mut v);

        // The AVX2 implementation must agree with the portable one.
        let mut avx_transposed = v.clone();
        let mut sse_transposed = v.clone();
        unsafe {
            transpose_bitmatrix(&v, &mut avx_transposed, rows);
        }
        crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows);

        assert_eq!(sse_transposed, avx_transposed);
    }
}