线段树 Segment tree
4954
2022.01.04
2026.02.10
发布于 未知归属地

线段树是算法中常用的用来维护区间信息的数据结构。一个包含n个区间的线段树,空间复杂度为,查询的时间复杂度则为,其中是符合条件的区间数量。
image.png
线段树将每个长度不为 的区间划分成左右两个区间递归求解,把整个线段划分为一个树形结构,通过合并左右两区间信息来求得该区间的信息。这种数据结构可以方便的进行大部分的区间操作。
在实现时,我们考虑递归建树。设当前的根节点为 ,如果根节点管辖的区间长度已经是 ,则可以直接根据 数组上相应位置的值初始化该节点。否则我们将该区间从中点处分割为两个子区间,分别进入左右子节点递归建树,最后合并两个子节点的信息。

segt1.svg

构建线段树

s和t是当前线段树的左右结点范围,p为父结点下标,arr为构建树的输入数组

void build(int s, int t, int p, const vector<int>& arr) {
    if (s == t) {
        tree[p] = SegmentItem(arr[s], 1);
        return;
    }
    int m = s + ((t - s) >> 1);
    build(s, m, p * 2, arr), build(m + 1, t, p * 2 + 1, arr);
    // push_up
    tree[p] = tree[p * 2] + tree[(p * 2) + 1];
}

查询

SegmentItem find(int l, int r, int s, int t, int p) {
    // [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号
    if (l <= s && t <= r)
        return tree[p];  // 当前区间为询问区间的子集时直接返回当前区间的和
    int m = s + ((t - s) >> 1);
    SegmentItem sum;
    if (r <= m) return find(l, r, s, m, p * 2);
    // 如果左儿子代表的区间 [l, m] 与询问区间有交集, 则递归查询左儿子
    if (l > m) return find(l, r, m + 1, t, p * 2 + 1);
    // 如果右儿子代表的区间 [m + 1, r] 与询问区间有交集, 则递归查询右儿子
    return find(l, r, s, m, p * 2) + find(l, r, m + 1, t, p * 2 + 1);
}

线段树的区间修改与懒惰标记

如果要求修改区间 ,把所有包含在区间 中的节点都遍历一次、修改一次,时间复杂度无法承受。我们这里要引入一个叫做 「懒惰标记」 的东西。懒惰标记,简单来说,就是通过延迟对节点信息的更改,从而减少可能不必要的操作次数。每次执行修改时,我们通过打标记的方法表明该节点对应的区间在某一次操作中被更改,但不更新该节点的子节点的信息。实质性的修改则在下一次访问带有标记的节点时才进行。
仍然以最开始的图为例,我们将执行若干次给区间内的数加上一个值的操作。我们现在给每个节点增加一个表示该节点带的标记值。

更多
更多

###举例
1157.子数组占绝大多数的元素

class SegmentItem{
public:
    SegmentItem(){}; 
    SegmentItem(int val, int cnt):val(val), cnt(cnt){};
    SegmentItem operator+(const SegmentItem& node){
        SegmentItem newNode = *this;
        if(val == node.val) newNode.cnt += node.cnt;
        else if(cnt >= node.cnt) newNode.cnt -= node.cnt;
        else newNode = node, newNode.cnt -= cnt;
        return newNode;
    }
    SegmentItem& operator=(const SegmentItem& node){
        val = node.val, cnt = node.cnt;
        return *this;
    }
    SegmentItem& operator+=(const SegmentItem& node){
        val += node.val, cnt += node.cnt;
        return *this;
    }
    int get_val(){return val;}
private:
    int val{0};
    int cnt{0};
};

class MajorityChecker {
    vector<int> v[20001];
    SegmentItem tree[65536]{};
    int n{0};
public:
    MajorityChecker(vector<int>& arr) {
        for(int i = 0; i < arr.size(); ++i){
            v[arr[i]].push_back(i);
        };
        n = arr.size();
        build(0, n - 1, 1, arr);
    }

