[模板]普通平衡树

$\text{Foreword}$

To be continued.

本文代码均为指针版.

本文中所说的平衡树指的均是二叉平衡树.

本文中所说的中序遍历指的均是左中右顺序.

题面传送门: Luogu-P3369

前言

平衡树板子.

给定 $n$ 次操作, 有且仅有下面6种操作:

  1. 插入 $x$.
  2. 删除 $x$ (若有多个相同的, 只删除一个).
  3. 查询 $x$ 的排名(排名定义: 比当前数小的数的个数+1. 若有多个相同的数, 输出最小的排名).
  4. 查询排名为 $x$ 的数.
  5. 求 $x$ 的前驱(前驱定义: 大于 $x$ 的最小的数)
  6. 求 $x$ 的后继(后继定义: 小于 $x$ 的最大的数)

$n \le 10^5,x \in [-10^7,10^7]$.

保证操作合法.

题解

平衡树基础之BST

(注: 如果你对下文内容看的不是非常明白, 请回来点击这里, 自己操作一番, 可以加深理解)

要学习平衡树, 你首先需要了解二叉查找树(Binary Search Tree, BST).

BST定义

  1. 树中每个节点都被赋予了一个权值;(这里先假设权值互不相同)
  2. 若左子树非空, 则左子树上所有节点的值均小于其根节点的值;
  3. 若右子树非空, 则右子树上所有节点的值均大于其根节点的值.

举个例子:

example1

这就是一棵二叉查找树.

看看这张图, 是不是正如我说的一样?

对于66号节点, 左子树中分别为1,19,23,45,54,59, 全部小于66;

右子树中分别为71,77,91, 全部大于66.

对于19号节点, 左子树中分别为1, 全部小于19;

右子树中分别为23,45,54,59, 全部大于19.

现在大致明白了这种神奇的树了吧?

这里给出一个性质:

二叉查找树的中序遍历是一个有序序列.

可以自己动手试一试画一画, 也可以用程序:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
typedef struct Node* ptr; // 将ptr作为Node*这种类型的别称

struct Node
{
int val;
ptr left, right;

Node(int val) : val(val) // 构造函数, 初始化新节点的左右儿子为空, 值为val
{ left = right = NULL; }
} *root;

void Print_Mid(ptr curr_node) // 中序遍历输出
{
if(!curr_node) // !可用于判断空指针
return;
Print_Mid(curr_node->left);
printf("%d\n", curr_node->val);
Print_Mid(curr_node->right);
}

几个基本操作

查找特定的值

1
2
3
4
5
6
7
8
9
inline ptr Find(int x, ptr curr_node) // 在以curr_node为根的子树中找到值为x的节点
{
while(curr_node->val != x)
if(x < curr_node->val)
curr_node = curr_node->left;
else
curr_node = curr_node->right;
return curr_node;
}

利用的还是二叉查找的思想.

查找最值

1
2
3
4
5
6
7
8
9
10
11
12
13
inline ptr Find_Min(ptr curr_node) // 在以curr_node为根的子树中找到最小值
{
while(curr_node->left) // 判断左子节点是否存在
curr_node = curr_node->left;
return curr_node;
}

inline ptr Find_Max(ptr curr_node) // 在以curr_node为根的子树中找到最大值
{
while(curr_node->right)
curr_node = curr_node->right;
return curr_node;
}

同理, 根据二叉查找树的性质可以得出.

插入节点

1
2
3
4
5
6
7
8
9
10
inline void Insert(int x) // 插入值为x的节点
{
register ptr curr_node = root;
while(curr_node)
if(x < curr_node->val)
curr_node = curr_node->left;
else
curr_node = curr_node->right;
curr_node = new Node(x);
}

同查找节点, 找到适合位置之后新建节点即可.

但是, 上述代码是完全错误的.

如果细心观察会发现: curr_node的改变, 不影响原值.

