diff --git a/ldpc_decoder.hh b/ldpc_decoder.hh index ba84d97..0d26d27 100644 --- a/ldpc_decoder.hh +++ b/ldpc_decoder.hh @@ -42,15 +42,15 @@ class LDPCDecoder typedef SIMD TYPE; typedef struct { uint16_t off; uint16_t shi; } Loc; - typedef uint32_t wd_t; - static_assert(sizeof(wd_t) * 8 >= CNC, "write disable mask needs at least as many bits as max check node links"); + typedef uint32_t wdm_t; + static_assert(sizeof(wdm_t) * 8 >= CNC, "write disable mask needs at least as many bits as max check node links"); Rotate rotate; TYPE bnl[BNL]; TYPE msg[MSG]; TYPE pty[PTY]; Loc loc[LOC]; - wd_t wd[PTY]; + wdm_t wdm[PTY]; uint8_t cnc[q]; static TYPE eor(TYPE a, TYPE b) @@ -114,6 +114,7 @@ class LDPCDecoder { TYPE *bl = bnl; Loc *lo = loc; + wdm_t *wd = wdm; for (int i = 0; i < q; ++i) { int cnt = cnc[i]; int deg = cnt + 2; @@ -122,11 +123,18 @@ class LDPCDecoder TYPE min0 = vdup(127); TYPE min1 = vdup(127); TYPE signs = vdup(127); + bool write_conflict = false; + int last_offset = -1; for (int k = 0; k < deg; ++k) { TYPE tmp; if (k < cnt) { - tmp = rotate(msg[lo[k].off], -lo[k].shi); + int offset = lo[k].off; + int shift = -lo[k].shi; + tmp = rotate(msg[offset], shift); + if (last_offset == offset) + write_conflict = true; + last_offset = offset; } else if (k == cnt) { tmp = pty[W*i+j]; } else { @@ -170,7 +178,7 @@ class LDPCDecoder TYPE tmp = vqadd(inp, out); if (k < cnt) { - if (!((wd[W*i+j]>>k)&1)) { + if (!write_conflict || !((*wd>>k)&1)) { bl[k] = out; msg[lo[k].off] = rotate(tmp, lo[k].shi); } @@ -189,7 +197,7 @@ class LDPCDecoder } } } - if (wd[W*i+j]) { + if (write_conflict) { for (int first = 0, c = 1; c < cnt; ++c) { if (lo[first].off != lo[c].off || c == cnt-1) { int last = c - 1; @@ -197,15 +205,16 @@ class LDPCDecoder ++last; if (last != first) { int count = last - first + 1; - wd_t mask = ((1 << count) - 1) << first; - wd_t cur = wd[W*i+j]; - wd_t tmp = cur & mask; - wd_t ror = (tmp >> 1) | (tmp << (count-1)); - wd[W*i+j] = (cur & ~mask) | (ror & mask); + wdm_t mask = ((1 << count) - 1) << first; + wdm_t cur = *wd; + wdm_t tmp = cur & mask; + wdm_t ror = (tmp >> 1) | (tmp << (count-1)); + *wd = (cur & ~mask) | (ror & mask); } first = c; } } + ++wd; } lo += cnt; bl += deg; @@ -235,6 +244,7 @@ public: } } Loc *lo = loc; + wdm_t *wd = wdm; for (int i = 0; i < q; ++i) { int cnt = cnc[i]; int offset[cnt], shift[cnt]; @@ -249,10 +259,12 @@ public: shift[c] = (shift[c] + 1) % M; } std::sort(lo, lo + cnt, [](const Loc &a, const Loc &b){ return a.off < b.off; }); - wd[W*i+j] = 0; + wdm_t tmp = 0; for (int c = 0; c < cnt-1; ++c) if (lo[c].off == lo[c+1].off) - wd[W*i+j] |= 1 << c; + tmp |= 1 << c; + if (tmp) + *wd++ = tmp; lo += cnt; } }