    void build(int s, int t, int p, const vector<int>& arr) {
        if (s == t) {
            tree[p] = SegmentItem(arr[s], 1);
            return;
        }
        int m = s + ((t - s) >> 1);
        build(s, m, p * 2, arr), build(m + 1, t, p * 2 + 1, arr);
        // push_up
        tree[p] = tree[p * 2] + tree[(p * 2) + 1];
    }
    SegmentItem find(int l, int r, int s, int t, int p) {
        // [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号
        if (l <= s && t <= r)
            return tree[p];  // 当前区间为询问区间的子集时直接返回当前区间的和
        int m = s + ((t - s) >> 1);
        SegmentItem sum;
        if (r <= m) return find(l, r, s, m, p * 2);
        // 如果左儿子代表的区间 [l, m] 与询问区间有交集, 则递归查询左儿子
        if (l > m) return find(l, r, m + 1, t, p * 2 + 1);
        // 如果右儿子代表的区间 [m + 1, r] 与询问区间有交集, 则递归查询右儿子
        return find(l, r, s, m, p * 2) + find(l, r, m + 1, t, p * 2 + 1);
    }
    int query(int left, int right, int threshold) {
        auto node = find(left, right, 0, n - 1, 1);
        auto i = node.get_val();
        auto l = lower_bound(v[i].begin(), v[i].end(), left) - v[i].begin();
        auto r = upper_bound(v[i].begin(), v[i].end(), right) - v[i].begin();
        if (r - l >= threshold) return i;
        return -1;
    }
};

327. 区间和的个数

class Solution {
    using LL = long long;
public:
    int countRangeSum(vector<int>& nums, int lower, int upper) {
        int ans = 0;
        // 离散
        vector<LL> preSum{0LL};
        for(auto& i : nums) preSum.push_back(preSum.back()+i);
        set<LL> sums;
        for(auto& i : preSum){
            sums.insert(i);
            sums.insert(i-lower);
            sums.insert(i-upper);
        }
        unordered_map<LL, int> idx;
        int i = 0;
        for(auto& v : sums) idx[v] = i++;
        int n = idx.size() - 1;
        // 建树
        for(auto& x : preSum){
            auto l = idx[x-upper], r = idx[x-lower];
            ans += query(l, r, 0, n, 1); //从根节点id=1开始查询
            update(idx[x], idx[x], 0, n, 1, 1);//从根结点id=1开始更新
        }
        return ans;
    }

    inline int ls(int p){return p<<1;}//左儿子 
	inline int rs(int p){return p<<1|1;}//右儿子 
    inline void f(int l, int r, int p, int k){
        tag[p] += k;
        arr[p] += k * (r - l + 1);
        //由于是这个区间统一改变,所以ans数组要加元素个数次
    }
    void push_up(int p){arr[p] = arr[ls(p)] + arr[rs(p)];}
    void push_down(int s, int t, int p){
        // 如果当前节点的懒标记非空,则更新当前节点两个子节点的值和懒标记值
        auto m = (s + t) >> 1;
        f(s, m, ls(p), tag[p]);
        f(m + 1, t, rs(p), tag[p]);
        // 清空父节点懒标记
        tag[p] = 0;
    }
    void build(int s, int t, int p) {
        tag[p] = 0;
        // 对 [s,t] 区间建立线段树,当前根的编号为 p
        if (s == t) {
            arr[p] = arr[s];
            return;
        }
        int m = (t + s) >> 1;
        build(s, m, ls(p));
        build(m + 1, t, rs(p));
        // push_up
        push_up(p);
    }
    void update(int l, int r, int s, int t, int p, int k) {
        /// [l, r] 为修改区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号, k 为被修改的元素的变化量
        if (l <= s && t <= r) {
            f(s, t, p, k);
            //arr[p] += k;
            return;
        }
        // push down 懒标记
        push_down(s, t, p);
        
        int m = (t + s) >> 1;
        if (l <= m) update(l, r, s, m, ls(p), k);
        if (r > m) update(l, r, m + 1, t, rs(p), k);
        push_up(p);
    }

    int query(int l, int r, int s, int t, int p) {
        // [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号
        if (l <= s && t <= r)
            return arr[p];  // 当前区间为询问区间的子集时直接返回当前区间的和
        int m = (t + s) >> 1;
        if (r <= m) return query(l, r, s, m, ls(p));
        // 如果左儿子代表的区间 [l, m] 与询问区间有交集, 则递归查询左儿子
        if (l > m) return query(l, r, m + 1, t, rs(p));
        // 如果右儿子代表的区间 [m + 1, r] 与询问区间有交集, 则递归查询右儿子
        return query(l, r, s, m, ls(p)) + query(l, r, m + 1, t, rs(p));
    }
    
private:
    int arr[2000000]{};
    int tag[2000000]{};
    int n{0};
};

6030. 由单个字符重复的最长子字符串

评论 (0)