mirror of
https://github.com/aicodix/code.git
synced 2026-04-27 14:30:36 +00:00
accelerate decoder using SIMD
Also added a fake parity bit to reduce code duplication. The fake parity bit seems not to affect the decoding performance.
This commit is contained in:
parent
199a988def
commit
3e881ce049
1 changed files with 116 additions and 115 deletions
233
ldpc_decoder.hh
233
ldpc_decoder.hh
|
|
@ -9,162 +9,163 @@ Copyright 2018 Ahmet Inan <inan@aicodix.de>
|
|||
|
||||
#include <algorithm>
|
||||
#include "exclusive_reduce.hh"
|
||||
#include "simd.hh"
|
||||
|
||||
namespace CODE {
|
||||
|
||||
template <typename TABLE, int BETA>
|
||||
class LDPCDecoder
|
||||
{
|
||||
#ifdef __AVX2__
|
||||
static const int SIMD_SIZE = 32;
|
||||
#else
|
||||
static const int SIMD_SIZE = 16;
|
||||
#endif
|
||||
static const int M = TABLE::M;
|
||||
static const int N = TABLE::N;
|
||||
static const int K = TABLE::K;
|
||||
static const int R = N-K;
|
||||
static const int q = R/M;
|
||||
static const int CNC = TABLE::LINKS_MAX_CN - 2;
|
||||
static const int LT = TABLE::LINKS_TOTAL;
|
||||
static const int BNL = TABLE::LINKS_TOTAL / SIMD_SIZE + TABLE::LINKS_MAX_CN * q;
|
||||
|
||||
int8_t bnl[LT];
|
||||
typedef SIMD<int8_t, SIMD_SIZE> TYPE;
|
||||
|
||||
TYPE bnl[BNL];
|
||||
int8_t pty[R];
|
||||
uint16_t pos[R * CNC];
|
||||
uint8_t cnc[R];
|
||||
|
||||
static int8_t sqadd(int8_t a, int8_t b)
|
||||
static TYPE eor(TYPE a, TYPE b)
|
||||
{
|
||||
int16_t x = int16_t(a) + int16_t(b);
|
||||
x = std::min<int16_t>(std::max<int16_t>(x, -128), 127);
|
||||
return x;
|
||||
return vreinterpret<TYPE>(veor(vmask(a), vmask(b)));
|
||||
}
|
||||
static int8_t sqsub(int8_t a, int8_t b)
|
||||
static TYPE min(TYPE a, TYPE b)
|
||||
{
|
||||
int16_t x = int16_t(a) - int16_t(b);
|
||||
x = std::min<int16_t>(std::max<int16_t>(x, -128), 127);
|
||||
return x;
|
||||
return vmin(a, b);
|
||||
}
|
||||
static uint8_t sqsubu(uint8_t a, uint8_t b)
|
||||
static void finalp(TYPE *links, int cnt)
|
||||
{
|
||||
int16_t x = int16_t(a) - int16_t(b);
|
||||
x = std::max<int16_t>(x, 0);
|
||||
return x;
|
||||
}
|
||||
static int8_t smin(int8_t a, int8_t b)
|
||||
{
|
||||
return std::min(a, b);
|
||||
}
|
||||
static int8_t sclamp(int8_t x, int8_t a, int8_t b)
|
||||
{
|
||||
return std::min(std::max(x, a), b);
|
||||
}
|
||||
static int8_t seor(int8_t a, int8_t b)
|
||||
{
|
||||
return a ^ b;
|
||||
}
|
||||
static int8_t sqabs(int8_t a)
|
||||
{
|
||||
return std::abs(std::max<int8_t>(a, -127));
|
||||
}
|
||||
static int8_t ssign(int8_t a, int8_t b)
|
||||
{
|
||||
return ((b > 0) - (b < 0)) * a;
|
||||
auto beta = vunsigned(vdup<TYPE>(BETA));
|
||||
TYPE mags[cnt];
|
||||
for (int i = 0; i < cnt; ++i)
|
||||
mags[i] = vsigned(vqsub(vunsigned(vqabs(links[i])), beta));
|
||||
|
||||
TYPE mins[cnt];
|
||||
exclusive_reduce(mags, mins, cnt, min);
|
||||
|
||||
TYPE signs[cnt];
|
||||
exclusive_reduce(links, signs, cnt, eor);
|
||||
for (int i = 0; i < cnt; ++i)
|
||||
signs[i] = vreinterpret<TYPE>(vorr(vmask(signs[i]), vmask(vdup<TYPE>(127))));
|
||||
|
||||
for (int i = 0; i < cnt; ++i)
|
||||
links[i] = vsign(mins[i], signs[i]);
|
||||
}
|
||||
|
||||
bool bad(int8_t *data, int8_t *parity)
|
||||
{
|
||||
{
|
||||
int cnt = cnc[0];
|
||||
{
|
||||
int8_t cnv = ssign(1, parity[0]);
|
||||
for (int c = 0; c < cnt; ++c)
|
||||
cnv = ssign(cnv, data[pos[c]]);
|
||||
if (cnv <= 0)
|
||||
return true;
|
||||
}
|
||||
for (int j = 1; j < M; ++j) {
|
||||
int8_t cnv = ssign(ssign(1, parity[j+(q-1)*M-1]), parity[j]);
|
||||
for (int c = 0; c < cnt; ++c)
|
||||
cnv = ssign(cnv, data[pos[CNC*j+c]]);
|
||||
if (cnv <= 0)
|
||||
return true;
|
||||
}
|
||||
}
|
||||
for (int i = 1; i < q; ++i) {
|
||||
for (int i = 0; i < q; ++i) {
|
||||
int cnt = cnc[i];
|
||||
for (int j = 0; j < M; ++j) {
|
||||
int8_t cnv = ssign(ssign(1, parity[M*(i-1)+j]), parity[M*i+j]);
|
||||
auto res = vmask(vzero<TYPE>());
|
||||
for (int j = 0; j < M; j += SIMD_SIZE) {
|
||||
int num = std::min(M - j, SIMD_SIZE);
|
||||
TYPE par[2];
|
||||
if (i) {
|
||||
for (int n = 0; n < num; ++n)
|
||||
par[0].v[n] = parity[M*(i-1)+j+n];
|
||||
} else {
|
||||
if (j) {
|
||||
for (int n = 0; n < num; ++n)
|
||||
par[0].v[n] = parity[j+(q-1)*M-1+n];
|
||||
} else {
|
||||
par[0].v[0] = 1;
|
||||
for (int n = 1; n < num; ++n)
|
||||
par[0].v[n] = parity[j+(q-1)*M-1+n];
|
||||
}
|
||||
}
|
||||
for (int n = 0; n < num; ++n)
|
||||
par[1].v[n] = parity[M*i+j+n];
|
||||
TYPE dat[cnt];
|
||||
for (int c = 0; c < cnt; ++c)
|
||||
cnv = ssign(cnv, data[pos[CNC*(M*i+j)+c]]);
|
||||
if (cnv <= 0)
|
||||
return true;
|
||||
for (int n = 0; n < num; ++n)
|
||||
dat[c].v[n] = data[pos[CNC*(M*i+j+n)+c]];
|
||||
TYPE cnv = vdup<TYPE>(1);
|
||||
for (int c = 0; c < 2; ++c)
|
||||
cnv = vsign(cnv, par[c]);
|
||||
for (int c = 0; c < cnt; ++c)
|
||||
cnv = vsign(cnv, dat[c]);
|
||||
for (int n = num; n < SIMD_SIZE; ++n)
|
||||
cnv.v[n] = 1;
|
||||
res = vorr(res, vclez(cnv));
|
||||
}
|
||||
for (int n = 0; n < SIMD_SIZE; ++n)
|
||||
if (res.v[n])
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
void finalp(int8_t *links, int cnt)
|
||||
{
|
||||
int8_t mags[cnt], mins[cnt];
|
||||
for (int i = 0; i < cnt; ++i)
|
||||
mags[i] = sqsubu(sqabs(links[i]), BETA);
|
||||
exclusive_reduce(mags, mins, cnt, smin);
|
||||
|
||||
int8_t signs[cnt];
|
||||
exclusive_reduce(links, signs, cnt, seor);
|
||||
for (int i = 0; i < cnt; ++i)
|
||||
signs[i] |= 127;
|
||||
|
||||
for (int i = 0; i < cnt; ++i)
|
||||
links[i] = ssign(mins[i], signs[i]);
|
||||
}
|
||||
void update(int8_t *data, int8_t *parity)
|
||||
{
|
||||
int8_t *bl = bnl;
|
||||
{
|
||||
int cnt = cnc[0];
|
||||
{
|
||||
int deg = cnt + 1;
|
||||
int8_t inp[deg], out[deg];
|
||||
for (int c = 0; c < cnt; ++c)
|
||||
inp[c] = out[c] = sqsub(data[pos[c]], bl[c]);
|
||||
inp[cnt] = out[cnt] = sqsub(parity[0], bl[cnt]);
|
||||
finalp(out, deg);
|
||||
for (int c = 0; c < cnt; ++c)
|
||||
data[pos[c]] = sqadd(inp[c], out[c]);
|
||||
parity[0] = sqadd(inp[cnt], out[cnt]);
|
||||
for (int d = 0; d < deg; ++d)
|
||||
*bl++ = sclamp(out[d], -32, 31);
|
||||
}
|
||||
int deg = cnt + 2;
|
||||
for (int j = 1; j < M; ++j) {
|
||||
int8_t inp[deg], out[deg];
|
||||
for (int c = 0; c < cnt; ++c)
|
||||
inp[c] = out[c] = sqsub(data[pos[CNC*j+c]], bl[c]);
|
||||
inp[cnt] = out[cnt] = sqsub(parity[j+(q-1)*M-1], bl[cnt]);
|
||||
inp[cnt+1] = out[cnt+1] = sqsub(parity[j], bl[cnt+1]);
|
||||
finalp(out, deg);
|
||||
for (int c = 0; c < cnt; ++c)
|
||||
data[pos[CNC*j+c]] = sqadd(inp[c], out[c]);
|
||||
parity[j+(q-1)*M-1] = sqadd(inp[cnt], out[cnt]);
|
||||
parity[j] = sqadd(inp[cnt+1], out[cnt+1]);
|
||||
for (int d = 0; d < deg; ++d)
|
||||
*bl++ = sclamp(out[d], -32, 31);
|
||||
}
|
||||
}
|
||||
for (int i = 1; i < q; ++i) {
|
||||
TYPE *bl = bnl;
|
||||
for (int i = 0; i < q; ++i) {
|
||||
int cnt = cnc[i];
|
||||
int deg = cnt + 2;
|
||||
for (int j = 0; j < M; ++j) {
|
||||
int8_t inp[deg], out[deg];
|
||||
for (int j = 0; j < M; j += SIMD_SIZE) {
|
||||
int num = std::min(M - j, SIMD_SIZE);
|
||||
TYPE par[2];
|
||||
if (i) {
|
||||
for (int n = 0; n < num; ++n)
|
||||
par[0].v[n] = parity[M*(i-1)+j+n];
|
||||
} else {
|
||||
if (j) {
|
||||
for (int n = 0; n < num; ++n)
|
||||
par[0].v[n] = parity[j+(q-1)*M-1+n];
|
||||
} else {
|
||||
par[0].v[0] = 1;
|
||||
for (int n = 1; n < num; ++n)
|
||||
par[0].v[n] = parity[j+(q-1)*M-1+n];
|
||||
}
|
||||
}
|
||||
for (int n = 0; n < num; ++n)
|
||||
par[1].v[n] = parity[M*i+j+n];
|
||||
TYPE dat[cnt];
|
||||
for (int c = 0; c < cnt; ++c)
|
||||
inp[c] = out[c] = sqsub(data[pos[CNC*(M*i+j)+c]], bl[c]);
|
||||
inp[cnt] = out[cnt] = sqsub(parity[M*(i-1)+j], bl[cnt]);
|
||||
inp[cnt+1] = out[cnt+1] = sqsub(parity[M*i+j], bl[cnt+1]);
|
||||
for (int n = 0; n < num; ++n)
|
||||
dat[c].v[n] = data[pos[CNC*(M*i+j+n)+c]];
|
||||
TYPE inp[deg], out[deg];
|
||||
for (int c = 0; c < cnt; ++c)
|
||||
inp[c] = out[c] = vqsub(dat[c], bl[c]);
|
||||
inp[cnt] = out[cnt] = vqsub(par[0], bl[cnt]);
|
||||
inp[cnt+1] = out[cnt+1] = vqsub(par[1], bl[cnt+1]);
|
||||
finalp(out, deg);
|
||||
for (int c = 0; c < cnt; ++c)
|
||||
data[pos[CNC*(M*i+j)+c]] = sqadd(inp[c], out[c]);
|
||||
parity[M*(i-1)+j] = sqadd(inp[cnt], out[cnt]);
|
||||
parity[M*i+j] = sqadd(inp[cnt+1], out[cnt+1]);
|
||||
dat[c] = vqadd(inp[c], out[c]);
|
||||
par[0] = vqadd(inp[cnt], out[cnt]);
|
||||
par[1] = vqadd(inp[cnt+1], out[cnt+1]);
|
||||
for (int d = 0; d < deg; ++d)
|
||||
*bl++ = sclamp(out[d], -32, 31);
|
||||
*bl++ = vclamp(out[d], -32, 31);
|
||||
if (i) {
|
||||
for (int n = 0; n < num; ++n)
|
||||
parity[M*(i-1)+j+n] = par[0].v[n];
|
||||
} else {
|
||||
if (j) {
|
||||
for (int n = 0; n < num; ++n)
|
||||
parity[j+(q-1)*M-1+n] = par[0].v[n];
|
||||
} else {
|
||||
for (int n = 1; n < num; ++n)
|
||||
parity[j+(q-1)*M-1+n] = par[0].v[n];
|
||||
}
|
||||
}
|
||||
for (int n = 0; n < num; ++n)
|
||||
parity[M*i+j+n] = par[1].v[n];
|
||||
for (int c = 0; c < cnt; ++c)
|
||||
for (int n = 0; n < num; ++n)
|
||||
data[pos[CNC*(M*i+j+n)+c]] = dat[c].v[n];
|
||||
}
|
||||
}
|
||||
//assert(bl <= bnl + BNL);
|
||||
//std::cerr << BNL - (bl - bnl) << std::endl;
|
||||
}
|
||||
public:
|
||||
LDPCDecoder()
|
||||
|
|
@ -201,8 +202,8 @@ public:
|
|||
}
|
||||
int operator()(int8_t *data, int8_t *parity, int trials = 25)
|
||||
{
|
||||
for (int i = 0; i < LT; ++i)
|
||||
bnl[i] = 0;
|
||||
for (int i = 0; i < BNL; ++i)
|
||||
bnl[i] = vzero<TYPE>();
|
||||
for (int i = 0; i < q; ++i)
|
||||
for (int j = 0; j < M; ++j)
|
||||
pty[M*i+j] = parity[q*j+i];
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue