分享 | python 线段树入门小结
3368
2023.06.23
2023.06.23
发布于 未知归属地
  1. 前言

  2. 个人模板

  3. 力扣题目

    1. 一、单点更新区间查询
    2. 二、区间更新区间查询
    3. 三、区间多个信息
    4. 四、树上二分
  4. 结尾

前言

终于下定决心学习了线段树,凑了一个较通用的模板,把力扣上相关的题练习了下。

最大的感受是线段树真的很灵活,很难把它当黑箱调用,最好还是要理解每一个函数,按需修改。

因此,这个小结先给出个人的模板,然后按力扣上的题由易到难总结下怎么修改的。

个人模板

针对每道题,有 注释 的行一定要看下是否需要修改,其它部分一般可以不动(为了速度,有些可以删除,比如不需要懒标记的将 down 函数都去掉)。

重要的概念:

  • 树节点 维护区间信息,默认是动态开点,初始值是 0,根据需要修改
  • 懒标记 维护子节点的延后更新信息,默认是动态开点,初始值是 0,根据需要修改
  • up 函数 代表由子节点计算父节点的函数,默认是求和,根据需要修改
  • do 函数 代表对树节点的具体更新操作,默认是将对应区间的每个数都加上 x,根据需要修改
  • 查询时的初值 默认是树节点的初值,特殊情况修改
Python
class Seg:
    def __init__(self, n, A=None):
        self.n = n                     
        self.t = defaultdict(int)      # 树节点,维护区间信息
        self.f = defaultdict(int)      # 懒标记,注意让初始值代表无标记
        if A:                          
            self.A = A
            self.build()

    def up(self,a,b):                  # 区间归并函数
        return a+b

    def do(self,o,l,r,x):              # 收到更新信息 x 后,树节点和懒标记的具体操作
        self.t[o] += x*(r-l+1)
        self.f[o] += x

    def build(self,o=1,l=0,r=None):
        r = self.n-1 if r is None else r
        if l==r:
            self.t[o] = self.A[l]
            return
        m = (l+r)//2
        self.build(o*2,l,m)
        self.build(o*2+1,m+1,r)
        self.t[o] = self.up(self.t[o*2],self.t[o*2+1])

    def down(self,o,l,m,r):
        if self.f[o] != self.f[0]:
            self.do(o*2,l,m,self.f[o])
            self.do(o*2+1,m+1,r,self.f[o])
            self.f[o] = self.f[0]

    def update(self,a,b,x,o=1,l=0,r=None):
        r = self.n-1 if r is None else r
        if a<=l and r<=b:
            self.do(o,l,r,x)
            return
        m = (l+r)//2
        self.down(o,l,m,r)
        if a<=m:
            self.update(a,b,x,o*2,l,m)
        if m<b:
            self.update(a,b,x,o*2+1,m+1,r)
        self.t[o] = self.up(self.t[o*2],self.t[o*2+1])

    def query(self,a,b,o=1,l=0,r=None):
        r = self.n-1 if r is None else r
        if a<=l and r<=b:
            return self.t[o]
        m = (l+r)//2
        self.down(o,l,m,r)
        res = self.t[0]                      # 查询时的初值,可能要修改 
        if a<=m:
            res = self.up(res,self.query(a,b,o*2,l,m))
        if m<b:
            res = self.up(res,self.query(a,b,o*2+1,m+1,r))
        return res

力扣题目

一、单点更新区间查询

单点更新不需要懒标记,所以懒标记和 down 函数相关部分都可以去掉

307. 区域和检索 - 数组可修改

  • up 函数:树节点维护的是区间和 ,所以无需修改
  • do 函数:更新是将区间(单点区间)赋值,所以修改为赋值

2407. 最长递增子序列 II

  • up 函数:树节点维护的是区间最大值,修改为 max 即可
  • do 函数:同上一题,修改为赋值
  • 没有初始数组,所以 build 部分也可以去掉

2736. 最大和查询

  • 树节点:区间范围很大,所以用动态开点。为了方便,可以令树节点的初值为 -1
  • up 函数:同上一题,修改为 max
  • do 函数:更新是覆盖为更大值,需要修改
307. 区域和检索 - 数组可修改
2407. 最长递增子序列 I
2736. 最大和查询
class Seg:
    def __init__(self, n, A=None):
        self.n = n                     
        self.t = [0]*n*4               # 树节点,维护区间信息
        if A:                          
            self.A = A
            self.build()

    def up(self,a,b):                  # 区间归并函数
        return a+b

    def do(self,o,l,r,x):              # 收到更新信息 x 后,树节点和懒标记的具体操作
        self.t[o] = x

    def build(self,o=1,l=0,r=None):
        r = self.n-1 if r is None else r
        if l==r:
            self.t[o] = self.A[l]
            return
        m = (l+r)//2
        self.build(o*2,l,m)
        self.build(o*2+1,m+1,r)
        self.t[o] = self.up(self.t[o*2],self.t[o*2+1])

    def update(self,a,b,x,o=1,l=0,r=None):
        r = self.n-1 if r is None else r
        if a<=l and r<=b:
            self.do(o,l,r,x)
            return
        m = (l+r)//2
        if a<=m:
            self.update(a,b,x,o*2,l,m)
        if m<b:
            self.update(a,b,x,o*2+1,m+1,r)
        self.t[o] = self.up(self.t[o*2],self.t[o*2+1])

    def query(self,a,b,o=1,l=0,r=None):
        r = self.n-1 if r is None else r
        if a<=l and r<=b:
            return self.t[o]
        m = (l+r)//2
        res = self.t[0]                      # 查询时的初值,可能要修改 
        if a<=m:
            res = self.up(res,self.query(a,b,o*2,l,m))
        if m<b:
            res = self.up(res,self.query(a,b,o*2+1,m+1,r))
        return res

class NumArray:

    def __init__(self, nums: List[int]):
        self.t = Seg(len(nums),nums)

    def update(self, index: int, val: int) -> None:
        self.t.update(index,index,val)

    def sumRange(self, left: int, right: int) -> int:
        return self.t.query(left,right)

二、区间更新区间查询

区间更新就必须用到懒标记和 down 函数了。

715. Range 模块

  • 树节点:范围大,用动态开点
  • 懒标记:注意本题有赋值 0 操作,0 不代表无标记,所以修改初始值
  • up 函数:维护的是区间与,要修改
  • do 函数:更新是对区间每个数赋值,要修改
  • 查询初值:本题是与运算,因此初值要为 1

732. 我的日程安排表 III

  • 树节点:依然动态开点
  • up 函数:维护的是区间最大值,修改为 max
  • do 函数:更新是将区间每个数都加1,因此树节点(注意维护的是最大值)加 1,下传懒标记也加 1,修改 do 函数
  • 只查询整个区间,所以 query 函数可以删去

2276. 统计区间中的整数数目

  • 树节点:依然动态开点
  • up 函数:维护的是区间和,无需修改
  • do 函数:更新是将区间每个数都赋值 1,因此修改树节点为区间长度,下传懒标记也赋值 1

2569. 更新数组后处理求和查询

  • up 函数:维护的是区间和,无需修改
  • do 函数:更新是将区间每个数都反转,因此树节点 t[o] 更新为 区间长度-t[o],下传懒标记也要反转(想想两个更新操作怎么叠加的)
