diff --git a/polar_parity_aided.hh b/polar_parity_aided.hh index 4a7f31a..6149274 100644 --- a/polar_parity_aided.hh +++ b/polar_parity_aided.hh @@ -20,10 +20,10 @@ class PolarParityEncoder return (bits[idx/32] >> (idx%32)) & 1; } public: - void operator()(TYPE *codeword, const TYPE *message, const uint32_t *frozen, int level, int stride) + void operator()(TYPE *codeword, const TYPE *message, const uint32_t *frozen, int level, int stride, int first) { int length = 1 << level; - int count = stride; + int count = first; TYPE parity = PH::one(); for (int i = 0; i < length; i += 2) { TYPE msg0, msg1; @@ -345,7 +345,7 @@ class PolarParityDecoder TYPE hard[MAX_N]; MAP maps[MAX_N]; public: - void operator()(PATH *metric, TYPE *message, const VALUE *codeword, const uint32_t *frozen, int level, int stride) + void operator()(PATH *metric, TYPE *message, const VALUE *codeword, const uint32_t *frozen, int level, int stride, int first) { assert(level <= MAX_M); int index = 0; @@ -356,7 +356,7 @@ public: for (int i = 0; i < length; ++i) soft[length+i] = vdup(codeword[i]); TYPE parity = PH::one(); - int count = stride; + int count = first; switch (level) { case 5: PolarParityTree::decode(metric, message, maps, &index, hard, soft, *frozen, &parity, &count, stride); break; diff --git a/tests/polar_list_regression_test.cc b/tests/polar_list_regression_test.cc index 5cf2cff..3bfb4f3 100644 --- a/tests/polar_list_regression_test.cc +++ b/tests/polar_list_regression_test.cc @@ -84,6 +84,9 @@ int main() for (int i = 0; i < N - K; ++i) frozen[reliability_sequence[i]/32] |= 1 << (reliability_sequence[i]%32); int P = K / (S + 1); + int F = K % (S + 1); + if (!crc_aided) + F += S; if (par_aided) K -= P; std::cerr << "Polar(" << N << ", " << K << ")" << std::endl; @@ -142,7 +145,7 @@ int main() assert(codeword[i] == message[j++]); } else if (par_aided) { CODE::PolarParityEncoder encode; - encode(codeword, message, frozen, M, S); + encode(codeword, message, frozen, M, S, F); } else { CODE::PolarEncoder encode; encode(codeword, message, frozen, M); @@ -169,7 +172,7 @@ int main() auto start = std::chrono::system_clock::now(); if (par_aided) - (*par_dec)(metric, decoded, codeword, frozen, M, S); + (*par_dec)(metric, decoded, codeword, frozen, M, S, F); else (*decode)(metric, decoded, codeword, frozen, M); auto end = std::chrono::system_clock::now();