From 9c55719f8592d53917ef24b2f137ca1719564ac1 Mon Sep 17 00:00:00 2001 From: Ahmet Inan Date: Sun, 1 Feb 2026 22:07:15 +0100 Subject: [PATCH] use a rank map instead of a frozen bit array --- pac_encoder.hh | 11 +- pac_list_decoder.hh | 203 +++++++----------------------- tests/pac_list_regression_test.cc | 16 +-- 3 files changed, 53 insertions(+), 177 deletions(-) diff --git a/pac_encoder.hh b/pac_encoder.hh index 6f5de42..1c6ce4e 100644 --- a/pac_encoder.hh +++ b/pac_encoder.hh @@ -14,10 +14,6 @@ template class PACEncoder { typedef PolarHelper PH; - static bool get(const uint32_t *bits, int idx) - { - return (bits[idx/32] >> (idx%32)) & 1; - } static bool conv(int *state, bool input) { // 1011011 @@ -30,13 +26,14 @@ class PACEncoder return output; } public: - void operator()(TYPE *codeword, const TYPE *message, const uint32_t *frozen, int level) + void operator()(TYPE *codeword, const TYPE *message, const int *rank_map, int mesg_bits, int level) { int length = 1 << level; int state = 0; + int frozen = length - mesg_bits; for (int i = 0; i < length; i += 2) { - TYPE msg0 = get(frozen, i) ? PH::one() : *message++; - TYPE msg1 = get(frozen, i+1) ? PH::one() : *message++; + TYPE msg0 = rank_map[i] < frozen ? PH::one() : *message++; + TYPE msg1 = rank_map[i+1] < frozen ? PH::one() : *message++; msg0 = 1 - 2 * conv(&state, msg0 < 0); msg1 = 1 - 2 * conv(&state, msg1 < 0); codeword[i] = PH::qmul(msg0, msg1); diff --git a/pac_list_decoder.hh b/pac_list_decoder.hh index ec4eae8..74ec39e 100644 --- a/pac_list_decoder.hh +++ b/pac_list_decoder.hh @@ -11,8 +11,29 @@ Copyright 2025 Ahmet Inan namespace CODE { +template +struct PACListTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, int *state, TYPE *hard, TYPE *soft, const int *rank, int frozen) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap = PACListTree::decode(metric, message, maps, count, state, hard, soft, rank, frozen); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + MAP rmap = PACListTree::decode(metric, message, maps, count, state, hard+N/2, soft, rank+N/2, frozen); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + template -struct PACListLeaf +struct PACListTree { typedef PolarHelper PH; typedef typename PH::PATH PATH; @@ -76,158 +97,19 @@ struct PACListLeaf *hard = hrd; return map; } -}; - -template -struct PACListTree -{ - typedef PolarHelper PH; - typedef typename PH::PATH PATH; - typedef typename PH::MAP MAP; - static const int N = 1 << M; - static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, int *state, TYPE *hard, TYPE *soft, const uint32_t *frozen) - { - for (int i = 0; i < N/2; ++i) - soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); - MAP lmap = PACListTree::decode(metric, message, maps, count, state, hard, soft, frozen); - for (int i = 0; i < N/2; ++i) - soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); - MAP rmap = PACListTree::decode(metric, message, maps, count, state, hard+N/2, soft, frozen+N/2/32); - for (int i = 0; i < N/2; ++i) - hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); - return vshuf(lmap, rmap); - } -}; - -template -struct PACListTree -{ - typedef PolarHelper PH; - typedef typename PH::PATH PATH; - typedef typename PH::MAP MAP; - static const int M = 6; - static const int N = 1 << M; - static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, int *state, TYPE *hard, TYPE *soft, const uint32_t *frozen) - { - for (int i = 0; i < N/2; ++i) - soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); - MAP lmap = PACListTree::decode(metric, message, maps, count, state, hard, soft, frozen[0]); - for (int i = 0; i < N/2; ++i) - soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); - MAP rmap = PACListTree::decode(metric, message, maps, count, state, hard+N/2, soft, frozen[1]); - for (int i = 0; i < N/2; ++i) - hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); - return vshuf(lmap, rmap); - } -}; - -template -struct PACListTree -{ - typedef PolarHelper PH; - typedef typename PH::PATH PATH; - typedef typename PH::MAP MAP; - static const int M = 5; - static const int N = 1 << M; - static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, int *state, TYPE *hard, TYPE *soft, uint32_t frozen) - { - for (int i = 0; i < N/2; ++i) - soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); - MAP lmap = PACListTree::decode(metric, message, maps, count, state, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); - for (int i = 0; i < N/2; ++i) - soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); - MAP rmap = PACListTree::decode(metric, message, maps, count, state, hard+N/2, soft, frozen >> (N/2)); - for (int i = 0; i < N/2; ++i) - hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); - return vshuf(lmap, rmap); - } -}; - -template -struct PACListTree -{ - typedef PolarHelper PH; - typedef typename PH::PATH PATH; - typedef typename PH::MAP MAP; - static const int M = 4; - static const int N = 1 << M; - static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, int *state, TYPE *hard, TYPE *soft, uint32_t frozen) - { - for (int i = 0; i < N/2; ++i) - soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); - MAP lmap = PACListTree::decode(metric, message, maps, count, state, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); - for (int i = 0; i < N/2; ++i) - soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); - MAP rmap = PACListTree::decode(metric, message, maps, count, state, hard+N/2, soft, frozen >> (N/2)); - for (int i = 0; i < N/2; ++i) - hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); - return vshuf(lmap, rmap); - } -}; - -template -struct PACListTree -{ - typedef PolarHelper PH; - typedef typename PH::PATH PATH; - typedef typename PH::MAP MAP; - static const int M = 3; - static const int N = 1 << M; - static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, int *state, TYPE *hard, TYPE *soft, uint32_t frozen) - { - for (int i = 0; i < N/2; ++i) - soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); - MAP lmap = PACListTree::decode(metric, message, maps, count, state, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); - for (int i = 0; i < N/2; ++i) - soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); - MAP rmap = PACListTree::decode(metric, message, maps, count, state, hard+N/2, soft, frozen >> (N/2)); - for (int i = 0; i < N/2; ++i) - hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); - return vshuf(lmap, rmap); - } -}; - -template -struct PACListTree -{ - typedef PolarHelper PH; - typedef typename PH::PATH PATH; - typedef typename PH::MAP MAP; - static const int M = 2; - static const int N = 1 << M; - static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, int *state, TYPE *hard, TYPE *soft, uint32_t frozen) - { - for (int i = 0; i < N/2; ++i) - soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); - MAP lmap = PACListTree::decode(metric, message, maps, count, state, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); - for (int i = 0; i < N/2; ++i) - soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); - MAP rmap = PACListTree::decode(metric, message, maps, count, state, hard+N/2, soft, frozen >> (N/2)); - for (int i = 0; i < N/2; ++i) - hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); - return vshuf(lmap, rmap); - } -}; - -template -struct PACListTree -{ - typedef PolarHelper PH; - typedef typename PH::PATH PATH; - typedef typename PH::MAP MAP; - static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, int *state, TYPE *hard, TYPE *soft, uint32_t frozen) + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, int *state, TYPE *hard, TYPE *soft, const int *rank, int frozen) { soft[1] = PH::prod(soft[2], soft[3]); MAP lmap, rmap; - if (frozen & 1) - lmap = PACListLeaf::rate0(metric, state, hard, soft); + if (rank[0] < frozen) + lmap = rate0(metric, state, hard, soft); else - lmap = PACListLeaf::rate1(metric, message, maps, count, state, hard, soft); + lmap = rate1(metric, message, maps, count, state, hard, soft); soft[1] = PH::madd(hard[0], vshuf(soft[2], lmap), vshuf(soft[3], lmap)); - if (frozen >> 1) - rmap = PACListLeaf::rate0(metric, state, hard+1, soft); + if (rank[1] < frozen) + rmap = rate0(metric, state, hard+1, soft); else - rmap = PACListLeaf::rate1(metric, message, maps, count, state, hard+1, soft); + rmap = rate1(metric, message, maps, count, state, hard+1, soft); hard[0] = PH::qmul(vshuf(hard[0], rmap), hard[1]); return vshuf(lmap, rmap); } @@ -246,7 +128,7 @@ class PACListDecoder TYPE hard[MAX_N]; MAP maps[MAX_N]; public: - void operator()(int *rank, TYPE *message, const VALUE *codeword, const uint32_t *frozen, int level) + void operator()(int *rank, TYPE *message, const VALUE *codeword, const int *rank_map, int mesg_bits, int level) { assert(level <= MAX_M); PATH metric[TYPE::SIZE]; @@ -260,20 +142,21 @@ public: int state[TYPE::SIZE]; for (int i = 0; i < TYPE::SIZE; ++i) state[i] = 0; + int frozen = length - mesg_bits; switch (level) { - case 5: PACListTree::decode(metric, message, maps, &count, state, hard, soft, *frozen); break; - case 6: PACListTree::decode(metric, message, maps, &count, state, hard, soft, frozen); break; - case 7: PACListTree::decode(metric, message, maps, &count, state, hard, soft, frozen); break; - case 8: PACListTree::decode(metric, message, maps, &count, state, hard, soft, frozen); break; - case 9: PACListTree::decode(metric, message, maps, &count, state, hard, soft, frozen); break; - case 10: PACListTree::decode(metric, message, maps, &count, state, hard, soft, frozen); break; - case 11: PACListTree::decode(metric, message, maps, &count, state, hard, soft, frozen); break; - case 12: PACListTree::decode(metric, message, maps, &count, state, hard, soft, frozen); break; - case 13: PACListTree::decode(metric, message, maps, &count, state, hard, soft, frozen); break; - case 14: PACListTree::decode(metric, message, maps, &count, state, hard, soft, frozen); break; - case 15: PACListTree::decode(metric, message, maps, &count, state, hard, soft, frozen); break; - case 16: PACListTree::decode(metric, message, maps, &count, state, hard, soft, frozen); break; + case 5: PACListTree::decode(metric, message, maps, &count, state, hard, soft, rank_map, frozen); break; + case 6: PACListTree::decode(metric, message, maps, &count, state, hard, soft, rank_map, frozen); break; + case 7: PACListTree::decode(metric, message, maps, &count, state, hard, soft, rank_map, frozen); break; + case 8: PACListTree::decode(metric, message, maps, &count, state, hard, soft, rank_map, frozen); break; + case 9: PACListTree::decode(metric, message, maps, &count, state, hard, soft, rank_map, frozen); break; + case 10: PACListTree::decode(metric, message, maps, &count, state, hard, soft, rank_map, frozen); break; + case 11: PACListTree::decode(metric, message, maps, &count, state, hard, soft, rank_map, frozen); break; + case 12: PACListTree::decode(metric, message, maps, &count, state, hard, soft, rank_map, frozen); break; + case 13: PACListTree::decode(metric, message, maps, &count, state, hard, soft, rank_map, frozen); break; + case 14: PACListTree::decode(metric, message, maps, &count, state, hard, soft, rank_map, frozen); break; + case 15: PACListTree::decode(metric, message, maps, &count, state, hard, soft, rank_map, frozen); break; + case 16: PACListTree::decode(metric, message, maps, &count, state, hard, soft, rank_map, frozen); break; default: assert(false); } diff --git a/tests/pac_list_regression_test.cc b/tests/pac_list_regression_test.cc index 12e5756..9a15412 100644 --- a/tests/pac_list_regression_test.cc +++ b/tests/pac_list_regression_test.cc @@ -46,15 +46,12 @@ int main() typedef std::default_random_engine generator; typedef std::uniform_int_distribution distribution; auto data = std::bind(distribution(0, 1), generator(rd())); - auto frozen = new uint32_t[N/32]; auto codeword = new code_type[N]; double erasure_probability = 0.5; int K = (1 - erasure_probability) * N; double design_SNR = 10 * std::log10(-std::log(erasure_probability)); std::cerr << "design SNR: " << design_SNR << std::endl; - for (int i = 0; i < N / 32; ++i) - frozen[i] = 0; const int *reliability_sequence; if (1) { auto construct = new CODE::ReedMullerSequence<10>; @@ -66,12 +63,11 @@ int main() } else { reliability_sequence = sequence; } - for (int i = 0, j = 0; i < 1024 && j < N - K; ++i) { + auto rank_map = new int[N]; + for (int i = 0, j = 0; i < 1024 && j < N; ++i) { int index = reliability_sequence[i]; - if (index < N) { - frozen[index/32] |= 1 << (index%32); - ++j; - } + if (index < N) + rank_map[index] = j++; } std::cerr << "Polar(" << N << ", " << K << ")" << std::endl; auto message = new code_type[K]; @@ -121,7 +117,7 @@ int main() } CODE::PACEncoder encode; - encode(codeword, message, frozen, M); + encode(codeword, message, rank_map, K, M); for (int i = 0; i < N; ++i) orig[i] = codeword[i]; @@ -144,7 +140,7 @@ int main() int rank[L]; auto start = std::chrono::system_clock::now(); - (*decode)(rank, decoded, codeword, frozen, M); + (*decode)(rank, decoded, codeword, rank_map, K, M); auto end = std::chrono::system_clock::now(); auto usec = std::chrono::duration_cast(end - start); double mbs = (double)K / usec.count();