上述代码所犯错误跟这段代码差不多:

1
2
3
ptr p = root;
p = NULL;
// root == NULL? False.

p值的修改是不可能影响到root的. 这个例子应该好懂了些吧?

正确代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
inline void Insert(int x) // 插入值为x的节点
{
if(!root)
{
root = new Node(x);
return;
}
register ptr curr_node = root;
while(true)
{
if(x < curr_node->val)
{
if(!curr_node->left)
{
curr_node->left = new Node(x);
return;
}
curr_node = curr_node->left;
}
else
{
if(!curr_node->right)
{
curr_node->right = new Node(x);
return;
}
curr_node = curr_node->right;
}
}
}

相当于改成这样:

1
2
3
ptr p = root;
p->left = NULL;
// root->left == NULL? True.

从直接修改当前节点变成了修改当前节点的子节点, 保证了不会出现引用啥啥乱七八糟的问题.

删除节点

这里较为复杂.

首先找到该节点, 然后分三种情况讨论

  1. 该节点是叶节点.

直接删除.

  1. 该节点是链节点(只有一个儿子).

用这个儿子代替它的位置就行.

  1. 该节点非叶非链.

这种情况复杂一点, 一般是找到它的后继代替他的位置然后删掉后继.

显然,它的后继一定为链节点或叶节点, 删除后继很好删除.

再用一下上面的图

example1

现在我们要删除19, 只需要找到它的后继23, 将19这个节点的值赋为23后删除23节点即可.

(顺便说一句: 有一个23没有变成23是有意义的, 23是特指节点, 23指的是二十三这个数字.)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
inline void Simple_Delete(ptr& curr_node) // 对应1, 2情况的节点删除
{
if(curr_node->left)
curr_node = curr_node->left;
else
curr_node = curr_node->right;
}

inline void Delete(int x) // 删除值为x的节点
{
ptr curr_node = Find(x, root);
if(curr_node->left and curr_node->right)
{
ptr temp = Find_Min(curr_node->right);
curr_node->val = temp->val;
Simple_Delete(temp);
}
else
Simple_Delete(curr_node);
}

然而同样的 代码完全错误- -

请读者自行思考如何写出正确删除操作

一些小提示:

  • Find函数返回值并未引用.
  • 可以像Insert函数的改正一样, 对当前节点的子节点进行操作, 而不是对当前节点操作.
  • 牺牲一小点效率, 写成递归形式.(这样对于参数的引用就非常好处理)

偷个懒, 这个代码就不放了- -

如果实在想不出来, 可以参照下面的Splay中的非递归形式Treap中的递归形式.

前驱后继

定义请参照题面.

要注意的是, 所查询元素可以不是当前树中的元素.

即如果现在树中有1, 22, 333个节点, 我可以查询24的后继(33)和前驱(22).

两个操作差不多, 就只讲后继了.

find操作里面加点零件就可以了.

定义ret为从根节点走到目前节点为止, 比所查询元素大的最小的元素.

如果当前节点的值大于所查询元素, 更新ret为当前节点的值, 往左子树走;

如果当前节点的值小于所查询元素, 直接往右子树走就行.

原理自己想去

前驱操作把左右和大于小于倒过来就差不多了.

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
inline int Pre/*predecessor*/(int val)
{
register ptr curr = root;
register int ret;

while(curr)
{
if(val > curr->val)
{
ret = curr->val;
curr = curr->son[1];
}
else
curr = curr->son[0];
}
return ret;
}

inline int Suc/*successor*/(int val)
{
register ptr curr = root;
register _val ret;

while(curr)
{
if(val < curr->val)
{
ret = curr->val;
curr = curr->son[0];
}
else
curr = curr->son[1];
}
return ret;
}

基本操作就讲这么多.

排名相关

然而…

??? 题目中的3,4操作呢?

别急.

刚刚讲的是最基础的BST, 真正要支持查询排名节点还需要维护一个东西, 叫size.

