1use std::{arch::x86_64::*, cmp};
3
4use bytemuck::{must_cast_slice, must_cast_slice_mut};
5use seq_macro::seq;
6
/// Transposes the 64 2x2 bit sub-matrices contained in two registers.
///
/// `x` holds two consecutive 128-bit rows and `y` the next two. The first
/// two permutes regroup them so `u` holds the even rows and `v` the odd
/// rows. The masked delta-swap then exchanges each odd-position bit of `u`
/// with the even-position bit one place lower in `v` — i.e. bit (2k+1) of an
/// even row swaps with bit 2k of the following odd row — which is exactly a
/// 2x2 bit transpose. The final permutes restore the original row order.
#[inline]
#[target_feature(enable = "avx2")]
fn transpose_2x2_matrices(x: &mut __m256i, y: &mut __m256i) {
    // u = [x.lo, y.lo] = even rows; v = [x.hi, y.hi] = odd rows.
    let u = _mm256_permute2x128_si256(*x, *y, 0x20);
    let v = _mm256_permute2x128_si256(*x, *y, 0x31);
    // Delta swap: diff marks where u and (v << 1) differ, restricted to the
    // odd bit positions by the 0b10.. mask. The epi16 shift is safe here
    // because the shift distance is 1 and masked positions never cross a
    // 16-bit lane boundary in a way that matters (bit 15 is an odd position
    // fed from bit 14 of the same lane).
    let mut diff = _mm256_xor_si256(u, _mm256_slli_epi16(v, 1));
    diff = _mm256_and_si256(diff, _mm256_set1_epi16(0b1010101010101010_u16 as i16));
    // XOR-ing diff back into both operands performs the masked exchange.
    let u = _mm256_xor_si256(u, diff);
    let v = _mm256_xor_si256(v, _mm256_srli_epi16(diff, 1));
    // Undo the even/odd regrouping: x = [u.lo, v.lo], y = [u.hi, v.hi].
    *x = _mm256_permute2x128_si256(u, v, 0x20);
    *y = _mm256_permute2x128_si256(u, v, 0x31);
}
39
/// Masked delta swap between two registers, per 64-bit lane: the bits of
/// `x` selected by `MASK` are exchanged with the bits of `y` located
/// `SHIFT_AMOUNT` positions lower.
///
/// One step of the recursive bit-matrix transpose: with the right
/// `SHIFT_AMOUNT`/`MASK` pair this swaps the two off-diagonal
/// 2^N x 2^N bit sub-blocks held across `x` and `y`.
#[inline]
#[target_feature(enable = "avx2")]
fn partial_swap_sub_matrices<const SHIFT_AMOUNT: i32, const MASK: u64>(
    x: &mut __m256i,
    y: &mut __m256i,
) {
    // Classic XOR-based exchange: diff holds, at MASK positions, the
    // difference between x and the shifted y; XOR-ing it back into each
    // operand (with the inverse shift for y) swaps exactly those bits.
    let mut diff = _mm256_xor_si256(*x, _mm256_slli_epi64::<SHIFT_AMOUNT>(*y));
    diff = _mm256_and_si256(diff, _mm256_set1_epi64x(MASK as i64));
    *x = _mm256_xor_si256(*x, diff);
    *y = _mm256_xor_si256(*y, _mm256_srli_epi64::<SHIFT_AMOUNT>(diff));
}
58
/// Swaps the off-diagonal 64-bit halves of two registers' 128-bit lanes.
///
/// Per 128-bit lane, the result is `x = [x.lo, y.lo]` and `y = [x.hi, y.hi]`
/// — the final (64x64 block) step of the 128x128 bit transpose, done with a
/// single unpack pair instead of shift/mask arithmetic.
#[inline]
#[target_feature(enable = "avx2")]
fn partial_swap_64x64_matrices(x: &mut __m256i, y: &mut __m256i) {
    let low_halves = _mm256_unpacklo_epi64(*x, *y);
    let high_halves = _mm256_unpackhi_epi64(*x, *y);
    (*x, *y) = (low_halves, high_halves);
}
69
70#[target_feature(enable = "avx2")]
75pub fn avx_transpose128x128(in_out: &mut [__m256i; 64]) {
76 for chunk in in_out.chunks_exact_mut(2) {
92 if let [x, y] = chunk {
93 transpose_2x2_matrices(x, y);
94 } else {
95 unreachable!("chunk size is 2")
96 }
97 }
98
99 seq!(N in 1..=5 {
102 const SHIFT_~N: i32 = 1 << N;
103 const MASK_~N: u64 = match N {
107 1 => mask(0b1100, 4),
108 2 => mask(0b11110000, 8),
109 3 => mask(0b1111111100000000, 16),
110 4 => mask(0b11111111111111110000000000000000, 32),
111 5 => 0xffffffff00000000,
112 _ => unreachable!(),
113 };
114 #[allow(clippy::eq_op)] const OFFSET~N: usize = 1 << (N - 1);
120
121 for chunk in in_out.chunks_exact_mut(2 * OFFSET~N) {
122 let (x_chunk, y_chunk) = chunk.split_at_mut(OFFSET~N);
123 for (x, y) in x_chunk.iter_mut().zip(y_chunk.iter_mut()) {
126 partial_swap_sub_matrices::<SHIFT_~N, MASK_~N>(x, y);
127 }
128 }
129 });
130
131 const SHIFT_6: usize = 6;
134 const OFFSET_6: usize = 1 << (SHIFT_6 - 1); for chunk in in_out.chunks_exact_mut(2 * OFFSET_6) {
137 let (x_chunk, y_chunk) = chunk.split_at_mut(OFFSET_6);
138 for (x, y) in x_chunk.iter_mut().zip(y_chunk.iter_mut()) {
139 partial_swap_64x64_matrices(x, y);
140 }
141 }
142}
143
144const fn mask(pattern: u64, pattern_len: u32) -> u64 {
146 let mut mask = pattern;
147 let mut current_block_len = pattern_len;
148
149 while current_block_len < 64 {
152 mask = (mask << current_block_len) | mask;
153 current_block_len *= 2;
154 }
155
156 mask
157}
158
159#[target_feature(enable = "avx2")]
176pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
177 assert_eq!(input.len(), output.len());
178 assert!(rows >= 128, "Number of rows must be >= 128.");
179 assert_eq!(
180 0,
181 input.len() % rows,
182 "input.len(), must be divisble by rows"
183 );
184 assert_eq!(0, rows % 128, "Number of rows must be a multiple of 128.");
185 let cols = input.len() * 8 / rows;
186 assert_eq!(0, cols % 8, "Number of columns must be a multiple of 8.");
187
188 let mut buf = [_mm256_setzero_si256(); 64 * 4];
191 let in_stride = cols / 8; let out_stride = rows / 8; let r_main = rows / 128;
196 let c_main = cols / 128;
197 let c_rest = cols % 128;
198
199 for i in 0..r_main {
202 let mut j = 0;
204 while j < c_main {
205 let input_offset = i * 128 * in_stride + j * 16;
206 let curr_addr = input[input_offset..].as_ptr().addr();
207 let next_cache_line_addr = (curr_addr + 1).next_multiple_of(64); let blocks_in_cache_line = (next_cache_line_addr - curr_addr) / 16;
209
210 let remaining_blocks_in_cache_line = if blocks_in_cache_line == 0 {
211 4
214 } else {
215 blocks_in_cache_line
216 };
217 let remaining_blocks_in_cache_line =
219 cmp::min(remaining_blocks_in_cache_line, c_main - j);
220
221 let buf_as_bytes: &mut [u8] = must_cast_slice_mut(&mut buf);
222
223 macro_rules! loading_loop {
228 ($remaining_blocks_in_cache_line:expr) => {
229 for k in 0..128 {
230 let src_slice = &input[input_offset + k * in_stride
231 ..input_offset + k * in_stride + 16 * remaining_blocks_in_cache_line];
232
233 for block in 0..remaining_blocks_in_cache_line {
234 buf_as_bytes[block * 2048 + k * 16..block * 2048 + (k + 1) * 16]
235 .copy_from_slice(&src_slice[block * 16..(block + 1) * 16]);
236 }
237 }
238 };
239 }
240
241 match remaining_blocks_in_cache_line {
243 4 => loading_loop!(4),
244 #[allow(unused_variables)] other => loading_loop!(other),
246 }
247
248 for block in 0..remaining_blocks_in_cache_line {
249 avx_transpose128x128(
250 (&mut buf[block * 64..(block + 1) * 64])
251 .try_into()
252 .expect("slice has length 64"),
253 );
254 }
255
256 let mut output_offset = j * 128 * out_stride + i * 16;
257 let buf_as_bytes: &[u8] = must_cast_slice(&buf);
258
259 if out_stride == 16 {
260 let dst_slice = &mut output
265 [output_offset..output_offset + 16 * 128 * remaining_blocks_in_cache_line];
266 dst_slice.copy_from_slice(&buf_as_bytes[..remaining_blocks_in_cache_line * 2048]);
267 } else {
268 for block in 0..remaining_blocks_in_cache_line {
269 for k in 0..128 {
270 let src_slice =
271 &buf_as_bytes[block * 2048 + k * 16..block * 2048 + (k + 1) * 16];
272 let dst_slice = &mut output
273 [output_offset + k * out_stride..output_offset + k * out_stride + 16];
274 dst_slice.copy_from_slice(src_slice);
275 }
276 output_offset += 128 * out_stride;
277 }
278 }
279
280 j += remaining_blocks_in_cache_line;
281 }
282
283 if c_rest > 0 {
284 handle_rest_cols(input, output, &mut buf, in_stride, out_stride, c_rest, i, j);
285 }
286 }
287}
288
/// Transposes the trailing partial column block (`c_rest` < 128 columns) of
/// the 128-row block row `i`, writing the resulting `c_rest` output rows.
///
/// The partial columns are staged into a zero-padded 128x128 block in
/// `buf[..64]`, transposed, and only the first `c_rest` transposed rows are
/// copied out. `#[inline(never)]` keeps this cold path out of the caller's
/// hot loop.
#[inline(never)]
#[target_feature(enable = "avx2")]
#[allow(clippy::too_many_arguments)]
fn handle_rest_cols(
    input: &[u8],
    output: &mut [u8],
    buf: &mut [__m256i; 256],
    in_stride: usize,
    out_stride: usize,
    c_rest: usize,
    i: usize,
    j: usize,
) {
    let input_offset = i * 128 * in_stride + j * 16;
    // c_rest is a multiple of 8 (asserted by the caller), so this is exact.
    let remaining_cols_bytes = c_rest / 8;
    // Zero the staging block so the unused tail columns transpose to zeros.
    buf[0..64].fill(_mm256_setzero_si256());
    let buf_as_bytes: &mut [u8] = must_cast_slice_mut(buf);

    // Copy each row's tail bytes into the 16-bytes-per-row staging block.
    for k in 0..128 {
        let src_row_offset = input_offset + k * in_stride;
        let src_slice = &input[src_row_offset..src_row_offset + remaining_cols_bytes];
        let buf_offset = k * 16;
        buf_as_bytes[buf_offset..buf_offset + remaining_cols_bytes].copy_from_slice(src_slice);
    }

    avx_transpose128x128((&mut buf[..64]).try_into().expect("slice has length 64"));

    let output_offset = j * 128 * out_stride + i * 16;
    let buf_as_bytes: &[u8] = must_cast_slice(&*buf);

    // Only the first c_rest transposed rows carry real data; scatter them
    // to their strided positions in the output.
    for k in 0..c_rest {
        let src_slice = &buf_as_bytes[k * 16..(k + 1) * 16];
        let dst_slice =
            &mut output[output_offset + k * out_stride..output_offset + k * out_stride + 16];
        dst_slice.copy_from_slice(src_slice);
    }
}
337
// Tests are only compiled when the build target itself enables AVX2, since
// the functions under test are `#[target_feature(enable = "avx2")]`.
#[cfg(all(test, target_feature = "avx2"))]
mod tests {
    use std::arch::x86_64::_mm256_setzero_si256;

