simplified by merging pty and msg arrays into var

while using location array for parity as well
This commit is contained in:
Ahmet Inan 2019-10-03 20:53:49 +02:00
commit 422b405596

View file

@ -36,22 +36,22 @@ class LDPCDecoder
static const int W = M/D;
static const int PTY = R/D;
static const int MSG = K/D;
static const int VAR = N/D;
static const int CNC = TABLE::LINKS_MAX_CN - 2;
static const int BNL = (TABLE::LINKS_TOTAL + D-1) / D;
static const int LOC = (TABLE::LINKS_TOTAL - (2*R-1) + D-1) / D;
static const int LOC = (TABLE::LINKS_TOTAL + D-1) / D;
typedef SIMD<int8_t, SIMD_SIZE> TYPE;
typedef struct { uint16_t off; uint16_t shi; } Loc;
typedef uint32_t wdm_t;
static_assert(sizeof(wdm_t) * 8 >= CNC, "write disable mask needs at least as many bits as max check node links");
static_assert(sizeof(wdm_t) * 8 >= TABLE::LINKS_MAX_CN, "write disable mask needs at least as many bits as max check node links");
Rotate<TYPE, D> rotate;
TYPE bnl[BNL];
TYPE msg[MSG];
TYPE pty[PTY];
TYPE var[VAR];
Loc loc[LOC];
wdm_t wdm[PTY];
int16_t csh[MSG];
int16_t csh[VAR];
uint8_t cnc[q];
static TYPE eor(TYPE a, TYPE b)
@ -79,35 +79,22 @@ class LDPCDecoder
{
Loc *lo = loc;
for (int i = 0; i < q; ++i) {
int cnt = cnc[i];
int deg = cnt + 2;
int deg = cnc[i] + 2;
auto res = vmask(vzero<TYPE>());
for (int j = 0; j < W; ++j) {
TYPE cnv = vdup<TYPE>(1);
for (int k = 0; k < deg; ++k) {
TYPE tmp;
if (k < cnt) {
int offset = lo[k].off;
int shift = -lo[k].shi;
shift -= csh[offset];
shift %= D;
tmp = rotate(msg[offset], shift);
} else if (k == cnt) {
tmp = pty[W*i+j];
} else {
if (i) {
tmp = pty[W*(i-1)+j];
} else if (j) {
tmp = pty[W*(q-1)+j-1];
} else {
tmp = rotate(pty[PTY-1], 1);
tmp.v[0] = 127;
}
}
int offset = lo[k].off;
int shift = lo[k].shi;
shift -= csh[offset];
shift %= D;
TYPE tmp = rotate(var[offset], shift);
if (i == 0 && j == 0 && offset == VAR-1)
tmp.v[0] = 127;
cnv = vsign(cnv, tmp);
}
res = vorr(res, vclez(cnv));
lo += cnt;
lo += deg;
}
for (int n = 0; n < D; ++n)
if (res.v[n])
@ -121,8 +108,7 @@ class LDPCDecoder
Loc *lo = loc;
wdm_t *wd = wdm;
for (int i = 0; i < q; ++i) {
int cnt = cnc[i];
int deg = cnt + 2;
int deg = cnc[i] + 2;
for (int j = 0; j < W; ++j) {
TYPE mags[deg], inps[deg];
TYPE min0 = vdup<TYPE>(127);
@ -130,31 +116,23 @@ class LDPCDecoder
TYPE signs = vdup<TYPE>(127);
bool write_conflict = false;
int last_offset = -1;
int8_t prev_val = 0;
for (int k = 0; k < deg; ++k) {
TYPE tmp;
if (k < cnt) {
int offset = lo[k].off;
int shift = -lo[k].shi;
shift -= csh[offset];
shift %= D;
tmp = rotate(msg[offset], shift);
if (last_offset == offset)
write_conflict = true;
last_offset = offset;
} else if (k == cnt) {
tmp = pty[W*i+j];
} else {
if (i) {
tmp = pty[W*(i-1)+j];
} else if (j) {
tmp = pty[W*(q-1)+j-1];
} else {
tmp = rotate(pty[PTY-1], 1);
tmp.v[0] = 127;
}
int offset = lo[k].off;
int shift = lo[k].shi;
shift -= csh[offset];
shift %= D;
TYPE tmp = rotate(var[offset], shift);
if (i == 0 && j == 0 && offset == VAR-1) {
prev_val = tmp.v[0];
tmp.v[0] = 127;
}
if (last_offset == offset)
write_conflict = true;
last_offset = offset;
TYPE inp = vqsub(tmp, bl[k]);
TYPE mag = vqabs(inp);
@ -184,34 +162,23 @@ class LDPCDecoder
TYPE tmp = vqadd(inp, out);
if (k < cnt) {
if (!write_conflict || !((*wd>>k)&1)) {
bl[k] = out;
int offset = lo[k].off;
int shift = -lo[k].shi;
msg[offset] = tmp;
csh[offset] = shift;
}
} else if (k == cnt) {
int offset = lo[k].off;
int shift = lo[k].shi;
if (i == 0 && j == 0 && offset == VAR-1)
tmp.v[0] = prev_val;
if (!write_conflict || !((*wd>>k)&1)) {
bl[k] = out;
pty[W*i+j] = tmp;
} else {
bl[k] = out;
if (i) {
pty[W*(i-1)+j] = tmp;
} else if (j) {
pty[W*(q-1)+j-1] = tmp;
} else {
tmp.v[0] = pty[PTY-1].v[D-1];
pty[PTY-1] = rotate(tmp, -1);
}
var[offset] = tmp;
csh[offset] = shift;
}
}
if (write_conflict) {
for (int first = 0, c = 1; c < cnt; ++c) {
if (lo[first].off != lo[c].off || c == cnt-1) {
int last = c - 1;
if (c == cnt-1)
for (int first = 0, k = 1; k < deg; ++k) {
if (lo[first].off != lo[k].off || k == deg-1) {
int last = k - 1;
if (k == deg-1)
++last;
if (last != first) {
int count = last - first + 1;
@ -221,12 +188,12 @@ class LDPCDecoder
wdm_t ror = (tmp >> 1) | (tmp << (count-1));
*wd = (cur & ~mask) | (ror & mask);
}
first = c;
first = k;
}
}
++wd;
}
lo += cnt;
lo += deg;
bl += deg;
}
}
@ -257,6 +224,7 @@ public:
wdm_t *wd = wdm;
for (int i = 0; i < q; ++i) {
int cnt = cnc[i];
int deg = cnt + 2;
int offset[cnt], shift[cnt];
for (int c = 0; c < cnt; ++c) {
shift[c] = pos[CNC*i+c] % M;
@ -265,17 +233,30 @@ public:
for (int j = 0; j < W; ++j) {
for (int c = 0; c < cnt; ++c) {
lo[c].off = offset[c] / D + shift[c] % W;
lo[c].shi = shift[c] / W;
lo[c].shi = (D - shift[c] / W) % D;
shift[c] = (shift[c] + 1) % M;
}
std::sort(lo, lo + cnt, [](const Loc &a, const Loc &b){ return a.off < b.off; });
if (i) {
lo[cnt].off = MSG+W*(i-1)+j;
lo[cnt].shi = 0;
} else if (j) {
lo[cnt].off = MSG+W*(q-1)+j-1;
lo[cnt].shi = 0;
} else {
lo[cnt].off = VAR-1;
lo[cnt].shi = 1;
}
lo[cnt+1].off = MSG+W*i+j;
lo[cnt+1].shi = 0;
std::sort(lo, lo + deg, [](const Loc &a, const Loc &b){ return a.off < b.off; });
wdm_t tmp = 0;
for (int c = 0; c < cnt-1; ++c)
if (lo[c].off == lo[c+1].off)
tmp |= 1 << c;
for (int d = 0; d < deg-1; ++d)
if (lo[d].off == lo[d+1].off)
tmp |= 1 << d;
if (tmp)
*wd++ = tmp;
lo += cnt;
lo += deg;
}
}
//assert(lo <= loc + LOC);
@ -285,28 +266,28 @@ public:
{
for (int i = 0; i < BNL; ++i)
bnl[i] = vzero<TYPE>();
for (int i = 0; i < MSG; ++i)
for (int i = 0; i < VAR; ++i)
csh[i] = 0;
for (int i = 0; i < K/M; ++i)
for (int j = 0; j < W; ++j)
for (int n = 0; n < D; ++n)
msg[W*i+j].v[n] = message[M*i+W*n+j];
var[W*i+j].v[n] = message[M*i+W*n+j];
for (int i = 0; i < q; ++i)
for (int j = 0; j < W; ++j)
for (int n = 0; n < D; ++n)
pty[W*i+j].v[n] = parity[q*(W*n+j)+i];
var[MSG+W*i+j].v[n] = parity[q*(W*n+j)+i];
while (bad() && --trials >= 0)
update();
for (int i = 0; i < MSG; ++i)
msg[i] = rotate(msg[i], -csh[i]);
for (int i = 0; i < VAR; ++i)
var[i] = rotate(var[i], -csh[i]);
for (int i = 0; i < K/M; ++i)
for (int j = 0; j < W; ++j)
for (int n = 0; n < D; ++n)
message[M*i+W*n+j] = msg[W*i+j].v[n];
message[M*i+W*n+j] = var[W*i+j].v[n];
for (int i = 0; i < q; ++i)
for (int j = 0; j < W; ++j)
for (int n = 0; n < D; ++n)
parity[q*(W*n+j)+i] = pty[W*i+j].v[n];
parity[q*(W*n+j)+i] = var[MSG+W*i+j].v[n];
return trials;
}
};