cryprot_core/transpose/avx2.rs

//! Implementation of AVX2 BitMatrix transpose based on libOTe.
use std::{arch::x86_64::*, cmp};

use bytemuck::{must_cast_slice, must_cast_slice_mut};
use seq_macro::seq;

/// Performs a 2x2 bit transpose operation on two 256-bit vectors representing a
/// 4x128 matrix.
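///
/// Each `__m256i` holds two rows of the 128x128 bit matrix, so the pair `x`, `y`
/// covers four consecutive rows. Conceptually, these rows are split into 2x2 bit
/// blocks `[a b; c d]`, and each block is replaced by its transpose `[a c; b d]`
/// by conditionally swapping `b` and `c`.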
#[inline]
#[target_feature(enable = "avx2")]
fn transpose_2x2_matrices(x: &mut __m256i, y: &mut __m256i) {
    // x = [x_H | x_L] and y = [y_H | y_L]
    // u = [y_L | x_L] u is the low 128 bits of x and y
    let u = _mm256_permute2x128_si256(*x, *y, 0x20);
    // v = [y_H | x_H] v is the high 128 bits of x and y
    let v = _mm256_permute2x128_si256(*x, *y, 0x31);
    // Shift v left by one so the element at (i + 1, j - 1) aligns with the element
    // at (i, j), then compute the difference. The row shift (i + 1) was already done
    // by the permute instructions above; the column shift is done by the slli
    // instruction.
    let mut diff = _mm256_xor_si256(u, _mm256_slli_epi16(v, 1));
    // Select all odd indices of diff and zero out even indices. The idea is to
    // calculate the difference of all odd numbered indices j of the even
    // numbered row i with the even numbered indices j - 1 in row i + 1.
    // These are precisely the elements in the 2x2 matrices that make up x and y
    // that potentially need to be swapped for the transpose if they differ.
    diff = _mm256_and_si256(diff, _mm256_set1_epi16(0b1010101010101010_u16 as i16));
    // Perform the swaps in u, which corresponds to the lower bits of x and y, by
    // XORing in the diff.
    let u = _mm256_xor_si256(u, diff);
    // For the bottom row in the 2x2 matrices (the high bits of x and y) we need to
    // shift the diff by 1 to the right so it aligns with the even numbered indices.
    let v = _mm256_xor_si256(v, _mm256_srli_epi16(diff, 1));
    // The permuted 2x2 matrices are split over u and v, with the upper row in u and
    // the lower in v. We perform the same permutation as in the beginning, thereby
    // writing the 2x2 permuted bits of x and y back.
    *x = _mm256_permute2x128_si256(u, v, 0x20);
    *y = _mm256_permute2x128_si256(u, v, 0x31);
}

/// Performs one round of the partial sub-matrix swap used by the bit-level
/// transpose.
///
/// `SHIFT_AMOUNT` is the constant shift value (e.g., 2, 4, 8, 16, 32) for the
/// intrinsics. `MASK` is the bitmask for the XOR-swap.
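///
/// This is a masked XOR-swap: `diff` collects the positions where `x` differs from
/// the shifted `y`, and XORing `diff` back into both operands exchanges exactly
/// those bits.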
#[inline]
#[target_feature(enable = "avx2")]
fn partial_swap_sub_matrices<const SHIFT_AMOUNT: i32, const MASK: u64>(
    x: &mut __m256i,
    y: &mut __m256i,
) {
    // calculate the diff of the bits that need to be potentially swapped
    let mut diff = _mm256_xor_si256(*x, _mm256_slli_epi64::<SHIFT_AMOUNT>(*y));
    diff = _mm256_and_si256(diff, _mm256_set1_epi64x(MASK as i64));
    // swap the bits in x by xoring the difference
    *x = _mm256_xor_si256(*x, diff);
    // and in y
    *y = _mm256_xor_si256(*y, _mm256_srli_epi64::<SHIFT_AMOUNT>(diff));
}

/// Performs a partial 64x64 bit matrix swap. This is used to swap the rows in
/// the upper right quadrant with those of the lower left in the 128x128 matrix.
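///
/// The unpack intrinsics interleave the 64-bit halves of `x` and `y` within each
/// 128-bit lane: the rows in `x` keep their lower 64 columns and receive the lower
/// 64 columns of the rows in `y`, while the rows in `y` receive the upper 64
/// columns of the rows in `x` and keep their own upper 64 columns.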
#[inline]
#[target_feature(enable = "avx2")]
fn partial_swap_64x64_matrices(x: &mut __m256i, y: &mut __m256i) {
    let out_x = _mm256_unpacklo_epi64(*x, *y);
    let out_y = _mm256_unpackhi_epi64(*x, *y);
    *x = out_x;
    *y = out_y;
}

/// Transpose a 128x128 bit matrix using AVX2 intrinsics.
///
/// # Safety
/// AVX2 needs to be enabled.
#[target_feature(enable = "avx2")]
pub fn avx_transpose128x128(in_out: &mut [__m256i; 64]) {
    // This function implements a bit-transpose of a 128x128 bit matrix using a
    // divide-and-conquer algorithm. The idea is that for
    // M = [ A B ]
    //     [ C D ]
    // M^T is equal to
    //     [ A^T C^T ]
    //     [ B^T D^T ]
    //
    // We first divide our matrix into 2x2 bit matrices which we transpose at the
    // bit level. Then we swap the 2x2 bit matrices to complete a 4x4
    // transpose. We swap the 4x4 bit matrices to complete an 8x8 transpose and so on
    // until we swap 64x64 bit matrices and thus complete the intended 128x128 bit
    // transpose.
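    //
    // For example, viewed as a 2x2 grid of 64x64 blocks [A B; C D]: once A, B, C
    // and D have each been transposed in place, swapping B and C yields the full
    // transpose, because [A B; C D]^T = [A^T C^T; B^T D^T].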

    // Part 1: Specialized 2x2 block transpose transposing individual bits
    for chunk in in_out.chunks_exact_mut(2) {
        if let [x, y] = chunk {
            transpose_2x2_matrices(x, y);
        } else {
            unreachable!("chunk size is 2")
        }
    }

    // Phases 1-5: swap sub-matrices of size 2x2, 4x4, 8x8, 16x16, 32x32 bit
    // Using seq_macro to reduce repetition
    seq!(N in 1..=5 {
        const SHIFT_~N: i32 = 1 << N;
        // Our mask selects the part of the sub-matrix that potentially needs to be
        // swapped along the diagonal. The lowest SHIFT_~N bits are 0 and the next
        // SHIFT_~N bits are 1, repeated to fill the 64 bit mask.
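        // E.g. for N = 1 (SHIFT_1 = 2) the mask is 0xcccc_cccc_cccc_cccc, selecting
        // every second pair of bit columns.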
        const MASK_~N: u64 = match N {
            1 => mask(0b1100, 4),
            2 => mask(0b11110000, 8),
            3 => mask(0b1111111100000000, 16),
            4 => mask(0b11111111111111110000000000000000, 32),
            5 => 0xffffffff00000000,
            _ => unreachable!(),
        };
        // The offset between x and y for matrix rows that need to be swapped, in
        // terms of 256 bit elements. In the first iteration we swap the 2x2 matrices
        // that are at positions in_out[i] and in_out[i + 1], so the offset is 1. For
        // 4x4 matrices the offset is 2.
        #[allow(clippy::eq_op)] // false positive due to use of seq!
        const OFFSET~N: usize = 1 << (N - 1);

        for chunk in in_out.chunks_exact_mut(2 * OFFSET~N) {
            let (x_chunk, y_chunk) = chunk.split_at_mut(OFFSET~N);
            // For larger matrices, and larger offsets, we need to iterate over all
            // rows of the sub-matrices
            for (x, y) in x_chunk.iter_mut().zip(y_chunk.iter_mut()) {
                partial_swap_sub_matrices::<SHIFT_~N, MASK_~N>(x, y);
            }
        }
    });

    // Phase 6: swap 64x64 bit matrices, thereby completing the 128x128 bit
    // transpose
    const SHIFT_6: usize = 6;
    const OFFSET_6: usize = 1 << (SHIFT_6 - 1); // 32

    for chunk in in_out.chunks_exact_mut(2 * OFFSET_6) {
        let (x_chunk, y_chunk) = chunk.split_at_mut(OFFSET_6);
        for (x, y) in x_chunk.iter_mut().zip(y_chunk.iter_mut()) {
            partial_swap_64x64_matrices(x, y);
        }
    }
}

/// Create a u64 bit mask by repeating the given `pattern` of `pattern_len` bits
/// until it fills the u64.
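///
/// For example, `mask(0b1100, 4)` evaluates to `0xcccc_cccc_cccc_cccc` and
/// `mask(0b11110000, 8)` to `0xf0f0_f0f0_f0f0_f0f0`.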
const fn mask(pattern: u64, pattern_len: u32) -> u64 {
    let mut mask = pattern;
    let mut current_block_len = pattern_len;

    // We keep doubling the effective length of our repeating block
    // until it covers 64 bits.
    while current_block_len < 64 {
        mask = (mask << current_block_len) | mask;
        current_block_len *= 2;
    }

    mask
}

/// Transpose a bit matrix using AVX2.
///
/// This implementation is specifically tuned for transposing `128 x l` matrices
/// as done in OT protocols. Performance might be better if `input` is 16-byte
/// aligned and the number of columns is divisible by 512 on systems with
/// 64-byte cache lines.
///
/// # Panics
/// If `input.len() != output.len()`.
/// If the number of rows is less than 128.
/// If `input.len()` is not divisible by `rows`.
/// If the number of rows is not divisible by 128.
/// If the number of columns (= `input.len() * 8 / rows`) is not divisible by 8.
///
/// # Safety
/// AVX2 instruction set must be available.
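///
/// # Example
/// A minimal usage sketch (illustrative only; the caller is responsible for
/// checking that AVX2 is actually available at runtime):
/// ```ignore
/// let rows = 128;
/// let cols = 256;
/// let input = vec![0u8; rows * cols / 8];
/// let mut output = vec![0u8; input.len()];
/// if is_x86_feature_detected!("avx2") {
///     // SAFETY: AVX2 was detected at runtime.
///     unsafe { transpose_bitmatrix(&input, &mut output, rows) };
/// }
/// ```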
#[target_feature(enable = "avx2")]
pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
    assert_eq!(input.len(), output.len());
    assert!(rows >= 128, "Number of rows must be >= 128.");
    assert_eq!(
        0,
        input.len() % rows,
        "input.len() must be divisible by rows"
    );
    assert_eq!(0, rows % 128, "Number of rows must be a multiple of 128.");
    let cols = input.len() * 8 / rows;
    assert_eq!(0, cols % 8, "Number of columns must be a multiple of 8.");

    // Buffer to hold 4 128x128 bit squares (64 * 4 __m256i registers = 2048 * 4
    // bytes)
    let mut buf = [_mm256_setzero_si256(); 64 * 4];
    let in_stride = cols / 8; // Stride in bytes for input rows
    let out_stride = rows / 8; // Stride in bytes for output rows

    // Number of 128x128 bit squares in rows and columns
    let r_main = rows / 128;
    let c_main = cols / 128;
    let c_rest = cols % 128;

    // Iterate through each 128x128 bit square in the matrix
    // Row block index
    for i in 0..r_main {
        // Column block index
        let mut j = 0;
        while j < c_main {
            let input_offset = i * 128 * in_stride + j * 16;
            let curr_addr = input[input_offset..].as_ptr().addr();
            let next_cache_line_addr = (curr_addr + 1).next_multiple_of(64); // cache line size
            let blocks_in_cache_line = (next_cache_line_addr - curr_addr) / 16;

            let remaining_blocks_in_cache_line = if blocks_in_cache_line == 0 {
                // The blocks will cross a cache line boundary; if they are not
                // 16-byte aligned, this is the best we can do.
                4
            } else {
                blocks_in_cache_line
            };
            // Ensure we don't read OOB of the input
            let remaining_blocks_in_cache_line =
                cmp::min(remaining_blocks_in_cache_line, c_main - j);

            let buf_as_bytes: &mut [u8] = must_cast_slice_mut(&mut buf);

            // The loading loop copies the input data into buf. By using a macro and
            // matching on the case of 4 blocks per cache line (each row of a block is
            // 16 bytes, so the rows of 4 consecutive blocks span 64 bytes), the
            // optimizer generates an unrolled version of the loop for this case.
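            // buf layout: block b occupies bytes [b * 2048, (b + 1) * 2048), and row k
            // of block b starts at byte offset b * 2048 + k * 16.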
            macro_rules! loading_loop {
                ($remaining_blocks_in_cache_line:expr) => {
                    for k in 0..128 {
                        let src_slice = &input[input_offset + k * in_stride
                            ..input_offset + k * in_stride + 16 * remaining_blocks_in_cache_line];

                        for block in 0..remaining_blocks_in_cache_line {
                            buf_as_bytes[block * 2048 + k * 16..block * 2048 + (k + 1) * 16]
                                .copy_from_slice(&src_slice[block * 16..(block + 1) * 16]);
                        }
                    }
                };
            }

            // This gets optimized to the unrolled loop for the default case of 4 blocks
            match remaining_blocks_in_cache_line {
                4 => loading_loop!(4),
                #[allow(unused_variables)] // false positive
                other => loading_loop!(other),
            }

            for block in 0..remaining_blocks_in_cache_line {
                avx_transpose128x128(
                    (&mut buf[block * 64..(block + 1) * 64])
                        .try_into()
                        .expect("slice has length 64"),
                );
            }

            let mut output_offset = j * 128 * out_stride + i * 16;
            let buf_as_bytes: &[u8] = must_cast_slice(&buf);

            if out_stride == 16 {
                // If the out_stride is 16 bytes, the transposed sub-matrices are in
                // contiguous memory in the output, so we can use a single
                // copy_from_slice. This is especially helpful for the case of
                // transposing a 128 x l matrix as done in OT extension.
                let dst_slice = &mut output
                    [output_offset..output_offset + 16 * 128 * remaining_blocks_in_cache_line];
                dst_slice.copy_from_slice(&buf_as_bytes[..remaining_blocks_in_cache_line * 2048]);
            } else {
                for block in 0..remaining_blocks_in_cache_line {
                    for k in 0..128 {
                        let src_slice =
                            &buf_as_bytes[block * 2048 + k * 16..block * 2048 + (k + 1) * 16];
                        let dst_slice = &mut output
                            [output_offset + k * out_stride..output_offset + k * out_stride + 16];
                        dst_slice.copy_from_slice(src_slice);
                    }
                    output_offset += 128 * out_stride;
                }
            }

            j += remaining_blocks_in_cache_line;
        }

        if c_rest > 0 {
            handle_rest_cols(input, output, &mut buf, in_stride, out_stride, c_rest, i, j);
        }
    }
}

// inline(never) is used to reduce the code size of `transpose_bitmatrix`. This
// method is called at most once per row block, and only if the number of columns
// is not divisible by 128. Since it is rarely executed compared to the core loop
// of `transpose_bitmatrix`, we annotate it with inline(never) so the optimizer
// doesn't inline it, which could negatively impact performance due to larger code
// size and potentially more instruction cache misses. This is an assumption and
// not verified by a benchmark, but even if it were wrong, it shouldn't negatively
// impact runtime because this method is called rarely in our use cases where we
// have 128 rows and many columns.
#[inline(never)]
#[target_feature(enable = "avx2")]
#[allow(clippy::too_many_arguments)]
fn handle_rest_cols(
    input: &[u8],
    output: &mut [u8],
    buf: &mut [__m256i; 256],
    in_stride: usize,
    out_stride: usize,
    c_rest: usize,
    i: usize,
    j: usize,
) {
    let input_offset = i * 128 * in_stride + j * 16;
    let remaining_cols_bytes = c_rest / 8;
    buf[0..64].fill(_mm256_setzero_si256());
    let buf_as_bytes: &mut [u8] = must_cast_slice_mut(buf);

    for k in 0..128 {
        let src_row_offset = input_offset + k * in_stride;
        let src_slice = &input[src_row_offset..src_row_offset + remaining_cols_bytes];
        // We use a row stride of 16 bytes in buf because we still transpose a full
        // 128x128 bit matrix, of which only a part is filled.
        let buf_offset = k * 16;
        buf_as_bytes[buf_offset..buf_offset + remaining_cols_bytes].copy_from_slice(src_slice);
    }

    avx_transpose128x128((&mut buf[..64]).try_into().expect("slice has length 64"));

    let output_offset = j * 128 * out_stride + i * 16;
    let buf_as_bytes: &[u8] = must_cast_slice(&*buf);

    for k in 0..c_rest {
        let src_slice = &buf_as_bytes[k * 16..(k + 1) * 16];
        let dst_slice =
            &mut output[output_offset + k * out_stride..output_offset + k * out_stride + 16];
        dst_slice.copy_from_slice(src_slice);
    }
}

#[cfg(all(test, target_feature = "avx2"))]
mod tests {
    use std::arch::x86_64::_mm256_setzero_si256;

    use rand::{RngCore, SeedableRng, rngs::StdRng};

    use super::{avx_transpose128x128, transpose_bitmatrix};

    #[test]
    fn test_avx_transpose128() {
        unsafe {
            let mut v = [_mm256_setzero_si256(); 64];
            StdRng::seed_from_u64(42).fill_bytes(bytemuck::cast_slice_mut(&mut v));

            let orig = v;
            avx_transpose128x128(&mut v);
            avx_transpose128x128(&mut v);
            let mut failed = false;
            for (i, (o, t)) in orig.into_iter().zip(v).enumerate() {
                let o = bytemuck::cast::<_, [u128; 2]>(o);
                let t = bytemuck::cast::<_, [u128; 2]>(t);
                if o != t {
                    eprintln!("difference in block {i}");
                    eprintln!("orig: {o:?}");
                    eprintln!("tran: {t:?}");
                    failed = true;
                }
            }
            if failed {
                panic!("double transposed is different than original")
            }
        }
    }

    #[test]
    fn test_avx_transpose() {
        let rows = 128 * 2;
        let cols = 128 * 2;
        let mut v = vec![0_u8; rows * cols / 8];
        StdRng::seed_from_u64(42).fill_bytes(&mut v);

        let mut avx_transposed = v.clone();
        let mut sse_transposed = v.clone();
        unsafe {
            transpose_bitmatrix(&v, &mut avx_transposed, rows);
        }
        crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows);

        assert_eq!(sse_transposed, avx_transposed);
    }

    #[test]
    fn test_avx_transpose_unaligned_data() {
        let rows = 128 * 2;
        let cols = 128 * 2;
        let mut v = vec![0_u8; rows * (cols + 128) / 8];
        StdRng::seed_from_u64(42).fill_bytes(&mut v);

        let v = {
            let addr = v.as_ptr().addr();
            let offset = addr.next_multiple_of(3) - addr;
            &v[offset..offset + rows * cols / 8]
        };
        assert_eq!(0, v.as_ptr().addr() % 3);
        // allocate out bufs with same dims
        let mut avx_transposed = v.to_owned();
        let mut sse_transposed = v.to_owned();

        unsafe {
            transpose_bitmatrix(&v, &mut avx_transposed, rows);
        }
        crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows);

        assert_eq!(sse_transposed, avx_transposed);
    }

    #[test]
    fn test_avx_transpose_larger_cols_divisible_by_4_times_128() {
        let rows = 128;
        let cols = 128 * 8;
        let mut v = vec![0_u8; rows * cols / 8];
        StdRng::seed_from_u64(42).fill_bytes(&mut v);

        let mut avx_transposed = v.clone();
        let mut sse_transposed = v.clone();
        unsafe {
            transpose_bitmatrix(&v, &mut avx_transposed, rows);
        }
        crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows);

        assert_eq!(sse_transposed, avx_transposed);
    }

    #[test]
    fn test_avx_transpose_larger_cols_divisible_by_8() {
        let rows = 128;
        let cols = 128 + 32;
        let mut v = vec![0_u8; rows * cols / 8];
        StdRng::seed_from_u64(42).fill_bytes(&mut v);

        let mut avx_transposed = v.clone();
        let mut sse_transposed = v.clone();
        unsafe {
            transpose_bitmatrix(&v, &mut avx_transposed, rows);
        }
        crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows);

        assert_eq!(sse_transposed, avx_transposed);
    }
}