shuffle codeword for SIMD aligned memory access

Idea taken from:

Low cost LDPC decoder for DVB-S2
by John Dielissen, Andries Hekstra and Vincent Berg - 2006
This commit is contained in:
Ahmet Inan 2019-09-22 20:45:49 +02:00
commit 717d482822
3 changed files with 169 additions and 131 deletions

View file

@ -61,7 +61,7 @@ This version stores only the first q bit positions and might be faster on low po
### [ldpc_decoder2.hh](ldpc_decoder2.hh)
[Low-density parity-check](https://en.wikipedia.org/wiki/Low-density_parity-check_code) layered decoder
This version stores and uses all bit positions and might be faster on high performance workstations.
This version stores and uses all word positions and might be faster on high performance workstations.
### [exclusive_reduce.hh](exclusive_reduce.hh)

View file

@ -19,21 +19,32 @@ class LDPCDecoder
{
#ifdef __AVX2__
static const int SIMD_SIZE = 32;
// M = 360 = 30 * 12
static const int WORD_SIZE = 30;
#else
static const int SIMD_SIZE = 16;
// M = 360 = 15 * 24
static const int WORD_SIZE = 15;
#endif
static_assert(TABLE::M % WORD_SIZE == 0, "M must be multiple of word size");
static_assert(WORD_SIZE <= SIMD_SIZE, "SIMD size must be bigger or equal word size");
static const int M = TABLE::M;
static const int N = TABLE::N;
static const int K = TABLE::K;
static const int R = N-K;
static const int q = R/M;
static const int D = WORD_SIZE;
static const int W = M/D;
static const int PTY = R/D;
static const int MSG = K/D;
static const int CNC = TABLE::LINKS_MAX_CN - 2;
static const int BNL = TABLE::LINKS_TOTAL / SIMD_SIZE + TABLE::LINKS_MAX_CN * q;
static const int BNL = (TABLE::LINKS_TOTAL + D-1) / D;
typedef SIMD<int8_t, SIMD_SIZE> TYPE;
TYPE bnl[BNL];
int8_t pty[R];
TYPE msg[MSG];
TYPE pty[PTY];
uint16_t pos[q * CNC];
uint8_t cnc[q];
@ -75,8 +86,21 @@ class LDPCDecoder
for (int i = 0; i < cnt; ++i)
links[i] = vsign(other(mags[i], mins[0], mins[1]), mine(signs, links[i]));
}
static TYPE rotate(TYPE a, int s)
{
if (s < 0)
s += D;
int t = D - s;
TYPE ret;
// TODO: I can has barrel shifter?
for (int n = 0; n < s; ++n)
ret.v[n] = a.v[n+t];
for (int n = 0; n < t; ++n)
ret.v[n+s] = a.v[n];
return ret;
}
bool bad(int8_t *data, int8_t *parity)
bool bad()
{
for (int i = 0; i < q; ++i) {
int cnt = cnc[i];
@ -86,48 +110,36 @@ class LDPCDecoder
offset[c] = pos[CNC*i+c] - shift[c];
}
auto res = vmask(vzero<TYPE>());
for (int j = 0; j < M; j += SIMD_SIZE) {
int num = std::min(M - j, SIMD_SIZE);
for (int j = 0; j < W; ++j) {
TYPE par[2];
if (i) {
for (int n = 0; n < num; ++n)
par[0].v[n] = parity[M*(i-1)+j+n];
par[0] = pty[W*(i-1)+j];
} else if (j) {
for (int n = 0; n < num; ++n)
par[0].v[n] = parity[M*(q-1)-1+j+n];
par[0] = pty[W*(q-1)+j-1];
} else {
par[0] = rotate(pty[PTY-1], 1);
par[0].v[0] = 127;
for (int n = 1; n < num; ++n)
par[0].v[n] = parity[M*(q-1)-1+j+n];
}
for (int n = 0; n < num; ++n)
par[1].v[n] = parity[M*i+j+n];
TYPE dat[cnt];
for (int c = 0; c < cnt; ++c) {
int tmp = std::min(num, M - shift[c]);
for (int n = 0; n < tmp; ++n)
dat[c].v[n] = data[offset[c]+shift[c]+n];
for (int n = tmp; n < num; ++n)
dat[c].v[n] = data[offset[c]+n-tmp];
}
par[1] = pty[W*i+j];
TYPE mes[cnt];
for (int c = 0; c < cnt; ++c)
mes[c] = rotate(msg[offset[c]/D+shift[c]%W], -shift[c]/W);
TYPE cnv = vdup<TYPE>(1);
for (int c = 0; c < 2; ++c)
cnv = vsign(cnv, par[c]);
for (int c = 0; c < cnt; ++c)
cnv = vsign(cnv, dat[c]);
for (int n = num; n < SIMD_SIZE; ++n)
cnv.v[n] = 1;
cnv = vsign(cnv, mes[c]);
res = vorr(res, vclez(cnv));
for (int c = 0; c < cnt; ++c)
shift[c] = (shift[c] + num) % M;
shift[c] = (shift[c] + 1) % M;
}
for (int n = 0; n < SIMD_SIZE; ++n)
for (int n = 0; n < D; ++n)
if (res.v[n])
return true;
}
return false;
}
void update(int8_t *data, int8_t *parity)
void update()
{
TYPE *bl = bnl;
for (int i = 0; i < q; ++i) {
@ -138,63 +150,45 @@ class LDPCDecoder
offset[c] = pos[CNC*i+c] - shift[c];
}
int deg = cnt + 2;
for (int j = 0; j < M; j += SIMD_SIZE) {
int num = std::min(M - j, SIMD_SIZE);
for (int j = 0; j < W; ++j) {
TYPE par[2];
if (i) {
for (int n = 0; n < num; ++n)
par[0].v[n] = parity[M*(i-1)+j+n];
par[0] = pty[W*(i-1)+j];
} else if (j) {
for (int n = 0; n < num; ++n)
par[0].v[n] = parity[M*(q-1)-1+j+n];
par[0] = pty[W*(q-1)+j-1];
} else {
par[0] = rotate(pty[PTY-1], 1);
par[0].v[0] = 127;
for (int n = 1; n < num; ++n)
par[0].v[n] = parity[M*(q-1)-1+j+n];
}
for (int n = 0; n < num; ++n)
par[1].v[n] = parity[M*i+j+n];
TYPE dat[cnt];
for (int c = 0; c < cnt; ++c) {
int tmp = std::min(num, M - shift[c]);
for (int n = 0; n < tmp; ++n)
dat[c].v[n] = data[offset[c]+shift[c]+n];
for (int n = tmp; n < num; ++n)
dat[c].v[n] = data[offset[c]+n-tmp];
}
par[1] = pty[W*i+j];
TYPE mes[cnt];
for (int c = 0; c < cnt; ++c)
mes[c] = rotate(msg[offset[c]/D+shift[c]%W], -shift[c]/W);
TYPE inp[deg], out[deg];
for (int c = 0; c < cnt; ++c)
inp[c] = out[c] = vqsub(dat[c], bl[c]);
inp[c] = out[c] = vqsub(mes[c], bl[c]);
inp[cnt] = out[cnt] = vqsub(par[0], bl[cnt]);
inp[cnt+1] = out[cnt+1] = vqsub(par[1], bl[cnt+1]);
finalp(out, deg);
for (int c = 0; c < cnt; ++c)
dat[c] = vqadd(inp[c], out[c]);
mes[c] = vqadd(inp[c], out[c]);
par[0] = vqadd(inp[cnt], out[cnt]);
par[1] = vqadd(inp[cnt+1], out[cnt+1]);
for (int d = 0; d < deg; ++d)
*bl++ = vclamp(out[d], -32, 31);
if (i) {
for (int n = 0; n < num; ++n)
parity[M*(i-1)+j+n] = par[0].v[n];
pty[W*(i-1)+j] = par[0];
} else if (j) {
for (int n = 0; n < num; ++n)
parity[M*(q-1)-1+j+n] = par[0].v[n];
pty[W*(q-1)+j-1] = par[0];
} else {
for (int n = 1; n < num; ++n)
parity[M*(q-1)-1+j+n] = par[0].v[n];
}
for (int n = 0; n < num; ++n)
parity[M*i+j+n] = par[1].v[n];
for (int c = 0; c < cnt; ++c) {
int tmp = std::min(num, M - shift[c]);
for (int n = 0; n < tmp; ++n)
data[offset[c]+shift[c]+n] = dat[c].v[n];
for (int n = tmp; n < num; ++n)
data[offset[c]+n-tmp] = dat[c].v[n];
par[0].v[0] = pty[PTY-1].v[D-1];
pty[PTY-1] = rotate(par[0], -1);
}
pty[W*i+j] = par[1];
for (int c = 0; c < cnt; ++c)
shift[c] = (shift[c] + num) % M;
msg[offset[c]/D+shift[c]%W] = rotate(mes[c], shift[c]/W);
for (int c = 0; c < cnt; ++c)
shift[c] = (shift[c] + 1) % M;
}
}
//assert(bl <= bnl + BNL);
@ -220,18 +214,28 @@ public:
}
}
}
int operator()(int8_t *data, int8_t *parity, int trials = 25)
int operator()(int8_t *message, int8_t *parity, int trials = 25)
{
for (int i = 0; i < BNL; ++i)
bnl[i] = vzero<TYPE>();
for (int i = 0; i < K/M; ++i)
for (int j = 0; j < W; ++j)
for (int n = 0; n < D; ++n)
msg[W*i+j].v[n] = message[M*i+W*n+j];
for (int i = 0; i < q; ++i)
for (int j = 0; j < M; ++j)
pty[M*i+j] = parity[q*j+i];
while (bad(data, pty) && --trials >= 0)
update(data, pty);
for (int j = 0; j < W; ++j)
for (int n = 0; n < D; ++n)
pty[W*i+j].v[n] = parity[q*(W*n+j)+i];
while (bad() && --trials >= 0)
update();
for (int i = 0; i < K/M; ++i)
for (int j = 0; j < W; ++j)
for (int n = 0; n < D; ++n)
message[M*i+W*n+j] = msg[W*i+j].v[n];
for (int i = 0; i < q; ++i)
for (int j = 0; j < M; ++j)
parity[q*j+i] = pty[M*i+j];
for (int j = 0; j < W; ++j)
for (int n = 0; n < D; ++n)
parity[q*(W*n+j)+i] = pty[W*i+j].v[n];
return trials;
}
};

