diff --git a/ldpc_decoder.hh b/ldpc_decoder.hh index aed39aa..8a6d2df 100644 --- a/ldpc_decoder.hh +++ b/ldpc_decoder.hh @@ -36,22 +36,22 @@ class LDPCDecoder static const int W = M/D; static const int PTY = R/D; static const int MSG = K/D; + static const int VAR = N/D; static const int CNC = TABLE::LINKS_MAX_CN - 2; static const int BNL = (TABLE::LINKS_TOTAL + D-1) / D; - static const int LOC = (TABLE::LINKS_TOTAL - (2*R-1) + D-1) / D; + static const int LOC = (TABLE::LINKS_TOTAL + D-1) / D; typedef SIMD TYPE; typedef struct { uint16_t off; uint16_t shi; } Loc; typedef uint32_t wdm_t; - static_assert(sizeof(wdm_t) * 8 >= CNC, "write disable mask needs at least as many bits as max check node links"); + static_assert(sizeof(wdm_t) * 8 >= TABLE::LINKS_MAX_CN, "write disable mask needs at least as many bits as max check node links"); Rotate rotate; TYPE bnl[BNL]; - TYPE msg[MSG]; - TYPE pty[PTY]; + TYPE var[VAR]; Loc loc[LOC]; wdm_t wdm[PTY]; - int16_t csh[MSG]; + int16_t csh[VAR]; uint8_t cnc[q]; static TYPE eor(TYPE a, TYPE b) @@ -79,35 +79,22 @@ class LDPCDecoder { Loc *lo = loc; for (int i = 0; i < q; ++i) { - int cnt = cnc[i]; - int deg = cnt + 2; + int deg = cnc[i] + 2; auto res = vmask(vzero()); for (int j = 0; j < W; ++j) { TYPE cnv = vdup(1); for (int k = 0; k < deg; ++k) { - TYPE tmp; - if (k < cnt) { - int offset = lo[k].off; - int shift = -lo[k].shi; - shift -= csh[offset]; - shift %= D; - tmp = rotate(msg[offset], shift); - } else if (k == cnt) { - tmp = pty[W*i+j]; - } else { - if (i) { - tmp = pty[W*(i-1)+j]; - } else if (j) { - tmp = pty[W*(q-1)+j-1]; - } else { - tmp = rotate(pty[PTY-1], 1); - tmp.v[0] = 127; - } - } + int offset = lo[k].off; + int shift = lo[k].shi; + shift -= csh[offset]; + shift %= D; + TYPE tmp = rotate(var[offset], shift); + if (i == 0 && j == 0 && offset == VAR-1) + tmp.v[0] = 127; cnv = vsign(cnv, tmp); } res = vorr(res, vclez(cnv)); - lo += cnt; + lo += deg; } for (int n = 0; n < D; ++n) if (res.v[n]) @@ -121,8 +108,7 @@ class LDPCDecoder Loc *lo = loc; wdm_t *wd = wdm; for (int i = 0; i < q; ++i) { - int cnt = cnc[i]; - int deg = cnt + 2; + int deg = cnc[i] + 2; for (int j = 0; j < W; ++j) { TYPE mags[deg], inps[deg]; TYPE min0 = vdup(127); @@ -130,31 +116,23 @@ class LDPCDecoder TYPE signs = vdup(127); bool write_conflict = false; int last_offset = -1; + int8_t prev_val = 0; for (int k = 0; k < deg; ++k) { - TYPE tmp; - if (k < cnt) { - int offset = lo[k].off; - int shift = -lo[k].shi; - shift -= csh[offset]; - shift %= D; - tmp = rotate(msg[offset], shift); - if (last_offset == offset) - write_conflict = true; - last_offset = offset; - } else if (k == cnt) { - tmp = pty[W*i+j]; - } else { - if (i) { - tmp = pty[W*(i-1)+j]; - } else if (j) { - tmp = pty[W*(q-1)+j-1]; - } else { - tmp = rotate(pty[PTY-1], 1); - tmp.v[0] = 127; - } + int offset = lo[k].off; + int shift = lo[k].shi; + shift -= csh[offset]; + shift %= D; + TYPE tmp = rotate(var[offset], shift); + if (i == 0 && j == 0 && offset == VAR-1) { + prev_val = tmp.v[0]; + tmp.v[0] = 127; } + if (last_offset == offset) + write_conflict = true; + last_offset = offset; + TYPE inp = vqsub(tmp, bl[k]); TYPE mag = vqabs(inp); @@ -184,34 +162,23 @@ class LDPCDecoder TYPE tmp = vqadd(inp, out); - if (k < cnt) { - if (!write_conflict || !((*wd>>k)&1)) { - bl[k] = out; - int offset = lo[k].off; - int shift = -lo[k].shi; - msg[offset] = tmp; - csh[offset] = shift; - } - } else if (k == cnt) { + int offset = lo[k].off; + int shift = lo[k].shi; + + if (i == 0 && j == 0 && offset == VAR-1) + tmp.v[0] = prev_val; + + if (!write_conflict || !((*wd>>k)&1)) { bl[k] = out; - pty[W*i+j] = tmp; - } else { - bl[k] = out; - if (i) { - pty[W*(i-1)+j] = tmp; - } else if (j) { - pty[W*(q-1)+j-1] = tmp; - } else { - tmp.v[0] = pty[PTY-1].v[D-1]; - pty[PTY-1] = rotate(tmp, -1); - } + var[offset] = tmp; + csh[offset] = shift; } } if (write_conflict) { - for (int first = 0, c = 1; c < cnt; ++c) { - if (lo[first].off != lo[c].off || c == cnt-1) { - int last = c - 1; - if (c == cnt-1) + for (int first = 0, k = 1; k < deg; ++k) { + if (lo[first].off != lo[k].off || k == deg-1) { + int last = k - 1; + if (k == deg-1) ++last; if (last != first) { int count = last - first + 1; @@ -221,12 +188,12 @@ class LDPCDecoder wdm_t ror = (tmp >> 1) | (tmp << (count-1)); *wd = (cur & ~mask) | (ror & mask); } - first = c; + first = k; } } ++wd; } - lo += cnt; + lo += deg; bl += deg; } } @@ -257,6 +224,7 @@ public: wdm_t *wd = wdm; for (int i = 0; i < q; ++i) { int cnt = cnc[i]; + int deg = cnt + 2; int offset[cnt], shift[cnt]; for (int c = 0; c < cnt; ++c) { shift[c] = pos[CNC*i+c] % M; @@ -265,17 +233,30 @@ public: for (int j = 0; j < W; ++j) { for (int c = 0; c < cnt; ++c) { lo[c].off = offset[c] / D + shift[c] % W; - lo[c].shi = shift[c] / W; + lo[c].shi = (D - shift[c] / W) % D; shift[c] = (shift[c] + 1) % M; } - std::sort(lo, lo + cnt, [](const Loc &a, const Loc &b){ return a.off < b.off; }); + if (i) { + lo[cnt].off = MSG+W*(i-1)+j; + lo[cnt].shi = 0; + } else if (j) { + lo[cnt].off = MSG+W*(q-1)+j-1; + lo[cnt].shi = 0; + } else { + lo[cnt].off = VAR-1; + lo[cnt].shi = 1; + } + lo[cnt+1].off = MSG+W*i+j; + lo[cnt+1].shi = 0; + + std::sort(lo, lo + deg, [](const Loc &a, const Loc &b){ return a.off < b.off; }); wdm_t tmp = 0; - for (int c = 0; c < cnt-1; ++c) - if (lo[c].off == lo[c+1].off) - tmp |= 1 << c; + for (int d = 0; d < deg-1; ++d) + if (lo[d].off == lo[d+1].off) + tmp |= 1 << d; if (tmp) *wd++ = tmp; - lo += cnt; + lo += deg; } } //assert(lo <= loc + LOC); @@ -285,28 +266,28 @@ public: { for (int i = 0; i < BNL; ++i) bnl[i] = vzero(); - for (int i = 0; i < MSG; ++i) + for (int i = 0; i < VAR; ++i) csh[i] = 0; for (int i = 0; i < K/M; ++i) for (int j = 0; j < W; ++j) for (int n = 0; n < D; ++n) - msg[W*i+j].v[n] = message[M*i+W*n+j]; + var[W*i+j].v[n] = message[M*i+W*n+j]; for (int i = 0; i < q; ++i) for (int j = 0; j < W; ++j) for (int n = 0; n < D; ++n) - pty[W*i+j].v[n] = parity[q*(W*n+j)+i]; + var[MSG+W*i+j].v[n] = parity[q*(W*n+j)+i]; while (bad() && --trials >= 0) update(); - for (int i = 0; i < MSG; ++i) - msg[i] = rotate(msg[i], -csh[i]); + for (int i = 0; i < VAR; ++i) + var[i] = rotate(var[i], -csh[i]); for (int i = 0; i < K/M; ++i) for (int j = 0; j < W; ++j) for (int n = 0; n < D; ++n) - message[M*i+W*n+j] = msg[W*i+j].v[n]; + message[M*i+W*n+j] = var[W*i+j].v[n]; for (int i = 0; i < q; ++i) for (int j = 0; j < W; ++j) for (int n = 0; n < D; ++n) - parity[q*(W*n+j)+i] = pty[W*i+j].v[n]; + parity[q*(W*n+j)+i] = var[MSG+W*i+j].v[n]; return trials; } };