added support for int16_t

This commit is contained in:
Ahmet Inan 2025-07-01 22:39:07 +02:00
commit e55a701789

View file

@ -153,6 +153,58 @@ struct PolarHelper<SIMD<int8_t, WIDTH>>
}
};
template <int WIDTH>
struct PolarHelper<SIMD<int16_t, WIDTH>>
{
typedef SIMD<int16_t, WIDTH> TYPE;
typedef int PATH;
typedef SIMD<uint16_t, WIDTH> MAP;
static TYPE one()
{
return vdup<TYPE>(1);
}
static TYPE zero()
{
return vzero<TYPE>();
}
static TYPE signum(TYPE a)
{
return vsignum(a);
}
static TYPE qabs(TYPE a)
{
return vqabs(a);
}
static TYPE qadd(TYPE a, TYPE b)
{
return vqadd(a, b);
}
static TYPE qmul(TYPE a, TYPE b)
{
#ifdef __ARM_NEON
return vmul(a, b);
#else
return vsign(a, b);
#endif
}
static TYPE prod(TYPE a, TYPE b)
{
#ifdef __ARM_NEON
return vmul(vmul(vsignum(a), vsignum(b)), vmin(vqabs(a), vqabs(b)));
#else
return vsign(vmin(vqabs(a), vqabs(b)), vsign(vsignum(a), b));
#endif
}
static TYPE madd(TYPE a, TYPE b, TYPE c)
{
#ifdef __ARM_NEON
return vmax(vqadd(vmul(a, vmax(b, vdup<TYPE>(-32767))), c), vdup<TYPE>(-32767));
#else
return vmax(vqadd(vsign(vmax(b, vdup<TYPE>(-32767)), a), c), vdup<TYPE>(-32767));
#endif
}
};
template <>
struct PolarHelper<int8_t>
{
@ -202,5 +254,54 @@ struct PolarHelper<int8_t>
}
};
template <>
struct PolarHelper<int16_t>
{
typedef int PATH;
static int16_t one()
{
return 1;
}
static int16_t zero()
{
return 0;
}
static int16_t signum(int16_t v)
{
return (v > 0) - (v < 0);
}
template <typename IN>
static int16_t quant(IN in)
{
return std::min<IN>(std::max<IN>(std::nearbyint(in), -32767), 32767);
}
static int16_t qabs(int16_t a)
{
return std::abs(std::max<int16_t>(a, -32767));
}
static int16_t qmin(int16_t a, int16_t b)
{
return std::min(a, b);
}
static int16_t qadd(int16_t a, int16_t b)
{
return std::min<int32_t>(std::max<int32_t>(int32_t(a) + int32_t(b), -32767), 32767);
}
static int16_t qmul(int16_t a, int16_t b)
{
// return std::min<int32_t>(std::max<int32_t>(int32_t(a) * int32_t(b), -32767), 32767);
// only used for hard decision values anyway
return a * b;
}
static int16_t prod(int16_t a, int16_t b)
{
return signum(a) * signum(b) * qmin(qabs(a), qabs(b));
}
static int16_t madd(int16_t a, int16_t b, int16_t c)
{
return std::min<int32_t>(std::max<int32_t>(int32_t(a) * int32_t(b) + int32_t(c), -32767), 32767);
}
};
}