目录

原地三路快速排序

img

今天讲一点简单的算法,算法参考 Sorting multisets stably in minimum space (opens new window)

为了让文章看起来更全面,就先聊一聊我对快速排序的一些分析。

# 1. 快排的经典实现

快排可以分为二路快排(Lomuto 或 Hoare 划分)、三路快排。

Lomuto 是同向的双指针,Hoare 是相向的双指针,三路快排两个向右、一个向左(或者三个向右)的三指针。

  1. 个人认为同向指针好实现,相向要注意指针相遇的“差一错误”。
  2. 这里可能要提醒一下初学者,Lomuto 或者劣化版 Hoare 划分,在处理相同元素数组时复杂度会退化到 O(n2)O(n^2),除非额外处理。只有原版 Hoare 划分和三路快排不会。

综合各种因素,我就直接给出同向三路快排作为参考。

template <typename RandomIt>
void quicksort(RandomIt first, RandomIt last) {
    if (last - first <= 1) {
        return;
    }
    auto pivot = *first;
    RandomIt p1 = first;
    RandomIt p2 = first;
    for (RandomIt it = first; it < last; ++it) {
        if (*it < pivot) {
            std::swap(*p1, *it);
            std::swap(*p2, *it);
            ++p1;
            ++p2;
        } else if (*it == pivot) {
            std::swap(*p2, *it);
            ++p2;
        }
    }
    quicksort(first, p1);
    quicksort(p2, last);
}

# 2. 快排的问题

从理论上看,快速排序会有两个问题,最坏情况复杂度和非严格原地。

# 2.1. 最坏情况复杂度

众所周知,快速排序会选取 pivot 作为划分的比较对象。在一般的实现中,这个 pivot 要么选择固定某个位置,要么数组里面随机取,要么从首、尾、中间三个数里选中位数(Median of Three 优化)。但是这些策略都无法避免最坏情况 O(n2)O(n^2) 的复杂度。

当然解决办法很简单,用选择算法(求第 k 小元素)可以直接拿到中位数,中位数作为 pivot 就避免了复杂度退化。

(当然面试不要这么写,std::nth_element 是黑盒)

template <typename RandomIt>
void quicksort(RandomIt first, RandomIt last) {
    if (last - first <= 1) {
        return;
    }
    RandomIt mid = first + (last - first) / 2;
    std::nth_element(first, mid, last);
    quicksort(first, mid);
    quicksort(mid + 1, last);
}

# 2.2. 非严格原地

快速排序是一个递归算法,会有 O(logn)O(\log n) 的递归栈。虽然在工程上认为是原地的,但是学术上不这么认为。消除递归栈是这篇文章的重点。

# 3. 前置算法

在往期文章里,我们实现了原地稳定划分和原地稳定选择,在这里就用 inplace_stable_partition_stub, inplace_stable_select_stub 来指代。如果没看过往期文章也没关系,就当作一个黑盒吧。

这里用不稳定的版本也可以,唯一区别是快排失去稳定。

// Stub: delegates to std::stable_partition (non-in-place, O(n) extra space).
template <typename RandomIt, typename Pred>
RandomIt inplace_stable_partition_stub(RandomIt first, RandomIt last, Pred pred) {
    return std::stable_partition(first, last, pred);
}

// Stub: copies to vector and uses std::ranges::nth_element (non-in-place, O(n) extra space).
template <typename RandomIt, typename Proj = std::identity>
void inplace_stable_select_stub(RandomIt first, RandomIt mid, RandomIt last, Proj proj = {}) {
    using T = std::iter_value_t<RandomIt>;
    auto buffer = std::vector<T>(first, last);
    std::ranges::nth_element(
        buffer.begin(), buffer.begin() + (mid - first), buffer.end(), {}, proj);
    T pivot = buffer[mid - first];
    RandomIt pivot_it =
        std::stable_partition(first, last, [&](T x) { return proj(x) < proj(pivot); });
    std::stable_partition(pivot_it, last, [&](T x) { return proj(x) == proj(pivot); });
}

# 4. 原地二路快速排序

消除递归栈的最简单的思路,就是让整个递归树完全可预测,让区间边界落在 2 的幂的位置。

如果再把深度优先改为广度优先就更简单了。分块大小 block 从 bit_ceil(n) 开始不断除以 2。每一个 block 的值都有,对于每一块进行中位数 pivot 划分,具体可以看代码。

template <typename RandomIt, typename Proj = std::identity>
void inplace_stable_quicksort(RandomIt first, RandomIt last, Proj proj = {}) {
    int64_t len = last - first;
    auto bit_ceil = [](int64_t x) {
        return static_cast<int64_t>(std::bit_ceil(static_cast<uint64_t>(x)));
    };
    for (int64_t block = bit_ceil(len); block > 1; block /= 2) {
        for (int64_t start = 0; start + (block / 2) < len; start += block) {
            RandomIt block_l = first + start;
            RandomIt block_r = block_l + std::min(block, last - block_l);
            RandomIt block_m = block_l + (block / 2);
            inplace_stable_select_stub(block_l, block_m, block_r, proj);
        }
    }
}

