// Patch summary: adds vshuf(a, b) to avx2.hh, neon.hh, simd.hh and sse4_1.hh.
//
// simd.hh (generic fallback, ten overloads): for every lane i the result is
// a.v[b.v[i]] when b.v[i] < WIDTH, otherwise zero (0, 0.f or 0. depending on
// the overload), i.e. b supplies the per-lane source index into a.
//
// avx2.hh: the two byte-wise specializations build two index vectors from b
// (c = b - 16; d = b OR-ed with the mask of lanes where b > 15 in a signed
// compare), shuffle two _mm256_permute2x128_si256(a, a, imm) results (imm 0
// and imm 17) with _mm256_shuffle_epi8 and OR both halves together. The third
// specialization is a single _mm256_permutevar8x32_ps call.
//
// neon.hh: both specializations split a into its low and high 8-byte halves,
// use them as a two-register table for vtbl2 with the low and the high half
// of b, and recombine the two 8-byte results.
//
// sse4_1.hh: both specializations forward a and b to _mm_shuffle_epi8.
//
// NOTE(review): the template argument lists (SIMD<...>, template <...>) look
// like they were stripped from this copy of the patch, so the overloads cannot
// be told apart here - restore them from the original commit before applying.
// NOTE(review): the generic fallback zeroes lanes whose index is >= WIDTH, but
// the SSE4.1 path passes b to _mm_shuffle_epi8 unmasked and the AVX2 byte path
// only masks "index > 15" for one half; presumably these return non-zero data
// for some out-of-range indices - TODO confirm against the Intel intrinsics
// guide and decide whether indices are required to be in range.
diff --git a/avx2.hh b/avx2.hh index 454c7d1..0d89776 100644 --- a/avx2.hh +++ b/avx2.hh @@ -1134,3 +1134,35 @@ inline SIMD vclamp(SIMD x, int64_t a, int64_t b) return tmp; } +template <> +inline SIMD vshuf(SIMD a, SIMD b) +{ + SIMD tmp; + __m256i c = _mm256_sub_epi8(b.m, _mm256_set1_epi8(16)); + __m256i d = _mm256_or_si256(b.m, _mm256_cmpgt_epi8(b.m, _mm256_set1_epi8(15))); + __m256i e = _mm256_shuffle_epi8(_mm256_permute2x128_si256(a.m, a.m, 0), d); + __m256i f = _mm256_shuffle_epi8(_mm256_permute2x128_si256(a.m, a.m, 17), c); + tmp.m = _mm256_or_si256(e, f); + return tmp; +} + +template <> +inline SIMD vshuf(SIMD a, SIMD b) +{ + SIMD tmp; + __m256i c = _mm256_sub_epi8(b.m, _mm256_set1_epi8(16)); + __m256i d = _mm256_or_si256(b.m, _mm256_cmpgt_epi8(b.m, _mm256_set1_epi8(15))); + __m256i e = _mm256_shuffle_epi8(_mm256_permute2x128_si256(a.m, a.m, 0), d); + __m256i f = _mm256_shuffle_epi8(_mm256_permute2x128_si256(a.m, a.m, 17), c); + tmp.m = _mm256_or_si256(e, f); + return tmp; +} + +template <> +inline SIMD vshuf(SIMD a, SIMD b) +{ + SIMD tmp; + tmp.m = _mm256_permutevar8x32_ps(a.m, b.m); + return tmp; +} + diff --git a/neon.hh b/neon.hh index e591ed1..a94cc6c 100644 --- a/neon.hh +++ b/neon.hh @@ -938,3 +938,25 @@ inline SIMD vclamp(SIMD x, int32_t a, int32_t b) return tmp; } +template <> +inline SIMD vshuf(SIMD a, SIMD b) +{ + SIMD tmp; + uint8x8x2_t c { vget_low_u8(a.m), vget_high_u8(a.m) }; + uint8x8_t d = vtbl2_u8(c, vget_low_u8(b.m)); + uint8x8_t e = vtbl2_u8(c, vget_high_u8(b.m)); + tmp.m = vcombine_u8(d, e); + return tmp; +} + +template <> +inline SIMD vshuf(SIMD a, SIMD b) +{ + SIMD tmp; + int8x8x2_t c { vget_low_s8(a.m), vget_high_s8(a.m) }; + int8x8_t d = vtbl2_s8(c, vget_low_s8((int8x16_t)b.m)); + int8x8_t e = vtbl2_s8(c, vget_high_s8((int8x16_t)b.m)); + tmp.m = vcombine_s8(d, e); + return tmp; +} + diff --git a/simd.hh b/simd.hh index 4201089..9b223b4 100644 --- a/simd.hh +++ b/simd.hh @@ -1387,6 +1387,96 @@ static inline SIMD vsign(SIMD a, SIMD 
+static inline SIMD vshuf(SIMD a, SIMD b) +{ + SIMD tmp; + for (int i = 0; i < WIDTH; ++i) + tmp.v[i] = b.v[i] < WIDTH ? a.v[b.v[i]] : 0; + return tmp; +} + +template +static inline SIMD vshuf(SIMD a, SIMD b) +{ + SIMD tmp; + for (int i = 0; i < WIDTH; ++i) + tmp.v[i] = b.v[i] < WIDTH ? a.v[b.v[i]] : 0; + return tmp; +} + +template +static inline SIMD vshuf(SIMD a, SIMD b) +{ + SIMD tmp; + for (int i = 0; i < WIDTH; ++i) + tmp.v[i] = b.v[i] < WIDTH ? a.v[b.v[i]] : 0; + return tmp; +} + +template +static inline SIMD vshuf(SIMD a, SIMD b) +{ + SIMD tmp; + for (int i = 0; i < WIDTH; ++i) + tmp.v[i] = b.v[i] < WIDTH ? a.v[b.v[i]] : 0; + return tmp; +} + +template +static inline SIMD vshuf(SIMD a, SIMD b) +{ + SIMD tmp; + for (int i = 0; i < WIDTH; ++i) + tmp.v[i] = b.v[i] < WIDTH ? a.v[b.v[i]] : 0; + return tmp; +} + +template +static inline SIMD vshuf(SIMD a, SIMD b) +{ + SIMD tmp; + for (int i = 0; i < WIDTH; ++i) + tmp.v[i] = b.v[i] < WIDTH ? a.v[b.v[i]] : 0; + return tmp; +} + +template +static inline SIMD vshuf(SIMD a, SIMD b) +{ + SIMD tmp; + for (int i = 0; i < WIDTH; ++i) + tmp.v[i] = b.v[i] < WIDTH ? a.v[b.v[i]] : 0; + return tmp; +} + +template +static inline SIMD vshuf(SIMD a, SIMD b) +{ + SIMD tmp; + for (int i = 0; i < WIDTH; ++i) + tmp.v[i] = b.v[i] < WIDTH ? a.v[b.v[i]] : 0; + return tmp; +} + +template +static inline SIMD vshuf(SIMD a, SIMD b) +{ + SIMD tmp; + for (int i = 0; i < WIDTH; ++i) + tmp.v[i] = b.v[i] < WIDTH ? a.v[b.v[i]] : 0.f; + return tmp; +} + +template +static inline SIMD vshuf(SIMD a, SIMD b) +{ + SIMD tmp; + for (int i = 0; i < WIDTH; ++i) + tmp.v[i] = b.v[i] < WIDTH ? 
a.v[b.v[i]] : 0.; + return tmp; +} + #if 1 #ifdef __AVX2__ #include "avx2.hh" diff --git a/sse4_1.hh b/sse4_1.hh index cc3f29a..0e61b9f 100644 --- a/sse4_1.hh +++ b/sse4_1.hh @@ -1127,3 +1127,19 @@ inline SIMD vclamp(SIMD x, int32_t a, int32_t b) return tmp; } +template <> +inline SIMD vshuf(SIMD a, SIMD b) +{ + SIMD tmp; + tmp.m = _mm_shuffle_epi8(a.m, b.m); + return tmp; +} + +template <> +inline SIMD vshuf(SIMD a, SIMD b) +{ + SIMD tmp; + tmp.m = _mm_shuffle_epi8(a.m, b.m); + return tmp; +} +