From 389fcc3bd8884c8b064910ba787982ede40c9cd8 Mon Sep 17 00:00:00 2001 From: Ahmet Inan Date: Fri, 8 Mar 2024 13:00:40 +0100 Subject: [PATCH] added comp argument and test for stability --- sort.hh | 31 ++++++++++++++++++++++++------- tests/sort_regression_test.cc | 24 +++++++++++++----------- 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/sort.hh b/sort.hh index 0db41d2..8742e6c 100644 --- a/sort.hh +++ b/sort.hh @@ -13,7 +13,18 @@ static void insertion_sort(TYPE *a, int n) { for (int i = 1, j; i < n; ++i) { TYPE t = a[i]; - for (j = i; j > 0 && a[j-1] > t; --j) + for (j = i; j > 0 && t < a[j-1]; --j) + a[j] = a[j-1]; + a[j] = t; + } +} + +template +static void insertion_sort(TYPE *a, int n, COMP comp) +{ + for (int i = 1, j; i < n; ++i) { + TYPE t = a[i]; + for (j = i; j > 0 && comp(t, a[j-1]); --j) a[j] = a[j-1]; a[j] = t; } @@ -25,7 +36,7 @@ static void insertion_sort(INDEX *p, TYPE *a, int n) p[0] = 0; for (int i = 1, j; i < n; ++i) { TYPE t = a[i]; - for (j = i; j > 0 && a[j-1] > t; --j) { + for (j = i; j > 0 && t < a[j-1]; --j) { a[j] = a[j-1]; p[j] = p[j-1]; } @@ -38,27 +49,33 @@ template class MergeSort { TYPE tmp[MAX_N]; - void merge(TYPE *a, int n, int left, int right, int end) + template + void merge(TYPE *a, int n, int left, int right, int end, COMP comp) { if (right > n) right = n; if (end > n) end = n; for (int i = left, j = right, k = left; k < end; ++k) - tmp[k] = (i < right && (j >= end || a[i] <= a[j])) ? a[i++] : a[j++]; + tmp[k] = (i >= right || (j < end && comp(a[j], a[i]))) ? a[j++] : a[i++]; } public: - void operator()(TYPE *a, int n) + template + void operator()(TYPE *a, int n, COMP comp) { for (int i = 0; i < n; i += M) - insertion_sort(a+i, i > n-M ? n-i : M); + insertion_sort(a+i, i > n-M ? n-i : M, comp); for (int l = M; l < n; l *= 2) { for (int i = 0; i < n; i += 2*l) - merge(a, n, i, i+l, i+2*l); + merge(a, n, i, i+l, i+2*l, comp); for (int i = 0; i < n; ++i) a[i] = tmp[i]; } } + void operator()(TYPE *a, int n) + { + operator()(a, n, [](TYPE x, TYPE y){ return x < y; }); + } }; } diff --git a/tests/sort_regression_test.cc b/tests/sort_regression_test.cc index 310dd47..464e541 100644 --- a/tests/sort_regression_test.cc +++ b/tests/sort_regression_test.cc @@ -22,25 +22,27 @@ int main() typedef std::default_random_engine generator; typedef std::uniform_int_distribution distribution; auto rand = std::bind(distribution(1, MAX_N), generator(seed)); - int a[MAX_N], b[MAX_N], c[MAX_N], d[MAX_N], e[MAX_N], f[MAX_N], g[MAX_N]; - CODE::MergeSort merge_sort; + typedef std::pair Pair; + Pair a[MAX_N], b[MAX_N], c[MAX_N]; + int d[MAX_N], e[MAX_N]; + auto comp = [](Pair a, Pair b){ return a.second < b.second; }; + CODE::MergeSort merge_sort; for (int loop = 0; loop < 1000000; ++loop) { int size = rand(); for (int i = 0; i < size; ++i) - a[i] = b[i] = c[i] = d[i] = e[i] = rand(); - std::sort(a, a+size); - CODE::insertion_sort(b, size); + a[i].first = b[i].first = c[i].first = e[i] = i; + for (int i = 0; i < size; ++i) + a[i].second = b[i].second = c[i].second = d[i] = rand(); + std::stable_sort(a, a+size, comp); + CODE::insertion_sort(b, size, comp); for (int i = 0; i < size; ++i) assert(a[i] == b[i]); - merge_sort(c, size); + merge_sort(c, size, comp); for (int i = 0; i < size; ++i) assert(a[i] == c[i]); + CODE::insertion_sort(e, d, size); for (int i = 0; i < size; ++i) - f[i] = i; - std::stable_sort(f, f+size, [d](int i, int j){ return d[i] < d[j]; }); - CODE::insertion_sort(g, e, size); - for (int i = 0; i < size; ++i) - assert(f[i] == g[i]); + assert(a[i].first == e[i]); } std::cerr << "Sorting regression test passed!" << std::endl; return 0;