From 0bb0eaf5a98cbaa5ae7a2c92b94b683b51b6cf8b Mon Sep 17 00:00:00 2001 From: Ahmet Inan Date: Fri, 14 Jul 2023 14:33:07 +0200 Subject: [PATCH] added parity check aided SCL decoding --- polar_parity_aided.hh | 386 ++++++++++++++++++++++++++++ tests/polar_list_regression_test.cc | 18 +- 2 files changed, 402 insertions(+), 2 deletions(-) create mode 100644 polar_parity_aided.hh diff --git a/polar_parity_aided.hh b/polar_parity_aided.hh new file mode 100644 index 0000000..4a7f31a --- /dev/null +++ b/polar_parity_aided.hh @@ -0,0 +1,386 @@ +/* +Parity aided successive cancellation list decoding of polar codes + +Copyright 2023 Ahmet Inan +*/ + +#pragma once + +#include +#include "polar_helper.hh" + +namespace CODE { + +template +class PolarParityEncoder +{ + typedef PolarHelper PH; + static bool get(const uint32_t *bits, int idx) + { + return (bits[idx/32] >> (idx%32)) & 1; + } +public: + void operator()(TYPE *codeword, const TYPE *message, const uint32_t *frozen, int level, int stride) + { + int length = 1 << level; + int count = stride; + TYPE parity = PH::one(); + for (int i = 0; i < length; i += 2) { + TYPE msg0, msg1; + if (get(frozen, i)) { + msg0 = PH::one(); + } else if (count) { + msg0 = *message++; + parity = PH::qmul(parity, msg0); + --count; + } else { + msg0 = parity; + parity = PH::one(); + count = stride; + } + if (get(frozen, i + 1)) { + msg1 = PH::one(); + } else if (count) { + msg1 = *message++; + parity = PH::qmul(parity, msg1); + --count; + } else { + msg1 = parity; + parity = PH::one(); + count = stride; + } + codeword[i] = PH::qmul(msg0, msg1); + codeword[i+1] = msg1; + } + for (int h = 2; h < length; h *= 2) + for (int i = 0; i < length; i += 2 * h) + for (int j = i; j < i + h; ++j) + codeword[j] = PH::qmul(codeword[j], codeword[j+h]); + } +}; + +template +struct PolarParityNode +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int N = 1 << M; + static MAP rate0(PATH *metric, TYPE *hard, TYPE *soft) + { + for (int i = 0; i < N; ++i) + hard[i] = PH::one(); + for (int i = 0; i < N; ++i) + for (int k = 0; k < TYPE::SIZE; ++k) + if (soft[i+N].v[k] < 0) + metric[k] -= soft[i+N].v[k]; + MAP map; + for (int k = 0; k < TYPE::SIZE; ++k) + map.v[k] = k; + return map; + } +}; + +template +struct PolarParityNode +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static MAP rate0(PATH *metric, TYPE *hard, TYPE *soft) + { + *hard = PH::one(); + for (int k = 0; k < TYPE::SIZE; ++k) + if (soft[1].v[k] < 0) + metric[k] -= soft[1].v[k]; + MAP map; + for (int k = 0; k < TYPE::SIZE; ++k) + map.v[k] = k; + return map; + } + static MAP rate1(PATH *metric, TYPE *message, MAP *maps, int *index, TYPE *hard, TYPE *soft, TYPE *parity, int *count, int stride) + { + TYPE sft = soft[1]; + PATH fork[2*TYPE::SIZE]; + for (int k = 0; k < TYPE::SIZE; ++k) + fork[k] = fork[k+TYPE::SIZE] = metric[k]; + for (int k = 0; k < TYPE::SIZE; ++k) + if (sft.v[k] < 0) + fork[k] -= sft.v[k]; + else + fork[k+TYPE::SIZE] += sft.v[k]; + int perm[2*TYPE::SIZE]; + for (int k = 0; k < 2*TYPE::SIZE; ++k) + perm[k] = k; + std::nth_element(perm, perm+TYPE::SIZE, perm+2*TYPE::SIZE, [fork](int a, int b){ return fork[a] < fork[b]; }); + for (int k = 0; k < TYPE::SIZE; ++k) + metric[k] = fork[perm[k]]; + MAP map; + for (int k = 0; k < TYPE::SIZE; ++k) + map.v[k] = perm[k] % TYPE::SIZE; + TYPE hrd; + for (int k = 0; k < TYPE::SIZE; ++k) + hrd.v[k] = perm[k] < TYPE::SIZE ? 1 : -1; + if (*count) { + message[*index] = hrd; + *parity = PH::qmul(vshuf(*parity, map), hrd); + maps[*index] = map; + ++*index; + --*count; + } else { + message[*index-1] = vshuf(message[*index-1], map); + maps[*index-1] = vshuf(maps[*index-1], map); + TYPE chk = vshuf(*parity, map); + for (int k = 0; k < TYPE::SIZE; ++k) + if (chk.v[k] != hrd.v[k]) + metric[k] = 1000; + *parity = PH::one(); + *count = stride; + } + *hard = hrd; + return map; + } +}; + +template +struct PolarParityTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *index, TYPE *hard, TYPE *soft, const uint32_t *frozen, TYPE *parity, int *count, int stride) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap = PolarParityTree::decode(metric, message, maps, index, hard, soft, frozen, parity, count, stride); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + MAP rmap = PolarParityTree::decode(metric, message, maps, index, hard+N/2, soft, frozen+N/2/32, parity, count, stride); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + +template +struct PolarParityTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int M = 6; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *index, TYPE *hard, TYPE *soft, const uint32_t *frozen, TYPE *parity, int *count, int stride) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap, rmap; + if (frozen[0] == 0xffffffff) + lmap = PolarParityNode::rate0(metric, hard, soft); + else + lmap = PolarParityTree::decode(metric, message, maps, index, hard, soft, frozen[0], parity, count, stride); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + if (frozen[1] == 0xffffffff) + rmap = PolarParityNode::rate0(metric, hard+N/2, soft); + else + rmap = PolarParityTree::decode(metric, message, maps, index, hard+N/2, soft, frozen[1], parity, count, stride); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + +template +struct PolarParityTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int M = 5; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *index, TYPE *hard, TYPE *soft, uint32_t frozen, TYPE *parity, int *count, int stride) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap, rmap; + if ((frozen & ((1<<(1<<(M-1)))-1)) == ((1<<(1<<(M-1)))-1)) + lmap = PolarParityNode::rate0(metric, hard, soft); + else + lmap = PolarParityTree::decode(metric, message, maps, index, hard, soft, frozen & ((1<<(1<<(M-1)))-1), parity, count, stride); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + if (frozen >> (N/2) == ((1<<(1<<(M-1)))-1)) + rmap = PolarParityNode::rate0(metric, hard+N/2, soft); + else + rmap = PolarParityTree::decode(metric, message, maps, index, hard+N/2, soft, frozen >> (N/2), parity, count, stride); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + +template +struct PolarParityTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int M = 4; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *index, TYPE *hard, TYPE *soft, uint32_t frozen, TYPE *parity, int *count, int stride) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap, rmap; + if ((frozen & ((1<<(1<<(M-1)))-1)) == ((1<<(1<<(M-1)))-1)) + lmap = PolarParityNode::rate0(metric, hard, soft); + else + lmap = PolarParityTree::decode(metric, message, maps, index, hard, soft, frozen & ((1<<(1<<(M-1)))-1), parity, count, stride); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + if (frozen >> (N/2) == ((1<<(1<<(M-1)))-1)) + rmap = PolarParityNode::rate0(metric, hard+N/2, soft); + else + rmap = PolarParityTree::decode(metric, message, maps, index, hard+N/2, soft, frozen >> (N/2), parity, count, stride); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + +template +struct PolarParityTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int M = 3; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *index, TYPE *hard, TYPE *soft, uint32_t frozen, TYPE *parity, int *count, int stride) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap, rmap; + if ((frozen & ((1<<(1<<(M-1)))-1)) == ((1<<(1<<(M-1)))-1)) + lmap = PolarParityNode::rate0(metric, hard, soft); + else + lmap = PolarParityTree::decode(metric, message, maps, index, hard, soft, frozen & ((1<<(1<<(M-1)))-1), parity, count, stride); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + if (frozen >> (N/2) == ((1<<(1<<(M-1)))-1)) + rmap = PolarParityNode::rate0(metric, hard+N/2, soft); + else + rmap = PolarParityTree::decode(metric, message, maps, index, hard+N/2, soft, frozen >> (N/2), parity, count, stride); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + +template +struct PolarParityTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int M = 2; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *index, TYPE *hard, TYPE *soft, uint32_t frozen, TYPE *parity, int *count, int stride) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap, rmap; + if ((frozen & ((1<<(1<<(M-1)))-1)) == ((1<<(1<<(M-1)))-1)) + lmap = PolarParityNode::rate0(metric, hard, soft); + else + lmap = PolarParityTree::decode(metric, message, maps, index, hard, soft, frozen & ((1<<(1<<(M-1)))-1), parity, count, stride); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + if (frozen >> (N/2) == ((1<<(1<<(M-1)))-1)) + rmap = PolarParityNode::rate0(metric, hard+N/2, soft); + else + rmap = PolarParityTree::decode(metric, message, maps, index, hard+N/2, soft, frozen >> (N/2), parity, count, stride); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + +template +struct PolarParityTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *index, TYPE *hard, TYPE *soft, uint32_t frozen, TYPE *parity, int *count, int stride) + { + soft[1] = PH::prod(soft[2], soft[3]); + MAP lmap, rmap; + if (frozen & 1) + lmap = PolarParityNode::rate0(metric, hard, soft); + else + lmap = PolarParityNode::rate1(metric, message, maps, index, hard, soft, parity, count, stride); + soft[1] = PH::madd(hard[0], vshuf(soft[2], lmap), vshuf(soft[3], lmap)); + if (frozen >> 1) + rmap = PolarParityNode::rate0(metric, hard+1, soft); + else + rmap = PolarParityNode::rate1(metric, message, maps, index, hard+1, soft, parity, count, stride); + hard[0] = PH::qmul(vshuf(hard[0], rmap), hard[1]); + return vshuf(lmap, rmap); + } +}; + +template +class PolarParityDecoder +{ + static_assert(MAX_M >= 5 && MAX_M <= 16); + typedef PolarHelper PH; + typedef typename TYPE::value_type VALUE; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int MAX_N = 1 << MAX_M; + TYPE soft[2*MAX_N]; + TYPE hard[MAX_N]; + MAP maps[MAX_N]; +public: + void operator()(PATH *metric, TYPE *message, const VALUE *codeword, const uint32_t *frozen, int level, int stride) + { + assert(level <= MAX_M); + int index = 0; + metric[0] = 0; + for (int k = 1; k < TYPE::SIZE; ++k) + metric[k] = 1000; + int length = 1 << level; + for (int i = 0; i < length; ++i) + soft[length+i] = vdup(codeword[i]); + TYPE parity = PH::one(); + int count = stride; + + switch (level) { + case 5: PolarParityTree::decode(metric, message, maps, &index, hard, soft, *frozen, &parity, &count, stride); break; + case 6: PolarParityTree::decode(metric, message, maps, &index, hard, soft, frozen, &parity, &count, stride); break; + case 7: PolarParityTree::decode(metric, message, maps, &index, hard, soft, frozen, &parity, &count, stride); break; + case 8: PolarParityTree::decode(metric, message, maps, &index, hard, soft, frozen, &parity, &count, stride); break; + case 9: PolarParityTree::decode(metric, message, maps, &index, hard, soft, frozen, &parity, &count, stride); break; + case 10: PolarParityTree::decode(metric, message, maps, &index, hard, soft, frozen, &parity, &count, stride); break; + case 11: PolarParityTree::decode(metric, message, maps, &index, hard, soft, frozen, &parity, &count, stride); break; + case 12: PolarParityTree::decode(metric, message, maps, &index, hard, soft, frozen, &parity, &count, stride); break; + case 13: PolarParityTree::decode(metric, message, maps, &index, hard, soft, frozen, &parity, &count, stride); break; + case 14: PolarParityTree::decode(metric, message, maps, &index, hard, soft, frozen, &parity, &count, stride); break; + case 15: PolarParityTree::decode(metric, message, maps, &index, hard, soft, frozen, &parity, &count, stride); break; + case 16: PolarParityTree::decode(metric, message, maps, &index, hard, soft, frozen, &parity, &count, stride); break; + default: assert(false); + } + + MAP acc = maps[index-1]; + for (int i = index-2; i >= 0; --i) { + message[i] = vshuf(message[i], acc); + acc = vshuf(maps[i], acc); + } + } +}; + +} + diff --git a/tests/polar_list_regression_test.cc b/tests/polar_list_regression_test.cc index c3ce7ce..5cf2cff 100644 --- a/tests/polar_list_regression_test.cc +++ b/tests/polar_list_regression_test.cc @@ -16,6 +16,7 @@ Copyright 2020 Ahmet Inan #include "polar_list_decoder.hh" #include "polar_encoder.hh" #include "polar_sequence.hh" +#include "polar_parity_aided.hh" #include "crc.hh" #include "sequence.h" @@ -28,10 +29,13 @@ int main() { const int M = 10; const int N = 1 << M; - const bool systematic = true; + const bool systematic = false; const bool crc_aided = true; + const bool par_aided = true; + static_assert(!par_aided || !systematic, "systematic and parity aided are mutually exclusive"); CODE::CRC crc(0xD419CC15); const int C = 32; + const int S = 32; #if 1 typedef int8_t code_type; double SCALE = 2; @@ -79,12 +83,16 @@ int main() frozen[i] = 0; for (int i = 0; i < N - K; ++i) frozen[reliability_sequence[i]/32] |= 1 << (reliability_sequence[i]%32); + int P = K / (S + 1); + if (par_aided) + K -= P; std::cerr << "Polar(" << N << ", " << K << ")" << std::endl; auto message = new code_type[K]; auto decoded = new simd_type[K]; CODE::PolarHelper::PATH metric[SIMD_WIDTH]; std::cerr << "sizeof(PolarListDecoder) = " << sizeof(CODE::PolarListDecoder) << std::endl; auto decode = new CODE::PolarListDecoder; + auto par_dec = new CODE::PolarParityDecoder; auto orig = new code_type[N]; auto noisy = new code_type[N]; @@ -132,6 +140,9 @@ int main() for (int i = 0, j = 0; i < N; ++i) if (!get_bit(frozen, i)) assert(codeword[i] == message[j++]); + } else if (par_aided) { + CODE::PolarParityEncoder encode; + encode(codeword, message, frozen, M, S); } else { CODE::PolarEncoder encode; encode(codeword, message, frozen, M); @@ -157,7 +168,10 @@ int main() noisy[i] = codeword[i]; auto start = std::chrono::system_clock::now(); - (*decode)(metric, decoded, codeword, frozen, M); + if (par_aided) + (*par_dec)(metric, decoded, codeword, frozen, M, S); + else + (*decode)(metric, decoded, codeword, frozen, M); auto end = std::chrono::system_clock::now(); auto usec = std::chrono::duration_cast(end - start); double mbs = (double)K / usec.count();