diff --git a/README.md b/README.md index 4fbcdcc..67da554 100644 --- a/README.md +++ b/README.md @@ -219,3 +219,9 @@ Sometimes we only need [trigonometric functions](https://en.wikipedia.org/wiki/T * [sine function](https://en.wikipedia.org/wiki/Sine) * [cosine function](https://en.wikipedia.org/wiki/Trigonometric_functions#cosine) +### [lms.hh](lms.hh) + +Some [least mean squares filter](https://en.wikipedia.org/wiki/Least_mean_squares_filter) implementations: +* Normalized Least Mean Squares +* Normalized Complex Least Mean Squares + diff --git a/lms.hh b/lms.hh new file mode 100644 index 0000000..1a9978d --- /dev/null +++ b/lms.hh @@ -0,0 +1,83 @@ +/* +Least mean squares filter + +Copyright 2020 Ahmet Inan +*/ + +#pragma once + +#include "bip_buffer.hh" + +namespace DSP { + +template +class NLMS +{ + BipBuffer history; + TYPE filter[SIZE]; +public: + NLMS() + { + for (int i = 0; i < SIZE; ++i) + filter[i] = 0; + } + inline operator const TYPE * () const + { + return filter; + } + TYPE operator () (TYPE input, TYPE desired, TYPE mu = 1) + { + const TYPE *hist = history(input); + TYPE estimate = 0; + for (int i = 0; i < SIZE; ++i) + estimate += filter[i] * hist[i]; + TYPE error = desired - estimate; + TYPE power = 0; + for (int i = 0; i < SIZE; ++i) + power += hist[i] * hist[i]; + if (power == 0) + return error; + TYPE rate = mu / power; + for (int i = 0; i < SIZE; ++i) + filter[i] += rate * error * hist[i]; + return error; + } +}; + +template +class NCLMS +{ + typedef typename CMPLX::value_type VALUE; + BipBuffer history; + CMPLX filter[SIZE]; +public: + NCLMS() + { + for (int i = 0; i < SIZE; ++i) + filter[i] = 0; + } + inline operator const CMPLX * () const + { + return filter; + } + CMPLX operator () (CMPLX input, CMPLX desired, VALUE mu = 1) + { + const CMPLX *hist = history(conj(input)); + CMPLX estimate = 0; + for (int i = 0; i < SIZE; ++i) + estimate += conj(filter[i]) * hist[i]; + CMPLX error = conj(conj(desired) - estimate); + VALUE power = 0; + for (int i = 0; i < SIZE; ++i) + power += norm(hist[i]); + if (power == 0) + return error; + VALUE rate = mu / power; + for (int i = 0; i < SIZE; ++i) + filter[i] += rate * error * hist[i]; + return error; + } +}; + +} +