replaced nth_element with insertion sort

This commit is contained in:
Ahmet Inan 2024-03-06 09:30:29 +01:00
commit d4b836fb4e

52
osd.hh
View file

@ -243,10 +243,10 @@ class OrderedStatisticsListDecoder
static const int S = sizeof(size_t); static const int S = sizeof(size_t);
static const int W = (N+S-1) & ~(S-1); static const int W = (N+S-1) & ~(S-1);
int8_t G[W*K]; int8_t G[W*K];
int8_t codeword[W], candidate[W*2*L]; int8_t codeword[W], candidate[W*L];
int8_t softperm[W]; int8_t softperm[W];
int16_t perm[W]; int16_t perm[W];
int score[2*L], scoreperm[2*L]; int score[L], cperm[L];
void row_echelon() void row_echelon()
{ {
for (int k = 0; k < K; ++k) { for (int k = 0; k < K; ++k) {
@ -302,13 +302,29 @@ class OrderedStatisticsListDecoder
for (int i = 0; i < W; ++i) for (int i = 0; i < W; ++i)
codeword[i] ^= G[W*j+i]; codeword[i] ^= G[W*j+i];
} }
static int metric(const int8_t *hard, const int8_t *soft) int metric()
{ {
int sum = 0; int sum = 0;
for (int i = 0; i < W; ++i) for (int i = 0; i < W; ++i)
sum += (1 - 2 * hard[i]) * soft[i]; sum += (1 - 2 * codeword[i]) * softperm[i];
return sum; return sum;
} }
void update()
{
int j = L-1;
int met = metric();
if (met <= score[j])
return;
int pos = cperm[j];
for (int i = 0; i < W; ++i)
candidate[pos*W+i] = codeword[i];
for (; j > 0 && met > score[j-1]; --j) {
score[j] = score[j-1];
cperm[j] = cperm[j-1];
}
score[j] = met;
cperm[j] = pos;
}
public: public:
void operator()(int *rank, uint8_t *hard, const int8_t *soft, const int8_t *genmat) void operator()(int *rank, uint8_t *hard, const int8_t *soft, const int8_t *genmat)
{ {
@ -331,24 +347,11 @@ public:
encode(); encode();
for (int i = 0; i < W; ++i) for (int i = 0; i < W; ++i)
candidate[i] = codeword[i]; candidate[i] = codeword[i];
score[0] = metric(codeword, softperm); score[0] = metric();
int count = 1; for (int i = 1; i < L; ++i)
int worst = -1; score[i] = -1;
for (int i = 0; i < 2*L; ++i) for (int i = 0; i < L; ++i)
scoreperm[i] = i; cperm[i] = i;
auto update = [this, &count, &worst]() {
int met = metric(codeword, softperm);
if (met > worst) {
score[scoreperm[count]] = met;
for (int i = 0; i < W; ++i)
candidate[scoreperm[count]*W+i] = codeword[i];
if (++count >= 2*L) {
std::nth_element(scoreperm, scoreperm+(L-1), scoreperm+2*L, [this](int a, int b){ return score[a] > score[b]; });
worst = score[scoreperm[L-1]];
count = L;
}
}
};
for (int a = 0; O >= 1 && a < K; ++a) { for (int a = 0; O >= 1 && a < K; ++a) {
flip(a); flip(a);
update(); update();
@ -379,13 +382,12 @@ public:
} }
flip(a); flip(a);
} }
std::sort(scoreperm, scoreperm+count, [this](int a, int b){ return score[a] > score[b]; });
for (int j = 0, r = 0; j < L; ++j) { for (int j = 0, r = 0; j < L; ++j) {
if (j > 0 && score[scoreperm[j-1]] != score[scoreperm[j]]) if (j > 0 && score[j-1] != score[j])
++r; ++r;
rank[j] = r; rank[j] = r;
for (int i = 0; i < N; ++i) for (int i = 0; i < N; ++i)
set_be_bit(hard+j*((N+7)/8), perm[i], candidate[scoreperm[j]*W+i]); set_be_bit(hard+j*((N+7)/8), perm[i], candidate[cperm[j]*W+i]);
} }
} }
}; };