diff --git a/ldpc_decoder.hh b/ldpc_decoder.hh index ba84d97..dac88c8 100644 --- a/ldpc_decoder.hh +++ b/ldpc_decoder.hh @@ -36,22 +36,21 @@ class LDPCDecoder static const int W = M/D; static const int PTY = R/D; static const int MSG = K/D; + static const int VAR = N/D; static const int CNC = TABLE::LINKS_MAX_CN - 2; - static const int BNL = (TABLE::LINKS_TOTAL + D-1) / D; - static const int LOC = (TABLE::LINKS_TOTAL - (2*R-1) + D-1) / D; + static const int BNL = (TABLE::LINKS_TOTAL + 1) / D; typedef SIMD TYPE; typedef struct { uint16_t off; uint16_t shi; } Loc; - typedef uint32_t wd_t; - static_assert(sizeof(wd_t) * 8 >= CNC, "write disable mask needs at least as many bits as max check node links"); Rotate rotate; TYPE bnl[BNL]; - TYPE msg[MSG]; - TYPE pty[PTY]; - Loc loc[LOC]; - wd_t wd[PTY]; - uint8_t cnc[q]; + TYPE var[VAR]; + Loc loc[BNL]; + bool wds[BNL]; + int16_t csh[VAR]; + uint8_t cnt[PTY]; + bool start; static TYPE eor(TYPE a, TYPE b) { @@ -74,150 +73,117 @@ class LDPCDecoder return vreinterpret(vand(vmask(b), vorr(vceqz(a), veor(vcgtz(a), vcltz(b))))); } - bool bad() - { - Loc *lo = loc; - for (int i = 0; i < q; ++i) { - int cnt = cnc[i]; - int deg = cnt + 2; - auto res = vmask(vzero()); - for (int j = 0; j < W; ++j) { - TYPE cnv = vdup(1); - for (int k = 0; k < deg; ++k) { - TYPE tmp; - if (k < cnt) { - tmp = rotate(msg[lo[k].off], -lo[k].shi); - } else if (k == cnt) { - tmp = pty[W*i+j]; - } else { - if (i) { - tmp = pty[W*(i-1)+j]; - } else if (j) { - tmp = pty[W*(q-1)+j-1]; - } else { - tmp = rotate(pty[PTY-1], 1); - tmp.v[0] = 127; - } - } - cnv = vsign(cnv, tmp); - } - res = vorr(res, vclez(cnv)); - lo += cnt; - } - for (int n = 0; n < D; ++n) - if (res.v[n]) - return true; - } - return false; - } - void update() + bool update() { TYPE *bl = bnl; Loc *lo = loc; - for (int i = 0; i < q; ++i) { - int cnt = cnc[i]; - int deg = cnt + 2; - for (int j = 0; j < W; ++j) { - TYPE mags[deg], inps[deg]; - TYPE min0 = vdup(127); - TYPE min1 = vdup(127); - TYPE signs = vdup(127); + bool *wd = wds; + auto bad = vmask(vzero()); + for (int i = 0; i < PTY; ++i) { + int deg = cnt[i]; + TYPE mags[deg], inps[deg]; + TYPE min0 = vdup(127); + TYPE min1 = vdup(127); + TYPE signs = vdup(127); + TYPE cnv = vdup(127); + bool first_wd; + int last_offset = 0; + int8_t prev_val = 0; - for (int k = 0; k < deg; ++k) { - TYPE tmp; - if (k < cnt) { - tmp = rotate(msg[lo[k].off], -lo[k].shi); - } else if (k == cnt) { - tmp = pty[W*i+j]; - } else { - if (i) { - tmp = pty[W*(i-1)+j]; - } else if (j) { - tmp = pty[W*(q-1)+j-1]; - } else { - tmp = rotate(pty[PTY-1], 1); - tmp.v[0] = 127; - } - } - - TYPE inp = vqsub(tmp, bl[k]); - - TYPE mag = vqabs(inp); - - if (BETA) { - auto beta = vunsigned(vdup(BETA)); - mag = vsigned(vqsub(vunsigned(mag), beta)); - } - - min1 = vmin(min1, vmax(min0, mag)); - min0 = vmin(min0, mag); - - signs = eor(signs, inp); - - inps[k] = inp; - mags[k] = mag; + for (int k = 0; k < deg; ++k) { + int offset = lo[k].off; + int shift = lo[k].shi; + int dshift = (shift - csh[offset]) % D; + TYPE tmp = rotate(var[offset], dshift); + if (offset == VAR-1 && shift == 1) { + prev_val = tmp.v[0]; + tmp.v[0] = 127; } - for (int k = 0; k < deg; ++k) { - TYPE mag = mags[k]; - TYPE inp = inps[k]; - TYPE out = vsign(other(mag, min0, min1), mine(signs, inp)); + TYPE inp = vqsub(tmp, bl[k]); + if (start) + inp = tmp; - out = vclamp(out, -32, 31); + TYPE mag = vqabs(inp); + if (BETA) { + auto beta = vunsigned(vdup(BETA)); + mag = vsigned(vqsub(vunsigned(mag), beta)); + } + + min1 = vmin(min1, vmax(min0, mag)); + min0 = vmin(min0, mag); + + signs = eor(signs, inp); + + inps[k] = inp; + mags[k] = mag; + } + for (int k = 0; k < deg; ++k) { + TYPE mag = mags[k]; + TYPE inp = inps[k]; + + TYPE out = vsign(other(mag, min0, min1), mine(signs, inp)); + + out = vclamp(out, -32, 31); + + if (!start) out = selfcorr(bl[k], out); - TYPE tmp = vqadd(inp, out); + TYPE tmp = vqadd(inp, out); - if (k < cnt) { - if (!((wd[W*i+j]>>k)&1)) { - bl[k] = out; - msg[lo[k].off] = rotate(tmp, lo[k].shi); - } - } else if (k == cnt) { - bl[k] = out; - pty[W*i+j] = tmp; - } else { - bl[k] = out; - if (i) { - pty[W*(i-1)+j] = tmp; - } else if (j) { - pty[W*(q-1)+j-1] = tmp; - } else { - tmp.v[0] = pty[PTY-1].v[D-1]; - pty[PTY-1] = rotate(tmp, -1); - } - } + cnv = vsign(cnv, tmp); + + int offset = lo[k].off; + int shift = lo[k].shi; + + if (offset == VAR-1 && shift == 1) + tmp.v[0] = prev_val; + + bool this_wd = wd[k]; + if (start) { + if (k) + this_wd = offset == last_offset; + else + this_wd = false; } - if (wd[W*i+j]) { - for (int first = 0, c = 1; c < cnt; ++c) { - if (lo[first].off != lo[c].off || c == cnt-1) { - int last = c - 1; - if (c == cnt-1) - ++last; - if (last != first) { - int count = last - first + 1; - wd_t mask = ((1 << count) - 1) << first; - wd_t cur = wd[W*i+j]; - wd_t tmp = cur & mask; - wd_t ror = (tmp >> 1) | (tmp << (count-1)); - wd[W*i+j] = (cur & ~mask) | (ror & mask); - } - first = c; - } - } + if (!this_wd) { + bl[k] = out; + var[offset] = tmp; + csh[offset] = shift; + } else if (start) { + bl[k] = vzero(); } - lo += cnt; - bl += deg; + if (k) { + bool next_wd = this_wd; + if (last_offset != offset) { + next_wd = first_wd; + first_wd = this_wd; + } + wd[k-1] = next_wd; + } else { + first_wd = this_wd; + } + last_offset = offset; } + wd[deg-1] = first_wd; + bad = vorr(bad, vclez(cnv)); + lo += deg; + bl += deg; + wd += deg; } //assert(bl <= bnl + BNL); //std::cerr << BNL - (bl - bnl) << std::endl; + for (int n = 0; n < D; ++n) + if (bad.v[n]) + return true; + return false; } public: LDPCDecoder() { uint16_t pos[q * CNC]; + uint8_t cnc[q]; for (int i = 0; i < q; ++i) cnc[i] = 0; int bit_pos = 0; @@ -234,9 +200,13 @@ public: bit_pos += M; } } + for (int i = 0; i < q; ++i) + for (int j = 0; j < W; ++j) + cnt[W*i+j] = cnc[i] + 2; Loc *lo = loc; for (int i = 0; i < q; ++i) { int cnt = cnc[i]; + int deg = cnt + 2; int offset[cnt], shift[cnt]; for (int c = 0; c < cnt; ++c) { shift[c] = pos[CNC*i+c] % M; @@ -245,42 +215,63 @@ public: for (int j = 0; j < W; ++j) { for (int c = 0; c < cnt; ++c) { lo[c].off = offset[c] / D + shift[c] % W; - lo[c].shi = shift[c] / W; + lo[c].shi = (D - shift[c] / W) % D; shift[c] = (shift[c] + 1) % M; } - std::sort(lo, lo + cnt, [](const Loc &a, const Loc &b){ return a.off < b.off; }); - wd[W*i+j] = 0; - for (int c = 0; c < cnt-1; ++c) - if (lo[c].off == lo[c+1].off) - wd[W*i+j] |= 1 << c; - lo += cnt; + if (i) { + lo[cnt].off = MSG+W*(i-1)+j; + lo[cnt].shi = 0; + } else if (j) { + lo[cnt].off = MSG+W*(q-1)+j-1; + lo[cnt].shi = 0; + } else { + lo[cnt].off = VAR-1; + lo[cnt].shi = 1; + } + lo[cnt+1].off = MSG+W*i+j; + lo[cnt+1].shi = 0; + + std::sort(lo, lo + deg, [](const Loc &a, const Loc &b){ return a.off < b.off; }); +#if 0 + std::cout << deg; + for (int d = 0; d < deg; ++d) + std::cout << '\t' << (int)lo[d].off << ':' << (int)lo[d].shi; + std::cout << std::endl; +#endif + lo += deg; } } - //assert(lo <= loc + LOC); - //std::cerr << LOC - (lo - loc) << std::endl; + //assert(lo <= loc + BNL); + //std::cerr << BNL - (lo - loc) << std::endl; } int operator()(int8_t *message, int8_t *parity, int trials = 25) { - for (int i = 0; i < BNL; ++i) - bnl[i] = vzero(); + for (int i = 0; i < VAR; ++i) + csh[i] = 0; for (int i = 0; i < K/M; ++i) for (int j = 0; j < W; ++j) for (int n = 0; n < D; ++n) - msg[W*i+j].v[n] = message[M*i+W*n+j]; + var[W*i+j].v[n] = message[M*i+W*n+j]; for (int i = 0; i < q; ++i) for (int j = 0; j < W; ++j) for (int n = 0; n < D; ++n) - pty[W*i+j].v[n] = parity[q*(W*n+j)+i]; - while (bad() && --trials >= 0) - update(); + var[MSG+W*i+j].v[n] = parity[q*(W*n+j)+i]; + + start = true; + while (--trials >= 0 && update()) + start = false; + + for (int i = 0; i < VAR; ++i) + var[i] = rotate(var[i], -csh[i]); + for (int i = 0; i < K/M; ++i) for (int j = 0; j < W; ++j) for (int n = 0; n < D; ++n) - message[M*i+W*n+j] = msg[W*i+j].v[n]; + message[M*i+W*n+j] = var[W*i+j].v[n]; for (int i = 0; i < q; ++i) for (int j = 0; j < W; ++j) for (int n = 0; n < D; ++n) - parity[q*(W*n+j)+i] = pty[W*i+j].v[n]; + parity[q*(W*n+j)+i] = var[MSG+W*i+j].v[n]; return trials; } };