use std::{arch::x86_64::*, hint::unreachable_unchecked};

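// The `_mm256_slli_epi64`/`_mm256_srli_epi64` intrinsics take their shift amount as a
// const generic, so these helpers dispatch the handful of runtime shift widths used by
// the transpose to the matching constant instantiation.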
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn _mm256_slli_epi64_var_shift(a: __m256i, shift: usize) -> __m256i {
    debug_assert!(
        matches!(shift, 2 | 4 | 8 | 16 | 32),
        "Must be called with correct shift"
    );
    unsafe {
        match shift {
            2 => _mm256_slli_epi64::<2>(a),
            4 => _mm256_slli_epi64::<4>(a),
            8 => _mm256_slli_epi64::<8>(a),
            16 => _mm256_slli_epi64::<16>(a),
            32 => _mm256_slli_epi64::<32>(a),
            _ => unreachable_unchecked(),
        }
    }
}

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn _mm256_srli_epi64_var_shift(a: __m256i, shift: usize) -> __m256i {
    debug_assert!(
        matches!(shift, 2 | 4 | 8 | 16 | 32),
        "Must be called with correct shift"
    );
    unsafe {
        match shift {
            2 => _mm256_srli_epi64::<2>(a),
            4 => _mm256_srli_epi64::<4>(a),
            8 => _mm256_srli_epi64::<8>(a),
            16 => _mm256_srli_epi64::<16>(a),
            32 => _mm256_srli_epi64::<32>(a),
            _ => unreachable_unchecked(),
        }
    }
}

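// One level of the recursive quadrant swap that drives the 128x128 bit transpose. Each
// `__m256i` holds two 128-bit rows; `x` and `y` are the row pair whose masked bits get
// exchanged. `block_size_shift == 6` swaps whole 64-bit words via unpack, and
// `block_size_shift == 1` also folds in the bit-level swap before the generic path runs.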
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn avx_transpose_block_iter1(
    in_out: *mut __m256i,
    block_size_shift: usize,
    block_rows_shift: usize,
    j: usize,
) {
    if j < (1 << block_size_shift) && block_size_shift == 6 {
        unsafe {
            let x = in_out.add(j / 2);
            let y = in_out.add(j / 2 + 32);

            let out_x = _mm256_unpacklo_epi64(*x, *y);
            let out_y = _mm256_unpackhi_epi64(*x, *y);
            *x = out_x;
            *y = out_y;
            return;
        }
    }

    if block_size_shift == 0 || block_size_shift >= 6 || block_rows_shift < 1 {
        return;
    }

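    // Mask selecting the upper half of every 2^(block_size_shift + 1)-bit column group
    // (e.g. 0xCC...CC for block_size_shift == 1, 0xFFFFFFFF00000000 for 5).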
    let mut mask = (!0u64) << 32;
    for k in (block_size_shift as i32..=4).rev() {
        mask ^= mask >> (1 << k);
    }

    unsafe {
        let x = in_out.add(j / 2);
        let y = in_out.add(j / 2 + (1 << (block_size_shift - 1)));

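        // block_size_shift == 1 additionally performs the bit-level (distance-1) swap:
        // interleave the two row pairs across lanes, exchange adjacent bits with 16-bit
        // shifts, then restore the layout before the generic 2-bit swap below.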
        if block_size_shift == 1 {
            let u = _mm256_permute2x128_si256::<0x20>(*x, *y);
            let v = _mm256_permute2x128_si256::<0x31>(*x, *y);

            let mut diff = _mm256_xor_si256(u, _mm256_slli_epi16::<1>(v));
            diff = _mm256_and_si256(diff, _mm256_set1_epi16(0b1010101010101010_u16 as i16));
            let u = _mm256_xor_si256(u, diff);
            let v = _mm256_xor_si256(v, _mm256_srli_epi16::<1>(diff));

            *x = _mm256_permute2x128_si256::<0x20>(u, v);
            *y = _mm256_permute2x128_si256::<0x31>(u, v);
        }

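        // Delta swap: exchange the mask-selected bits of `x` with the bits of `y` that
        // sit `1 << block_size_shift` positions lower.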
        let mut diff = _mm256_xor_si256(*x, _mm256_slli_epi64_var_shift(*y, 1 << block_size_shift));
        diff = _mm256_and_si256(diff, _mm256_set1_epi64x(mask as i64));
        *x = _mm256_xor_si256(*x, diff);
        *y = _mm256_xor_si256(*y, _mm256_srli_epi64_var_shift(diff, 1 << block_size_shift));
    }
}

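// Applies one swap level across the sub-matrix: visits every block of
// `2 << block_size_shift` rows and every affected row pair inside it.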
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn avx_transpose_block_iter2(
    in_out: *mut __m256i,
    block_size_shift: usize,
    block_rows_shift: usize,
    n_rows: usize,
) {
    let mat_size = 1 << (block_size_shift + 1);

    for i in (0..n_rows).step_by(mat_size) {
        for j in (0..(1 << block_size_shift)).step_by(1 << block_rows_shift) {
            unsafe {
                avx_transpose_block_iter1(in_out.add(i / 2), block_size_shift, block_rows_shift, j);
            }
        }
    }
}

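// Recursively applies the swap levels from `block_size_shift` up to (but excluding)
// `mat_size_shift`.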
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn avx_transpose_block(
    in_out: *mut __m256i,
    block_size_shift: usize,
    mat_size_shift: usize,
    block_rows_shift: usize,
    mat_rows_shift: usize,
) {
    if block_size_shift >= mat_size_shift {
        return;
    }

    let total_rows = 1 << (mat_rows_shift + mat_size_shift);

    unsafe {
        avx_transpose_block_iter2(in_out, block_size_shift, block_rows_shift, total_rows);

        avx_transpose_block(
            in_out,
            block_size_shift + 1,
            mat_size_shift,
            block_rows_shift,
            mat_rows_shift,
        );
    }
}

const AVX_BLOCK_SHIFT: usize = 4;
const AVX_BLOCK_SIZE: usize = 1 << AVX_BLOCK_SHIFT;

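/// Transposes a 128x128 bit matrix in place. The matrix is stored as 64 `__m256i`
/// values, each packing two consecutive 128-bit rows.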
#[target_feature(enable = "avx2")]
pub fn avx_transpose128x128(in_out: &mut [__m256i; 64]) {
    const MAT_SIZE_SHIFT: usize = 7;
    unsafe {
        let in_out = in_out.as_mut_ptr();
        for i in (0..64).step_by(AVX_BLOCK_SIZE) {
            avx_transpose_block(
                in_out.add(i),
                1,
                MAT_SIZE_SHIFT - AVX_BLOCK_SHIFT,
                1,
                AVX_BLOCK_SHIFT + 1 - (MAT_SIZE_SHIFT - AVX_BLOCK_SHIFT),
            );
        }

        let block_size_shift = MAT_SIZE_SHIFT - AVX_BLOCK_SHIFT;

        for i in 0..(1 << (block_size_shift - 1)) {
            avx_transpose_block(
                in_out.add(i),
                block_size_shift,
                MAT_SIZE_SHIFT,
                block_size_shift,
                0,
            );
        }
    }
}

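/// Transposes the bit matrix in `input` (`rows` x `cols` bits, row-major, bit-packed)
/// into `output` (`cols` x `rows` bits). Both dimensions must be multiples of 128 and
/// the two buffers must have equal length.
///
/// # Safety
///
/// The caller must ensure that AVX2 is available on the executing CPU.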
#[target_feature(enable = "avx2")]
pub unsafe fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
    assert_eq!(input.len(), output.len());
    let cols = input.len() * 8 / rows;
    assert_eq!(0, cols % 128);
    assert_eq!(0, rows % 128);
    #[allow(unused_unsafe)]
    let mut buf = [unsafe { _mm256_setzero_si256() }; 64];
    let in_stride = cols / 8;
    let out_stride = rows / 8;

    let r_main = rows / 128;
    let c_main = cols / 128;

    for i in 0..r_main {
        for j in 0..c_main {
            unsafe {
                // Gather the 128x128-bit block at (i, j): 16 bytes from each of 128 rows.
                let src_ptr = input.as_ptr().add(i * 128 * in_stride + j * 16);

                let buf_u8_ptr = buf.as_mut_ptr() as *mut u8;

                for k in 0..128 {
                    let src_row = src_ptr.add(k * in_stride);
                    std::ptr::copy_nonoverlapping(src_row, buf_u8_ptr.add(k * 16), 16);
                }
            }
            avx_transpose128x128(&mut buf);

            unsafe {
                // Scatter the transposed block to the mirrored position (j, i) in the output.
                let buf_u8_ptr = buf.as_mut_ptr() as *mut u8;
                let dst_ptr = output.as_mut_ptr().add(j * 128 * out_stride + i * 16);
                for k in 0..128 {
                    let dst_row = dst_ptr.add(k * out_stride);
                    std::ptr::copy_nonoverlapping(buf_u8_ptr.add(k * 16), dst_row, 16);
                }
            }
        }
    }
}

#[cfg(all(test, target_feature = "avx2"))]
mod tests {
    use std::arch::x86_64::_mm256_setzero_si256;

    use rand::{RngCore, SeedableRng, rngs::StdRng};

    use super::{avx_transpose128x128, transpose_bitmatrix};

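    // Transposing twice must reproduce the original matrix.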
    #[test]
    fn test_avx_transpose128() {
        unsafe {
            let mut v = [_mm256_setzero_si256(); 64];
            StdRng::seed_from_u64(42).fill_bytes(bytemuck::cast_slice_mut(&mut v));

            let orig = v.clone();
            avx_transpose128x128(&mut v);
            avx_transpose128x128(&mut v);
            let mut failed = false;
            for (i, (o, t)) in orig.into_iter().zip(v).enumerate() {
                let o = bytemuck::cast::<_, [u128; 2]>(o);
                let t = bytemuck::cast::<_, [u128; 2]>(t);
                if o != t {
                    eprintln!("difference in block {i}");
                    eprintln!("orig: {o:?}");
                    eprintln!("tran: {t:?}");
                    failed = true;
                }
            }
            if failed {
                panic!("double transpose differs from the original")
            }
        }
    }

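    // The AVX2 result must agree with the portable implementation.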
    #[test]
    fn test_avx_transpose() {
        let rows = 128 * 2;
        let cols = 128 * 2;
        let mut v = vec![0_u8; rows * cols / 8];
        StdRng::seed_from_u64(42).fill_bytes(&mut v);

        let mut avx_transposed = v.clone();
        let mut sse_transposed = v.clone();
        unsafe {
            transpose_bitmatrix(&v, &mut avx_transposed, rows);
        }
        crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows);

        assert_eq!(sse_transposed, avx_transposed);
    }
}