View file

@ -1,7 +1,7 @@
/*
LDPC SISO layered decoder v2
This version stores and uses all bit positions
This version stores and uses all word positions
Copyright 2018 Ahmet Inan <inan@aicodix.de>
*/
@ -19,22 +19,36 @@ class LDPCDecoder
{
#ifdef __AVX2__
static const int SIMD_SIZE = 32;
// M = 360 = 30 * 12
static const int WORD_SIZE = 30;
#else
static const int SIMD_SIZE = 16;
// M = 360 = 15 * 24
static const int WORD_SIZE = 15;
#endif
static_assert(TABLE::M % WORD_SIZE == 0, "M must be multiple of word size");
static_assert(WORD_SIZE <= SIMD_SIZE, "SIMD size must be bigger or equal word size");
static const int M = TABLE::M;
static const int N = TABLE::N;
static const int K = TABLE::K;
static const int R = N-K;
static const int q = R/M;
static const int D = WORD_SIZE;
static const int W = M/D;
static const int PTY = R/D;
static const int MSG = K/D;
static const int CNC = TABLE::LINKS_MAX_CN - 2;
static const int BNL = TABLE::LINKS_TOTAL / SIMD_SIZE + TABLE::LINKS_MAX_CN * q;
static const int BNL = (TABLE::LINKS_TOTAL + D-1) / D;
static const int POS = PTY * CNC;
typedef SIMD<int8_t, SIMD_SIZE> TYPE;
TYPE bnl[BNL];
int8_t pty[R];
uint16_t pos[R * CNC];
TYPE msg[MSG];
TYPE pty[PTY];
uint16_t pos[q * CNC];
uint16_t off[POS];
uint8_t shi[POS];
uint8_t cnc[q];
static TYPE eor(TYPE a, TYPE b)
@ -75,100 +89,103 @@ class LDPCDecoder
for (int i = 0; i < cnt; ++i)
links[i] = vsign(other(mags[i], mins[0], mins[1]), mine(signs, links[i]));
}
bool bad(int8_t *data, int8_t *parity)
static TYPE rotate(TYPE a, int s)
{
if (s < 0)
s += D;
int t = D - s;
TYPE ret;
// TODO: I can has barrel shifter?
for (int n = 0; n < s; ++n)
ret.v[n] = a.v[n+t];
for (int n = 0; n < t; ++n)
ret.v[n+s] = a.v[n];
return ret;
}
bool bad()
{
uint16_t *of = off;
uint8_t *sh = shi;
for (int i = 0; i < q; ++i) {
int cnt = cnc[i];
auto res = vmask(vzero<TYPE>());
for (int j = 0; j < M; j += SIMD_SIZE) {
int num = std::min(M - j, SIMD_SIZE);
for (int j = 0; j < W; ++j) {
TYPE par[2];
if (i) {
for (int n = 0; n < num; ++n)
par[0].v[n] = parity[M*(i-1)+j+n];
par[0] = pty[W*(i-1)+j];
} else if (j) {
for (int n = 0; n < num; ++n)
par[0].v[n] = parity[M*(q-1)-1+j+n];
par[0] = pty[W*(q-1)+j-1];
} else {
par[0] = rotate(pty[PTY-1], 1);
par[0].v[0] = 127;
for (int n = 1; n < num; ++n)
par[0].v[n] = parity[M*(q-1)-1+j+n];
}
for (int n = 0; n < num; ++n)
par[1].v[n] = parity[M*i+j+n];
TYPE dat[cnt];
par[1] = pty[W*i+j];
TYPE mes[cnt];
for (int c = 0; c < cnt; ++c)
for (int n = 0; n < num; ++n)
dat[c].v[n] = data[pos[CNC*(M*i+j+n)+c]];
mes[c] = rotate(msg[of[c]], -sh[c]);
TYPE cnv = vdup<TYPE>(1);
for (int c = 0; c < 2; ++c)
cnv = vsign(cnv, par[c]);
for (int c = 0; c < cnt; ++c)
cnv = vsign(cnv, dat[c]);
for (int n = num; n < SIMD_SIZE; ++n)
cnv.v[n] = 1;
cnv = vsign(cnv, mes[c]);
res = vorr(res, vclez(cnv));
of += cnt;
sh += cnt;
}
for (int n = 0; n < SIMD_SIZE; ++n)
for (int n = 0; n < D; ++n)
if (res.v[n])
return true;
}
return false;
}
void update(int8_t *data, int8_t *parity)
void update()
{
TYPE *bl = bnl;
uint16_t *of = off;
uint8_t *sh = shi;
for (int i = 0; i < q; ++i) {
int cnt = cnc[i];
int deg = cnt + 2;
for (int j = 0; j < M; j += SIMD_SIZE) {
int num = std::min(M - j, SIMD_SIZE);
for (int j = 0; j < W; ++j) {
TYPE par[2];
if (i) {
for (int n = 0; n < num; ++n)
par[0].v[n] = parity[M*(i-1)+j+n];
par[0] = pty[W*(i-1)+j];
} else if (j) {
for (int n = 0; n < num; ++n)
par[0].v[n] = parity[M*(q-1)-1+j+n];
par[0] = pty[W*(q-1)+j-1];
} else {
par[0] = rotate(pty[PTY-1], 1);
par[0].v[0] = 127;
for (int n = 1; n < num; ++n)
par[0].v[n] = parity[M*(q-1)-1+j+n];
}
for (int n = 0; n < num; ++n)
par[1].v[n] = parity[M*i+j+n];
TYPE dat[cnt];
par[1] = pty[W*i+j];
TYPE mes[cnt];
for (int c = 0; c < cnt; ++c)
for (int n = 0; n < num; ++n)
dat[c].v[n] = data[pos[CNC*(M*i+j+n)+c]];
mes[c] = rotate(msg[of[c]], -sh[c]);
TYPE inp[deg], out[deg];
for (int c = 0; c < cnt; ++c)
inp[c] = out[c] = vqsub(dat[c], bl[c]);
inp[c] = out[c] = vqsub(mes[c], bl[c]);
inp[cnt] = out[cnt] = vqsub(par[0], bl[cnt]);
inp[cnt+1] = out[cnt+1] = vqsub(par[1], bl[cnt+1]);
finalp(out, deg);
for (int c = 0; c < cnt; ++c)
dat[c] = vqadd(inp[c], out[c]);
mes[c] = vqadd(inp[c], out[c]);
par[0] = vqadd(inp[cnt], out[cnt]);
par[1] = vqadd(inp[cnt+1], out[cnt+1]);
for (int d = 0; d < deg; ++d)
*bl++ = vclamp(out[d], -32, 31);
if (i) {
for (int n = 0; n < num; ++n)
parity[M*(i-1)+j+n] = par[0].v[n];
pty[W*(i-1)+j] = par[0];
} else if (j) {
for (int n = 0; n < num; ++n)
parity[M*(q-1)-1+j+n] = par[0].v[n];
pty[W*(q-1)+j-1] = par[0];
} else {
for (int n = 1; n < num; ++n)
parity[M*(q-1)-1+j+n] = par[0].v[n];
par[0].v[0] = pty[PTY-1].v[D-1];
pty[PTY-1] = rotate(par[0], -1);
}
for (int n = 0; n < num; ++n)
parity[M*i+j+n] = par[1].v[n];
pty[W*i+j] = par[1];
for (int c = 0; c < cnt; ++c)
for (int n = 0; n < num; ++n)
data[pos[CNC*(M*i+j+n)+c]] = dat[c].v[n];
msg[of[c]] = rotate(mes[c], sh[c]);
of += cnt;
sh += cnt;
}
}
//assert(bl <= bnl + BNL);
@ -187,39 +204,56 @@ public:
for (int d = 0; d < bit_deg; ++d) {
int n = row_ptr[d] % q;
int m = row_ptr[d] / q;
pos[CNC*M*n+cnc[n]++] = bit_pos + (M - m) % M;
pos[CNC*n+cnc[n]++] = bit_pos + (M - m) % M;
}
row_ptr += bit_deg;
bit_pos += M;
}
}
uint16_t *of = off;
uint8_t *sh = shi;
for (int i = 0; i < q; ++i) {
int cnt = cnc[i];
int offset[cnt], shift[cnt];
for (int c = 0; c < cnt; ++c) {
shift[c] = pos[CNC*M*i+c] % M;
offset[c] = pos[CNC*M*i+c] - shift[c];
shift[c] = pos[CNC*i+c] % M;
offset[c] = pos[CNC*i+c] - shift[c];
}
for (int j = 1; j < M; ++j) {
for (int j = 0; j < W; ++j) {
for (int c = 0; c < cnt; ++c) {
*of++ = offset[c] / D + shift[c] % W;
*sh++ = shift[c] / W;
shift[c] = (shift[c] + 1) % M;
pos[CNC*(M*i+j)+c] = offset[c] + shift[c];
}
}
}
//assert(of <= off + POS);
//std::cerr << POS - (sh - shi) << std::endl;
//assert(sh <= shi + POS);
//std::cerr << POS - (of - off) << std::endl;
}
int operator()(int8_t *data, int8_t *parity, int trials = 25)
int operator()(int8_t *message, int8_t *parity, int trials = 25)
{
for (int i = 0; i < BNL; ++i)
bnl[i] = vzero<TYPE>();
for (int i = 0; i < K/M; ++i)
for (int j = 0; j < W; ++j)
for (int n = 0; n < D; ++n)
msg[W*i+j].v[n] = message[M*i+W*n+j];
for (int i = 0; i < q; ++i)
for (int j = 0; j < M; ++j)
pty[M*i+j] = parity[q*j+i];
while (bad(data, pty) && --trials >= 0)
update(data, pty);
for (int j = 0; j < W; ++j)
for (int n = 0; n < D; ++n)
pty[W*i+j].v[n] = parity[q*(W*n+j)+i];
while (bad() && --trials >= 0)
update();
for (int i = 0; i < K/M; ++i)
for (int j = 0; j < W; ++j)
for (int n = 0; n < D; ++n)
message[M*i+W*n+j] = msg[W*i+j].v[n];
for (int i = 0; i < q; ++i)
for (int j = 0; j < M; ++j)
parity[q*j+i] = pty[M*i+j];
for (int j = 0; j < W; ++j)
for (int n = 0; n < D; ++n)
parity[q*(W*n+j)+i] = pty[W*i+j].v[n];
return trials;
}
};