(下)平衡树总结
700
2023.07.03
2023.07.05
发布于 未知归属地

上一期提到了几个常见的平衡树,这一期主要撰写三个不算常见但常数十分优秀,码量较小的平衡树.

SBT

SBT是一种常数较为不错的平衡树,又名(SB Tree.为什么会在(下)当中才提及这三种平衡树,那是因为这三种平衡树在有前面几种平衡树的经验下是十分容易学会和撰写的.

SBT引入学习

全称名为Size Balanced Tree,它的平衡依赖于旋转=>这个旋转实际上可以直接套treap的旋转函数即可.那么它与哪种平衡树比较像呢,其实可以把它类比替罪羊树/AVL树.替罪羊树在上一章提到当超出比例因子时就会重构调整平衡.SBT则是超出xxxx,则会通过旋转调整平滑.

旋转条件

SBT的旋转条件其实很简单.对于一棵SBT树而言,保证上层的size永远大于下层size即可.
而对于某个节点:node
我们考虑这几个点:left,right,left_left,left_right,right_left,right_right
显而易见的是,造成不平衡只有四种可能,两大类:右边right<left_left/left_right,左边left<right_left/right_right
所以我们可以撰写出这样的一个代码:

Python
@staticmethod
    def push_down(node: 'Node'):
        l = get_size(node.child[0])
        r = get_size(node.child[1])
        l_l = l_r = r_l = r_r = 0
        if node.child[0]:
            l_l = get_size(node.child[0].child[0])
            l_r = get_size(node.child[0].child[1])
        if node.child[1]:
            r_l = get_size(node.child[1].child[0])
            r_r = get_size(node.child[1].child[1])
        if l_l > r:
            node = SBT.rot(node, SBT.R)
            node.child[1] = SBT.push_down(node.child[1])
            node = SBT.push_down(node)
        elif l_r > r:
            node.child[0] = SBT.rot(node.child[0], SBT.L)
            node = SBT.rot(node, SBT.R)
            node.child[0] = SBT.push_down(node.child[0])
            node.child[1] = SBT.push_down(node.child[1])
            node = SBT.push_down(node)
        elif r_l > l:
            node.child[1] = SBT.rot(node.child[1], SBT.R)
            node = SBT.rot(node, SBT.L)
            node.child[0] = SBT.push_down(node.child[0])
            node.child[1] = SBT.push_down(node.child[1])
            node = SBT.push_down(node)
        elif r_r > l:
            node = SBT.rot(node, SBT.L)
            node.child[0] = SBT.push_down(node.child[0])
            node = SBT.push_down(node)
        return node

旋转规则的话其实就和AVL四种情况调整差不多.可以参照AVL的四种情况LL,LR,RL,RR类比书写,当然可以观察出其实这四种情况我们可以合并成两种,当然展开写也比较清晰,也是SBT的核心代码之一.

旋转

旋转直接可以套旋转Treap的代码合并成一种情况即可:

Python
# 旋转某个根,旋转类型t
    @staticmethod
    def rot(node: 'Node', rot_type: int):
        new_root = node.child[rot_type]  # 新根
        # 左旋转,原根对应的右子树换成右子树的左子树
        # 右旋转,原根对应的左子树换成左子树的右子树
        # 左边原根的某子树,右边对应同向的子树的相反方向的子树,参照AVL
        node.child[rot_type] = new_root.child[rot_type ^ 1]
        # 新根的对应另一侧为原根,右旋,原变左;左旋,原变右
        new_root.child[rot_type ^ 1] = node
        # 更新树的大小
        node.update()
        new_root.update()
        return new_root

在SBT当中,所有的代码除了remove都可以套旋转treap或者说普通的BST的代码,唯一一个需要注意的点就是删除.这里的删除常见的其实有两种方式前驱删除和类似treap的移根删除.这里提提移根删除,比较容易衔接treap.

删除核心

删除最难的一步其实是删除点有左右子树,然后删掉以后需要合并,也正是因为删除这个操作,导致SBT不容易像其他平衡树一样可以用cnt来记录重复的点的个数,而是一个点只能表示一个数.

回顾合并

我们来回顾下前面提到的每种平衡树是如何处理这种情况.

  1. 旋转treap:利用旋转将根移动到一侧子树上进行删除,最终变为左右子树至少一个不存在
Python
tmp = self.R if node.child[0].rank < node.child[1].rank else self.L
node = self.rot(node, tmp)
node.child[tmp ^ 1] = delete(node.child[tmp ^ 1])
node.update()
  1. 无旋转Treap:利用分割与合并实现
Python
def delete(self, val: int):
    l1, r1 = self.split_val(self.root, val)
    l2, r2 = self.split_val(l1, val - 1)
    if r2.cnt > 1:
        r2.cnt -= 1
        r2.update()
        l2 = self.merge(l2, r2)
    self.root = self.merge(l2, r1)
  1. Splay:利用合并左右子树实现,这里因为splay操作可以将某个点移动到根
Python
def merge(x: 'Node', y: 'Node'):
    if x is None:
        if y:
            y.fa = None
        return y
    if y is None:
        x.fa = None
        return x
    x.fa = y.fa = None
    y, _ = y.kth(1)
    y.child[0] = x
    x.fa = y
    y.update()
    return y
self.root = merge(tmp.child[0], tmp.child[1])
  1. 替罪羊树:懒删除,当重构时再去掉已经删除掉的点
Python
def remove(self, val: int):
    def delete(node: 'Node'):
        if node:
            if node.val == val:
                if node.cnt:
                    node.cnt -= 1
            else:
                if node.val > val:
                    node.child[0] = delete(node.child[0])
                else:
                    node.child[1] = delete(node.child[1])
            node.update()
            if is_need_rebuild(node):
                return rebuild(node)
            return node
        return None
    self.root = delete(self.root)

那么SBT跟那个比较像呢,答案是旋转treap,这里先阐述下合并思路:其实合并两棵子树的核心正如我在splay合并所讲:
1、查找left的最大值(第left_size小),查找完以后由于有splay操作,它会被变为left的根节点(这里注意提前设置left和right的father为空防止转到整棵伸展树的根去).然后left.right=right更新返回left即可
2、也可以反过来查找right的最小的值(第1小),然后right.left=left,更新返回right即可
而这里我们可以采用迭代找到那个点,但是否需要“摘”下来进行拼接呢,答案是这样的话更新不容易:
需要用一个栈存储路径上的点然后依次更新,或者再找的过程中反复-1,然后并且需要取到答案点的前一个点pre(如果存在的话),因为摘掉以后如果ans的另一边还有子树是需要进行与它的father进行拼接的.
这样码量略长,而且容易看出如果cnt不为1,则需要用最终的cnt去更新所有路径上的点,麻烦
这里给出换根删除,我们只需要做一步操作,将curr赋值ans的值然后往对应子树删除curr.val即可

Python
ans = node.child[0]
while ans.child[1]:
    ans = ans.child[1]
node.val = ans.val
node.child[0] = delete(node.child[0], node.val)

以上操作的正确性:cnt固定为1,ans最终需要拼接左右子树,然后删除ans原先那个点,实际上删除点换根到了子树上,变为了左右子树至少一个为空的情况进行删除了.

细节

这里由于size固定为1了,所以我们不能单独考虑val==node.val的情况了,需要合并为node.val<=val和node.val>val两大类情况

SBT参考代码

Python
C++
class Node:
    # 左右子树用child数组,值,优先级(用随机数),重复数,树大小
    __slots__ = "child", "val", "size"

    def __init__(self, val: int, left: 'Node' = None, right: 'Node' = None):
        self.val = val
        self.child = [left, right]
        self.size = 1

    # 更新子树大小
    def update(self):
        self.size = 1 + get_size(self.child[0]) + get_size(self.child[1])


def get_size(node: 'Node'):
    if node:
        return node.size
    return 0


class SBT:
    __slots__ = "root"
    # L左旋,把右子树变根;R右旋,把左子树变根
    L, R = 1, 0

    def __init__(self):
        self.root = None

    # 旋转某个根,旋转类型t
    @staticmethod
    def rot(node: 'Node', rot_type: int):
        new_root = node.child[rot_type]  # 新根
        # 左旋转,原根对应的右子树换成右子树的左子树
        # 右旋转,原根对应的左子树换成左子树的右子树
        # 左边原根的某子树,右边对应同向的子树的相反方向的子树,参照AVL
        node.child[rot_type] = new_root.child[rot_type ^ 1]
        # 新根的对应另一侧为原根,右旋,原变左;左旋,原变右
        new_root.child[rot_type ^ 1] = node
        # 更新树的大小
        node.update()
        new_root.update()
        return new_root

    @staticmethod
    def push_down(node: 'Node'):
        l = get_size(node.child[0])
        r = get_size(node.child[1])
        l_l = l_r = r_l = r_r = 0
        if node.child[0]:
            l_l = get_size(node.child[0].child[0])
            l_r = get_size(node.child[0].child[1])
        if node.child[1]:
            r_l = get_size(node.child[1].child[0])
            r_r = get_size(node.child[1].child[1])
        if l_l > r:
            node = SBT.rot(node, SBT.R)
            node.child[1] = SBT.push_down(node.child[1])
            node = SBT.push_down(node)
        elif l_r > r:
            node.child[0] = SBT.rot(node.child[0], SBT.L)
            node = SBT.rot(node, SBT.R)
            node.child[0] = SBT.push_down(node.child[0])
            node.child[1] = SBT.push_down(node.child[1])
            node = SBT.push_down(node)
        elif r_l > l:
            node.child[1] = SBT.rot(node.child[1], SBT.R)
            node = SBT.rot(node, SBT.L)
            node.child[0] = SBT.push_down(node.child[0])
            node.child[1] = SBT.push_down(node.child[1])
            node = SBT.push_down(node)
        elif r_r > l:
            node = SBT.rot(node, SBT.L)
            node.child[0] = SBT.push_down(node.child[0])
            node = SBT.push_down(node)
        return node

    def add(self, val: int):
        def insert(node: 'Node'):
            if node is None:
                return Node(val)
            if val < node.val:
                node.child[0] = insert(node.child[0])
                node.update()
            else:
                node.child[1] = insert(node.child[1])
                node.update()
            return SBT.push_down(node)

        self.root = insert(self.root)

    def remove(self, v: int):
        def delete(node: 'Node', val: int):
            if val > node.val:
                node.child[1] = delete(node.child[1], val)
            elif val < node.val:
                node.child[0] = delete(node.child[0], val)
            else:
                state = 0
                state |= int(node.child[0] is not None)
                state |= int(node.child[1] is not None) << 1
                # 00表示都无,01右有,10左有,11都有
                if state == 0:
                    return None
                elif state == 1:
                    return node.child[0]
                elif state == 2:
                    return node.child[1]
                elif state == 3:
                    ans = node.child[0]
                    while ans.child[1]:
                        ans = ans.child[1]
                    node.val = ans.val
                    node.child[0] = delete(node.child[0], node.val)

            node.update()
            return node

        self.root = delete(self.root, v)

    def query_kth(self, val: int):
        def ans(node: 'Node'):
            if node is None:
                return 1
            size = get_size(node.child[0])
            if val <= node.val:
                return ans(node.child[0])
            else:
                return size + 1 + ans(node.child[1])

        return ans(self.root)

    def kth_val(self, rank: int):
        def ans(node: 'Node', k: int):
            size = get_size(node.child[0])
            if k <= size:
                return ans(node.child[0], k)
            elif k == size + 1:
                return node.val
            else:
                return ans(node.child[1], k - size - 1)

        return ans(self.root, rank)

    def pre(self, val: int):
        def ans(node: 'Node'):
            res = node.val
            while node:
                if node.val < val:
                    res = node.val
                    node = node.child[1]
                else:
                    node = node.child[0]
            return res

        return ans(self.root)

    def nxt(self, val: int):
        def ans(node: 'Node'):
            res = node.val
            while node:
                if node.val > val:
                    res = node.val
                    node = node.child[0]
                else:
                    node = node.child[1]
            return res

        return ans(self.root)

LeafyTree

这是一类比较特殊的树,与其他树不同,它的信息都是保存在"叶子节点"上,而不是其他非叶子节点.
其实替罪羊树也可以看做一类leafytree.它是通过旋转实现平衡,它的条件其实和替罪羊树差不多,当左子树的大小>n倍右子树大小或者反过来进行调整平滑.每个非叶子节点保存的是叶子节点的合并信息.
所以其实它和线段树的思想是比较相近的.
这里倍数常取4倍.(一些大神的数学计算结果

细节阐述

关于add的话,一般来说是与左子树的value去比较value,当前节点的value记录的其实是右子树的.而size则为二者共同维护.事实上可以写出像push_up这样的代码为
Node(right.val,left.size+right.size,left,right)
删除操作较为容易(因为答案都在叶子节点上,是莫有子树的,美滋滋的)
通常为了使根不变,加入个INF大为根,插入操作左子树保存下界,右子树则是上界,有点二分查找那味了.
旋转操作很easy,可以看做其实就是线段树的push_up,不过根据实际不平衡来push_up

参考代码

Python
C++
from cmath import inf

alpha = 4


def merge(x: 'Node', y: 'Node'):
    return Node(y.val, x.size + y.size, x, y)


class Node:
    __slots__ = "val", "left", "right", "size"

    def __init__(self, val: int, size: int = 1, left: 'Node' = None, right: 'Node' = None):
        self.val = val
        self.left = left
        self.right = right
        self.size = size

    def update(self):
        if self.left is None:
            self.size = 1
            return
        self.size = self.left.size + self.right.size
        self.val = self.right.val

    def rotate(self):
        if self.left is None or self.right is None:
            return
        if self.left.size > alpha * self.right.size:
            self.right = merge(self.left.right, self.right)
            self.left = self.left.left
        if self.right.size > alpha * self.left.size:
            self.left = merge(self.left, self.right.left)
            self.right = self.right.right

    def insert(self, x: int):
        if self.size == 1:
            self.left = Node(min(self.val, x))
            self.right = Node(max(self.val, x))
        else:
            if self.left.val < x:
                self.right.insert(x)
            else:
                self.left.insert(x)
        self.update()
        self.rotate()

    def remove(self, x: int):
        if x > self.left.val:
            curr = self.right
            other = self.left
        else:
            curr = self.left
            other = self.right
        if curr.size == 1:
            if curr.val == x:
                self.left = other.left
                self.right = other.right
                self.val = other.val
            else:
                return
        else:
            curr.remove(x)
        self.update()
        self.rotate()

    def kth(self, k: int):
        if self.size == 1:
            return self.val
        if k > self.left.size:
            return self.right.kth(k - self.left.size)
        else:
            return self.left.kth(k)

    def query(self, val: int):
        if self.size == 1:
            return 1
        if val > self.left.val:
            return self.left.size + self.right.query(val)
        else:
            return self.left.query(val)
class LeafyTree:
    __slots__ = "root"

    def __init__(self):
        self.root = Node(inf, 1)

    def add(self, val: int):
        self.root.insert(val)

    def remove(self, val: int):
        self.root.remove(val)

    def query(self, val: int):
        return self.root.query(val)

    def kth(self, k: int):
        return self.root.kth(k)

    def pre(self, val: int):
        return self.kth(self.query(val) - 1)

    def nxt(self, val: int):
        return self.kth(self.query(val + 1))

WBLT

全称Weight Balanced Leafy Tree,当然也有高度的了,跟AVL的平衡因子很像.
这里可以当做加权平衡的Leafy树,常数更小的特点,这里通常将倍数取做5.

修改之处

只需要把psuh_up修改下即可,将原来的只判断2次改为判断4次,当旋转后还不满足,继续旋转,不过是类似AVL的旋转

Python
def rotate(self, flag: bool):
    if flag:
        self.left = merge(self.left, self.right.left)
        self.right = self.right.right
    else:
        self.right = merge(self.left.right, self.right)
        self.left = self.left.left

def push_up(self):
    if self.left is None or self.right is None:
        return
    if self.left.size > self.right.size * alpha:
        self.rotate(False)
    elif self.right.size > self.left.size * alpha:
        self.rotate(True)
    if self.left.size > self.right.size * alpha:
        self.left.rotate(True)
        self.rotate(False)
    elif self.right.size > self.left.size * alpha:
        self.right.rotate(False)
        self.rotate(True)

原理实际上为WBLT可以看做类似堆一样的结构,可以用来实现堆.拥有堆的性质,所以在旋转的过程中是需要考虑size是否满足堆的性质的.其余与leafytree是一样的

参考代码

Python
C++
from cmath import inf

alpha = 5


def merge(x: 'Node', y: 'Node'):
    return Node(y.val, x.size + y.size, x, y)


class Node:
    __slots__ = "val", "left", "right", "size"

    def __init__(self, val: int, size: int = 1, left: 'Node' = None, right: 'Node' = None):
        self.val = val
        self.left = left
        self.right = right
        self.size = size

    def rotate(self, flag: bool):
        if flag:
            self.left = merge(self.left, self.right.left)
            self.right = self.right.right
        else:
            self.right = merge(self.left.right, self.right)
            self.left = self.left.left

    def push_up(self):
        if self.left is None or self.right is None:
            return
        if self.left.size > self.right.size * alpha:
            self.rotate(False)
        elif self.right.size > self.left.size * alpha:
            self.rotate(True)
        if self.left.size > self.right.size * alpha:
            self.left.rotate(True)
            self.rotate(False)
        elif self.right.size > self.left.size * alpha:
            self.right.rotate(False)
            self.rotate(True)

    def update(self):
        if self.left:
            self.size = self.left.size + self.right.size
            self.val = self.right.val
        else:
            self.size = 1

    def insert(self, x: int):
        if self.size == 1:
            self.left = Node(min(self.val, x))
            self.right = Node(max(self.val, x))
        else:
            if self.left.val < x:
                self.right.insert(x)
            else:
                self.left.insert(x)
        self.update()
        self.push_up()

    def remove(self, x: int):
        if x > self.left.val:
            curr = self.right
            other = self.left
        else:
            curr = self.left
            other = self.right
        if curr.size == 1:
            if curr.val == x:
                self.left = other.left
                self.right = other.right
                self.val = other.val
            else:
                return
        else:
            curr.remove(x)
        self.update()
        self.push_up()

    def kth(self, k: int):
        if self.size == 1:
            return self.val
        if k > self.left.size:
            return self.right.kth(k - self.left.size)
        else:
            return self.left.kth(k)

    def query(self, val: int):
        if self.size == 1:
            return 1
        if val > self.left.val:
            return self.left.size + self.right.query(val)
        else:
            return self.left.query(val)


class WBLT:
    __slots__ = "root"

    def __init__(self):
        self.root = Node(inf, 1)

    def add(self, val: int):
        self.root.insert(val)

    def remove(self, val: int):
        self.root.remove(val)

    def query(self, val: int):
        return self.root.query(val)

    def kth(self, k: int):
        return self.root.kth(k)

    def pre(self, val: int):
        return self.kth(self.query(val) - 1)

    def nxt(self, val: int):
        return self.kth(self.query(val + 1))

给出洛谷树套树板题关于这三棵树的嵌套

C++
C++
C++
#include <bits/stdc++.h>

using namespace std;
typedef pair<int, int> pii;
typedef long long ll;

template<typename T, T N>
T ModInt(T value) { return (value % N + N) % N; }

#define mod(m, n) ModInt<int,n>(m);

struct Node {
    ll val;
    Node *child[2];
    ll size{1};

    Node(ll val) : val(val){
        child[0] = child[1] = nullptr;
    }

    void update() {
        size = 1;
        if (child[0] != nullptr)size += child[0]->size;
        if (child[1] != nullptr)size += child[1]->size;
    }
};

ll get_size(Node *node) {
    return node ? node->size : 0;
}

enum Rota {
    L = 1, R = 0
};

struct SBT {
    Node *root{nullptr};
    void insert(ll v) {
        root = add(root, v);
    }

    static Node *rota(Node *node, Rota t) {
        auto new_root = node->child[t];
        node->child[t] = new_root->child[t ^ 1];
        new_root->child[t ^ 1] = node;
        node->update();
        new_root->update();
        return new_root;
    }

    Node *push_down(Node *node) {
        if(node== nullptr)return nullptr;
        ll l = get_size(node->child[0]);
        ll r = get_size(node->child[1]);
        ll l_l = 0, l_r = 0, r_l = 0, r_r = 0;
        if (node->child[0]) {
            l_l = get_size(node->child[0]->child[0]), l_r = get_size(node->child[0]->child[1]);
        }
        if (node->child[1]) {
            r_l = get_size(node->child[1]->child[0]), r_r = get_size(node->child[1]->child[1]);
        }
        if (l_l > r) {
            node = rota(node, Rota::R);
            node->child[1] = push_down(node->child[1]);
            node = push_down(node);
        } else if (l_r > r) {
            node->child[0] = rota(node->child[0], Rota::L);
            node = rota(node, Rota::R);
            node->child[0] = push_down(node->child[0]);
            node->child[1] = push_down(node->child[1]);
            node = push_down(node);
        } else if (r_l > l) {
            node->child[1] = rota(node->child[1], Rota::R);
            node = rota(node, Rota::L);
            node->child[0] = push_down(node->child[0]);
            node->child[1] = push_down(node->child[1]);
            node = push_down(node);
        } else if (r_r > l) {
            node = rota(node, Rota::L);
            node->child[0] = push_down(node->child[0]);
            node = push_down(node);
        }
        return node;
    }

    Node *add(Node *node, ll val) {
        if (node == nullptr)return new Node(val);
        if (node->val > val) {
            node->child[0] = add(node->child[0], val);
        } else{
            node->child[1] = add(node->child[1], val);
        }
        node->update();
        return push_down(node);
    }

    void remove(ll val) {
        root = del(root, val);
    }

    Node *del(Node *node, ll val) {
        if (node == nullptr)return nullptr;
        if (node->val > val) {
            node->child[0] = del(node->child[0], val);
        } else if (node->val < val) {
            node->child[1] = del(node->child[1], val);
        } else {
            int state = 0;
            if (node->child[0] != nullptr)state |= 1;
            if (node->child[1] != nullptr)state |= 1 << 1;
            if (state == 0)return nullptr;
            else if (state == 1)return node->child[0];
            else if (state == 2)return node->child[1];
            else {
               Node* ans=node->child[0];
               while(ans->child[1])ans=ans->child[1];
               node->val=ans->val;
               node->child[0]=del(node->child[0],node->val);
            }
        }
        node->update();
        return node;
    }
    ll query_kth(ll val) {
        return q_kth(root, val);
    }

    ll q_kth(Node *node, ll val) {
        if(node== nullptr)return 0;
        ll size = node->child[0] == nullptr ? 0 : node->child[0]->size;
        if (val <= node->val) {
            return q_kth(node->child[0], val);
        } else {
            return q_kth(node->child[1], val) + 1 + size;
        }
    }

    ll kth_val(ll k) {
        return k_val(root, k);
    }

    ll k_val(Node *node, ll k) {
        if (node == nullptr)return -1;
        auto size = node->child[0] == nullptr ? 0 : node->child[0]->size;
        if (k <= size)return k_val(node->child[0], k);
        else if (k == size +1)return node->val;
        else return k_val(node->child[1], k - 1 - size);
    }

    ll pre(ll v) {
        return pre_ans(root, v);
    }

    ll pre_ans(Node *node, ll v) {
        ll ans = -2147483647;
        while (node != nullptr) {
            if (node->val < v) {
                ans = node->val;
                node = node->child[1];
            } else node = node->child[0];
        }
        return ans;
    }

    ll nxt(ll v) {
        return nxt_ans(root, v);
    }

    ll nxt_ans(Node *node, ll v) {
        ll ans = 2147483647;
        while (node != nullptr) {
            if (node->val > v) {
                ans = node->val;
                node = node->child[0];
            } else node = node->child[1];
        }
        return ans;
    }
};
struct Node2 {
    Node2 *left;
    Node2 *right;
    SBT *root;
    int l, r, mid;

    Node2(int l, int r) : l(l), r(r) {
        mid = (l + r) >> 1;
        left = right = nullptr;
        root = new SBT();
    }
};
void push_down(Node2 *node) {
    if (node->left == nullptr)node->left = new Node2(node->l, node->mid);
    if (node->right == nullptr)node->right = new Node2(node->mid + 1, node->r);
}

struct Seg {
    Node2 *root;
    int *b;

    Seg(int *a, int n) {
        auto build = [&](auto d, Node2 *node) -> void {
            push_down(node);
            for (int i = node->l; i <= node->r; i++)node->root->insert(a[i]);
            if (node->l == node->r) {
                return;
            }
            d(d, node->left);
            d(d, node->right);
        };
        root = new Node2(0, n - 1);
        build(build, root);
        b = a;
    }

    void update(int k, int val) {
        u(b[k], k, val, root);
        b[k] = val;
    }

    void u(int old, int x, int val, Node2 *node) {
        node->root->remove(old);
        node->root->insert(val);
        if (node->l == x && node->r == x) {
            return;
        }
        if (x <= node->mid)u(old, x, val, node->left);
        else u(old, x, val, node->right);
    }

    ll query(int l, int r, int val) {
        return q(l, r, val, root) + 1;
    }

    ll q(int l, int r, int val, Node2 *node) {
        if (l <= node->l && node->r <= r) return node->root->query_kth(val);
        ll ans = 0;
        if (l <= node->mid)ans += q(l, r, val, node->left);
        if (r > node->mid) ans += q(l, r, val, node->right);
        return ans;
    }

    ll kth(int l, int r, int k) {
        ll left = 0;
        ll right = 1e8;
        while (left < right) {
            auto mid = (right + left + 1) >> 1;
            if (query(l, r, (int) mid) <= k)left = mid;
            else right = mid - 1;
        }
        return left;
    }

    ll pre(int l, int r, int val) {
        return pre_ans(l, r, val, root);
    }

    ll pre_ans(int l, int r, int val, Node2 *node) {
        if (l <= node->l && node->r <= r)return node->root->pre(val);
        ll ans = -2147483647;
        if (l <= node->mid)ans = max(ans, pre_ans(l, r, val, node->left));
        if (r > node->mid)ans = max(ans, pre_ans(l, r, val, node->right));
        return ans;
    }

    ll nxt(int l, int r, int val) {
        return nxt_ans(l, r, val, root);
    }

    ll nxt_ans(int l, int r, int val, Node2 *node) {
        if (l <= node->l && node->r <= r)return node->root->nxt(val);
        ll ans = 2147483647;
        if (l <= node->mid)ans = min(ans, nxt_ans(l, r, val, node->left));
        if (r > node->mid)ans = min(ans, nxt_ans(l, r, val, node->right));
        return ans;
    }
};

void solve() {
    int n, m;
    cin >> n >> m;
    int a[n];
    for (int i = 0; i < n; i++)cin >> a[i];
    auto tree = new Seg(a, n);
    while (m--) {
        int t;
        cin >> t;
        if (t == 1) {
            int l, r, val;
            cin >> l >> r >> val;
            l--, r--;
            cout << tree->query(l, r, val) << endl;
        } else if (t == 2) {
            int l, r, k;
            cin >> l >> r >> k;
            l--, r--;
            cout << tree->kth(l, r, k) << endl;
        } else if (t == 3) {
            int pos, k;
            cin >> pos >> k;
            pos--;
            tree->update(pos, k);
        } else if (t == 4) {
            int l, r, k;
            cin >> l >> r >> k;
            l--, r--;
            cout << tree->pre(l, r, k) << endl;
        } else {
            int l, r, k;
            cin >> l >> r >> k;
            l--, r--;
            cout << tree->nxt(l, r, k) << endl;
        }
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
//------------------------------------------------------
    int t = 1;
//    cin >> t;
    while (t--)solve();
}

总结

SBT是比较优秀的数据结构,其实很接近稳定的BST,可以将它作为一个好的选择.
另两种树属于空间换常数时间,它们的空间一般都为普通平衡树的2.5-3倍.如果题目给出的空间限制为100mb以下还是不要尝试.优点为码量小,常数小.SBT可以用来实现哈希表,在oiwiki上给出了具体实现代码.而另两种可以用来实现堆.

补充:若文章有不足之处,欢迎提出改正.红黑树与AVL属于码量基本上很长,不利于比赛debug,后面有时间会更新乐扣支持的绝大部分语言(大概十几种的样子),每一种平衡树的实现代码(大工程~.这里不得不提提大部分语言内置的平衡树,c++、java为红黑树.python大部分选手都是用的第三方库的sortedList,而c#虽然叫sortedList,但本质底层是红黑树.kotlin、scala都可以调用java的,php、ruby、js、ts不借用第三方库基本都没有.Go常用自己手写的旋转treap,rust是B树,Dart为伸展树...

评论 (2)