目录

原地归并排序

前排提示:这篇文章讲的算法 ACM 用不上。

众所周知,归并排序需要 O(n)O(n) 的额外空间,这无疑成为了归并排序的短板。

实际上,是存在额外空间为 O(1)O(1) 的归并排序(原地归并排序)。

# 1. 归并算法

在归并排序之前,我们需要对归并算法做一些改变。这里的归并算法指将两个有序数列合并为一个有序数列。

由于原地归并排序只能在原数列上操作,为了不丢失信息,必须每次操作都要用交换。

template <typename Iter>
void mergeTo(Iter l1, Iter r1, Iter l2, Iter r2, Iter res) {
    while (l1 != r1 || l2 != r2) {
        if (l2 == r2) {
            std::swap(*res++, *l1++);
        } else if (l1 == r1) {
            std::swap(*res++, *l2++);
        } else {
            std::swap(*res++, (*l2 < *l1) ? *l2++ : *l1++);
        }
    }
}

在上述代码中,mergeTo 函数的作用是将 [l1,r1)[l_1,r_1) 区间和 [l2,r2)[l_2,r_2) 区间这两个有序数组合并到 res 开始的数组里。这个函数的功能和 std::merge 非常像,但是用了交换而非直接赋值。

对于归并算法,有一个关键的操作。假设 [0,a),[b,n)[0,a),[b,n) 这两个区间各自都是有序的,且 abaa\le b-a,能否只用空间 [0,n)[0,n) 来完成这两个区间的合并呢?

答案是肯定的,调用一下 mergeTo(0, a, b, n, a) 即可(类型不严谨,看得懂就行)。

为什么呢?因为对于迭代器 res 来说,要想影响原来的有序数组,必须要 res > l2 才行。但是 swap(*res++, *l2++) 是不会缩短 resl2 距离的,且 swap(*res++, *l1++) 只会执行 aa 次,abaa \le b-a,因此这样的 mergeTo 操作是安全的。

# 2. 传统归并排序

有了归并算法,我们可以实现一个传统的归并排序(没错就是你所了解的那个)。很显然这个传统归并排序并不是最终的原地归并排序,因为 mergeTo 需要一倍的额外空间。

这里偷了个懒,用了递归写法导致额外空间变为 O(logn)O(\log n),不用递归可以进一步减少空间。

template <typename Iter>
void mergeSortTrivial(Iter l, Iter r, Iter work) {
    if (l + 1 == r) {
        return;
    }
    Iter m = l + (r - l) / 2;
    mergeSortTrivial(l, m, work);
    mergeSortTrivial(m, r, work);
    mergeTo(l, m, m, r, work);
    for (int i = 0; i < (r - l); i++) {
        std::swap(l[i], work[i]);
    }
}

在上述代码中,mergeSortTrivial[l,r)[l,r) 区间排序,但是利用了从 work 开始长度为 (rl)(r-l) 的“额外空间”,这部分空间被成为工作区。

mergeSortTrivial 结束后,工作区里的元素只是丢失了原有顺序,没有丢失信息。

# 3. 原地归并排序

终于到了最关键的部分了。

假设区间 [b,n)[b,n) 是有序的,一开始有 b=nb = n

a=b/2,a=b/2a=\lfloor b/2\rfloor,a'=\lceil b/2\rceil,首先用 [a,b)[a,b) 作为工作区对 [0,a)[0,a) 进行传统的归并排序。此时 [0,a)[0,a)[b,n)[b,n) 都是有序的,用那个“关键的操作”将这两个区间合并到 [a,n)[a',n)。这样 bb 就变成 aa',下降到原来的二分之一。

不断地减少 bb 直到 b=1b = 1,也就是说只有第一个元素是无序的,对它特殊处理一下。

template <typename Iter>
void mergeSort(Iter l, Iter r) {
    Iter b = r;
    while (l + 1 < b) {
        Iter a = l + (b - l) / 2;
        mergeSortTrivial(l, a, a);
        mergeTo(l, a, b, r, b - (a - l));
        b -= (a - l);
    }
    if (l + 1 == b) {
        while (l + 1 != r && l[0] > l[1]) {
            std::swap(l[0], l[1]);
            l++;
        }
    }
}

# 4. 复杂度

额外空间复杂度不说了。

对于时间复杂度,因为每次 bb 都变为原来的二分之一,即 b=n,n/2,n/4,b=n,n/2,n/4,\ldots,传统归并排序的复杂度总和为 O(nlogn+n/2logn/2+n/4logn/4+)=O(nlogn)O(n\log n+n/2\cdot\log n/2+n/4\cdot\log n/4+\ldots)=O(n\log n),“关键的操作”复杂度为 O(n+n+n++)O(n+n+n++\ldots)(共 logn\log n 项)=O(nlogn)=O(n\log n),最终复杂度为 O(nlogn)O(n\log n)

# 5. 完整代码

#include <algorithm>
#include <iostream>
#include <vector>

template <typename Iter>
void mergeTo(Iter l1, Iter r1, Iter l2, Iter r2, Iter res) {
    while (l1 != r1 || l2 != r2) {
        if (l2 == r2) {
            std::swap(*res++, *l1++);
        } else if (l1 == r1) {
            std::swap(*res++, *l2++);
        } else {
            std::swap(*res++, (*l2 < *l1) ? *l2++ : *l1++);
        }
    }
}

template <typename Iter>
void mergeSortTrivial(Iter l, Iter r, Iter work) {
    if (l + 1 == r) {
        return;
    }
    Iter m = l + (r - l) / 2;
    mergeSortTrivial(l, m, work);
    mergeSortTrivial(m, r, work);
    mergeTo(l, m, m, r, work);
    for (int i = 0; i < (r - l); i++) {
        std::swap(l[i], work[i]);
    }
}

template <typename Iter>
void mergeSort(Iter l, Iter r) {
    Iter b = r;
    while (l + 1 < b) {
        Iter a = l + (b - l) / 2;
        mergeSortTrivial(l, a, a);
        mergeTo(l, a, b, r, b - (a - l));
        b -= (a - l);
    }
    if (l + 1 == b) {
        while (l + 1 != r && l[0] > l[1]) {
            std::swap(l[0], l[1]);
            l++;
        }
    }
}

signed main() {
    std::vector<int> a = {7, 6, 3, 5, 4, 2, 1};
    mergeSort(a.begin(), a.end());
    for (int i : a) {
        std::cout << i << ' ';
    }
    std::cout << std::endl;
    return 0;
}

# 6. 参考

优化原地归并排序:实现 O (1) 空间复杂度 (opens new window)