From 2f42744db7b26c9a8342e70fc4a9230d52efbe23 Mon Sep 17 00:00:00 2001 From: Ahmet Inan Date: Fri, 9 Feb 2024 13:20:44 +0100 Subject: [PATCH] expose sorted rank instead of the metric --- polar_list_decoder.hh | 17 +++++++++++++++-- polar_parity_aided.hh | 17 +++++++++++++++-- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/polar_list_decoder.hh b/polar_list_decoder.hh index b495925..d67dc9d 100644 --- a/polar_list_decoder.hh +++ b/polar_list_decoder.hh @@ -284,9 +284,10 @@ class PolarListDecoder TYPE hard[MAX_N]; MAP maps[MAX_N]; public: - void operator()(PATH *metric, TYPE *message, const VALUE *codeword, const uint32_t *frozen, int level) + void operator()(int *rank, TYPE *message, const VALUE *codeword, const uint32_t *frozen, int level) { assert(level <= MAX_M); + PATH metric[TYPE::SIZE]; int count = 0; metric[0] = 0; for (int k = 1; k < TYPE::SIZE; ++k) @@ -311,7 +312,19 @@ public: default: assert(false); } - MAP acc = maps[count-1]; + int perm[TYPE::SIZE]; + for (int k = 0; k < TYPE::SIZE; ++k) + perm[k] = k; + std::sort(perm, perm + TYPE::SIZE, [metric](int a, int b) { return metric[a] < metric[b]; }); + for (int i = 0, r = 0; i < TYPE::SIZE; ++i) { + if (i > 0 && metric[perm[i-1]] != metric[perm[i]]) + ++r; + rank[i] = r; + } + MAP acc; + for (int k = 0; k < TYPE::SIZE; ++k) + acc.v[k] = perm[k]; + acc = vshuf(maps[count-1], acc); for (int i = count-2; i >= 0; --i) { message[i] = vshuf(message[i], acc); acc = vshuf(maps[i], acc); diff --git a/polar_parity_aided.hh b/polar_parity_aided.hh index 6149274..c91eda4 100644 --- a/polar_parity_aided.hh +++ b/polar_parity_aided.hh @@ -345,10 +345,11 @@ class PolarParityDecoder TYPE hard[MAX_N]; MAP maps[MAX_N]; public: - void operator()(PATH *metric, TYPE *message, const VALUE *codeword, const uint32_t *frozen, int level, int stride, int first) + void operator()(int *rank, TYPE *message, const VALUE *codeword, const uint32_t *frozen, int level, int stride, int first) { assert(level <= MAX_M); int index = 0; + PATH metric[TYPE::SIZE]; metric[0] = 0; for (int k = 1; k < TYPE::SIZE; ++k) metric[k] = 1000; @@ -374,7 +375,19 @@ public: default: assert(false); } - MAP acc = maps[index-1]; + int perm[TYPE::SIZE]; + for (int k = 0; k < TYPE::SIZE; ++k) + perm[k] = k; + std::sort(perm, perm + TYPE::SIZE, [metric](int a, int b) { return metric[a] < metric[b]; }); + for (int i = 0, r = 0; i < TYPE::SIZE; ++i) { + if (i > 0 && metric[perm[i-1]] != metric[perm[i]]) + ++r; + rank[i] = r; + } + MAP acc; + for (int k = 0; k < TYPE::SIZE; ++k) + acc.v[k] = perm[k]; + acc = vshuf(maps[index-1], acc); for (int i = index-2; i >= 0; --i) { message[i] = vshuf(message[i], acc); acc = vshuf(maps[i], acc);