diff --git a/prime_field.hh b/prime_field.hh new file mode 100644 index 0000000..88996bd --- /dev/null +++ b/prime_field.hh @@ -0,0 +1,145 @@ +/* +Prime field arithmetic + +Copyright 2024 Ahmet Inan +*/ + +#pragma once + +#include + +namespace CODE { + +template +struct PrimeField +{ + static_assert(std::is_unsigned::value, "TYPE must be unsigned"); + static_assert(std::numeric_limits::max() / (PRIME-1) >= (PRIME-1), "Type not wide enough"); + static constexpr TYPE P = PRIME; + TYPE v; + PrimeField() = default; + explicit PrimeField(TYPE v) : v(v) + { + } + PrimeField operator *= (PrimeField a) + { + return *this = *this * a; + } + PrimeField operator /= (PrimeField a) + { + return *this = *this / a; + } + PrimeField operator += (PrimeField a) + { + return *this = *this + a; + } + PrimeField operator -= (PrimeField a) + { + return *this = *this - a; + } + TYPE operator () () + { + return v; + } +}; + +template +PrimeField reduce(PrimeField a) +{ + return PrimeField(a.v % a.P); +} + +template +bool operator == (PrimeField a, PrimeField b) +{ + return a.v == b.v; +} + +template +bool operator != (PrimeField a, PrimeField b) +{ + return a.v != b.v; +} + +template +PrimeField add(PrimeField a, PrimeField b) +{ + return PrimeField(a.v + b.v); +} + +template +PrimeField operator + (PrimeField a, PrimeField b) +{ + return reduce(add(a, b)); +} + +template +PrimeField sub(PrimeField a, PrimeField b) +{ + return PrimeField(a.v - b.v + a.P); +} + +template +PrimeField operator - (PrimeField a, PrimeField b) +{ + return reduce(sub(a, b)); +} + +template +PrimeField neg(PrimeField a) +{ + return PrimeField(a.P - a.v); +} + +template +PrimeField operator - (PrimeField a) +{ + return reduce(neg(a)); +} + +template +PrimeField mul(PrimeField a, PrimeField b) +{ + return PrimeField(a.v * b.v); +} + +template +PrimeField operator * (PrimeField a, PrimeField b) +{ + return reduce(mul(a, b)); +} + +template +PrimeField rcp(PrimeField a) +{ + assert(a.v); + if (a.v == 1) + return a; + TYPE t = 0, newt = 1; + TYPE r = a.P, newr = a.v; + while (newr) { + TYPE quotient = r / newr; + t -= quotient * newt; + r -= quotient * newr; + std::swap(newt, t); + std::swap(newr, r); + } + assert(r == 1); + if (t >= a.P) + t += a.P; + return PrimeField(t); +} + +template +PrimeField div(PrimeField a, PrimeField b) +{ + return mul(a, rcp(b)); +} + +template +PrimeField operator / (PrimeField a, PrimeField b) +{ + return reduce(div(a, b)); +} + +} diff --git a/tests/pf_test.cc b/tests/pf_test.cc new file mode 100644 index 0000000..ee38696 --- /dev/null +++ b/tests/pf_test.cc @@ -0,0 +1,36 @@ +/* +Test for the prime field arithmetic + +Copyright 2024 Ahmet Inan +*/ + +#include +#include +#include +#include "prime_field.hh" + +template +void exhaustive_test() +{ + typedef CODE::PrimeField PF; + for (TYPE a = 0; a < PRIME; ++a) + for (TYPE b = 0; b < PRIME; ++b) + assert((PF(a) * PF(b))() == (a * b) % PRIME); + for (TYPE a = 1; a < PRIME; ++a) + assert(rcp(PF(a)) * PF(a) == PF(1)); + for (TYPE a = 0; a < PRIME; ++a) + for (TYPE b = 0; b < PRIME; ++b) + assert((PF(a) + PF(b))() == (a + b) % PRIME); + for (TYPE a = 0; a < PRIME; ++a) + for (TYPE b = 0; b < PRIME; ++b) + assert((PF(a) - PF(b))() == (a - b + PRIME) % PRIME); +} + +int main() +{ + exhaustive_test(); + exhaustive_test(); + std::cerr << "Prime field arithmetic test passed!" << std::endl; + return 0; +} +