这题因为有可重复元素, 还要维护一个东西, 叫cnt.

分别解释一下:

  • size表示以当前节点为根的子树中节点的个数(包含重复元素).
  • cnt表示当前节点有多少个重复元素.(就是val一样, 然后累在一起)

给个定义不够清楚, 还是上刚刚那个图

example1

66size是10, 19size是6, 91size是3, 45size是4, 77size是1, 等等.

这图里面元素不重复, 所以cnt没有体现

如果我现在插入19, 它就不会新建节点, 而是使19cnt++(同时, 1966size也会++).

维护了size域之后, 我们就可以着手做3,4操作了

基础思想是分治.

首先可以知道的是, 在一棵子树中, 根节点的排名大小取决于左子树的size.

可以得出, 根节点的权值在这棵子树中的排名是一个闭区间: $[lsize+1,lsize+cnt]$.

(为什么说是一个区间, 因为有重复元素. 题目里面明确讲了取最小排名.)

(其中 $lsize$ 是左儿子的节点个数, $cnt$ 是当前节点的cnt)

依据这个原理, 分别来做3,4操作:

查询排名第 $k$ 的元素

其实排名为 $k$ 就是查找 $k$ 小.

因为BST的中序遍历有序, 所以, 如果将一个BST执行中序遍历, 比当前节点小的数的个数就是左子树的size(这句话需要好好理解一下).

总共三种情况:

  1. $lsize < k \le lsize + cnt$

要查找元素就是当前元素, 直接return即可.

  1. $k \le lsize$

说明要查找的节点在左子树中, 并且在左子树中还是排名第 $k$ (结合我前面说的, 想想为什么). 在左子树继续查找即可.

  1. $k>lsize+cnt$

说明要查找的节点在右子树中, 并且在右子树的排名为 $k-lsize-cnt$ (同样结合我前面那一句话, 想想为什么). 将 $k$ 减去对应值之后再在右子树查找即可.

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
inline int QK/*Query Kth*/(int k)
{
register ptr curr = root;
register int size;

while(curr)
{
size = Size(curr->son[0]);
if(k > size and k <= size + curr->cnt)
return curr->val;
if(k <= size)
curr = curr->son[0];
else
{
k -= size + curr->cnt;
curr = curr->son[1];
}
}
assert(false); // 正常的话会在while循环里面就return掉了, 如果执行到这里说明程序挂了, 这个指令会自动RE
}

查询元素排名

(本题中保证所查询的元素现在一定存在).

这和上述查找 $k$ 小的操作非常类似, 可以近似认为是互逆运算(其实它更好理解一些).

过程和find差不多, 但是需要额外定义一个k, 表示从根节点遍历到目前为止, 比所查询元素大的元素个数(初始化为0).

find的过程中, 如果你需要往左走, 对于k值是无影响的; 而如果要往右走, 则k需要加上左子树的size和当前节点的cnt.

结合代码理解一下(最后k还加上一个值, 是因为比查找到的节点小的数也包括了其左子树的size, 而在循环中是没有加上的):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
inline int QR/*Query Rank*/(int val)
{
register ptr curr = root;
register int k = 0;

while(curr->val != val)
{
if(curr->val < val)
{
k += Size(curr->son[0]) + curr->cnt;
curr = curr->son[1];
}
else
curr = curr->son[0];
}
k += Size(curr->son[0]) + 1;
return k;
}

平衡树引入

经过刚刚的一通乱讲说明, 这道题已经可以做了, 但时间复杂度还无法保证.

非常显然,BST各个操作的时间复杂度均为 $O(h)$. 其中 $h​$ 为二叉搜索树的高度.

然而BST极其可能会退化成一条链, 或者说树深度过深, 导致 $h$ 趋近于 $n$, 显然每个操作 $O(n)$ 的复杂度是我们无法接受的.

