原地归并排序
前排提示:这篇文章讲的算法 ACM 用不上。
众所周知,归并排序需要 的额外空间,这无疑成为了归并排序的短板。
实际上,是存在额外空间为 的归并排序(原地归并排序)。
# 1. 归并算法
在归并排序之前,我们需要对归并算法做一些改变。这里的归并算法指将两个有序数列合并为一个有序数列。
由于原地归并排序只能在原数列上操作,为了不丢失信息,必须每次操作都要用交换。
template <typename Iter>
void mergeTo(Iter l1, Iter r1, Iter l2, Iter r2, Iter res) {
while (l1 != r1 || l2 != r2) {
if (l2 == r2) {
std::swap(*res++, *l1++);
} else if (l1 == r1) {
std::swap(*res++, *l2++);
} else {
std::swap(*res++, (*l2 < *l1) ? *l2++ : *l1++);
}
}
}
在上述代码中,mergeTo
函数的作用是将 区间和 区间这两个有序数组合并到 res
开始的数组里。这个函数的功能和 std::merge
非常像,但是用了交换而非直接赋值。
对于归并算法,有一个关键的操作。假设 这两个区间各自都是有序的,且 ,能否只用空间 来完成这两个区间的合并呢?
答案是肯定的,调用一下 mergeTo(0, a, b, n, a)
即可(类型不严谨,看得懂就行)。
为什么呢?因为对于迭代器 res
来说,要想影响原来的有序数组,必须要 res > l2
才行。但是 swap(*res++, *l2++)
是不会缩短 res
和 l2
距离的,且 swap(*res++, *l1++)
只会执行 次,,因此这样的 mergeTo
操作是安全的。
# 2. 传统归并排序
有了归并算法,我们可以实现一个传统的归并排序(没错就是你所了解的那个)。很显然这个传统归并排序并不是最终的原地归并排序,因为 mergeTo
需要一倍的额外空间。
这里偷了个懒,用了递归写法导致额外空间变为 ,不用递归可以进一步减少空间。
template <typename Iter>
void mergeSortTrivial(Iter l, Iter r, Iter work) {
if (l + 1 == r) {
return;
}
Iter m = l + (r - l) / 2;
mergeSortTrivial(l, m, work);
mergeSortTrivial(m, r, work);
mergeTo(l, m, m, r, work);
for (int i = 0; i < (r - l); i++) {
std::swap(l[i], work[i]);
}
}
在上述代码中,mergeSortTrivial
将 区间排序,但是利用了从 work
开始长度为 的“额外空间”,这部分空间被成为工作区。
在 mergeSortTrivial
结束后,工作区里的元素只是丢失了原有顺序,没有丢失信息。
# 3. 原地归并排序
终于到了最关键的部分了。
假设区间 是有序的,一开始有 。
令 ,首先用 作为工作区对 进行传统的归并排序。此时 和 都是有序的,用那个“关键的操作”将这两个区间合并到 。这样 就变成 ,下降到原来的二分之一。
不断地减少 直到 ,也就是说只有第一个元素是无序的,对它特殊处理一下。
template <typename Iter>
void mergeSort(Iter l, Iter r) {
Iter b = r;
while (l + 1 < b) {
Iter a = l + (b - l) / 2;
mergeSortTrivial(l, a, a);
mergeTo(l, a, b, r, b - (a - l));
b -= (a - l);
}
if (l + 1 == b) {
while (l + 1 != r && l[0] > l[1]) {
std::swap(l[0], l[1]);
l++;
}
}
}
# 4. 复杂度
额外空间复杂度不说了。
对于时间复杂度,因为每次 都变为原来的二分之一,即 ,传统归并排序的复杂度总和为 ,“关键的操作”复杂度为 (共 项),最终复杂度为 。
# 5. 完整代码
#include <algorithm>
#include <iostream>
#include <vector>
template <typename Iter>
void mergeTo(Iter l1, Iter r1, Iter l2, Iter r2, Iter res) {
while (l1 != r1 || l2 != r2) {
if (l2 == r2) {
std::swap(*res++, *l1++);
} else if (l1 == r1) {
std::swap(*res++, *l2++);
} else {
std::swap(*res++, (*l2 < *l1) ? *l2++ : *l1++);
}
}
}
template <typename Iter>
void mergeSortTrivial(Iter l, Iter r, Iter work) {
if (l + 1 == r) {
return;
}
Iter m = l + (r - l) / 2;
mergeSortTrivial(l, m, work);
mergeSortTrivial(m, r, work);
mergeTo(l, m, m, r, work);
for (int i = 0; i < (r - l); i++) {
std::swap(l[i], work[i]);
}
}
template <typename Iter>
void mergeSort(Iter l, Iter r) {
Iter b = r;
while (l + 1 < b) {
Iter a = l + (b - l) / 2;
mergeSortTrivial(l, a, a);
mergeTo(l, a, b, r, b - (a - l));
b -= (a - l);
}
if (l + 1 == b) {
while (l + 1 != r && l[0] > l[1]) {
std::swap(l[0], l[1]);
l++;
}
}
}
signed main() {
std::vector<int> a = {7, 6, 3, 5, 4, 2, 1};
mergeSort(a.begin(), a.end());
for (int i : a) {
std::cout << i << ' ';
}
std::cout << std::endl;
return 0;
}