715. Range 模块
732. 我的日程...
2276. 统计区间...
2569. 更新数组后...
class Seg:
    def __init__(self, n, A=None):
        self.n = n                     
        self.t = defaultdict(int)      # 树节点,维护区间信息
        self.f = defaultdict(lambda:-1)      # 懒标记,注意让初始值代表无标记

    def up(self,a,b):                  # 区间归并函数
        return a&b

    def do(self,o,l,r,x):              # 收到更新信息 x 后,树节点和懒标记的具体操作
        self.t[o] = x
        self.f[o] = x

    def down(self,o,l,m,r):
        if self.f[o] != self.f[0]:
            self.do(o*2,l,m,self.f[o])
            self.do(o*2+1,m+1,r,self.f[o])
            self.f[o] = self.f[0]

    def update(self,a,b,x,o=1,l=0,r=None):
        r = self.n-1 if r is None else r
        if a<=l and r<=b:
            self.do(o,l,r,x)
            return
        m = (l+r)//2
        self.down(o,l,m,r)
        if a<=m:
            self.update(a,b,x,o*2,l,m)
        if m<b:
            self.update(a,b,x,o*2+1,m+1,r)
        self.t[o] = self.up(self.t[o*2],self.t[o*2+1])

    def query(self,a,b,o=1,l=0,r=None):
        r = self.n-1 if r is None else r
        if a<=l and r<=b:
            return self.t[o]
        m = (l+r)//2
        self.down(o,l,m,r)
        res = 1               # 查询时的初值,可能要修改 
        if a<=m:
            res = self.up(res,self.query(a,b,o*2,l,m))
        if m<b:
            res = self.up(res,self.query(a,b,o*2+1,m+1,r))
        return res

class RangeModule:

    def __init__(self):
        self.tree = Seg(10**9+1)

    def addRange(self, left: int, right: int) -> None:
        self.tree.update(left,right-1,1)

    def queryRange(self, left: int, right: int) -> bool:
        return bool(self.tree.query(left,right-1))

    def removeRange(self, left: int, right: int) -> None:
        self.tree.update(left,right-1,0)

三、区间多个信息

有的题目要同时维护区间的多个信息,比如区间和 + 区间最大值。有时 up 函数就会比较复杂了。
1157. 子数组中占绝大多数的元素

  • 树节点:维护的是区间进行摩尔投票后剩下的元素及其个数,有两个信息
  • up 函数:模拟摩尔投票
  • do 函数:本题没有区间更新,所以懒标记、down、do、update 全都可以去掉。。。

1622. 奇妙序列

  • 懒标记:维护乘和加的系数,两个信息
  • do 函数:线性运算,注意取模

2213. 由单个字符重复的最长子字符串

  • 树节点:为了有足够的信息得到区间的最长重复子串,需要同时维护 "最长重复子串的长度、区间长度、前缀最长重复子串的字符和长度、后缀最长重复子串的字符和长度" 共 6 个信息(不过反正要 build,初值其实无所谓。。。)
  • 懒标记:单点更新,无需懒标记和 down 函数
  • up 函数:根据 2 个子节点的 6 个信息推出父节点的 6 个信息,比较麻烦。。。
  • do 函数:单点更新,直接赋值即可
1157. 子数组中占...
1622. 奇妙序列
2213. 由单个字符...
class Seg:
    def __init__(self, n, A=None):
        self.n = n                     
        self.t = defaultdict(lambda:[0,0])      # 树节点,维护区间信息
        if A:                          
            self.A = A
            self.build()

    def up(self,a,b):                  # 区间归并函数
        if a[0]==b[0]:
            return [a[0],a[1]+b[1]]
        if a[1]>b[1]:
            return [a[0],a[1]-b[1]]
        return [b[0],b[1]-a[1]]

    def build(self,o=1,l=0,r=None):
        r = self.n-1 if r is None else r
        if l==r:
            self.t[o] = self.A[l]
            return
        m = (l+r)//2
        self.build(o*2,l,m)
        self.build(o*2+1,m+1,r)
        self.t[o] = self.up(self.t[o*2],self.t[o*2+1])

    def query(self,a,b,o=1,l=0,r=None):
        r = self.n-1 if r is None else r
        if a<=l and r<=b:
            return self.t[o]
        m = (l+r)//2
        res = self.t[0]                      # 查询时的初值,可能要修改 
        if a<=m:
            res = self.up(res,self.query(a,b,o*2,l,m))
        if m<b:
            res = self.up(res,self.query(a,b,o*2+1,m+1,r))
        return res