这个时候, 就需要我们的平衡树来救场了.

其实平衡树的本质就是BST加上调整.

调整一般有几种(OI用): 旋转, 重构, 分裂合并.

典型的代表就是 Splay, 替罪羊树, fhq Treap.

Splay

Splay是平衡树的一种.

TarjanDaniel Sleator开发.(感觉信息这方面哪里都能看见Tarjan…)

它的核心操作只有一个:Splay(顾名思义, 伸展操作).

节点定义

Splay本质而言还是一棵树, 我们要先确定其节点要存放什么东西.

相对BST的节点要存的东西其实也没多多少, 只是多了一个父指针指向父亲而已.

(其实也有递归实现的无父指针版, 如果想学可以参考这里)

先留个坑- -放个代码在这里2333

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
/**********************************************************
* Author : EndSaH
* Email : hjxhb1@gmail.com
* Created Time : 2018-12-11 19:26
* FileName : new.cpp
* Website : https://endsah.cf
* *******************************************************/

#include <sys/mman.h>
#include <unistd.h>
#include <cstdio>
#include <cctype>
#include <cassert>

class Istream
{
char *ipos;

public :
Istream()
{
#ifndef ONLINE_JUDGE
freopen("new.in", "r", stdin);
freopen("new.out", "w", stdout);
#endif
ipos = (char*)mmap(NULL, lseek(0, 0, SEEK_END), PROT_READ, MAP_PRIVATE, 0, 0);
}

Istream& operator>>(int& n)
{
register bool flag = false;
n = 0;
while(!isdigit(*ipos))
{
if(*ipos == '-')
flag = true;
++ipos;
}
while(n = (n << 3) + (n << 1) + (*ipos++ & 15), isdigit(*ipos));
if(flag)
n = -n;
return *this;
}
} in;

char _obuf[1 << 20], _stk[20];
class Ostream
{
char *opos, *oend, *stkpos;

public :
Ostream()
{
oend = (opos = _obuf) + (1 << 20) - 1;
stkpos = _stk;
}

~Ostream()
{ fwrite(_obuf, 1, opos - _obuf, stdout); }

void Putchar(char ch)
{
*opos = ch;
if(opos == oend)
{
fwrite(_obuf, 1, 1 << 20, stdout);
opos = _obuf;
}
++opos;
}

Ostream& operator<<(int n)
{
if(n < 0)
{
Putchar('-');
n = -n;
}
do
{
*++stkpos = n % 10 ^ 48;
n /= 10;
} while(n);
while(stkpos != _stk)
Putchar(*stkpos--);
return *this;
}

Ostream& operator<<(char c)
{
Putchar(c);
return *this;
}
} out;

template<typename _Tp>
inline bool Chkmin(_Tp& x, const _Tp& y)
{ return x > y ? x = y, true : false; }

template<typename _Tp>
inline bool Chkmax(_Tp& x, const _Tp& y)
{ return x < y ? x = y, true : false; }

template<typename _val>
struct Splay_Node
{
typedef Splay_Node <_val>* ptr;

ptr son[2], fa;
_val val;
int size, cnt;

Splay_Node(const _val& val, ptr fa) : val(val), fa(fa)
{
size = cnt = 1;
son[0] = son[1] = NULL;
}
};