# 5. 原地三路快速排序

转了一圈终于回到论文。既然有了原地二路快速排序,为什么还需要三路呢?

事实上论文的目标是 O(nlogni=1mnilogni+n)O(n\log n - \sum_{i=1}^{m}n_i\log n_i + n) 原地稳定排序,nin_i 是第 i 个相同元素出现次数,这是多重集排序的下界。论文利用的就是三路快排快速跳过重复元素的特性。

用两个字总结就是:哨兵。

# 5.1. 哨兵

算法需要两个哨兵。

一开始,两个哨兵位于当前处理区间的右边,区间所有元素都小于哨兵。快排的划分结束后会出现左中右三个区间,左右两个都会递归。把两个哨兵放在右区间的两边(交换第一个哨兵和右区间最左边的数),这样左区间递归完后,就能定位到右区间了。

# 5.2. 初始化

哨兵要严格大于剩下的数。因此等于哨兵但不是哨兵的元素可以放在最右边,它们已经有序了,不参与剩下排序的步骤。

代码里是先找到最大值,按最大值进行划分,如果哨兵数小于 2 个再重复。tail_ittail_it + 1 被选取为哨兵。

void inplace_stable_quicksort(RandomIt first, RandomIt last, Proj proj = {}) {
    using T = std::iter_value_t<RandomIt>;
    RandomIt tail_it = last;
    while (last - tail_it < 2) {
        if (tail_it == first) {
            return;
        }
        T max = *std::ranges::max_element(first, tail_it, {}, proj);
        tail_it =
            inplace_stable_partition_stub(first, tail_it, [&](T x) { return proj(x) < proj(max); });
    }
    // ...
}

# 5.3. 模拟递归

处理当前区间,保证一开始 rightright + 1 是两个哨兵且第一个哨兵小于等于第二个。

此时如果区间长度大于 1 就往下走。首先用原地稳定选择找中位数,根据中位数划分为左、中、右三个区间。设置哨兵,即交换第一个哨兵和右区间最左边的数。然后就是递归左区间。

这里递归左区间能保证它有两个哨兵吗?答案是肯定的,中间区间的元素大于左区间,可以当哨兵用。如果中间区间只有一个元素,它的右边刚好是我们设置的第一个哨兵,同样大于左区间。第一个哨兵小于等于第二个也是满足的。

如果区间长度小于等于 1 就转移到右区间。首先 left = std::ranges::find_if(...); 跳过所有 pivot(就是父问题划分时产生的三个区间里中间那个区间),找到第一个哨兵。然后 right = std::ranges::find_if(...) - 1; 找到第二个哨兵左边的位置。最后 std::swap(*left, *right); 复原之前的交换。

template <typename RandomIt, typename Proj = std::identity>
std::tuple<RandomIt, RandomIt> three_way_partition(
    RandomIt first, RandomIt last, std::iter_value_t<RandomIt> pivot, Proj proj = {}) {
    using T = std::iter_value_t<RandomIt>;
    RandomIt pivot_start =
        inplace_stable_partition_stub(first, last, [&](T x) { return proj(x) < proj(pivot); });
    RandomIt pivot_end = inplace_stable_partition_stub(
        pivot_start, last, [&](T x) { return proj(x) == proj(pivot); });
    return {pivot_start, pivot_end};
}

void inplace_stable_quicksort(RandomIt first, RandomIt last, Proj proj = {}) {
    // ...
    RandomIt left = first;
    RandomIt right = tail_it;
    while (true) {
        if (right - left > 1) {
            inplace_stable_select_stub(left, left + ((right - left) / 2), right, proj);
            T pivot = left[(right - left) / 2];
            auto [pivot_start, pivot_end] = three_way_partition(left, right, pivot, proj);
            std::swap(*pivot_end, *right);
            right = pivot_start;
        } else {
            assert_or_throw(right <= tail_it);
            if (right == tail_it) {
                break;
            }
            left =
                std::ranges::find_if(right + 1, last, [&](T x) { return proj(x) != proj(*right); });
            right = std::ranges::find_if(left + 1, last, [&](T x) {
                return proj(x) >= proj(*left);
            }) - 1;
            std::swap(*left, *right);
        }
    }
}

# 6. 完整代码

完整实现 (opens new window)测试 (opens new window)

# 7. 结尾

这一期讲了快速排序的很多东西,并给出了最终的理论算法。虽然这个最终算法只是一个壳,原地稳定选择才是真正的核心,但是不妨碍我们感受哨兵的优雅机制。

目前原地算法系列已经把归并排序、快速排序两条主线都讲完了,算是一个小小的里程碑。