From d4b836fb4ebb80359ec99b75e7a4ab4e130c657b Mon Sep 17 00:00:00 2001 From: Ahmet Inan Date: Wed, 6 Mar 2024 09:30:29 +0100 Subject: [PATCH] replaced nth_element with insertion sort --- osd.hh | 52 +++++++++++++++++++++++++++------------------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/osd.hh b/osd.hh index 91b9dcf..27bc9b3 100644 --- a/osd.hh +++ b/osd.hh @@ -243,10 +243,10 @@ class OrderedStatisticsListDecoder static const int S = sizeof(size_t); static const int W = (N+S-1) & ~(S-1); int8_t G[W*K]; - int8_t codeword[W], candidate[W*2*L]; + int8_t codeword[W], candidate[W*L]; int8_t softperm[W]; int16_t perm[W]; - int score[2*L], scoreperm[2*L]; + int score[L], cperm[L]; void row_echelon() { for (int k = 0; k < K; ++k) { @@ -302,13 +302,29 @@ class OrderedStatisticsListDecoder for (int i = 0; i < W; ++i) codeword[i] ^= G[W*j+i]; } - static int metric(const int8_t *hard, const int8_t *soft) + int metric() { int sum = 0; for (int i = 0; i < W; ++i) - sum += (1 - 2 * hard[i]) * soft[i]; + sum += (1 - 2 * codeword[i]) * softperm[i]; return sum; } + void update() + { + int j = L-1; + int met = metric(); + if (met <= score[j]) + return; + int pos = cperm[j]; + for (int i = 0; i < W; ++i) + candidate[pos*W+i] = codeword[i]; + for (; j > 0 && met > score[j-1]; --j) { + score[j] = score[j-1]; + cperm[j] = cperm[j-1]; + } + score[j] = met; + cperm[j] = pos; + } public: void operator()(int *rank, uint8_t *hard, const int8_t *soft, const int8_t *genmat) { @@ -331,24 +347,11 @@ public: encode(); for (int i = 0; i < W; ++i) candidate[i] = codeword[i]; - score[0] = metric(codeword, softperm); - int count = 1; - int worst = -1; - for (int i = 0; i < 2*L; ++i) - scoreperm[i] = i; - auto update = [this, &count, &worst]() { - int met = metric(codeword, softperm); - if (met > worst) { - score[scoreperm[count]] = met; - for (int i = 0; i < W; ++i) - candidate[scoreperm[count]*W+i] = codeword[i]; - if (++count >= 2*L) { - std::nth_element(scoreperm, scoreperm+(L-1), scoreperm+2*L, [this](int a, int b){ return score[a] > score[b]; }); - worst = score[scoreperm[L-1]]; - count = L; - } - } - }; + score[0] = metric(); + for (int i = 1; i < L; ++i) + score[i] = -1; + for (int i = 0; i < L; ++i) + cperm[i] = i; for (int a = 0; O >= 1 && a < K; ++a) { flip(a); update(); @@ -379,13 +382,12 @@ public: } flip(a); } - std::sort(scoreperm, scoreperm+count, [this](int a, int b){ return score[a] > score[b]; }); for (int j = 0, r = 0; j < L; ++j) { - if (j > 0 && score[scoreperm[j-1]] != score[scoreperm[j]]) + if (j > 0 && score[j-1] != score[j]) ++r; rank[j] = r; for (int i = 0; i < N; ++i) - set_be_bit(hard+j*((N+7)/8), perm[i], candidate[scoreperm[j]*W+i]); + set_be_bit(hard+j*((N+7)/8), perm[i], candidate[cperm[j]*W+i]); } } };