expose sorted rank instead of the metric

This commit is contained in:
Ahmet Inan 2024-02-09 13:20:44 +01:00
commit 2f42744db7
2 changed files with 30 additions and 4 deletions

View file

@ -284,9 +284,10 @@ class PolarListDecoder
TYPE hard[MAX_N];
MAP maps[MAX_N];
public:
void operator()(PATH *metric, TYPE *message, const VALUE *codeword, const uint32_t *frozen, int level)
void operator()(int *rank, TYPE *message, const VALUE *codeword, const uint32_t *frozen, int level)
{
assert(level <= MAX_M);
PATH metric[TYPE::SIZE];
int count = 0;
metric[0] = 0;
for (int k = 1; k < TYPE::SIZE; ++k)
@ -311,7 +312,19 @@ public:
default: assert(false);
}
MAP acc = maps[count-1];
int perm[TYPE::SIZE];
for (int k = 0; k < TYPE::SIZE; ++k)
perm[k] = k;
std::sort(perm, perm + TYPE::SIZE, [metric](int a, int b) { return metric[a] < metric[b]; });
for (int i = 0, r = 0; i < TYPE::SIZE; ++i) {
if (i > 0 && metric[perm[i-1]] != metric[perm[i]])
++r;
rank[i] = r;
}
MAP acc;
for (int k = 0; k < TYPE::SIZE; ++k)
acc.v[k] = perm[k];
acc = vshuf(maps[count-1], acc);
for (int i = count-2; i >= 0; --i) {
message[i] = vshuf(message[i], acc);
acc = vshuf(maps[i], acc);

View file

@ -345,10 +345,11 @@ class PolarParityDecoder
TYPE hard[MAX_N];
MAP maps[MAX_N];
public:
void operator()(PATH *metric, TYPE *message, const VALUE *codeword, const uint32_t *frozen, int level, int stride, int first)
void operator()(int *rank, TYPE *message, const VALUE *codeword, const uint32_t *frozen, int level, int stride, int first)
{
assert(level <= MAX_M);
int index = 0;
PATH metric[TYPE::SIZE];
metric[0] = 0;
for (int k = 1; k < TYPE::SIZE; ++k)
metric[k] = 1000;
@ -374,7 +375,19 @@ public:
default: assert(false);
}
MAP acc = maps[index-1];
int perm[TYPE::SIZE];
for (int k = 0; k < TYPE::SIZE; ++k)
perm[k] = k;
std::sort(perm, perm + TYPE::SIZE, [metric](int a, int b) { return metric[a] < metric[b]; });
for (int i = 0, r = 0; i < TYPE::SIZE; ++i) {
if (i > 0 && metric[perm[i-1]] != metric[perm[i]])
++r;
rank[i] = r;
}
MAP acc;
for (int k = 0; k < TYPE::SIZE; ++k)
acc.v[k] = perm[k];
acc = vshuf(maps[index-1], acc);
for (int i = index-2; i >= 0; --i) {
message[i] = vshuf(message[i], acc);
acc = vshuf(maps[i], acc);