From fbd253aff41e00c3c5b047eb6ecc2c7de5e97596 Mon Sep 17 00:00:00 2001 From: Ahmet Inan Date: Mon, 22 Dec 2025 12:54:26 +0100 Subject: [PATCH] prepare for Polarization-Adjusted Convolutional codes --- pac_encoder.hh | 39 ++++ pac_list_decoder.hh | 327 ++++++++++++++++++++++++++++++ tests/pac_list_regression_test.cc | 216 ++++++++++++++++++++ 3 files changed, 582 insertions(+) create mode 100644 pac_encoder.hh create mode 100644 pac_list_decoder.hh create mode 100644 tests/pac_list_regression_test.cc diff --git a/pac_encoder.hh b/pac_encoder.hh new file mode 100644 index 0000000..b358391 --- /dev/null +++ b/pac_encoder.hh @@ -0,0 +1,39 @@ +/* +Encoder for Polarization-Adjusted Convolutional codes + +Copyright 2025 Ahmet Inan +*/ + +#pragma once + +#include "polar_helper.hh" + +namespace CODE { + +template +class PACEncoder +{ + typedef PolarHelper PH; + static bool get(const uint32_t *bits, int idx) + { + return (bits[idx/32] >> (idx%32)) & 1; + } +public: + void operator()(TYPE *codeword, const TYPE *message, const uint32_t *frozen, int level) + { + int length = 1 << level; + for (int i = 0; i < length; i += 2) { + TYPE msg0 = get(frozen, i) ? PH::one() : *message++; + TYPE msg1 = get(frozen, i+1) ? PH::one() : *message++; + codeword[i] = PH::qmul(msg0, msg1); + codeword[i+1] = msg1; + } + for (int h = 2; h < length; h *= 2) + for (int i = 0; i < length; i += 2 * h) + for (int j = i; j < i + h; ++j) + codeword[j] = PH::qmul(codeword[j], codeword[j+h]); + } +}; + +} + diff --git a/pac_list_decoder.hh b/pac_list_decoder.hh new file mode 100644 index 0000000..29b3720 --- /dev/null +++ b/pac_list_decoder.hh @@ -0,0 +1,327 @@ +/* +List decoding of Polarization-Adjusted Convolutional codes + +Copyright 2025 Ahmet Inan +*/ + +#pragma once + +#include "sort.hh" +#include "polar_helper.hh" + +namespace CODE { + +template +struct PACListNode +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int N = 1 << M; + static MAP rate0(PATH *metric, TYPE *hard, TYPE *soft) + { + for (int i = 0; i < N; ++i) + hard[i] = PH::one(); + for (int i = 0; i < N; ++i) + for (int k = 0; k < TYPE::SIZE; ++k) + if (soft[i+N].v[k] < 0) + metric[k] -= soft[i+N].v[k]; + MAP map; + for (int k = 0; k < TYPE::SIZE; ++k) + map.v[k] = k; + return map; + } +}; + +template +struct PACListNode +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static MAP rate0(PATH *metric, TYPE *hard, TYPE *soft) + { + *hard = PH::one(); + for (int k = 0; k < TYPE::SIZE; ++k) + if (soft[1].v[k] < 0) + metric[k] -= soft[1].v[k]; + MAP map; + for (int k = 0; k < TYPE::SIZE; ++k) + map.v[k] = k; + return map; + } + static MAP rate1(PATH *metric, TYPE *message, MAP *maps, int *count, TYPE *hard, TYPE *soft) + { + TYPE sft = soft[1]; + PATH fork[2*TYPE::SIZE]; + for (int k = 0; k < TYPE::SIZE; ++k) + fork[2*k] = fork[2*k+1] = metric[k]; + for (int k = 0; k < TYPE::SIZE; ++k) + if (sft.v[k] < 0) + fork[2*k] -= sft.v[k]; + else + fork[2*k+1] += sft.v[k]; + int perm[2*TYPE::SIZE]; + CODE::insertion_sort(perm, fork, 2*TYPE::SIZE); + for (int k = 0; k < TYPE::SIZE; ++k) + metric[k] = fork[k]; + MAP map; + for (int k = 0; k < TYPE::SIZE; ++k) + map.v[k] = perm[k] >> 1; + TYPE hrd; + for (int k = 0; k < TYPE::SIZE; ++k) + hrd.v[k] = 1 - 2 * (perm[k] & 1); + message[*count] = hrd; + maps[*count] = map; + ++*count; + *hard = hrd; + return map; + } +}; + +template +struct PACListTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, TYPE *hard, TYPE *soft, const uint32_t *frozen) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap = PACListTree::decode(metric, message, maps, count, hard, soft, frozen); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + MAP rmap = PACListTree::decode(metric, message, maps, count, hard+N/2, soft, frozen+N/2/32); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + +template +struct PACListTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int M = 6; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, TYPE *hard, TYPE *soft, const uint32_t *frozen) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap, rmap; + if (frozen[0] == 0xffffffff) + lmap = PACListNode::rate0(metric, hard, soft); + else + lmap = PACListTree::decode(metric, message, maps, count, hard, soft, frozen[0]); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + if (frozen[1] == 0xffffffff) + rmap = PACListNode::rate0(metric, hard+N/2, soft); + else + rmap = PACListTree::decode(metric, message, maps, count, hard+N/2, soft, frozen[1]); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + +template +struct PACListTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int M = 5; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, TYPE *hard, TYPE *soft, uint32_t frozen) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap, rmap; + if ((frozen & ((1<<(1<<(M-1)))-1)) == ((1<<(1<<(M-1)))-1)) + lmap = PACListNode::rate0(metric, hard, soft); + else + lmap = PACListTree::decode(metric, message, maps, count, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + if (frozen >> (N/2) == ((1<<(1<<(M-1)))-1)) + rmap = PACListNode::rate0(metric, hard+N/2, soft); + else + rmap = PACListTree::decode(metric, message, maps, count, hard+N/2, soft, frozen >> (N/2)); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + +template +struct PACListTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int M = 4; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, TYPE *hard, TYPE *soft, uint32_t frozen) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap, rmap; + if ((frozen & ((1<<(1<<(M-1)))-1)) == ((1<<(1<<(M-1)))-1)) + lmap = PACListNode::rate0(metric, hard, soft); + else + lmap = PACListTree::decode(metric, message, maps, count, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + if (frozen >> (N/2) == ((1<<(1<<(M-1)))-1)) + rmap = PACListNode::rate0(metric, hard+N/2, soft); + else + rmap = PACListTree::decode(metric, message, maps, count, hard+N/2, soft, frozen >> (N/2)); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + +template +struct PACListTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int M = 3; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, TYPE *hard, TYPE *soft, uint32_t frozen) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap, rmap; + if ((frozen & ((1<<(1<<(M-1)))-1)) == ((1<<(1<<(M-1)))-1)) + lmap = PACListNode::rate0(metric, hard, soft); + else + lmap = PACListTree::decode(metric, message, maps, count, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + if (frozen >> (N/2) == ((1<<(1<<(M-1)))-1)) + rmap = PACListNode::rate0(metric, hard+N/2, soft); + else + rmap = PACListTree::decode(metric, message, maps, count, hard+N/2, soft, frozen >> (N/2)); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + +template +struct PACListTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int M = 2; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, TYPE *hard, TYPE *soft, uint32_t frozen) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap, rmap; + if ((frozen & ((1<<(1<<(M-1)))-1)) == ((1<<(1<<(M-1)))-1)) + lmap = PACListNode::rate0(metric, hard, soft); + else + lmap = PACListTree::decode(metric, message, maps, count, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + if (frozen >> (N/2) == ((1<<(1<<(M-1)))-1)) + rmap = PACListNode::rate0(metric, hard+N/2, soft); + else + rmap = PACListTree::decode(metric, message, maps, count, hard+N/2, soft, frozen >> (N/2)); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + +template +struct PACListTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, TYPE *hard, TYPE *soft, uint32_t frozen) + { + soft[1] = PH::prod(soft[2], soft[3]); + MAP lmap, rmap; + if (frozen & 1) + lmap = PACListNode::rate0(metric, hard, soft); + else + lmap = PACListNode::rate1(metric, message, maps, count, hard, soft); + soft[1] = PH::madd(hard[0], vshuf(soft[2], lmap), vshuf(soft[3], lmap)); + if (frozen >> 1) + rmap = PACListNode::rate0(metric, hard+1, soft); + else + rmap = PACListNode::rate1(metric, message, maps, count, hard+1, soft); + hard[0] = PH::qmul(vshuf(hard[0], rmap), hard[1]); + return vshuf(lmap, rmap); + } +}; + +template +class PACListDecoder +{ + static_assert(MAX_M >= 5 && MAX_M <= 16); + typedef PolarHelper PH; + typedef typename TYPE::value_type VALUE; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int MAX_N = 1 << MAX_M; + TYPE soft[2*MAX_N]; + TYPE hard[MAX_N]; + MAP maps[MAX_N]; +public: + void operator()(int *rank, TYPE *message, const VALUE *codeword, const uint32_t *frozen, int level) + { + assert(level <= MAX_M); + PATH metric[TYPE::SIZE]; + int count = 0; + metric[0] = 0; + for (int k = 1; k < TYPE::SIZE; ++k) + metric[k] = 1000000; + int length = 1 << level; + for (int i = 0; i < length; ++i) + soft[length+i] = vdup(codeword[i]); + + switch (level) { + case 5: PACListTree::decode(metric, message, maps, &count, hard, soft, *frozen); break; + case 6: PACListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; + case 7: PACListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; + case 8: PACListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; + case 9: PACListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; + case 10: PACListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; + case 11: PACListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; + case 12: PACListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; + case 13: PACListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; + case 14: PACListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; + case 15: PACListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; + case 16: PACListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; + default: assert(false); + } + + for (int i = 0, r = 0; rank != nullptr && i < TYPE::SIZE; ++i) { + if (i > 0 && metric[i-1] != metric[i]) + ++r; + rank[i] = r; + } + MAP acc = maps[count-1]; + for (int i = count-2; i >= 0; --i) { + message[i] = vshuf(message[i], acc); + acc = vshuf(maps[i], acc); + } + } +}; + +} + diff --git a/tests/pac_list_regression_test.cc b/tests/pac_list_regression_test.cc new file mode 100644 index 0000000..c41faf5 --- /dev/null +++ b/tests/pac_list_regression_test.cc @@ -0,0 +1,216 @@ +/* +Regression Test for the Polarization-Adjusted Convolutional Encoder and List Decoder + +Copyright 2025 Ahmet Inan +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "polar_helper.hh" +#include "pac_list_decoder.hh" +#include "pac_encoder.hh" +#include "polar_sequence.hh" +#include "crc.hh" +#include "sequence.h" + +bool get_bit(const uint32_t *bits, int idx) +{ + return (bits[idx/32] >> (idx%32)) & 1; +} + +int main() +{ + const int M = 10; + const int N = 1 << M; + const bool crc_aided = true; + CODE::CRC crc(0xD419CC15); + const int C = 32; +#if 1 + const int L = 32; + typedef int8_t code_type; +#else + const int L = 8; + typedef float code_type; +#endif + + typedef SIMD simd_type; + + std::random_device rd; + typedef std::default_random_engine generator; + typedef std::uniform_int_distribution distribution; + auto data = std::bind(distribution(0, 1), generator(rd())); + auto frozen = new uint32_t[N/32]; + auto codeword = new code_type[N]; + + const int *reliability_sequence; + double erasure_probability = 0.3; + int K = (1 - erasure_probability) * N; + double design_SNR = 10 * std::log10(-std::log(erasure_probability)); + std::cerr << "design SNR: " << design_SNR << std::endl; + if (0) { + auto construct = new CODE::PolarSeqConst0; + std::cerr << "sizeof(PolarSeqConst0) = " << sizeof(CODE::PolarSeqConst0) << std::endl; + double better_SNR = design_SNR + 1.59175; + std::cerr << "better SNR: " << better_SNR << std::endl; + double probability = std::exp(-pow(10.0, better_SNR / 10)); + std::cerr << "prob: " << probability << std::endl; + auto rel_seq = new int[N]; + (*construct)(rel_seq, M, probability); + delete construct; + reliability_sequence = rel_seq; + } else { + reliability_sequence = sequence; + } + for (int i = 0; i < N / 32; ++i) + frozen[i] = 0; + for (int i = 0; i < N - K; ++i) + frozen[reliability_sequence[i]/32] |= 1 << (reliability_sequence[i]%32); + std::cerr << "Polar(" << N << ", " << K << ")" << std::endl; + auto message = new code_type[K]; + auto decoded = new simd_type[K]; + std::cerr << "sizeof(PACDecoder) = " << sizeof(CODE::PACListDecoder) << std::endl; + auto decode = new CODE::PACListDecoder; + + auto orig = new code_type[N]; + auto noisy = new code_type[N]; + auto symb = new double[N]; + double low_SNR = std::floor(design_SNR-3); + double high_SNR = std::ceil(design_SNR+5); + double min_SNR = high_SNR, max_mbs = 0; + int count = 0; + std::cerr << "SNR BER Mbit/s Eb/N0" << std::endl; + for (double SNR = low_SNR; count <= 3 && SNR <= high_SNR; SNR += 0.1, ++count) { + //double mean_signal = 0; + double sigma_signal = 1; + double mean_noise = 0; + double sigma_noise = std::sqrt(sigma_signal * sigma_signal / (2 * std::pow(10, SNR / 10))); + + typedef std::normal_distribution normal; + auto awgn = std::bind(normal(mean_noise, sigma_noise), generator(rd())); + + int64_t awgn_errors = 0; + int64_t quantization_erasures = 0; + int64_t uncorrected_errors = 0; + int64_t ambiguity_erasures = 0; + int64_t frame_errors = 0; + double avg_mbs = 0; + int64_t loops = 0; + while (uncorrected_errors < 10000 && ++loops < 1000) { + if (crc_aided) { + crc.reset(); + for (int i = 0; i < K-C; ++i) { + bool bit = data(); + crc(bit); + message[i] = 1 - 2 * bit; + } + for (int i = 0; i < C; ++i) { + bool bit = (crc() >> i) & 1; + message[K-C+i] = 1 - 2 * bit; + } + } else { + for (int i = 0; i < K; ++i) + message[i] = 1 - 2 * data(); + } + + CODE::PACEncoder encode; + encode(codeword, message, frozen, M); + + for (int i = 0; i < N; ++i) + orig[i] = codeword[i]; + + for (int i = 0; i < N; ++i) + symb[i] = codeword[i]; + + for (int i = 0; i < N; ++i) + symb[i] += awgn(); + + // $LLR=log(\frac{p(x=+1|y)}{p(x=-1|y)})$ + // $p(x|\mu,\sigma)=\frac{1}{\sqrt{2\pi}\sigma}}e^{-\frac{(x-\mu)^2}{2\sigma^2}}$ + double DIST = 2; // BPSK + double fact = DIST / (sigma_noise * sigma_noise); + for (int i = 0; i < N; ++i) + codeword[i] = CODE::PolarHelper::quant(fact * symb[i]); + + for (int i = 0; i < N; ++i) + noisy[i] = codeword[i]; + + int rank[L]; + auto start = std::chrono::system_clock::now(); + (*decode)(rank, decoded, codeword, frozen, M); + auto end = std::chrono::system_clock::now(); + auto usec = std::chrono::duration_cast(end - start); + double mbs = (double)K / usec.count(); + avg_mbs += mbs; + + int best = 0; + if (crc_aided) { + bool error = true; + for (int k = 0; k < L; ++k) { + crc.reset(); + for (int i = 0; i < K; ++i) + crc(decoded[i].v[k] < 0); + if (crc() == 0) { + best = k; + error = false; + break; + } + } + frame_errors += error; + } else { + bool error = rank[0] == rank[1]; + for (int i = 0; i < K; ++i) + error |= decoded[i].v[0] * message[i] <= 0; + frame_errors += error; + } + + for (int i = 0; i < N; ++i) + awgn_errors += noisy[i] * (orig[i] < 0); + for (int i = 0; i < N; ++i) + quantization_erasures += !noisy[i]; + for (int i = 0; i < K; ++i) + uncorrected_errors += decoded[i].v[best] * message[i] <= 0; + for (int i = 0; i < K; ++i) + ambiguity_erasures += !decoded[i].v[best]; + } + + avg_mbs /= loops; + + max_mbs = std::max(max_mbs, avg_mbs); + double frame_error_rate = (double)frame_errors / (double)loops; + double bit_error_rate = (double)uncorrected_errors / (double)(K * loops); + if (!uncorrected_errors) + min_SNR = std::min(min_SNR, SNR); + else + count = 0; + + int MOD_BITS = 1; // BPSK + double code_rate = (double)K / (double)N; + double spectral_efficiency = code_rate * MOD_BITS; + double EbN0 = 10 * std::log10(sigma_signal * sigma_signal / (spectral_efficiency * 2 * sigma_noise * sigma_noise)); + + if (0) { + std::cerr << SNR << " Es/N0 => AWGN with standard deviation of " << sigma_noise << " and mean " << mean_noise << std::endl; + std::cerr << EbN0 << " Eb/N0, using spectral efficiency of " << spectral_efficiency << " from " << code_rate << " code rate and " << MOD_BITS << " bits per symbol." << std::endl; + std::cerr << awgn_errors << " errors caused by AWGN." << std::endl; + std::cerr << quantization_erasures << " erasures caused by quantization." << std::endl; + std::cerr << uncorrected_errors << " errors uncorrected." << std::endl; + std::cerr << ambiguity_erasures << " ambiguity erasures." << std::endl; + std::cerr << frame_error_rate << " frame error rate." << std::endl; + std::cerr << bit_error_rate << " bit error rate." << std::endl; + std::cerr << avg_mbs << " megabit per second." << std::endl; + } else { + std::cout << SNR << " " << frame_error_rate << " " << bit_error_rate << " " << avg_mbs << " " << EbN0 << std::endl; + } + } + std::cerr << "QEF at: " << min_SNR << " SNR, speed: " << max_mbs << " Mb/s." << std::endl; + double QEF_SNR = design_SNR + 0.5; + assert(min_SNR < QEF_SNR); + std::cerr << "Polarization-Adjusted Convolutional list regression test passed!" << std::endl; + return 0; +}