replaced nth_element with insertion sort

This commit is contained in:
Ahmet Inan 2024-03-06 09:30:29 +01:00
commit d4b836fb4e

52
osd.hh
View file

@ -243,10 +243,10 @@ class OrderedStatisticsListDecoder
static const int S = sizeof(size_t);
static const int W = (N+S-1) & ~(S-1);
int8_t G[W*K];
int8_t codeword[W], candidate[W*2*L];
int8_t codeword[W], candidate[W*L];
int8_t softperm[W];
int16_t perm[W];
int score[2*L], scoreperm[2*L];
int score[L], cperm[L];
void row_echelon()
{
for (int k = 0; k < K; ++k) {
@ -302,13 +302,29 @@ class OrderedStatisticsListDecoder
for (int i = 0; i < W; ++i)
codeword[i] ^= G[W*j+i];
}
static int metric(const int8_t *hard, const int8_t *soft)
int metric()
{
int sum = 0;
for (int i = 0; i < W; ++i)
sum += (1 - 2 * hard[i]) * soft[i];
sum += (1 - 2 * codeword[i]) * softperm[i];
return sum;
}
void update()
{
int j = L-1;
int met = metric();
if (met <= score[j])
return;
int pos = cperm[j];
for (int i = 0; i < W; ++i)
candidate[pos*W+i] = codeword[i];
for (; j > 0 && met > score[j-1]; --j) {
score[j] = score[j-1];
cperm[j] = cperm[j-1];
}
score[j] = met;
cperm[j] = pos;
}
public:
void operator()(int *rank, uint8_t *hard, const int8_t *soft, const int8_t *genmat)
{
@ -331,24 +347,11 @@ public:
encode();
for (int i = 0; i < W; ++i)
candidate[i] = codeword[i];
score[0] = metric(codeword, softperm);
int count = 1;
int worst = -1;
for (int i = 0; i < 2*L; ++i)
scoreperm[i] = i;
auto update = [this, &count, &worst]() {
int met = metric(codeword, softperm);
if (met > worst) {
score[scoreperm[count]] = met;
for (int i = 0; i < W; ++i)
candidate[scoreperm[count]*W+i] = codeword[i];
if (++count >= 2*L) {
std::nth_element(scoreperm, scoreperm+(L-1), scoreperm+2*L, [this](int a, int b){ return score[a] > score[b]; });
worst = score[scoreperm[L-1]];
count = L;
}
}
};
score[0] = metric();
for (int i = 1; i < L; ++i)
score[i] = -1;
for (int i = 0; i < L; ++i)
cperm[i] = i;
for (int a = 0; O >= 1 && a < K; ++a) {
flip(a);
update();
@ -379,13 +382,12 @@ public:
}
flip(a);
}
std::sort(scoreperm, scoreperm+count, [this](int a, int b){ return score[a] > score[b]; });
for (int j = 0, r = 0; j < L; ++j) {
if (j > 0 && score[scoreperm[j-1]] != score[scoreperm[j]])
if (j > 0 && score[j-1] != score[j])
++r;
rank[j] = r;
for (int i = 0; i < N; ++i)
set_be_bit(hard+j*((N+7)/8), perm[i], candidate[scoreperm[j]*W+i]);
set_be_bit(hard+j*((N+7)/8), perm[i], candidate[cperm[j]*W+i]);
}
}
};