    use rand::{RngCore, SeedableRng, rngs::StdRng};

    use super::{avx_transpose128x128, transpose_bitmatrix};

    // Transposing a random 128x128 block twice must yield the original.
    #[test]
    fn test_avx_transpose128() {
        unsafe {
            let mut v = [_mm256_setzero_si256(); 64];
            StdRng::seed_from_u64(42).fill_bytes(bytemuck::cast_slice_mut(&mut v));

            let orig = v;
            avx_transpose128x128(&mut v);
            avx_transpose128x128(&mut v);
            // Report every mismatching block before failing, to aid debugging.
            let mut failed = false;
            for (i, (o, t)) in orig.into_iter().zip(v).enumerate() {
                let o = bytemuck::cast::<_, [u128; 2]>(o);
                let t = bytemuck::cast::<_, [u128; 2]>(t);
                if o != t {
                    eprintln!("difference in block {i}");
                    eprintln!("orig: {o:?}");
                    eprintln!("tran: {t:?}");
                    failed = true;
                }
            }
            if failed {
                panic!("double transposed is different than original")
            }
        }
    }

    // AVX2 result must agree with the portable implementation (256x256).
    #[test]
    fn test_avx_transpose() {
        let rows = 128 * 2;
        let cols = 128 * 2;
        let mut v = vec![0_u8; rows * cols / 8];
        StdRng::seed_from_u64(42).fill_bytes(&mut v);

        let mut avx_transposed = v.clone();
        let mut sse_transposed = v.clone();
        unsafe {
            transpose_bitmatrix(&v, &mut avx_transposed, rows);
        }
        crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows);

