/*
Prime field arithmetic

Copyright 2024 Ahmet Inan
*/

#pragma once

#include <cassert>

namespace CODE {

/*
Element of the prime field with PRIME elements, held in an integer of
type TYPE.

The free operators below expect their operands to be reduced, i.e. to
hold a value in [0, PRIME), and always return reduced results. The
explicit constructor stores its argument as-is, without reduction.

TYPE must be able to hold the unreduced intermediates computed by the
helpers add/sub/mul: up to 2 * PRIME - 1 for add and sub, and up to
(PRIME - 1) * (PRIME - 1) for mul. This is not checked; choose TYPE
wide enough for the chosen PRIME.
*/
template <typename TYPE, TYPE PRIME>
struct PrimeField
{
	// A modulus of 0 would make the "% P" in reduce() undefined.
	static_assert(PRIME > 1, "PRIME must be a prime, hence greater than one");
	static constexpr TYPE P = PRIME;
	// Raw representative; left uninitialized by the default constructor.
	TYPE v;
	PrimeField() = default;
	explicit PrimeField(TYPE v) : v(v)
	{
	}
	// Compound assignments return a reference to *this, like the built-in
	// operators, so chained forms such as (x += y) *= z update x itself.
	PrimeField &operator *= (PrimeField a)
	{
		return *this = *this * a;
	}
	PrimeField &operator /= (PrimeField a)
	{
		return *this = *this / a;
	}
	PrimeField &operator += (PrimeField a)
	{
		return *this = *this + a;
	}
	PrimeField &operator -= (PrimeField a)
	{
		return *this = *this - a;
	}
	// Returns the stored representative without reducing it.
	TYPE operator () () const
	{
		return v;
	}
};

// Maps the raw representative into [0, P).
template <typename TYPE, TYPE PRIME>
PrimeField<TYPE, PRIME> reduce(PrimeField<TYPE, PRIME> a)
{
	return PrimeField<TYPE, PRIME>(a.v % a.P);
}

// Compares raw representatives, so both sides are assumed to be reduced.
template <typename TYPE, TYPE PRIME>
bool operator == (PrimeField<TYPE, PRIME> a, PrimeField<TYPE, PRIME> b)
{
	return a.v == b.v;
}

template <typename TYPE, TYPE PRIME>
bool operator != (PrimeField<TYPE, PRIME> a, PrimeField<TYPE, PRIME> b)
{
	return a.v != b.v;
}

// Unreduced sum, at most 2 * P - 2 for reduced operands.
template <typename TYPE, TYPE PRIME>
PrimeField<TYPE, PRIME> add(PrimeField<TYPE, PRIME> a, PrimeField<TYPE, PRIME> b)
{
	return PrimeField<TYPE, PRIME>(a.v + b.v);
}

template <typename TYPE, TYPE PRIME>
PrimeField<TYPE, PRIME> operator + (PrimeField<TYPE, PRIME> a, PrimeField<TYPE, PRIME> b)
{
	return reduce(add(a, b));
}

// Unreduced difference. Adding P keeps the result in [1, 2 * P - 1] for
// reduced operands, so it never goes negative.
template <typename TYPE, TYPE PRIME>
PrimeField<TYPE, PRIME> sub(PrimeField<TYPE, PRIME> a, PrimeField<TYPE, PRIME> b)
{
	return PrimeField<TYPE, PRIME>(a.v - b.v + a.P);
}

template <typename TYPE, TYPE PRIME>
PrimeField<TYPE, PRIME> operator - (PrimeField<TYPE, PRIME> a, PrimeField<TYPE, PRIME> b)
{
	return reduce(sub(a, b));
}

// Unreduced negation in [1, P]; negating zero yields P, which reduce() maps to 0.
template <typename TYPE, TYPE PRIME>
PrimeField<TYPE, PRIME> neg(PrimeField<TYPE, PRIME> a)
{
	return PrimeField<TYPE, PRIME>(a.P - a.v);
}

template <typename TYPE, TYPE PRIME>
PrimeField<TYPE, PRIME> operator - (PrimeField<TYPE, PRIME> a)
{
	return reduce(neg(a));
}

// Unreduced product, at most (P - 1) * (P - 1) for reduced operands.
template <typename TYPE, TYPE PRIME>
PrimeField<TYPE, PRIME> mul(PrimeField<TYPE, PRIME> a, PrimeField<TYPE, PRIME> b)
{
	return PrimeField<TYPE, PRIME>(a.v * b.v);
}

template <typename TYPE, TYPE PRIME>
PrimeField<TYPE, PRIME> operator * (PrimeField<TYPE, PRIME> a, PrimeField<TYPE, PRIME> b)
{
	return reduce(mul(a, b));
}

/*
Computes a^m by binary exponentiation. The bits of m are scanned from
the least significant one upward: a is squared at every step and
multiplied into the result where the bit is set. m is expected to be
non-negative; pow(a, 0) returns 1.
*/
template <typename TYPE, TYPE PRIME>
PrimeField<TYPE, PRIME> pow(PrimeField<TYPE, PRIME> a, TYPE m)
{
	PrimeField<TYPE, PRIME> t(1);
	for (;m; m >>= 1, a *= a)
		if (m & 1)
			t *= a;
	return t;
}

/*
Multiplicative inverse of a nonzero element.

Uses Fermat's little theorem: for prime P and a != 0, a^(P-2) = a^-1.
The exponent is converted back to TYPE because "a.P - 2" is promoted
to int for narrow types such as uint8_t or uint16_t. Without the
conversion, TYPE would be deduced inconsistently in pow(a, m) and the
call would not compile.
*/
template <typename TYPE, TYPE PRIME>
PrimeField<TYPE, PRIME> rcp(PrimeField<TYPE, PRIME> a)
{
	// Zero has no multiplicative inverse.
	assert(a.v);
#if 1
	return pow(a, TYPE(a.P - 2));
#else
	// Alternative: extended Euclidean algorithm (needs <utility> for
	// std::swap if enabled). NOTE(review): the final correction only
	// catches a "negative" t as a value >= P, which appears to rely on
	// unsigned wrap-around; with a signed TYPE a negative t would be
	// returned unchanged.
	if (a.v == 1)
		return a;
	TYPE t = 0, newt = 1;
	TYPE r = a.P, newr = a.v;
	while (newr) {
		TYPE quotient = r / newr;
		t -= quotient * newt;
		r -= quotient * newr;
		std::swap(newt, t);
		std::swap(newr, r);
	}
	assert(r == 1);
	if (t >= a.P)
		t += a.P;
	return PrimeField<TYPE, PRIME>(t);
#endif
}

// Unreduced quotient a * b^-1; b must be nonzero.
template <typename TYPE, TYPE PRIME>
PrimeField<TYPE, PRIME> div(PrimeField<TYPE, PRIME> a, PrimeField<TYPE, PRIME> b)
{
	return mul(a, rcp(b));
}

template <typename TYPE, TYPE PRIME>
PrimeField<TYPE, PRIME> operator / (PrimeField<TYPE, PRIME> a, PrimeField<TYPE, PRIME> b)
{
	return reduce(div(a, b));
}

}