diff --git a/cauchy_reed_solomon_erasure_coding.hh b/cauchy_reed_solomon_erasure_coding.hh
index f1656b5..ccf96db 100644
--- a/cauchy_reed_solomon_erasure_coding.hh
+++ b/cauchy_reed_solomon_erasure_coding.hh
@@ -6,6 +6,17 @@ Copyright 2023 Ahmet Inan
 
 #pragma once
 
+#ifdef __ARM_NEON
+#include <arm_neon.h>
+#endif
+#ifdef __AVX2__
+#include <immintrin.h>
+#else
+#ifdef __SSE4_1__
+#include <smmintrin.h>
+#endif
+#endif
+
 namespace CODE {
 
 template <typename GF>
@@ -35,8 +46,220 @@ struct CauchyReedSolomonErasureCoding
 		}
 		return prod_xy / (index(rows[j] + col_i) * prod_x * prod_y);
 	}
+#if defined(__ARM_NEON) && !defined(__aarch64__)
+	// ARMv7 NEON lacks vqtbl1q_u8; emulate the 16-byte table lookup
+	// with two vtbl2_u8 lookups on the vector halves.
+	static inline uint8x16_t vqtbl1q_u8(uint8x16_t lut, uint8x16_t idx)
+	{
+		uint8x8x2_t l82 = {{ vget_low_u8(lut), vget_high_u8(lut) }};
+		uint8x8_t lo = vtbl2_u8(l82, vget_low_u8(idx));
+		uint8x8_t hi = vtbl2_u8(l82, vget_high_u8(idx));
+		return vcombine_u8(lo, hi);
+	}
+#endif
+#if defined(__ARM_NEON) || defined(__AVX2__) || defined(__SSE4_1__)
+	void mac_simd(uint8_t *c, const uint8_t *a, IndexType b, int size, bool init)
+	{
+		// Two 16-entry tables of GF(256) products: bln[i] = b * i and
+		// bhn[i] = b * (i << 4). Multiplication distributes over XOR,
+		// so b * a == bln[a & 15] ^ bhn[a >> 4].
+		alignas(16) uint8_t bln[16], bhn[16];
+		for (int i = 0; i < 16; ++i) {
+			bln[i] = (b * ValueType(i)).v;
+			bhn[i] = (b * ValueType(i << 4)).v;
+		}
+#ifdef __ARM_NEON
+		uint8x16_t l16 = vld1q_u8(bln);
+		uint8x16_t h16 = vld1q_u8(bhn);
+		for (int i = 0; i < size; i += 16, a += 16, c += 16) {
+			uint8x16_t a16 = vld1q_u8(reinterpret_cast<const uint8_t *>(__builtin_assume_aligned(a, 16)));
+			uint8x16_t aln = vandq_u8(a16, vdupq_n_u8(15));
+			uint8x16_t ahn = vshrq_n_u8(a16, 4);
+			uint8x16_t cln = vqtbl1q_u8(l16, aln);
+			uint8x16_t chn = vqtbl1q_u8(h16, ahn);
+			uint8x16_t c16 = veorq_u8(cln, chn);
+			if (!init)
+				c16 = veorq_u8(c16, vld1q_u8(reinterpret_cast<const uint8_t *>(__builtin_assume_aligned(c, 16))));
+			vst1q_u8(reinterpret_cast<uint8_t *>(__builtin_assume_aligned(c, 16)), c16);
+		}
+#else
+#ifdef __AVX2__
+		__m128i l16 = _mm_load_si128(reinterpret_cast<const __m128i *>(bln));
+		__m128i h16 = _mm_load_si128(reinterpret_cast<const __m128i *>(bhn));
+		__m256i l162 = _mm256_broadcastsi128_si256(l16);
+		__m256i h162 = _mm256_broadcastsi128_si256(h16);
+		for (int i = 0; i < size; i += 32, a += 32, c += 32) {
+			__m256i a32 = _mm256_load_si256(reinterpret_cast<const __m256i *>(__builtin_assume_aligned(a, 32)));
+			__m256i aln = _mm256_and_si256(a32, _mm256_set1_epi8(15));
+			__m256i cln = _mm256_shuffle_epi8(l162, aln);
+			__m256i ahn = _mm256_and_si256(_mm256_srli_epi16(a32, 4), _mm256_set1_epi8(15));
+			__m256i chn = _mm256_shuffle_epi8(h162, ahn);
+			__m256i c32 = _mm256_xor_si256(cln, chn);
+			if (!init)
+				c32 = _mm256_xor_si256(c32, _mm256_load_si256(reinterpret_cast<const __m256i *>(__builtin_assume_aligned(c, 32))));
+			_mm256_store_si256(reinterpret_cast<__m256i *>(__builtin_assume_aligned(c, 32)), c32);
+		}
+#else
+		__m128i l16 = _mm_load_si128(reinterpret_cast<const __m128i *>(bln));
+		__m128i h16 = _mm_load_si128(reinterpret_cast<const __m128i *>(bhn));
+		for (int i = 0; i < size; i += 16, a += 16, c += 16) {
+			__m128i a16 = _mm_load_si128(reinterpret_cast<const __m128i *>(__builtin_assume_aligned(a, 16)));
+			__m128i aln = _mm_and_si128(a16, _mm_set1_epi8(15));
+			__m128i cln = _mm_shuffle_epi8(l16, aln);
+			__m128i ahn = _mm_and_si128(_mm_srli_epi16(a16, 4), _mm_set1_epi8(15));
+			__m128i chn = _mm_shuffle_epi8(h16, ahn);
+			__m128i c16 = _mm_xor_si128(cln, chn);
+			if (!init)
+				c16 = _mm_xor_si128(c16, _mm_load_si128(reinterpret_cast<const __m128i *>(__builtin_assume_aligned(c, 16))));
+			_mm_store_si128(reinterpret_cast<__m128i *>(__builtin_assume_aligned(c, 16)), c16);
+		}
+#endif
+#endif
+	}
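+	// GF(65536) variant of the same nibble-table trick: each 16-bit symbol
+	// is split into four nibbles, with separate tables for the low and high
+	// bytes of the product, so eight 16-entry tables replace the two above.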
+	void mac_simd(uint16_t *c, const uint16_t *a, IndexType b, int size, bool init)
+	{
+		alignas(16) uint8_t blll[16], bllh[16], blhl[16], blhh[16], bhll[16], bhlh[16], bhhl[16], bhhh[16];
+		for (int i = 0; i < 16; ++i) {
+			uint16_t bll = (b * ValueType(i)).v;
+			uint16_t blh = (b * ValueType(i << 4)).v;
+			uint16_t bhl = (b * ValueType(i << 8)).v;
+			uint16_t bhh = (b * ValueType(i << 12)).v;
+			blll[i] = bll; bhll[i] = bll >> 8;
+			bllh[i] = blh; bhlh[i] = blh >> 8;
+			blhl[i] = bhl; bhhl[i] = bhl >> 8;
+			blhh[i] = bhh; bhhh[i] = bhh >> 8;
+		}
+#ifdef __ARM_NEON
+		uint8x16_t lll16 = vld1q_u8(blll);
+		uint8x16_t llh16 = vld1q_u8(bllh);
+		uint8x16_t lhl16 = vld1q_u8(blhl);
+		uint8x16_t lhh16 = vld1q_u8(blhh);
+		uint8x16_t hll16 = vld1q_u8(bhll);
+		uint8x16_t hlh16 = vld1q_u8(bhlh);
+		uint8x16_t hhl16 = vld1q_u8(bhhl);
+		uint8x16_t hhh16 = vld1q_u8(bhhh);
+		for (int i = 0; i < size; i += 16, a += 16, c += 16) {
+			uint16x8_t al8 = vld1q_u16(reinterpret_cast<const uint16_t *>(__builtin_assume_aligned(a, 16)));
+			uint16x8_t ah8 = vld1q_u16(reinterpret_cast<const uint16_t *>(__builtin_assume_aligned(a+8, 16)));
+			uint8x16_t al16 = (uint8x16_t)vorrq_u16(vshlq_n_u16(ah8, 8), vandq_u16(al8, vdupq_n_u16(255)));
+			uint8x16_t ah16 = (uint8x16_t)vorrq_u16(vandq_u16(ah8, vdupq_n_u16(0xff00)), vshrq_n_u16(al8, 8));
+			uint8x16_t alln = vandq_u8(al16, vdupq_n_u8(15));
+			uint8x16_t clll = vqtbl1q_u8(lll16, alln);
+			uint8x16_t chll = vqtbl1q_u8(hll16, alln);
+			uint8x16_t alhn = vshrq_n_u8(al16, 4);
+			uint8x16_t cllh = vqtbl1q_u8(llh16, alhn);
+			uint8x16_t chlh = vqtbl1q_u8(hlh16, alhn);
+			uint8x16_t ahln = vandq_u8(ah16, vdupq_n_u8(15));
+			uint8x16_t clhl = vqtbl1q_u8(lhl16, ahln);
+			uint8x16_t chhl = vqtbl1q_u8(hhl16, ahln);
+			uint8x16_t ahhn = vshrq_n_u8(ah16, 4);
+			uint8x16_t clhh = vqtbl1q_u8(lhh16, ahhn);
+			uint8x16_t chhh = vqtbl1q_u8(hhh16, ahhn);
+			uint16x8_t cl16 = (uint16x8_t)veorq_u8(veorq_u8(clll, cllh), veorq_u8(clhl, clhh));
+			uint16x8_t ch16 = (uint16x8_t)veorq_u8(veorq_u8(chll, chlh), veorq_u8(chhl, chhh));
+			uint16x8_t cl8 = vorrq_u16(vshlq_n_u16(ch16, 8), vandq_u16(cl16, vdupq_n_u16(255)));
+			uint16x8_t ch8 = vorrq_u16(vandq_u16(ch16, vdupq_n_u16(0xff00)), vshrq_n_u16(cl16, 8));
+			if (!init) {
+				cl8 = veorq_u16(cl8, vld1q_u16(reinterpret_cast<const uint16_t *>(__builtin_assume_aligned(c, 16))));
+				ch8 = veorq_u16(ch8, vld1q_u16(reinterpret_cast<const uint16_t *>(__builtin_assume_aligned(c+8, 16))));
+			}
+			vst1q_u16(reinterpret_cast<uint16_t *>(__builtin_assume_aligned(c, 16)), cl8);
+			vst1q_u16(reinterpret_cast<uint16_t *>(__builtin_assume_aligned(c+8, 16)), ch8);
+		}
+#else
+#ifdef __AVX2__
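+		// Each 16-entry table is broadcast to both 128-bit lanes so one
+		// _mm256_shuffle_epi8 performs 32 nibble lookups per instruction.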
+		__m128i lll16 = _mm_load_si128(reinterpret_cast<const __m128i *>(blll));
+		__m128i llh16 = _mm_load_si128(reinterpret_cast<const __m128i *>(bllh));
+		__m128i lhl16 = _mm_load_si128(reinterpret_cast<const __m128i *>(blhl));
+		__m128i lhh16 = _mm_load_si128(reinterpret_cast<const __m128i *>(blhh));
+		__m128i hll16 = _mm_load_si128(reinterpret_cast<const __m128i *>(bhll));
+		__m128i hlh16 = _mm_load_si128(reinterpret_cast<const __m128i *>(bhlh));
+		__m128i hhl16 = _mm_load_si128(reinterpret_cast<const __m128i *>(bhhl));
+		__m128i hhh16 = _mm_load_si128(reinterpret_cast<const __m128i *>(bhhh));
+		__m256i lll162 = _mm256_broadcastsi128_si256(lll16);
+		__m256i llh162 = _mm256_broadcastsi128_si256(llh16);
+		__m256i lhl162 = _mm256_broadcastsi128_si256(lhl16);
+		__m256i lhh162 = _mm256_broadcastsi128_si256(lhh16);
+		__m256i hll162 = _mm256_broadcastsi128_si256(hll16);
+		__m256i hlh162 = _mm256_broadcastsi128_si256(hlh16);
+		__m256i hhl162 = _mm256_broadcastsi128_si256(hhl16);
+		__m256i hhh162 = _mm256_broadcastsi128_si256(hhh16);
+		for (int i = 0; i < size; i += 32, a += 32, c += 32) {
+			__m256i al16 = _mm256_load_si256(reinterpret_cast<const __m256i *>(__builtin_assume_aligned(a, 32)));
+			__m256i ah16 = _mm256_load_si256(reinterpret_cast<const __m256i *>(__builtin_assume_aligned(a+16, 32)));
+			__m256i al32 = _mm256_or_si256(_mm256_slli_epi16(ah16, 8), _mm256_and_si256(al16, _mm256_set1_epi16(255)));
+			__m256i ah32 = _mm256_or_si256(_mm256_and_si256(ah16, _mm256_set1_epi16(0xff00)), _mm256_srli_epi16(al16, 8));
+			__m256i alln = _mm256_and_si256(al32, _mm256_set1_epi8(15));
+			__m256i clll = _mm256_shuffle_epi8(lll162, alln);
+			__m256i chll = _mm256_shuffle_epi8(hll162, alln);
+			__m256i alhn = _mm256_and_si256(_mm256_srli_epi16(al32, 4), _mm256_set1_epi8(15));
+			__m256i cllh = _mm256_shuffle_epi8(llh162, alhn);
+			__m256i chlh = _mm256_shuffle_epi8(hlh162, alhn);
+			__m256i ahln = _mm256_and_si256(ah32, _mm256_set1_epi8(15));
+			__m256i clhl = _mm256_shuffle_epi8(lhl162, ahln);
+			__m256i chhl = _mm256_shuffle_epi8(hhl162, ahln);
+			__m256i ahhn = _mm256_and_si256(_mm256_srli_epi16(ah32, 4), _mm256_set1_epi8(15));
+			__m256i clhh = _mm256_shuffle_epi8(lhh162, ahhn);
+			__m256i chhh = _mm256_shuffle_epi8(hhh162, ahhn);
+			__m256i cl32 = _mm256_xor_si256(_mm256_xor_si256(clll, cllh), _mm256_xor_si256(clhl, clhh));
+			__m256i ch32 = _mm256_xor_si256(_mm256_xor_si256(chll, chlh), _mm256_xor_si256(chhl, chhh));
+			__m256i cl16 = _mm256_or_si256(_mm256_slli_epi16(ch32, 8), _mm256_and_si256(cl32, _mm256_set1_epi16(255)));
+			__m256i ch16 = _mm256_or_si256(_mm256_and_si256(ch32, _mm256_set1_epi16(0xff00)), _mm256_srli_epi16(cl32, 8));
+			if (!init) {
+				cl16 = _mm256_xor_si256(cl16, _mm256_load_si256(reinterpret_cast<const __m256i *>(__builtin_assume_aligned(c, 32))));
+				ch16 = _mm256_xor_si256(ch16, _mm256_load_si256(reinterpret_cast<const __m256i *>(__builtin_assume_aligned(c+16, 32))));
+			}
+			_mm256_store_si256(reinterpret_cast<__m256i *>(__builtin_assume_aligned(c, 32)), cl16);
+			_mm256_store_si256(reinterpret_cast<__m256i *>(__builtin_assume_aligned(c+16, 32)), ch16);
+		}
+#else
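+		// SSE4.1 fallback: the same lookups on 128-bit registers,
+		// processing 16 symbols (two __m128i loads) per iteration.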
+		__m128i lll16 = _mm_load_si128(reinterpret_cast<const __m128i *>(blll));
+		__m128i llh16 = _mm_load_si128(reinterpret_cast<const __m128i *>(bllh));
+		__m128i lhl16 = _mm_load_si128(reinterpret_cast<const __m128i *>(blhl));
+		__m128i lhh16 = _mm_load_si128(reinterpret_cast<const __m128i *>(blhh));
+		__m128i hll16 = _mm_load_si128(reinterpret_cast<const __m128i *>(bhll));
+		__m128i hlh16 = _mm_load_si128(reinterpret_cast<const __m128i *>(bhlh));
+		__m128i hhl16 = _mm_load_si128(reinterpret_cast<const __m128i *>(bhhl));
+		__m128i hhh16 = _mm_load_si128(reinterpret_cast<const __m128i *>(bhhh));
+		for (int i = 0; i < size; i += 16, a += 16, c += 16) {
+			__m128i al8 = _mm_load_si128(reinterpret_cast<const __m128i *>(__builtin_assume_aligned(a, 16)));
+			__m128i ah8 = _mm_load_si128(reinterpret_cast<const __m128i *>(__builtin_assume_aligned(a+8, 16)));
+			__m128i al16 = _mm_or_si128(_mm_slli_epi16(ah8, 8), _mm_and_si128(al8, _mm_set1_epi16(255)));
+			__m128i ah16 = _mm_or_si128(_mm_and_si128(ah8, _mm_set1_epi16(0xff00)), _mm_srli_epi16(al8, 8));
+			__m128i alln = _mm_and_si128(al16, _mm_set1_epi8(15));
+			__m128i clll = _mm_shuffle_epi8(lll16, alln);
+			__m128i chll = _mm_shuffle_epi8(hll16, alln);
+			__m128i alhn = _mm_and_si128(_mm_srli_epi16(al16, 4), _mm_set1_epi8(15));
+			__m128i cllh = _mm_shuffle_epi8(llh16, alhn);
+			__m128i chlh = _mm_shuffle_epi8(hlh16, alhn);
+			__m128i ahln = _mm_and_si128(ah16, _mm_set1_epi8(15));
+			__m128i clhl = _mm_shuffle_epi8(lhl16, ahln);
+			__m128i chhl = _mm_shuffle_epi8(hhl16, ahln);
+			__m128i ahhn = _mm_and_si128(_mm_srli_epi16(ah16, 4), _mm_set1_epi8(15));
+			__m128i clhh = _mm_shuffle_epi8(lhh16, ahhn);
+			__m128i chhh = _mm_shuffle_epi8(hhh16, ahhn);
+			__m128i cl16 = _mm_xor_si128(_mm_xor_si128(clll, cllh), _mm_xor_si128(clhl, clhh));
+			__m128i ch16 = _mm_xor_si128(_mm_xor_si128(chll, chlh), _mm_xor_si128(chhl, chhh));
+			__m128i cl8 = _mm_or_si128(_mm_slli_epi16(ch16, 8), _mm_and_si128(cl16, _mm_set1_epi16(255)));
+			__m128i ch8 = _mm_or_si128(_mm_and_si128(ch16, _mm_set1_epi16(0xff00)), _mm_srli_epi16(cl16, 8));
+			if (!init) {
+				cl8 = _mm_xor_si128(cl8, _mm_load_si128(reinterpret_cast<const __m128i *>(__builtin_assume_aligned(c, 16))));
+				ch8 = _mm_xor_si128(ch8, _mm_load_si128(reinterpret_cast<const __m128i *>(__builtin_assume_aligned(c+8, 16))));
+			}
+			_mm_store_si128(reinterpret_cast<__m128i *>(__builtin_assume_aligned(c, 16)), cl8);
+			_mm_store_si128(reinterpret_cast<__m128i *>(__builtin_assume_aligned(c+8, 16)), ch8);
+		}
+#endif
+#endif
+	}
+#endif
 	void multiply_accumulate(ValueType *c, const ValueType *a, IndexType b, int len, bool init)
 	{
+#if defined(__ARM_NEON) || defined(__AVX2__) || defined(__SSE4_1__)
+#ifdef __AVX2__
+		int SIMD = 32;
+#else
+		int SIMD = 16;
+#endif
+		// Dispatch to the SIMD kernels only when the field layout matches
+		// and both buffers are SIMD-aligned with a multiple of SIMD length.
+		if (((GF::M == 8 && sizeof(value_type) == 1) || (GF::M == 16 && sizeof(value_type) == 2)) && reinterpret_cast<uintptr_t>(c) % SIMD == 0 && reinterpret_cast<uintptr_t>(a) % SIMD == 0 && len % SIMD == 0) {
+			mac_simd(reinterpret_cast<value_type *>(c), reinterpret_cast<const value_type *>(a), b, len, init);
+		} else
+#endif
 		if (init) {
 			for (int i = 0; i < len; i++)
 				c[i] = b * a[i];
diff --git a/tests/crs_regression_test.cc b/tests/crs_regression_test.cc
index 0ee641a..500cf34 100644
--- a/tests/crs_regression_test.cc
+++ b/tests/crs_regression_test.cc
@@ -4,6 +4,7 @@ Regression Test for the Cauchy Reed Solomon Encoder and Decoder
 Copyright 2023 Ahmet Inan
 */
 
+#include <cstdlib>
 #include <cassert>
 #include <chrono>
 #include <random>
@@ -15,6 +16,11 @@ Copyright 2023 Ahmet Inan
 template <typename GF>
 void crs_test(int trials)
 {
+#ifdef __AVX2__
+	int SIMD = 32;
+#else
+	int SIMD = 16;
+#endif
 	CODE::CauchyReedSolomonErasureCoding<GF> crs;
 	std::random_device rd;
 	std::default_random_engine generator(rd());
@@ -25,12 +31,13 @@ void crs_test(int trials)
 	while (--trials) {
 		int block_count = rnd_cnt();
 		int numbers_total = GF::Q - block_count;
-		int block_bytes = rnd_len() * sizeof(typename GF::value_type);
+		int block_bytes = rnd_len() * sizeof(typename GF::value_type) * SIMD;
 		int data_bytes = block_count * block_bytes;
-		uint8_t *orig = new uint8_t[data_bytes];
+		uint8_t *orig = reinterpret_cast<uint8_t *>(std::aligned_alloc(SIMD, data_bytes));
+		uint8_t *data = reinterpret_cast<uint8_t *>(std::aligned_alloc(SIMD, data_bytes));
+		uint8_t *blocks = reinterpret_cast<uint8_t *>(std::aligned_alloc(SIMD, data_bytes));
 		for (int i = 0; i < data_bytes; ++i)
 			orig[i] = rnd_dat();
-		uint8_t *blocks = new uint8_t[data_bytes];
 		auto numbers = new typename GF::value_type[numbers_total];
 		for (int i = 0; i < numbers_total; ++i)
 			numbers[i] = i;
@@ -44,7 +51,6 @@ void crs_test(int trials)
 		auto enc_end = std::chrono::system_clock::now();
 		auto enc_usec = std::chrono::duration_cast<std::chrono::microseconds>(enc_end - enc_start);
 		double enc_mbs = double(data_bytes) / enc_usec.count();
-		uint8_t *data = new uint8_t[data_bytes];
 		auto dec_start = std::chrono::system_clock::now();
 		for (int i = 0; i < block_count; ++i)
 			crs.decode(data + block_bytes * i, blocks, numbers, i, block_bytes, block_count);
@@ -55,9 +61,9 @@ void crs_test(int trials)
 		for (int i = 0; i < data_bytes; ++i)
 			assert(data[i] == orig[i]);
 		delete[] numbers;
-		delete[] blocks;
-		delete[] orig;
-		delete[] data;
+		std::free(blocks);
+		std::free(orig);
+		std::free(data);
 	}
 }
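
Note: all of the kernels above implement the same 4-bit table-lookup
multiplication. For reference, here is a minimal scalar sketch of the
GF(256) case; gf_mul and its reduction polynomial 0x11d are illustrative
stand-ins for the library's ValueType multiply, not part of this patch:

	#include <cstdint>
	#include <cstddef>

	// Carry-less multiply modulo x^8 + x^4 + x^3 + x^2 + 1 (0x11d); the
	// actual polynomial is a template parameter of the library's GF type.
	static uint8_t gf_mul(uint8_t a, uint8_t b)
	{
		uint8_t p = 0;
		while (b) {
			if (b & 1)
				p ^= a;
			b >>= 1;
			a = (a << 1) ^ ((a & 128) ? 0x1d : 0);
		}
		return p;
	}

	// c[] ^= b * a[] over GF(256) (or c[] = b * a[] when init is set):
	// build two 16-entry nibble tables once, then each product costs two
	// lookups and one XOR. This is exactly what the vqtbl1q_u8 and
	// _mm_shuffle_epi8 loops do, 16 or 32 bytes at a time.
	void mac_scalar(uint8_t *c, const uint8_t *a, uint8_t b, size_t n, bool init)
	{
		uint8_t lo[16], hi[16];
		for (int i = 0; i < 16; ++i) {
			lo[i] = gf_mul(b, i);       // products of b with low nibbles
			hi[i] = gf_mul(b, i << 4);  // products of b with high nibbles
		}
		for (size_t i = 0; i < n; ++i) {
			uint8_t p = lo[a[i] & 15] ^ hi[a[i] >> 4];
			c[i] = init ? p : c[i] ^ p;
		}
	}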