终于下定决心学习了线段树,凑了一个较通用的模板,把力扣上相关的题练习了下。
最大的感受是线段树真的很灵活,很难把它当黑箱调用,最好还是要理解每一个函数,按需修改。
因此,这个小结先给出个人的模板,然后按力扣上的题由易到难总结下怎么修改的。
针对每道题,有 注释 的行一定要看下是否需要修改,其它部分一般可以不动(为了速度,有些可以删除,比如不需要懒标记的将 down 函数都去掉)。
重要的概念:
class Seg:
def __init__(self, n, A=None):
self.n = n
self.t = defaultdict(int) # 树节点,维护区间信息
self.f = defaultdict(int) # 懒标记,注意让初始值代表无标记
if A:
self.A = A
self.build()
def up(self,a,b): # 区间归并函数
return a+b
def do(self,o,l,r,x): # 收到更新信息 x 后,树节点和懒标记的具体操作
self.t[o] += x*(r-l+1)
self.f[o] += x
def build(self,o=1,l=0,r=None):
r = self.n-1 if r is None else r
if l==r:
self.t[o] = self.A[l]
return
m = (l+r)//2
self.build(o*2,l,m)
self.build(o*2+1,m+1,r)
self.t[o] = self.up(self.t[o*2],self.t[o*2+1])
def down(self,o,l,m,r):
if self.f[o] != self.f[0]:
self.do(o*2,l,m,self.f[o])
self.do(o*2+1,m+1,r,self.f[o])
self.f[o] = self.f[0]
def update(self,a,b,x,o=1,l=0,r=None):
r = self.n-1 if r is None else r
if a<=l and r<=b:
self.do(o,l,r,x)
return
m = (l+r)//2
self.down(o,l,m,r)
if a<=m:
self.update(a,b,x,o*2,l,m)
if m<b:
self.update(a,b,x,o*2+1,m+1,r)
self.t[o] = self.up(self.t[o*2],self.t[o*2+1])
def query(self,a,b,o=1,l=0,r=None):
r = self.n-1 if r is None else r
if a<=l and r<=b:
return self.t[o]
m = (l+r)//2
self.down(o,l,m,r)
res = self.t[0] # 查询时的初值,可能要修改
if a<=m:
res = self.up(res,self.query(a,b,o*2,l,m))
if m<b:
res = self.up(res,self.query(a,b,o*2+1,m+1,r))
return res单点更新不需要懒标记,所以懒标记和 down 函数相关部分都可以去掉
class Seg:
def __init__(self, n, A=None):
self.n = n
self.t = [0]*n*4 # 树节点,维护区间信息
if A:
self.A = A
self.build()
def up(self,a,b): # 区间归并函数
return a+b
def do(self,o,l,r,x): # 收到更新信息 x 后,树节点和懒标记的具体操作
self.t[o] = x
def build(self,o=1,l=0,r=None):
r = self.n-1 if r is None else r
if l==r:
self.t[o] = self.A[l]
return
m = (l+r)//2
self.build(o*2,l,m)
self.build(o*2+1,m+1,r)
self.t[o] = self.up(self.t[o*2],self.t[o*2+1])
def update(self,a,b,x,o=1,l=0,r=None):
r = self.n-1 if r is None else r
if a<=l and r<=b:
self.do(o,l,r,x)
return
m = (l+r)//2
if a<=m:
self.update(a,b,x,o*2,l,m)
if m<b:
self.update(a,b,x,o*2+1,m+1,r)
self.t[o] = self.up(self.t[o*2],self.t[o*2+1])
def query(self,a,b,o=1,l=0,r=None):
r = self.n-1 if r is None else r
if a<=l and r<=b:
return self.t[o]
m = (l+r)//2
res = self.t[0] # 查询时的初值,可能要修改
if a<=m:
res = self.up(res,self.query(a,b,o*2,l,m))
if m<b:
res = self.up(res,self.query(a,b,o*2+1,m+1,r))
return res
class NumArray:
def __init__(self, nums: List[int]):
self.t = Seg(len(nums),nums)
def update(self, index: int, val: int) -> None:
self.t.update(index,index,val)
def sumRange(self, left: int, right: int) -> int:
return self.t.query(left,right)区间更新就必须用到懒标记和 down 函数了。
class Seg:
def __init__(self, n, A=None):
self.n = n
self.t = defaultdict(int) # 树节点,维护区间信息
self.f = defaultdict(lambda:-1) # 懒标记,注意让初始值代表无标记
def up(self,a,b): # 区间归并函数
return a&b
def do(self,o,l,r,x): # 收到更新信息 x 后,树节点和懒标记的具体操作
self.t[o] = x
self.f[o] = x
def down(self,o,l,m,r):
if self.f[o] != self.f[0]:
self.do(o*2,l,m,self.f[o])
self.do(o*2+1,m+1,r,self.f[o])
self.f[o] = self.f[0]
def update(self,a,b,x,o=1,l=0,r=None):
r = self.n-1 if r is None else r
if a<=l and r<=b:
self.do(o,l,r,x)
return
m = (l+r)//2
self.down(o,l,m,r)
if a<=m:
self.update(a,b,x,o*2,l,m)
if m<b:
self.update(a,b,x,o*2+1,m+1,r)
self.t[o] = self.up(self.t[o*2],self.t[o*2+1])
def query(self,a,b,o=1,l=0,r=None):
r = self.n-1 if r is None else r
if a<=l and r<=b:
return self.t[o]
m = (l+r)//2
self.down(o,l,m,r)
res = 1 # 查询时的初值,可能要修改
if a<=m:
res = self.up(res,self.query(a,b,o*2,l,m))
if m<b:
res = self.up(res,self.query(a,b,o*2+1,m+1,r))
return res
class RangeModule:
def __init__(self):
self.tree = Seg(10**9+1)
def addRange(self, left: int, right: int) -> None:
self.tree.update(left,right-1,1)
def queryRange(self, left: int, right: int) -> bool:
return bool(self.tree.query(left,right-1))
def removeRange(self, left: int, right: int) -> None:
self.tree.update(left,right-1,0)有的题目要同时维护区间的多个信息,比如区间和 + 区间最大值。有时 up 函数就会比较复杂了。
1157. 子数组中占绝大多数的元素
class Seg:
def __init__(self, n, A=None):
self.n = n
self.t = defaultdict(lambda:[0,0]) # 树节点,维护区间信息
if A:
self.A = A
self.build()
def up(self,a,b): # 区间归并函数
if a[0]==b[0]:
return [a[0],a[1]+b[1]]
if a[1]>b[1]:
return [a[0],a[1]-b[1]]
return [b[0],b[1]-a[1]]
def build(self,o=1,l=0,r=None):
r = self.n-1 if r is None else r
if l==r:
self.t[o] = self.A[l]
return
m = (l+r)//2
self.build(o*2,l,m)
self.build(o*2+1,m+1,r)
self.t[o] = self.up(self.t[o*2],self.t[o*2+1])
def query(self,a,b,o=1,l=0,r=None):
r = self.n-1 if r is None else r
if a<=l and r<=b:
return self.t[o]
m = (l+r)//2
res = self.t[0] # 查询时的初值,可能要修改
if a<=m:
res = self.up(res,self.query(a,b,o*2,l,m))
if m<b:
res = self.up(res,self.query(a,b,o*2+1,m+1,r))
return res
class MajorityChecker:
def __init__(self, arr: List[int]):
self.d = defaultdict(list)
for i,x in enumerate(arr):
self.d[x].append(i)
self.tree = Seg(len(arr),[(a,1) for a in arr])
def query(self, left: int, right: int, threshold: int) -> int:
x = self.tree.query(left,right)[0]
A = self.d[x]
if bisect_right(A,right)- bisect_left(A,left)>=threshold:
return x
return -1class Seg:
def __init__(self, n, A=None):
self.n = n
self.t = [0]*n*4 # 树节点,维护区间信息
self.f = [0]*n*4 # 懒标记,注意让初始值代表无标记
if A:
self.A = A
self.build()
def up(self,a,b): # 区间归并函数
return [max(a[0],b[0]),a[1]+b[1]]
def do(self,o,l,r,x): # 收到更新信息 x 后,树节点和懒标记的具体操作
self.t[o] = [x,x*(r-l+1)]
self.f[o] = x
def build(self,o=1,l=0,r=None):
r = self.n-1 if r is None else r
if l==r:
self.t[o] = self.A[l]
return
m = (l+r)//2
self.build(o*2,l,m)
self.build(o*2+1,m+1,r)
self.t[o] = self.up(self.t[o*2],self.t[o*2+1])
def down(self,o,l,m,r):
if self.f[o] != self.f[0]:
self.do(o*2,l,m,self.f[o])
self.do(o*2+1,m+1,r,self.f[o])
self.f[o] = self.f[0]
def update(self,a,b,x,o=1,l=0,r=None):
r = self.n-1 if r is None else r
if a<=l and r<=b:
self.do(o,l,r,x)
return
m = (l+r)//2
self.down(o,l,m,r)
if a<=m:
self.update(a,b,x,o*2,l,m)
if m<b:
self.update(a,b,x,o*2+1,m+1,r)
self.t[o] = self.up(self.t[o*2],self.t[o*2+1])
def qr1(self,a,b,x,o=1,l=0,r=None):
r = self.n-1 if r is None else r
if self.t[o][0]<x:
return -1, None
if l==r:
return l,self.t[o][1]-x
m = (l+r)//2
self.down(o,l,m,r)
if a<=m and self.t[o*2][0]>=x:
return self.qr1(a,b,x,o*2,l,m)
return self.qr1(a,b,x,o*2+1,m+1,r) if m<b else (-1,None)
def qr2(self,a,b,x,o=1,l=0,r=None):
r = self.n-1 if r is None else r
if self.t[o][1]<x:
return -1,None
if l==r:
return l,self.t[o][1]-x
m = (l+r)//2
self.down(o,l,m,r)
y = self.t[o*2][1]
if a<=m and y>=x:
return self.qr2(a,b,x,o*2,l,m)
return self.qr2(a,b,x-y,o*2+1,m+1,r) if m<b else (-1,None)
class BookMyShow:
def __init__(self, n: int, m: int):
self.tree = Seg(n,[(m,m) for _ in range(n)])
self.m = m
def gather(self, k: int, maxRow: int) -> List[int]:
i, r = self.tree.qr1(0,maxRow,k)
if i==-1:
return []
self.tree.update(i,i,r)
return [i,self.m-r-k]
def scatter(self, k: int, maxRow: int) -> bool:
i, r = self.tree.qr2(0,maxRow,k)
if i==-1:
return False
if i:
self.tree.update(0,i-1,0)
self.tree.update(i,i,r)
return True