diff --git a/polar_decoder.hh b/polar_decoder.hh index 206263d..7beea1a 100644 --- a/polar_decoder.hh +++ b/polar_decoder.hh @@ -16,26 +16,135 @@ struct PolarTree { typedef PolarHelper PH; static const int N = 1 << M; - static void decode(TYPE **message, TYPE *hard, TYPE *soft, const uint8_t *frozen) + static void decode(TYPE **message, TYPE *hard, TYPE *soft, const uint32_t *frozen) { for (int i = 0; i < N/2; ++i) soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); PolarTree::decode(message, hard, soft, frozen); for (int i = 0; i < N/2; ++i) soft[i+N/2] = PH::madd(hard[i], soft[i+N], soft[i+N/2+N]); - PolarTree::decode(message, hard+N/2, soft, frozen+N/2); + PolarTree::decode(message, hard+N/2, soft, frozen+N/2/32); for (int i = 0; i < N/2; ++i) hard[i] = PH::qmul(hard[i], hard[i+N/2]); } }; +template +struct PolarTree +{ + typedef PolarHelper PH; + static const int M = 6; + static const int N = 1 << M; + static void decode(TYPE **message, TYPE *hard, TYPE *soft, const uint32_t *frozen) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + PolarTree::decode(message, hard, soft, frozen[0]); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], soft[i+N], soft[i+N/2+N]); + PolarTree::decode(message, hard+N/2, soft, frozen[1]); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(hard[i], hard[i+N/2]); + } +}; + +template +struct PolarTree +{ + typedef PolarHelper PH; + static const int M = 5; + static const int N = 1 << M; + static void decode(TYPE **message, TYPE *hard, TYPE *soft, uint32_t frozen) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + PolarTree::decode(message, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], soft[i+N], soft[i+N/2+N]); + PolarTree::decode(message, hard+N/2, soft, frozen >> (N/2)); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(hard[i], hard[i+N/2]); + } +}; + +template +struct PolarTree +{ + typedef PolarHelper PH; + static const int M = 4; + static const int N = 1 << M; + static void decode(TYPE **message, TYPE *hard, TYPE *soft, uint32_t frozen) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + PolarTree::decode(message, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], soft[i+N], soft[i+N/2+N]); + PolarTree::decode(message, hard+N/2, soft, frozen >> (N/2)); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(hard[i], hard[i+N/2]); + } +}; + +template +struct PolarTree +{ + typedef PolarHelper PH; + static const int M = 3; + static const int N = 1 << M; + static void decode(TYPE **message, TYPE *hard, TYPE *soft, uint32_t frozen) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + PolarTree::decode(message, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], soft[i+N], soft[i+N/2+N]); + PolarTree::decode(message, hard+N/2, soft, frozen >> (N/2)); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(hard[i], hard[i+N/2]); + } +}; + +template +struct PolarTree +{ + typedef PolarHelper PH; + static const int M = 2; + static const int N = 1 << M; + static void decode(TYPE **message, TYPE *hard, TYPE *soft, uint32_t frozen) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + PolarTree::decode(message, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], soft[i+N], soft[i+N/2+N]); + PolarTree::decode(message, hard+N/2, soft, frozen >> (N/2)); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(hard[i], hard[i+N/2]); + } +}; + +template +struct PolarTree +{ + typedef PolarHelper PH; + static void decode(TYPE **message, TYPE *hard, TYPE *soft, uint32_t frozen) + { + soft[1] = PH::prod(soft[2], soft[3]); + PolarTree::decode(message, hard, soft, frozen & 1); + soft[1] = PH::madd(hard[0], soft[2], soft[3]); + PolarTree::decode(message, hard+1, soft, frozen >> 1); + hard[0] = PH::qmul(hard[0], hard[1]); + } +}; + template struct PolarTree { typedef PolarHelper PH; - static void decode(TYPE **message, TYPE *hard, TYPE *soft, const uint8_t *frozen) + static void decode(TYPE **message, TYPE *hard, TYPE *soft, uint32_t frozen) { - if (*frozen) { + if (frozen) { *hard = PH::one(); } else { *hard = PH::signum(soft[1]); @@ -47,12 +156,13 @@ struct PolarTree template class PolarDecoder { + static_assert(MAX_M >= 5 && MAX_M <= 29); typedef PolarHelper PH; static const int MAX_N = 1 << MAX_M; TYPE soft[2*MAX_N]; TYPE hard[MAX_N]; public: - void operator()(TYPE *message, const TYPE *codeword, const uint8_t *frozen, int level) + void operator()(TYPE *message, const TYPE *codeword, const uint32_t *frozen, int level) { assert(level <= MAX_M); int length = 1 << level; @@ -60,12 +170,7 @@ public: soft[length+i] = codeword[i]; switch (level) { - case 0: PolarTree::decode(&message, hard, soft, frozen); break; - case 1: PolarTree::decode(&message, hard, soft, frozen); break; - case 2: PolarTree::decode(&message, hard, soft, frozen); break; - case 3: PolarTree::decode(&message, hard, soft, frozen); break; - case 4: PolarTree::decode(&message, hard, soft, frozen); break; - case 5: PolarTree::decode(&message, hard, soft, frozen); break; + case 5: PolarTree::decode(&message, hard, soft, *frozen); break; case 6: PolarTree::decode(&message, hard, soft, frozen); break; case 7: PolarTree::decode(&message, hard, soft, frozen); break; case 8: PolarTree::decode(&message, hard, soft, frozen); break; diff --git a/polar_encoder.hh b/polar_encoder.hh index 8eb27b0..6009270 100644 --- a/polar_encoder.hh +++ b/polar_encoder.hh @@ -14,13 +14,17 @@ template class PolarEncoder { typedef PolarHelper PH; + static bool get(const uint32_t *bits, int idx) + { + return (bits[idx/32] >> (idx%32)) & 1; + } public: - void operator()(TYPE *codeword, const TYPE *message, const uint8_t *frozen, int level) + void operator()(TYPE *codeword, const TYPE *message, const uint32_t *frozen, int level) { int length = 1 << level; for (int i = 0; i < length; i += 2) { - TYPE msg0 = frozen[i] ? PH::one() : *message++; - TYPE msg1 = frozen[i+1] ? PH::one() : *message++; + TYPE msg0 = get(frozen, i) ? PH::one() : *message++; + TYPE msg1 = get(frozen, i+1) ? PH::one() : *message++; codeword[i] = PH::qmul(msg0, msg1); codeword[i+1] = msg1; } @@ -35,13 +39,17 @@ template class PolarSysEnc { typedef PolarHelper PH; + static bool get(const uint32_t *bits, int idx) + { + return (bits[idx/32] >> (idx%32)) & 1; + } public: - void operator()(TYPE *codeword, const TYPE *message, const uint8_t *frozen, int level) + void operator()(TYPE *codeword, const TYPE *message, const uint32_t *frozen, int level) { int length = 1 << level; for (int i = 0; i < length; i += 2) { - TYPE msg0 = frozen[i] ? PH::one() : *message++; - TYPE msg1 = frozen[i+1] ? PH::one() : *message++; + TYPE msg0 = get(frozen, i) ? PH::one() : *message++; + TYPE msg1 = get(frozen, i+1) ? PH::one() : *message++; codeword[i] = PH::qmul(msg0, msg1); codeword[i+1] = msg1; } @@ -50,8 +58,8 @@ public: for (int j = i; j < i + h; ++j) codeword[j] = PH::qmul(codeword[j], codeword[j+h]); for (int i = 0; i < length; i += 2) { - TYPE msg0 = frozen[i] ? PH::one() : codeword[i]; - TYPE msg1 = frozen[i+1] ? PH::one() : codeword[i+1]; + TYPE msg0 = get(frozen, i) ? PH::one() : codeword[i]; + TYPE msg1 = get(frozen, i+1) ? PH::one() : codeword[i+1]; codeword[i] = PH::qmul(msg0, msg1); codeword[i+1] = msg1; } diff --git a/polar_freezer.hh b/polar_freezer.hh index 35a1a09..755eb74 100644 --- a/polar_freezer.hh +++ b/polar_freezer.hh @@ -12,23 +12,32 @@ namespace CODE { class PolarFreezer { - static void freeze(uint8_t *bits, long double pe, long double th, int i, int h) + static bool get_bit(const uint32_t *bits, int idx) + { + return (bits[idx/32] >> (idx%32)) & 1; + } + static void set_bit(uint32_t *bits, int idx, bool val) + { + bits[idx/32] &= ~(1 << (idx%32)); + bits[idx/32] |= (uint32_t)val << (idx%32); + } + static void freeze(uint32_t *bits, long double pe, long double th, int i, int h) { if (h) { freeze(bits, pe * (2-pe), th, i, h/2); freeze(bits, pe * pe, th, i+h, h/2); } else { - bits[i] = pe > th; + set_bit(bits, i, pe > th); } } public: - int operator()(uint8_t *frozen_bits, int level, long double erasure_probability = 0.5L, long double freezing_threshold = 0.5L) + int operator()(uint32_t *frozen_bits, int level, long double erasure_probability = 0.5L, long double freezing_threshold = 0.5L) { int length = 1 << level; freeze(frozen_bits, erasure_probability, freezing_threshold, 0, length / 2); - int K = 0; + int K = length; for (int i = 0; i < length; ++i) - K += !frozen_bits[i]; + K -= (frozen_bits[i/32] >> (i%32)) & 1; return K; } }; @@ -36,6 +45,14 @@ public: template class PolarCodeConst0 { + static void inform_bit(uint32_t *bits, int idx) + { + bits[idx/32] &= ~(1 << (idx%32)); + } + static void frozen_bit(uint32_t *bits, int idx) + { + bits[idx/32] |= 1 << (idx%32); + } void compute(long double pe, int i, int h) { if (h) { @@ -48,7 +65,7 @@ class PolarCodeConst0 long double prob[1<::decode(metric, message, maps, count, hard, soft, frozen); for (int i = 0; i < N/2; ++i) soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); - MAP rmap = PolarListTree::decode(metric, message, maps, count, hard+N/2, soft, frozen+N/2); + MAP rmap = PolarListTree::decode(metric, message, maps, count, hard+N/2, soft, frozen+N/2/32); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + +template +struct PolarListTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int M = 6; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, TYPE *hard, TYPE *soft, const uint32_t *frozen) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap = PolarListTree::decode(metric, message, maps, count, hard, soft, frozen[0]); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + MAP rmap = PolarListTree::decode(metric, message, maps, count, hard+N/2, soft, frozen[1]); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + +template +struct PolarListTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int M = 5; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, TYPE *hard, TYPE *soft, uint32_t frozen) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap = PolarListTree::decode(metric, message, maps, count, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + MAP rmap = PolarListTree::decode(metric, message, maps, count, hard+N/2, soft, frozen >> (N/2)); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + +template +struct PolarListTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int M = 4; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, TYPE *hard, TYPE *soft, uint32_t frozen) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap = PolarListTree::decode(metric, message, maps, count, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + MAP rmap = PolarListTree::decode(metric, message, maps, count, hard+N/2, soft, frozen >> (N/2)); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + +template +struct PolarListTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int M = 3; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, TYPE *hard, TYPE *soft, uint32_t frozen) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap = PolarListTree::decode(metric, message, maps, count, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + MAP rmap = PolarListTree::decode(metric, message, maps, count, hard+N/2, soft, frozen >> (N/2)); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + +template +struct PolarListTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int M = 2; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, TYPE *hard, TYPE *soft, uint32_t frozen) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap = PolarListTree::decode(metric, message, maps, count, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + MAP rmap = PolarListTree::decode(metric, message, maps, count, hard+N/2, soft, frozen >> (N/2)); + for (int i = 0; i < N/2; ++i) + hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); + return vshuf(lmap, rmap); + } +}; + +template +struct PolarListTree +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static const int M = 1; + static const int N = 1 << M; + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, TYPE *hard, TYPE *soft, uint32_t frozen) + { + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); + MAP lmap = PolarListTree::decode(metric, message, maps, count, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); + for (int i = 0; i < N/2; ++i) + soft[i+N/2] = PH::madd(hard[i], vshuf(soft[i+N], lmap), vshuf(soft[i+N/2+N], lmap)); + MAP rmap = PolarListTree::decode(metric, message, maps, count, hard+N/2, soft, frozen >> (N/2)); for (int i = 0; i < N/2; ++i) hard[i] = PH::qmul(vshuf(hard[i], rmap), hard[i+N/2]); return vshuf(lmap, rmap); @@ -38,11 +170,11 @@ struct PolarListTree typedef PolarHelper PH; typedef typename PH::PATH PATH; typedef typename PH::MAP MAP; - static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, TYPE *hard, TYPE *soft, const uint8_t *frozen) + static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, TYPE *hard, TYPE *soft, uint32_t frozen) { MAP map; TYPE hrd, sft = soft[1]; - if (*frozen) { + if (frozen) { for (int k = 0; k < TYPE::SIZE; ++k) if (sft.v[k] < 0) metric[k] -= sft.v[k]; @@ -80,6 +212,7 @@ struct PolarListTree template class PolarListDecoder { + static_assert(MAX_M >= 5 && MAX_M <= 29); typedef PolarHelper PH; typedef typename TYPE::value_type VALUE; typedef typename PH::PATH PATH; @@ -89,7 +222,7 @@ class PolarListDecoder TYPE hard[MAX_N]; MAP maps[MAX_N]; public: - void operator()(PATH *metric, TYPE *message, const VALUE *codeword, const uint8_t *frozen, int level) + void operator()(PATH *metric, TYPE *message, const VALUE *codeword, const uint32_t *frozen, int level) { assert(level <= MAX_M); int count = 0; @@ -101,12 +234,7 @@ public: soft[length+i] = vdup(codeword[i]); switch (level) { - case 0: PolarListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; - case 1: PolarListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; - case 2: PolarListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; - case 3: PolarListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; - case 4: PolarListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; - case 5: PolarListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; + case 5: PolarListTree::decode(metric, message, maps, &count, hard, soft, *frozen); break; case 6: PolarListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; case 7: PolarListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; case 8: PolarListTree::decode(metric, message, maps, &count, hard, soft, frozen); break; diff --git a/tests/polar_list_regression_test.cc b/tests/polar_list_regression_test.cc index dd40282..24c37bc 100644 --- a/tests/polar_list_regression_test.cc +++ b/tests/polar_list_regression_test.cc @@ -17,6 +17,11 @@ Copyright 2020 Ahmet Inan #include "polar_encoder.hh" #include "polar_freezer.hh" +bool get_bit(const uint32_t *bits, int idx) +{ + return (bits[idx/32] >> (idx%32)) & 1; +} + int main() { const int M = 16; @@ -42,7 +47,7 @@ int main() typedef std::default_random_engine generator; typedef std::uniform_int_distribution distribution; auto data = std::bind(distribution(0, 1), generator(rd())); - auto frozen = new uint8_t[N]; + auto frozen = new uint32_t[(N+31)/32]; auto codeword = new code_type[N]; auto temp = new simd_type[N]; @@ -101,7 +106,7 @@ int main() CODE::PolarSysEnc sysenc; sysenc(codeword, message, frozen, M); for (int i = 0, j = 0; i < N; ++i) - if (!frozen[i]) + if (!get_bit(frozen, i)) assert(codeword[i] == message[j++]); } else { CODE::PolarEncoder encode; @@ -138,7 +143,7 @@ int main() CODE::PolarEncoder encode; encode(temp, decoded, frozen, M); for (int i = 0, j = 0; i < N; ++i) - if (!frozen[i]) + if (!get_bit(frozen, i)) decoded[j++] = temp[i]; } diff --git a/tests/polar_regression_test.cc b/tests/polar_regression_test.cc index 07974a4..bf0a6a8 100644 --- a/tests/polar_regression_test.cc +++ b/tests/polar_regression_test.cc @@ -17,6 +17,11 @@ Copyright 2020 Ahmet Inan #include "polar_encoder.hh" #include "polar_freezer.hh" +bool get_bit(const uint32_t *bits, int idx) +{ + return (bits[idx/32] >> (idx%32)) & 1; +} + int main() { const int M = 20; @@ -34,7 +39,7 @@ int main() typedef std::default_random_engine generator; typedef std::uniform_int_distribution distribution; auto data = std::bind(distribution(0, 1), generator(rd())); - auto frozen = new uint8_t[N]; + auto frozen = new uint32_t[(N+31)/32]; auto codeword = new code_type[N]; auto temp = new code_type[N]; @@ -92,7 +97,7 @@ int main() CODE::PolarSysEnc sysenc; sysenc(codeword, message, frozen, M); for (int i = 0, j = 0; i < N; ++i) - if (!frozen[i]) + if (!get_bit(frozen, i)) assert(codeword[i] == message[j++]); } else { CODE::PolarEncoder encode; @@ -129,7 +134,7 @@ int main() CODE::PolarEncoder encode; encode(temp, decoded, frozen, M); for (int i = 0, j = 0; i < N; ++i) - if (!frozen[i]) + if (!get_bit(frozen, i)) decoded[j++] = temp[i]; }