O(n) 原地稳定划分

在写完原地归并的文章后,我启动了原地划分的探索,但是过程非常曲折。原论文研究不动,我只好自己构造算法框架,也算是对原论文做了一些简化。
我还用 AI 写了篇论文记录一下这个简化的贡献(当然本文不是 AI 写的)。
主要参考的论文:
- Stable Minimum Space Partitioning in Linear Time (opens new window)
- 我的论文 A Simplified Linear-Time Minimum-Space Algorithm for Stable 0-1 Sorting (opens new window)
# 1. 问题定义
给定一个数组 a 和一个谓词 pred,要求将谓词为真的元素排到谓词为假的元素前面,同时要求 时间复杂度、原地( 额外空间复杂度)、稳定(划分前后,谓词为真的元素相对顺序不变,谓词为假同理)。
这个问题其实等价于 原地稳定 0-1 排序(元素的 key 只有 0 和 1),后面会按 0-1 排序来讲。
算法选用的模型是最常用的 Word RAM。Word RAM 模型的存储单元是 Word,可容纳 w 比特,要求刚好能表示内存里每个存储单元的地址。一个 Word 能存储指针,但是存不了更多的信息。这个特性会在算法里用到。
# 2. 前置算法
# 2.1. 区间旋转
把两个相邻区间 [A B] 原地变成 [B A],保持区间内部顺序不变。经典的做法是三次翻转法(或手摇算法),这里就不展开了。
代码里直接调用标准库的 std::rotate,复杂度 。
# 2.2. 0-1 归并
0-1 归并就是合并两个有序的 0-1 数组。做法非常简单,用区间旋转把左边数组的 1 和右边数组的 0 交换一下。
复杂度 。
template <typename T, typename Proj>
void inplace_01_merge(T* first, T* last, Proj proj) {
T* split_left = std::find_if(first, last, proj);
T* mid = std::find_if(split_left, last, [proj](T x) { return proj(x) == 0; });
T* split_right = std::find_if(mid, last, proj);
std::rotate(split_left, mid, split_right);
assert_or_throw(std::is_sorted(first, last, [proj](T a, T b) { return proj(a) < proj(b); }));
}
# 3. O(n) 原地稳定划分
# 3.1. 提取缓冲区
我们需要 个元素作为缓冲区,要求一半 0 一半 1。0 和 1 一一配对,它们可以提供两个状态(按顺序是 01 和 10),也就是 1 bit 信息;一共 bit,这在后续步骤会用到。
提取过程类似反过来的旋转归并算法。首先收集 0:找第 1 个 0,把它旋转到第 2 个 0 前面,然后把第 1、2 个 0 一起旋转到第 3 个 0 前面。不断这么做直到收集到 个 0,然后把这些 0 旋转到数组开头。
1 也是同理。
对于 个收集的元素,它们最多被旋转 次,乘起来就是复杂度 。其他 个元素都最多被旋转 1 次,复杂度也是 。
template <typename T, typename Pred>
std::tuple<T*, T*, T*> stable_collect_first_n(T* first, T* last, int64_t n, Pred pred) {
T* collect = first;
int64_t count = 0;
for (T* iter = first; iter < last; iter++) {
if (count < n && pred(*iter)) {
std::rotate(collect, collect + count, iter);
collect = iter - count;
count++;
}
}
std::rotate(first, collect, collect + count);
return {first, first + count, last};
}
如果缓冲区 0 或 1 的数量不够,那收集完后再进行一次旋转就能排好整个数组,可以提前退出算法了。
# 3.2. 计数排序与循环置换
这两个算法都很经典。
计数排序,假设要排 n 个数,值域是 。先开个长度 m 的计数数组,遍历 n 个数统计每个值出现的频次。然后计数数组求 exclusive scan(一种前缀和,结果的第一个数是 0)或 inclusive scan。
再遍历一次 n 个数,对于每个数,取出它在计数数组中当前值作为它的目标位置,并将该计数加 1。
此时如果没有原地要求,就开个 buffer 暂存排序好的 n 个数,最后移动回原数组。代码如下:
void counting_sort(std::vector<int>& arr) {
std::vector<int> count(*std::max_element(arr.begin(), arr.end()) + 1, 0);
for (int64_t i = 0; i < int64_t(arr.size()); i++) {
count[arr[i]]++;
}
int sum = 0;
for (int64_t i = 0; i < int64_t(count.size()); i++) {
int tmp = count[i];
count[i] = sum;
sum += tmp;
}
std::vector<int> buffer(arr.size());
for (int64_t i = 0; i < int64_t(arr.size()); i++) {
buffer[count[arr[i]]] = arr[i];
count[arr[i]]++;
}
arr = buffer;
}
如果计数排序有原地要求,只能先存 n 个数的目标位置(存在哪后面讲),然后进行循环置换。
什么是循环置换算法 (cycle permutation)?学过置换或图论的都知道,把 n 个位置看成点,从其初始位置向目标位置连一条有向边,就会形成若干个不相交的环。
算法会依次处理每个环,沿着环的方向绕一圈,通过交换操作将每个元素移动到正确位置。每当一个位置被正确处理,就将目标数组中该位置的值修改为当前索引,表示该位置已处理过,从而避免重复移动。
void cycle_permutation(std::vector<int>& arr, std::vector<int>& dests) {
for (int64_t i = 0; i < int64_t(arr.size()); i++) {
for (int64_t j = dests[i]; j != i;) {
std::swap(arr[i], arr[j]);
int64_t next = dests[j];
dests[j] = j;
j = next;
}
dests[i] = i;
}
}
由于 0-1 排序只有 01 两个值,所以计数数组只有两个,特别简单。因此复杂度 ,不用考虑值域。
# 3.3. word-base 存储器
计数排序需要一个地方存目标数组,可以把它存到 Word 里。Word 只能存整个数组的一个索引,但对于 长度的子数组就不一样了。Word 有 比特,子数组的索引不超过 比特,除一下是 个索引,正好放下目标数组。
这里提供存储器的实现,get / set 复杂度 。
struct WordStorage {
int64_t pos_bits;
uint64_t word = 0;
WordStorage(int64_t pos_bits) : pos_bits(pos_bits) {}
uint64_t get(int64_t index) const { return (word >> (index * pos_bits)) & ((uint64_t{1} << pos_bits) - 1); }
void set(int64_t index, uint64_t value) {
auto slot_mask = ((uint64_t{1} << pos_bits) - 1) << (index * pos_bits);
word = (word & ~slot_mask) | (static_cast<uint64_t>(value) << (index * pos_bits));
}
void reset() { word = 0; }
};
# 3.4. buffer-base 存储器
缓冲区也是类似。对于 的子数组,缓冲区可以存储 比特,子数组的索引不超过 比特,除一下是 个索引,正好放下目标数组。
这里提供存储器的实现,get / set 复杂度 ,B 为子数组大小。
template <typename T, typename Proj>
struct BufferStorage {
T* buf0;
T* buf1;
int64_t buffer_len;
int64_t pos_bits;
Proj proj;
BufferStorage(T* buf0, T* buf1, int64_t buffer_len, int64_t pos_bits, Proj proj)
: buf0(buf0), buf1(buf1), buffer_len(buffer_len), pos_bits(pos_bits), proj(proj) {}
uint64_t get(int64_t index) const {
uint64_t res = 0;
for (int64_t i = 0; i < pos_bits; i++) {
res |= uint64_t(proj(buf0[index * pos_bits + i])) << i;
}
return res;
}
void set(int64_t index, uint64_t value) {
for (int64_t i = 0; i < pos_bits; i++) {
if (uint64_t(proj(buf0[index * pos_bits + i])) != ((value >> i) & 1)) {
std::swap(buf0[index * pos_bits + i], buf1[index * pos_bits + i]);
}
}
}
void reset() { reset_buffer(buf0, buf1, buffer_len, proj); }
};
# 3.5. 块间排序
基于上面计数排序、循环置换和存储器,我们可以完成多个块的排序(要求每个块是纯 0 或纯 1)。这里代码是对算法的一些整合。
复杂度 ,B 是块数量,S 是块大小,T 是存储器 get / set 的复杂度。对于 word-base 存储器,T 是 不用担心复杂度。对于 buffer-base 存储器,T 是 约等于 ,因此 buffer-base 块排序有个门槛 ,来保证不会超过 。(这个结论不好讲明白,先记住)
template <typename T, typename Proj, typename Storage>
void sort_blocks_impl(T* first, T* last, int64_t block_size, Proj proj, Storage& storage) {
int64_t n_blocks = (last - first) / block_size;
if (n_blocks <= 1) {
return;
}
int64_t n_zeros = block_count_if(first, last, block_size, [&proj](T x) { return proj(x) == 0; });
std::array<int64_t, 2> pointers = {0, n_zeros};
for (int64_t i = 0; i < n_blocks; i++) {
int64_t key = proj(first[i * block_size]);
storage.set(i, pointers[key]);
pointers[key]++;
}
for (int64_t i = 0; i < n_blocks; i++) {
for (int64_t j = storage.get(i); j != i;) {
std::swap_ranges(first + i * block_size, first + (i + 1) * block_size, first + j * block_size);
int next = storage.get(j);
storage.set(j, j);
j = next;
}
storage.set(i, i);
}
storage.reset();
}
# 3.6. 块同质化
对于 k 个块,如果块内都已经有序,可以快速把前 k - 1 个块变为同质块(块内是纯 0 或纯 1)。
只要从左到右遍历相邻的两个块。如果两个块的 0 个数大于等于块大小,就用旋转把左边块填充 0,右边还是有序;否则左边块填充 1,右边还是有序。具体细节在代码里呈现。
每个元素被旋转常数次,因此复杂度为 。
template <typename T, typename Proj>
void homogenize_blocks(T* first, T* last, int64_t block_size, Proj proj) {
assert_or_throw((last - first) % block_size == 0);
int64_t n_blocks = (last - first) / block_size;
for (int64_t i = 0; i + 2 <= n_blocks; i++) {
T* left = first + i * block_size;
T* mid = left + block_size;
T* right = mid + block_size;
assert_or_throw(std::is_sorted(left, mid, [proj](T a, T b) { return proj(a) < proj(b); }));
assert_or_throw(std::is_sorted(mid, right, [proj](T a, T b) { return proj(a) < proj(b); }));
T* split_left = std::find_if(left, mid, proj);
T* split_right = std::find_if(mid, right, proj);
int64_t n_zeros = (split_left - left) + (split_right - mid);
if (n_zeros >= block_size) {
// 00000111 00000011 -> 00000[111 | 000000]11 -> 00000000 00011111
std::rotate(split_left, mid, split_right);
} else {
// 00011111 00111111 -> [000 | 11111] 00111111 -> 11111000 00111111
std::rotate(left, split_left, mid);
split_left = mid - split_left + left;
// 11111000 00111111 -> 11111[000 00 | 111]111 -> 11111111 00000111
std::rotate(split_left, split_right, split_right + (mid - split_left));
}
}
}
# 3.7. 块增长
块同质化和块间排序结合,就可以把 k 个大小为 b 的有序块合并为 1 个有序块。
首先用块同质化把前 k - 1 块变为纯 0 或纯 1 块(同质块),然后用计数排序把同质块排序,最后用 0-1 归并完成最后一块的处理。
这里的计数排序有两种,所以块增长也有两种方案。word-base 块增长系数 ,buffer-base 块增长系数 ,但是要求初始块大小大于等于 。
在实现上,增长系数可以用如下代码计算:
int64_t len = last - first;
int64_t max_word_bits = ceil_log2(len);
int64_t buffer_len = std::floor(std::sqrt(len));
int64_t max_blocks_for_word = 1;
while ((max_blocks_for_word + 1) * ceil_log2(max_blocks_for_word + 1) <= max_word_bits) {
max_blocks_for_word++;
}
int64_t max_blocks_for_buffer = 1;
while ((max_blocks_for_buffer + 1) * ceil_log2(max_blocks_for_buffer + 1) <= buffer_len) {
max_blocks_for_buffer++;
}
因此策略是,先用无门槛的 word-base 块增长 2 次,块大小变为 ,达到了 buffer-base 块增长的门槛 。然后用 3 次缓冲区块增长,变为 。这个块大规模数据可以证明大于等于 n,也就完成整个数组的排序。小规模数据逐一测试即可。
我在一开始用的是多层分块,但是实现起来太困难了。因为 、 不是整数,向上或向下取整很容易导致某一层分块数太多,超出 Word 或缓冲区大小。
数组长度不一定是块大小的倍数,不对齐部分的处理是很“脏”的工作。
得益于块增长的设计,只要把不对齐部分隔离出来,完成块同质化、块间排序、最后一块合并后,再用 0-1 归并把不对齐部分合入数组。(看似简单,其实我尝试了好多想法才搞出来)
// block_size 是小块的大小,merge_size 是大块的大小
int64_t merge_size = block_size * k;
for (int64_t start = 0; start < len; start += merge_size) {
int64_t end = std::min(start + merge_size, len);
int64_t end_l2_aligned = end / block_size * block_size;
int64_t mid = std::max(start, end_l2_aligned - block_size);
homogenize_blocks(first + start, first + end_l2_aligned, block_size, proj);
sort_blocks(first + start, first + mid, block_size, ...);
inplace_01_merge(first + start, first + end_l2_aligned, proj);
inplace_01_merge(first + start, first + end, proj);
}
block_size = merge_size;
# 3.8. 归还缓冲区
缓冲区也是一个有序区间,用 0-1 归并即可归还。
inplace_01_merge(buf0, last, proj);
# 4. 完整代码
完整代码包含测试:
https://gist.github.com/axiomofchoice-hjt/3000e1705d29f2ad92551d6c2e085a52 (opens new window)
# 5. 一个争议点
将多个数放到一个 Word 里(Word 压位“作弊”),这在原地算法领域是有争议的,反对者认为应该禁止位运算阻止这个操作。
那么禁止了位运算还能实现吗?还真可以,论文名叫 Radix Sorting With No Extra Space。这个论文非常重量级,可以 完成原地基数排序,以后有机会再来探讨。
# 6. 原论文说了什么
原论文写得很简单,将数组按 大小分块,块内用 Word 存储计数器和标记进行排序,然后块同质化。块间排序依赖另一篇复杂得多的论文,可以 复杂度、 次移动、原地稳定排序,元素的 key 要求可以枚举。
但是块内排序写的太模糊,难以理解,我只能 2 次块增长来替代这一步骤。块间排序更是硬核,我只能 3 次块增长替代这一步骤。
这样虽然不是很优雅,但可以 200 多行代码完整实现,也是一个值得记录的想法。
← 原地算法系列