only need write disable masks for write conflicts

This commit is contained in:
Ahmet Inan 2019-10-03 10:08:00 +02:00
commit 1f4da1bed3

View file

@ -42,15 +42,15 @@ class LDPCDecoder
typedef SIMD<int8_t, SIMD_SIZE> TYPE;
typedef struct { uint16_t off; uint16_t shi; } Loc;
typedef uint32_t wd_t;
static_assert(sizeof(wd_t) * 8 >= CNC, "write disable mask needs at least as many bits as max check node links");
typedef uint32_t wdm_t;
static_assert(sizeof(wdm_t) * 8 >= CNC, "write disable mask needs at least as many bits as max check node links");
Rotate<TYPE, D> rotate;
TYPE bnl[BNL];
TYPE msg[MSG];
TYPE pty[PTY];
Loc loc[LOC];
wd_t wd[PTY];
wdm_t wdm[PTY];
uint8_t cnc[q];
static TYPE eor(TYPE a, TYPE b)
@ -114,6 +114,7 @@ class LDPCDecoder
{
TYPE *bl = bnl;
Loc *lo = loc;
wdm_t *wd = wdm;
for (int i = 0; i < q; ++i) {
int cnt = cnc[i];
int deg = cnt + 2;
@ -122,11 +123,18 @@ class LDPCDecoder
TYPE min0 = vdup<TYPE>(127);
TYPE min1 = vdup<TYPE>(127);
TYPE signs = vdup<TYPE>(127);
bool write_conflict = false;
int last_offset = -1;
for (int k = 0; k < deg; ++k) {
TYPE tmp;
if (k < cnt) {
tmp = rotate(msg[lo[k].off], -lo[k].shi);
int offset = lo[k].off;
int shift = -lo[k].shi;
tmp = rotate(msg[offset], shift);
if (last_offset == offset)
write_conflict = true;
last_offset = offset;
} else if (k == cnt) {
tmp = pty[W*i+j];
} else {
@ -170,7 +178,7 @@ class LDPCDecoder
TYPE tmp = vqadd(inp, out);
if (k < cnt) {
if (!((wd[W*i+j]>>k)&1)) {
if (!write_conflict || !((*wd>>k)&1)) {
bl[k] = out;
msg[lo[k].off] = rotate(tmp, lo[k].shi);
}
@ -189,7 +197,7 @@ class LDPCDecoder
}
}
}
if (wd[W*i+j]) {
if (write_conflict) {
for (int first = 0, c = 1; c < cnt; ++c) {
if (lo[first].off != lo[c].off || c == cnt-1) {
int last = c - 1;
@ -197,15 +205,16 @@ class LDPCDecoder
++last;
if (last != first) {
int count = last - first + 1;
wd_t mask = ((1 << count) - 1) << first;
wd_t cur = wd[W*i+j];
wd_t tmp = cur & mask;
wd_t ror = (tmp >> 1) | (tmp << (count-1));
wd[W*i+j] = (cur & ~mask) | (ror & mask);
wdm_t mask = ((1 << count) - 1) << first;
wdm_t cur = *wd;
wdm_t tmp = cur & mask;
wdm_t ror = (tmp >> 1) | (tmp << (count-1));
*wd = (cur & ~mask) | (ror & mask);
}
first = c;
}
}
++wd;
}
lo += cnt;
bl += deg;
@ -235,6 +244,7 @@ public:
}
}
Loc *lo = loc;
wdm_t *wd = wdm;
for (int i = 0; i < q; ++i) {
int cnt = cnc[i];
int offset[cnt], shift[cnt];
@ -249,10 +259,12 @@ public:
shift[c] = (shift[c] + 1) % M;
}
std::sort(lo, lo + cnt, [](const Loc &a, const Loc &b){ return a.off < b.off; });
wd[W*i+j] = 0;
wdm_t tmp = 0;
for (int c = 0; c < cnt-1; ++c)
if (lo[c].off == lo[c+1].off)
wd[W*i+j] |= 1 << c;
tmp |= 1 << c;
if (tmp)
*wd++ = tmp;
lo += cnt;
}
}