template<typename _val>
class Splay_Tree
{
typedef Splay_Node <_val>* ptr;

ptr root;

public :
Splay_Tree()
{ root = NULL; }

int Ident/*Identify*/(ptr x)
{ return x->fa == NULL ? 0 : x->fa->son[1] == x; }

int Size(ptr x)
{ return x == NULL ? 0 : x->size; }

void Update(ptr x)
{ x->size = Size(x->son[0]) + Size(x->son[1]) + x->cnt; }

void Link(ptr fa, ptr son, int dir/*direction*/)
{
if(!fa)
root = son;
else
fa->son[dir] = son;
if(son)
son->fa = fa;
}

void Rotate(ptr x)
{
register ptr fa = x->fa;
register int dir = Ident(x);

Link(fa->fa, x, Ident(fa));
Link(fa, x->son[dir ^ 1], dir);
Link(x, fa, dir ^ 1);
Update(fa);
}

void Splay(ptr x, ptr goal)
{
register ptr fa;

while(x->fa != goal)
{
fa = x->fa;
if(fa->fa == goal)
{
Rotate(x);
break;
}
if(Ident(fa) ^ Ident(x))
Rotate(x);
else
Rotate(fa);
Rotate(x);
}
Update(x);
}

void Insert(const _val& val)
{
if(!root)
{
root = new Splay_Node<_val>(val, NULL);
return;
}
register ptr curr = root;
register int dir = val > curr->val;

while(true)
{
++curr->size;
if(curr->val == val)
{
++curr->cnt;
return;
}
dir = val > curr->val;
if(curr->son[dir])
curr = curr->son[dir];
else
{
curr->son[dir] = new Splay_Node<_val>(val, curr);
Splay(curr->son[dir], NULL);
return;
}
}
}

void Delete(const _val& val)
{
register ptr curr/*current*/ = root, pre/*predecessor*/;

while(curr->val != val)
curr = curr->son[val > curr->val];
Splay(curr, NULL);
if(curr->cnt > 1)
{
--curr->cnt, --curr->size;
return;
}
pre = curr->son[0];
if(pre)
{
while(pre->son[1])
pre = pre->son[1];
Splay(pre, curr);
root = pre;
root->fa = NULL, root->son[1] = curr->son[1];
if(root->son[1])
{
root->son[1]->fa = root;
root->size += root->son[1]->size;
}
}
else
{
root = curr->son[1];
if(root)
root->fa = NULL;
}
delete curr;
}

int QR/*Query Rank*/(const _val& val)
{
register ptr curr = root;
register int k = 0;

while(curr->val != val)
{
if(curr->val < val)
{
k += Size(curr->son[0]) + curr->cnt;
curr = curr->son[1];
}
else
curr = curr->son[0];
}
k += Size(curr->son[0]) + 1;
Splay(curr, NULL);
return k;
}

_val QK/*Query Kth*/(int k)
{
register ptr curr = root;
register int size;

while(curr)
{
size = Size(curr->son[0]);
if(k > size and k <= size + curr->cnt)
return curr->val;
if(k <= size)
curr = curr->son[0];
else
{
k -= size + curr->cnt;
curr = curr->son[1];
}
}
assert(false);
}

_val Pre(const _val& val)
{
register ptr curr = root;
register _val ret;

while(curr)
{
if(val > curr->val)
{
ret = curr->val;
curr = curr->son[1];
}
else
curr = curr->son[0];
}
return ret;
}

_val Suc/*successor*/(const _val& val)
{
register ptr curr = root;
register _val ret;

while(curr)
{
if(val < curr->val)
{
ret = curr->val;
curr = curr->son[0];
}
else
curr = curr->son[1];
}
return ret;
}
};

int main()
{
register int n, opt, x;
Splay_Tree<int> wib;

in >> n;
while(n--)
{
in >> opt >> x;
switch(opt)
{
case 1 :
wib.Insert(x);
break;
case 2 :
wib.Delete(x);
break;
case 3 :
out << wib.QR(x) << '\n';
break;
case 4 :
out << wib.QK(x) << '\n';
break;
case 5 :
out << wib.Pre(x) << '\n';
break;
case 6 :
out << wib.Suc(x) << '\n';
break;
}
}
}

Treap

Treap是一种相对较为随机的平衡树.

具体来说, 就是对每个节点赋一个优先级. 保证不破坏其二叉查找树的本质的同时, 使节点按所赋的优先级排为堆序.

继续留坑(代码比较丑, 到时候重写一下再放上来)