        assert_eq!(sse_transposed, avx_transposed);
    }

    // Exercises the unaligned-input fallback path by forcing the input
    // slice to start at an address that is a multiple of 3 (so generally
    // not 16-byte aligned).
    #[test]
    fn test_avx_transpose_unaligned_data() {
        let rows = 128 * 2;
        let cols = 128 * 2;
        let mut v = vec![0_u8; rows * (cols + 128) / 8];
        StdRng::seed_from_u64(42).fill_bytes(&mut v);

        let v = {
            let addr = v.as_ptr().addr();
            let offset = addr.next_multiple_of(3) - addr;
            &v[offset..offset + rows * cols / 8]
        };
        assert_eq!(0, v.as_ptr().addr() % 3);
        let mut avx_transposed = v.to_owned();
        let mut sse_transposed = v.to_owned();

        unsafe {
            transpose_bitmatrix(&v, &mut avx_transposed, rows);
        }
        crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows);

        assert_eq!(sse_transposed, avx_transposed);
    }

    // Exercises the bulk `out_stride == 16` output path with many full
    // column blocks (cols a multiple of 4 * 128 -> full 4-block batches).
    #[test]
    fn test_avx_transpose_larger_cols_divisible_by_4_times_128() {
        let rows = 128;
        let cols = 128 * 8;
        let mut v = vec![0_u8; rows * cols / 8];
        StdRng::seed_from_u64(42).fill_bytes(&mut v);

        let mut avx_transposed = v.clone();
        let mut sse_transposed = v.clone();
        unsafe {
            transpose_bitmatrix(&v, &mut avx_transposed, rows);
        }
        crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows);

        assert_eq!(sse_transposed, avx_transposed);
    }

    // Exercises the partial-column-block path (cols % 128 != 0).
    #[test]
    fn test_avx_transpose_larger_cols_divisible_by_8() {
        let rows = 128;
        let cols = 128 + 32;
        let mut v = vec![0_u8; rows * cols / 8];
        StdRng::seed_from_u64(42).fill_bytes(&mut v);

        let mut avx_transposed = v.clone();
        let mut sse_transposed = v.clone();
        unsafe {
            transpose_bitmatrix(&v, &mut avx_transposed, rows);
        }
        crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows);

        assert_eq!(sse_transposed, avx_transposed);
    }
}