diff --git a/polar_list_decoder.hh b/polar_list_decoder.hh index b495925..d67dc9d 100644 --- a/polar_list_decoder.hh +++ b/polar_list_decoder.hh @@ -284,9 +284,10 @@ class PolarListDecoder TYPE hard[MAX_N]; MAP maps[MAX_N]; public: - void operator()(PATH *metric, TYPE *message, const VALUE *codeword, const uint32_t *frozen, int level) + void operator()(int *rank, TYPE *message, const VALUE *codeword, const uint32_t *frozen, int level) { assert(level <= MAX_M); + PATH metric[TYPE::SIZE]; int count = 0; metric[0] = 0; for (int k = 1; k < TYPE::SIZE; ++k) @@ -311,7 +312,19 @@ public: default: assert(false); } - MAP acc = maps[count-1]; + int perm[TYPE::SIZE]; + for (int k = 0; k < TYPE::SIZE; ++k) + perm[k] = k; + std::sort(perm, perm + TYPE::SIZE, [metric](int a, int b) { return metric[a] < metric[b]; }); + for (int i = 0, r = 0; i < TYPE::SIZE; ++i) { + if (i > 0 && metric[perm[i-1]] != metric[perm[i]]) + ++r; + rank[i] = r; + } + MAP acc; + for (int k = 0; k < TYPE::SIZE; ++k) + acc.v[k] = perm[k]; + acc = vshuf(maps[count-1], acc); for (int i = count-2; i >= 0; --i) { message[i] = vshuf(message[i], acc); acc = vshuf(maps[i], acc); diff --git a/polar_parity_aided.hh b/polar_parity_aided.hh index 6149274..c91eda4 100644 --- a/polar_parity_aided.hh +++ b/polar_parity_aided.hh @@ -345,10 +345,11 @@ class PolarParityDecoder TYPE hard[MAX_N]; MAP maps[MAX_N]; public: - void operator()(PATH *metric, TYPE *message, const VALUE *codeword, const uint32_t *frozen, int level, int stride, int first) + void operator()(int *rank, TYPE *message, const VALUE *codeword, const uint32_t *frozen, int level, int stride, int first) { assert(level <= MAX_M); int index = 0; + PATH metric[TYPE::SIZE]; metric[0] = 0; for (int k = 1; k < TYPE::SIZE; ++k) metric[k] = 1000; @@ -374,7 +375,19 @@ public: default: assert(false); } - MAP acc = maps[index-1]; + int perm[TYPE::SIZE]; + for (int k = 0; k < TYPE::SIZE; ++k) + perm[k] = k; + std::sort(perm, perm + TYPE::SIZE, [metric](int a, int b) { return metric[a] < metric[b]; }); + for (int i = 0, r = 0; i < TYPE::SIZE; ++i) { + if (i > 0 && metric[perm[i-1]] != metric[perm[i]]) + ++r; + rank[i] = r; + } + MAP acc; + for (int k = 0; k < TYPE::SIZE; ++k) + acc.v[k] = perm[k]; + acc = vshuf(maps[index-1], acc); for (int i = index-2; i >= 0; --i) { message[i] = vshuf(message[i], acc); acc = vshuf(maps[i], acc);