class MajorityChecker:

    def __init__(self, arr: List[int]):
        self.d = defaultdict(list)
        for i,x in enumerate(arr):
            self.d[x].append(i)
        self.tree = Seg(len(arr),[(a,1) for a in arr])

    def query(self, left: int, right: int, threshold: int) -> int:
        x = self.tree.query(left,right)[0]
        A = self.d[x]
        if bisect_right(A,right)- bisect_left(A,left)>=threshold:
            return x
        return -1

四、树上二分

2286. 以组为单位订音乐会的门票

  • 树节点:维护区间最大值、区间和两个信息
  • qr1 函数:二分查找 [a,b] 区间内第一个>=x的位置,仿照 query 函数来写,不过每次只会递归一边
  • qr2 函数:二分查找 [a,b] 区间内第一个使得 [a,i]区间和>=x的位置 i,同样每次递归一边
2286. 以组为单位订音乐会的门票
class Seg:
    def __init__(self, n, A=None):
        self.n = n                     
        self.t = [0]*n*4      # 树节点,维护区间信息
        self.f = [0]*n*4      # 懒标记,注意让初始值代表无标记
        if A:                          
            self.A = A
            self.build()

    def up(self,a,b):                  # 区间归并函数
        return [max(a[0],b[0]),a[1]+b[1]]

    def do(self,o,l,r,x):              # 收到更新信息 x 后,树节点和懒标记的具体操作
        self.t[o] = [x,x*(r-l+1)]
        self.f[o] = x

    def build(self,o=1,l=0,r=None):
        r = self.n-1 if r is None else r
        if l==r:
            self.t[o] = self.A[l]
            return
        m = (l+r)//2
        self.build(o*2,l,m)
        self.build(o*2+1,m+1,r)
        self.t[o] = self.up(self.t[o*2],self.t[o*2+1])

    def down(self,o,l,m,r):
        if self.f[o] != self.f[0]:
            self.do(o*2,l,m,self.f[o])
            self.do(o*2+1,m+1,r,self.f[o])
            self.f[o] = self.f[0]

    def update(self,a,b,x,o=1,l=0,r=None):
        r = self.n-1 if r is None else r
        if a<=l and r<=b:
            self.do(o,l,r,x)
            return
        m = (l+r)//2
        self.down(o,l,m,r)
        if a<=m:
            self.update(a,b,x,o*2,l,m)
        if m<b:
            self.update(a,b,x,o*2+1,m+1,r)
        self.t[o] = self.up(self.t[o*2],self.t[o*2+1])

    def qr1(self,a,b,x,o=1,l=0,r=None):
        r = self.n-1 if r is None else r
        if self.t[o][0]<x:
            return -1, None
        if l==r:
            return l,self.t[o][1]-x
        m = (l+r)//2
        self.down(o,l,m,r)
        if a<=m and self.t[o*2][0]>=x:
            return self.qr1(a,b,x,o*2,l,m)
        return self.qr1(a,b,x,o*2+1,m+1,r) if m<b else (-1,None)
    
    def qr2(self,a,b,x,o=1,l=0,r=None):
        r = self.n-1 if r is None else r
        if self.t[o][1]<x:
            return -1,None
        if l==r:
            return l,self.t[o][1]-x
        m = (l+r)//2
        self.down(o,l,m,r)
        y = self.t[o*2][1]
        if a<=m and y>=x:
            return self.qr2(a,b,x,o*2,l,m)
        return self.qr2(a,b,x-y,o*2+1,m+1,r) if m<b else (-1,None)

class BookMyShow:

    def __init__(self, n: int, m: int):
        self.tree = Seg(n,[(m,m) for _ in range(n)])
        self.m = m

    def gather(self, k: int, maxRow: int) -> List[int]:
        i, r = self.tree.qr1(0,maxRow,k)
        if i==-1:
            return []
        self.tree.update(i,i,r)
        return [i,self.m-r-k]
        
    def scatter(self, k: int, maxRow: int) -> bool:
        i, r = self.tree.qr2(0,maxRow,k)
        if i==-1:
            return False
        if i:
            self.tree.update(0,i-1,0)
        self.tree.update(i,i,r)
        return True

结尾

微信图片_20230623162241.jpg

评论 (5)