From 3e881ce049fc6919f6d1ec000ae74fc9eb17d529 Mon Sep 17 00:00:00 2001 From: Ahmet Inan Date: Wed, 12 Dec 2018 12:09:17 +0100 Subject: [PATCH] accelerate decoder using SIMD Also added a fake parity bit to reduce code duplication. The fake parity bit seems not to affect the decoding performance. --- ldpc_decoder.hh | 233 ++++++++++++++++++++++++------------------------ 1 file changed, 117 insertions(+), 116 deletions(-) diff --git a/ldpc_decoder.hh b/ldpc_decoder.hh index d8272d9..a33c1d5 100644 --- a/ldpc_decoder.hh +++ b/ldpc_decoder.hh @@ -9,162 +9,163 @@ Copyright 2018 Ahmet Inan #include #include "exclusive_reduce.hh" +#include "simd.hh" namespace CODE { template class LDPCDecoder { +#ifdef __AVX2__ + static const int SIMD_SIZE = 32; +#else + static const int SIMD_SIZE = 16; +#endif static const int M = TABLE::M; static const int N = TABLE::N; static const int K = TABLE::K; static const int R = N-K; static const int q = R/M; static const int CNC = TABLE::LINKS_MAX_CN - 2; - static const int LT = TABLE::LINKS_TOTAL; + static const int BNL = TABLE::LINKS_TOTAL / SIMD_SIZE + TABLE::LINKS_MAX_CN * q; - int8_t bnl[LT]; + typedef SIMD TYPE; + + TYPE bnl[BNL]; int8_t pty[R]; uint16_t pos[R * CNC]; uint8_t cnc[R]; - static int8_t sqadd(int8_t a, int8_t b) + static TYPE eor(TYPE a, TYPE b) { - int16_t x = int16_t(a) + int16_t(b); - x = std::min(std::max(x, -128), 127); - return x; + return vreinterpret(veor(vmask(a), vmask(b))); } - static int8_t sqsub(int8_t a, int8_t b) + static TYPE min(TYPE a, TYPE b) { - int16_t x = int16_t(a) - int16_t(b); - x = std::min(std::max(x, -128), 127); - return x; + return vmin(a, b); } - static uint8_t sqsubu(uint8_t a, uint8_t b) + static void finalp(TYPE *links, int cnt) { - int16_t x = int16_t(a) - int16_t(b); - x = std::max(x, 0); - return x; - } - static int8_t smin(int8_t a, int8_t b) - { - return std::min(a, b); - } - static int8_t sclamp(int8_t x, int8_t a, int8_t b) - { - return std::min(std::max(x, a), b); - } - static int8_t seor(int8_t a, int8_t b) - { - return a ^ b; - } - static int8_t sqabs(int8_t a) - { - return std::abs(std::max(a, -127)); - } - static int8_t ssign(int8_t a, int8_t b) - { - return ((b > 0) - (b < 0)) * a; + auto beta = vunsigned(vdup(BETA)); + TYPE mags[cnt]; + for (int i = 0; i < cnt; ++i) + mags[i] = vsigned(vqsub(vunsigned(vqabs(links[i])), beta)); + + TYPE mins[cnt]; + exclusive_reduce(mags, mins, cnt, min); + + TYPE signs[cnt]; + exclusive_reduce(links, signs, cnt, eor); + for (int i = 0; i < cnt; ++i) + signs[i] = vreinterpret(vorr(vmask(signs[i]), vmask(vdup(127)))); + + for (int i = 0; i < cnt; ++i) + links[i] = vsign(mins[i], signs[i]); } bool bad(int8_t *data, int8_t *parity) { - { - int cnt = cnc[0]; - { - int8_t cnv = ssign(1, parity[0]); - for (int c = 0; c < cnt; ++c) - cnv = ssign(cnv, data[pos[c]]); - if (cnv <= 0) - return true; - } - for (int j = 1; j < M; ++j) { - int8_t cnv = ssign(ssign(1, parity[j+(q-1)*M-1]), parity[j]); - for (int c = 0; c < cnt; ++c) - cnv = ssign(cnv, data[pos[CNC*j+c]]); - if (cnv <= 0) - return true; - } - } - for (int i = 1; i < q; ++i) { + for (int i = 0; i < q; ++i) { int cnt = cnc[i]; - for (int j = 0; j < M; ++j) { - int8_t cnv = ssign(ssign(1, parity[M*(i-1)+j]), parity[M*i+j]); + auto res = vmask(vzero()); + for (int j = 0; j < M; j += SIMD_SIZE) { + int num = std::min(M - j, SIMD_SIZE); + TYPE par[2]; + if (i) { + for (int n = 0; n < num; ++n) + par[0].v[n] = parity[M*(i-1)+j+n]; + } else { + if (j) { + for (int n = 0; n < num; ++n) + par[0].v[n] = parity[j+(q-1)*M-1+n]; + } else { + par[0].v[0] = 1; + for (int n = 1; n < num; ++n) + par[0].v[n] = parity[j+(q-1)*M-1+n]; + } + } + for (int n = 0; n < num; ++n) + par[1].v[n] = parity[M*i+j+n]; + TYPE dat[cnt]; for (int c = 0; c < cnt; ++c) - cnv = ssign(cnv, data[pos[CNC*(M*i+j)+c]]); - if (cnv <= 0) - return true; + for (int n = 0; n < num; ++n) + dat[c].v[n] = data[pos[CNC*(M*i+j+n)+c]]; + TYPE cnv = vdup(1); + for (int c = 0; c < 2; ++c) + cnv = vsign(cnv, par[c]); + for (int c = 0; c < cnt; ++c) + cnv = vsign(cnv, dat[c]); + for (int n = num; n < SIMD_SIZE; ++n) + cnv.v[n] = 1; + res = vorr(res, vclez(cnv)); } + for (int n = 0; n < SIMD_SIZE; ++n) + if (res.v[n]) + return true; } return false; } - void finalp(int8_t *links, int cnt) - { - int8_t mags[cnt], mins[cnt]; - for (int i = 0; i < cnt; ++i) - mags[i] = sqsubu(sqabs(links[i]), BETA); - exclusive_reduce(mags, mins, cnt, smin); - - int8_t signs[cnt]; - exclusive_reduce(links, signs, cnt, seor); - for (int i = 0; i < cnt; ++i) - signs[i] |= 127; - - for (int i = 0; i < cnt; ++i) - links[i] = ssign(mins[i], signs[i]); - } void update(int8_t *data, int8_t *parity) { - int8_t *bl = bnl; - { - int cnt = cnc[0]; - { - int deg = cnt + 1; - int8_t inp[deg], out[deg]; - for (int c = 0; c < cnt; ++c) - inp[c] = out[c] = sqsub(data[pos[c]], bl[c]); - inp[cnt] = out[cnt] = sqsub(parity[0], bl[cnt]); - finalp(out, deg); - for (int c = 0; c < cnt; ++c) - data[pos[c]] = sqadd(inp[c], out[c]); - parity[0] = sqadd(inp[cnt], out[cnt]); - for (int d = 0; d < deg; ++d) - *bl++ = sclamp(out[d], -32, 31); - } - int deg = cnt + 2; - for (int j = 1; j < M; ++j) { - int8_t inp[deg], out[deg]; - for (int c = 0; c < cnt; ++c) - inp[c] = out[c] = sqsub(data[pos[CNC*j+c]], bl[c]); - inp[cnt] = out[cnt] = sqsub(parity[j+(q-1)*M-1], bl[cnt]); - inp[cnt+1] = out[cnt+1] = sqsub(parity[j], bl[cnt+1]); - finalp(out, deg); - for (int c = 0; c < cnt; ++c) - data[pos[CNC*j+c]] = sqadd(inp[c], out[c]); - parity[j+(q-1)*M-1] = sqadd(inp[cnt], out[cnt]); - parity[j] = sqadd(inp[cnt+1], out[cnt+1]); - for (int d = 0; d < deg; ++d) - *bl++ = sclamp(out[d], -32, 31); - } - } - for (int i = 1; i < q; ++i) { + TYPE *bl = bnl; + for (int i = 0; i < q; ++i) { int cnt = cnc[i]; int deg = cnt + 2; - for (int j = 0; j < M; ++j) { - int8_t inp[deg], out[deg]; + for (int j = 0; j < M; j += SIMD_SIZE) { + int num = std::min(M - j, SIMD_SIZE); + TYPE par[2]; + if (i) { + for (int n = 0; n < num; ++n) + par[0].v[n] = parity[M*(i-1)+j+n]; + } else { + if (j) { + for (int n = 0; n < num; ++n) + par[0].v[n] = parity[j+(q-1)*M-1+n]; + } else { + par[0].v[0] = 1; + for (int n = 1; n < num; ++n) + par[0].v[n] = parity[j+(q-1)*M-1+n]; + } + } + for (int n = 0; n < num; ++n) + par[1].v[n] = parity[M*i+j+n]; + TYPE dat[cnt]; for (int c = 0; c < cnt; ++c) - inp[c] = out[c] = sqsub(data[pos[CNC*(M*i+j)+c]], bl[c]); - inp[cnt] = out[cnt] = sqsub(parity[M*(i-1)+j], bl[cnt]); - inp[cnt+1] = out[cnt+1] = sqsub(parity[M*i+j], bl[cnt+1]); + for (int n = 0; n < num; ++n) + dat[c].v[n] = data[pos[CNC*(M*i+j+n)+c]]; + TYPE inp[deg], out[deg]; + for (int c = 0; c < cnt; ++c) + inp[c] = out[c] = vqsub(dat[c], bl[c]); + inp[cnt] = out[cnt] = vqsub(par[0], bl[cnt]); + inp[cnt+1] = out[cnt+1] = vqsub(par[1], bl[cnt+1]); finalp(out, deg); for (int c = 0; c < cnt; ++c) - data[pos[CNC*(M*i+j)+c]] = sqadd(inp[c], out[c]); - parity[M*(i-1)+j] = sqadd(inp[cnt], out[cnt]); - parity[M*i+j] = sqadd(inp[cnt+1], out[cnt+1]); + dat[c] = vqadd(inp[c], out[c]); + par[0] = vqadd(inp[cnt], out[cnt]); + par[1] = vqadd(inp[cnt+1], out[cnt+1]); for (int d = 0; d < deg; ++d) - *bl++ = sclamp(out[d], -32, 31); + *bl++ = vclamp(out[d], -32, 31); + if (i) { + for (int n = 0; n < num; ++n) + parity[M*(i-1)+j+n] = par[0].v[n]; + } else { + if (j) { + for (int n = 0; n < num; ++n) + parity[j+(q-1)*M-1+n] = par[0].v[n]; + } else { + for (int n = 1; n < num; ++n) + parity[j+(q-1)*M-1+n] = par[0].v[n]; + } + } + for (int n = 0; n < num; ++n) + parity[M*i+j+n] = par[1].v[n]; + for (int c = 0; c < cnt; ++c) + for (int n = 0; n < num; ++n) + data[pos[CNC*(M*i+j+n)+c]] = dat[c].v[n]; } } + //assert(bl <= bnl + BNL); + //std::cerr << BNL - (bl - bnl) << std::endl; } public: LDPCDecoder() @@ -201,8 +202,8 @@ public: } int operator()(int8_t *data, int8_t *parity, int trials = 25) { - for (int i = 0; i < LT; ++i) - bnl[i] = 0; + for (int i = 0; i < BNL; ++i) + bnl[i] = vzero(); for (int i = 0; i < q; ++i) for (int j = 0; j < M; ++j) pty[M*i+j] = parity[q*j+i];