diff --git a/polar_decoder.hh b/polar_decoder.hh index 9cfbe92..c50ad86 100644 --- a/polar_decoder.hh +++ b/polar_decoder.hh @@ -16,11 +16,44 @@ struct PolarNode { typedef PolarHelper PH; static const int N = 1 << M; + static void trans(TYPE *out, const TYPE *inp) + { + for (int i = 0; i < N; i += 2) { + out[i] = PH::qmul(inp[i], inp[i+1]); + out[i+1] = inp[i+1]; + } + for (int h = 2; h < N; h *= 2) + for (int i = 0; i < N; i += 2 * h) + for (int j = i; j < i + h; ++j) + out[j] = PH::qmul(out[j], out[j+h]); + } static void rate0(TYPE *hard) { for (int i = 0; i < N; ++i) hard[i] = PH::one(); } + static void rate1(TYPE **message, TYPE *hard, TYPE *soft) + { + for (int i = 0; i < N; ++i) + hard[i] = PH::signum(soft[i+N]); + trans(*message, hard); + *message += N; + } +}; + +template +struct PolarNode +{ + typedef PolarHelper PH; + static void rate0(TYPE *hard) + { + *hard = PH::one(); + } + static void rate1(TYPE **message, TYPE *hard, TYPE *soft) + { + *hard = PH::signum(soft[1]); + *(*message)++ = *hard; + } }; template @@ -53,12 +86,16 @@ struct PolarTree soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); if (frozen[0] == 0xffffffff) PolarNode::rate0(hard); + else if (frozen[0] == 0) + PolarNode::rate1(message, hard, soft); else PolarTree::decode(message, hard, soft, frozen[0]); for (int i = 0; i < N/2; ++i) soft[i+N/2] = PH::madd(hard[i], soft[i+N], soft[i+N/2+N]); if (frozen[1] == 0xffffffff) PolarNode::rate0(hard+N/2); + else if (frozen[1] == 0) + PolarNode::rate1(message, hard+N/2, soft); else PolarTree::decode(message, hard+N/2, soft, frozen[1]); for (int i = 0; i < N/2; ++i) @@ -78,12 +115,16 @@ struct PolarTree soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); if ((frozen & ((1<<(1<<(M-1)))-1)) == ((1<<(1<<(M-1)))-1)) PolarNode::rate0(hard); + else if ((frozen & ((1<<(1<<(M-1)))-1)) == 0) + PolarNode::rate1(message, hard, soft); else PolarTree::decode(message, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); for (int i = 0; i < N/2; ++i) soft[i+N/2] = PH::madd(hard[i], soft[i+N], soft[i+N/2+N]); if (frozen >> (N/2) == ((1<<(1<<(M-1)))-1)) PolarNode::rate0(hard+N/2); + else if (frozen >> (N/2) == 0) + PolarNode::rate1(message, hard+N/2, soft); else PolarTree::decode(message, hard+N/2, soft, frozen >> (N/2)); for (int i = 0; i < N/2; ++i) @@ -103,12 +144,16 @@ struct PolarTree soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); if ((frozen & ((1<<(1<<(M-1)))-1)) == ((1<<(1<<(M-1)))-1)) PolarNode::rate0(hard); + else if ((frozen & ((1<<(1<<(M-1)))-1)) == 0) + PolarNode::rate1(message, hard, soft); else PolarTree::decode(message, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); for (int i = 0; i < N/2; ++i) soft[i+N/2] = PH::madd(hard[i], soft[i+N], soft[i+N/2+N]); if (frozen >> (N/2) == ((1<<(1<<(M-1)))-1)) PolarNode::rate0(hard+N/2); + else if (frozen >> (N/2) == 0) + PolarNode::rate1(message, hard+N/2, soft); else PolarTree::decode(message, hard+N/2, soft, frozen >> (N/2)); for (int i = 0; i < N/2; ++i) @@ -128,12 +173,16 @@ struct PolarTree soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); if ((frozen & ((1<<(1<<(M-1)))-1)) == ((1<<(1<<(M-1)))-1)) PolarNode::rate0(hard); + else if ((frozen & ((1<<(1<<(M-1)))-1)) == 0) + PolarNode::rate1(message, hard, soft); else PolarTree::decode(message, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); for (int i = 0; i < N/2; ++i) soft[i+N/2] = PH::madd(hard[i], soft[i+N], soft[i+N/2+N]); if (frozen >> (N/2) == ((1<<(1<<(M-1)))-1)) PolarNode::rate0(hard+N/2); + else if (frozen >> (N/2) == 0) + PolarNode::rate1(message, hard+N/2, soft); else PolarTree::decode(message, hard+N/2, soft, frozen >> (N/2)); for (int i = 0; i < N/2; ++i) @@ -153,12 +202,16 @@ struct PolarTree soft[i+N/2] = PH::prod(soft[i+N], soft[i+N/2+N]); if ((frozen & ((1<<(1<<(M-1)))-1)) == ((1<<(1<<(M-1)))-1)) PolarNode::rate0(hard); + else if ((frozen & ((1<<(1<<(M-1)))-1)) == 0) + PolarNode::rate1(message, hard, soft); else PolarTree::decode(message, hard, soft, frozen & ((1<<(1<<(M-1)))-1)); for (int i = 0; i < N/2; ++i) soft[i+N/2] = PH::madd(hard[i], soft[i+N], soft[i+N/2+N]); if (frozen >> (N/2) == ((1<<(1<<(M-1)))-1)) PolarNode::rate0(hard+N/2); + else if (frozen >> (N/2) == 0) + PolarNode::rate1(message, hard+N/2, soft); else PolarTree::decode(message, hard+N/2, soft, frozen >> (N/2)); for (int i = 0; i < N/2; ++i) @@ -176,31 +229,16 @@ struct PolarTree if (frozen & 1) PolarNode::rate0(hard); else - PolarTree::decode(message, hard, soft, 0); + PolarNode::rate1(message, hard, soft); soft[1] = PH::madd(hard[0], soft[2], soft[3]); if (frozen >> 1) PolarNode::rate0(hard+1); else - PolarTree::decode(message, hard+1, soft, 0); + PolarNode::rate1(message, hard+1, soft); hard[0] = PH::qmul(hard[0], hard[1]); } }; -template -struct PolarTree -{ - typedef PolarHelper PH; - static void decode(TYPE **message, TYPE *hard, TYPE *soft, uint32_t frozen) - { - if (frozen) { - *hard = PH::one(); - } else { - *hard = PH::signum(soft[1]); - *(*message)++ = *hard; - } - } -}; - template class PolarDecoder { diff --git a/polar_list_decoder.hh b/polar_list_decoder.hh index 97f1f20..a558c72 100644 --- a/polar_list_decoder.hh +++ b/polar_list_decoder.hh @@ -33,6 +33,54 @@ struct PolarListNode } }; +template +struct PolarListNode +{ + typedef PolarHelper PH; + typedef typename PH::PATH PATH; + typedef typename PH::MAP MAP; + static MAP rate0(PATH *metric, TYPE *hard, TYPE *soft) + { + *hard = PH::one(); + for (int k = 0; k < TYPE::SIZE; ++k) + if (soft[1].v[k] < 0) + metric[k] -= soft[1].v[k]; + MAP map; + for (int k = 0; k < TYPE::SIZE; ++k) + map.v[k] = k; + return map; + } + static MAP rate1(PATH *metric, TYPE *message, MAP *maps, int *count, TYPE *hard, TYPE *soft) + { + TYPE sft = soft[1]; + PATH fork[2*TYPE::SIZE]; + for (int k = 0; k < TYPE::SIZE; ++k) + fork[k] = fork[k+TYPE::SIZE] = metric[k]; + for (int k = 0; k < TYPE::SIZE; ++k) + if (sft.v[k] < 0) + fork[k] -= sft.v[k]; + else + fork[k+TYPE::SIZE] += sft.v[k]; + int perm[2*TYPE::SIZE]; + for (int k = 0; k < 2*TYPE::SIZE; ++k) + perm[k] = k; + std::nth_element(perm, perm+TYPE::SIZE, perm+2*TYPE::SIZE, [fork](int a, int b){ return fork[a] < fork[b]; }); + for (int k = 0; k < TYPE::SIZE; ++k) + metric[k] = fork[perm[k]]; + MAP map; + for (int k = 0; k < TYPE::SIZE; ++k) + map.v[k] = perm[k] % TYPE::SIZE; + TYPE hrd; + for (int k = 0; k < TYPE::SIZE; ++k) + hrd.v[k] = perm[k] < TYPE::SIZE ? 1 : -1; + message[*count] = hrd; + maps[*count] = map; + ++*count; + *hard = hrd; + return map; + } +}; + template struct PolarListTree { @@ -212,62 +260,17 @@ struct PolarListTree if (frozen & 1) lmap = PolarListNode::rate0(metric, hard, soft); else - lmap = PolarListTree::decode(metric, message, maps, count, hard, soft, 0); + lmap = PolarListNode::rate1(metric, message, maps, count, hard, soft); soft[1] = PH::madd(hard[0], vshuf(soft[2], lmap), vshuf(soft[3], lmap)); if (frozen >> 1) rmap = PolarListNode::rate0(metric, hard+1, soft); else - rmap = PolarListTree::decode(metric, message, maps, count, hard+1, soft, 0); + rmap = PolarListNode::rate1(metric, message, maps, count, hard+1, soft); hard[0] = PH::qmul(vshuf(hard[0], rmap), hard[1]); return vshuf(lmap, rmap); } }; -template -struct PolarListTree -{ - typedef PolarHelper PH; - typedef typename PH::PATH PATH; - typedef typename PH::MAP MAP; - static MAP decode(PATH *metric, TYPE *message, MAP *maps, int *count, TYPE *hard, TYPE *soft, uint32_t frozen) - { - MAP map; - TYPE hrd, sft = soft[1]; - if (frozen) { - for (int k = 0; k < TYPE::SIZE; ++k) - if (sft.v[k] < 0) - metric[k] -= sft.v[k]; - hrd = PH::one(); - for (int k = 0; k < TYPE::SIZE; ++k) - map.v[k] = k; - } else { - PATH fork[2*TYPE::SIZE]; - for (int k = 0; k < TYPE::SIZE; ++k) - fork[k] = fork[k+TYPE::SIZE] = metric[k]; - for (int k = 0; k < TYPE::SIZE; ++k) - if (sft.v[k] < 0) - fork[k] -= sft.v[k]; - else - fork[k+TYPE::SIZE] += sft.v[k]; - int perm[2*TYPE::SIZE]; - for (int k = 0; k < 2*TYPE::SIZE; ++k) - perm[k] = k; - std::nth_element(perm, perm+TYPE::SIZE, perm+2*TYPE::SIZE, [fork](int a, int b){ return fork[a] < fork[b]; }); - for (int k = 0; k < TYPE::SIZE; ++k) - metric[k] = fork[perm[k]]; - for (int k = 0; k < TYPE::SIZE; ++k) - map.v[k] = perm[k] % TYPE::SIZE; - for (int k = 0; k < TYPE::SIZE; ++k) - hrd.v[k] = perm[k] < TYPE::SIZE ? 1 : -1; - message[*count] = hrd; - maps[*count] = map; - ++*count; - } - *hard = hrd; - return map; - } -}; - template class PolarListDecoder {