静态主席树教程
1249
2023.07.19
2023.07.19
发布于 未知归属地

什么是主席树,主席树可以通常被认为是可持久化的线段树的.而静态主席树可以支持:

在区间[l,r]上查找第k大的数

前置知识:

  1. 动态开点线段树
  2. 前缀和
  3. 权值线段树

权值线段树

和常见的维护区间的线段树不同,权值线段树是把整个数轴上的数去分区间,例如[-1,-1]其实也是个区间,[-2147483648,2147483647]也是个区间,不再拘泥于下标为区间.这样一来就有一个特别好的性质:有序性,数从左到右是递增的,计数排序也是基于这种思想
那么不用区间了,我们该如何去书写呢?
常见的我们将val作为轴,将cnt作为val,类似与计数排序.
常见的一种是我个人比较喜爱的,这种写法常常可以把Node单独分离去思考维护量有哪些(但常数上会略大,如果非要卡,需要更改为数组版的动态开点)
数组版动态开点一般开id,Root,L,R,Val这几个数组,然后传引用,if(!curr)curr=++id;分配节点.这里不做长篇赘述

动态开点线段树

和常见的线段树不同,动态开点不需要预先开辟出特别大的大小,而且比较适用与无build的操作的线段树.

Python
class Node:
    __slots__ = "l", "mid", "r", "lazy", "val", "left", "right"

    def __init__(self, l: int, r: int, val: int = 0, left: 'Node' = None, right: 'Node' = None):
        self.l = l
        self.r = r
        self.mid = (l + r) >> 1
        self.val = val
        self.left = left
        self.right = right
        self.lazy = False

这里尤其是注意一点mid的话一定要用位运算实现.例如[-2,-1]这个区间,如果使用mid=(-2-1)/2,那么对于-1.5而言,大部分语言是向0取整,所以mid=-1,这样一来则会分割为
[-2.-1],[-1+1,-1]=>[0,-1]是一个不合法的区间,正确是应该取-2,[-2,-2],[-1,-1]所以这里我们常常使用位运算实现.[l,r]就是表示的一个区间例如[-5,8] (闭区间),如果是开区间我们可以通过-1,+1去变为闭区间.left和right表示的是左右线段树区间节点.lazy一般用于区间修改问题,延迟修改.

这里注意的是,如果有build的情况下,root为[lMin,rMax]叶子节点至少都是rMax-lMin+1的叶子节点大小,所以一般非权值线段树这种需要build的线段树我们都不用动态开点,因为它不会比数组写法更优.(特殊的例如主席树而言需要build的动态开的那我们就需要离散化了)

提提两个动态开点常用函数:

  1. push_up:用于叶子节点信息更新父亲节点
    例如:
Python
def push_up(node: 'Node'):
    node.val = node.left.val + node.right.val

也有其他写法,例如摩尔投票法的状态量合并,总而言之,这里考虑如何合并左右节点信息是最关键的

  1. push_down:不仅用于动态开点,还伴随着区间修改的lazy下传
Python
def push_down(node: 'Node'):
    if node.left is None:
        node.left = Node(node.l, node.mid)
    if node.right is None:
        node.right = Node(node.mid + 1, node.r)
    # if node.lazy:
    #     xxxxxxxx

这两个函数常常作为线段树的核心思考,以及Node需要维护哪些变量.其他的就是大同小异的修改与查询.

关于build,对于动态开点线段树而言,build和update、query的书写代码量和参数都是极小的.
例如build:

Python
def build(node: 'Node'):
    if node.l == node.r:
        # node.val=xxxx
        return
    push_down(node)  # 开点
    build(node.left)
    build(node.right)
    push_up(node)

前缀和

这里的前缀和更多是一种思想层次上的.
举个例子:
1 1 1 2 2 2 2 3 请问第4大的数是多少
1:3个,2:4个,3:1个

  1. 那么先看1,3<4,说明我们应该找比1大的数的第(4-3)=1个数(并且移除1,其实这是一个新的子问题),1<=4说明应该找小于等于2的第1个数,这个数就是2,所以第四大的数为2
    上述其实蕴含这类似分治与dp的思想.
  2. 加强问题,[2,4]的第2大的数是多少
    那么我们的第一步显而易见的取出[2,4]区间的数1,1,2,2然后变为了上一个问题
  3. 再次加强问题,如果这个数组无序呢?那么我们先取出[2,4],再将[2,4]排序,然后就变为了问题1.这里排序应该用哪种排序呢?
    答案:计数排序=>权值线段树思想.使用权值线段树来维护[l,r]区间的cnt数量,自带有序.
    那么再来思考假如说我们已经通过线段树拿到了[l,r]区间的权值线段树了如何实现问题1?
    那么这个其实就是经典的权值线段树上二分:
    [l,mid]、[mid+1,r],lSize<=k,找左区间第k大数 else 找右区间第k-lSize大的数
    if l==r:那么当前的l/r就是这个数=>轴上的坐标/点就是代表的这个数
    举例:
    上述问题1,线段树维护[-1e9,1e9],容易可知[1,1]size=3,[2,2]size=4,[3,3]size=1,
    [1,2]=7,[2,3]=5,[1,3]=8.
    开始二分[-1e9,0]、[1,1e9]显然lSize<=4=>在[1,1e9]找第4大的数
    =>例如达到left为[1,2]、right为[3,3]这个点.7>4,所以应该在[1,2]里面找第4大的数
    =>[1,1]、[2,2],4>3,所以应该在[2,2]当中找4-3=1的数,l==r,返回2,这里换成第5、6、7都是一样的,因为最后都不会超出[2,2]的size,最终都为l==r==2

主席树引入

基于上述思想其实有两种其他比较不适合初学者的做法:
树套树:

  1. 线段树套平衡树,直接借助平衡树查找
  2. 线段树套权值线段树(外层可以改为树状数组),借助线段树二分查找

主席树的前缀和思想:
这里我们模拟下上述数列插入的过程更能体会:数列的数从左到右反复添加进权值线段树,那么假如我们要找[2,4]的区间,那么我们去观察当插入第一个数的时候有多少1的size=1,插入第四个数的时候每个数有多少个size,我们惊讶的发现如果用插入第四个数的时候1的size-插入第一个数的时候1的size我们可以得到插入了[2,4]区间上的1的个数2
回忆前缀和:
pre[r]-pre[l-1],其中[0,r]-[0,l-1]=[l,r]我们可以惊讶地发现上述要得到某一个数的size都可以借助前缀和的思想去删掉不需要的区间段.这就是主席树的核心思想:前缀和思想,而这里的前缀和=>朴素地来看每次插入一个数,我们建立一棵权值线段树,可以轻松地算出区间上的总size.
举例:
插入第4个数拿到的线段树其实是[1,4]这个"区间状态"对应的线段树,[-1e9,0],[1,1e9]的size轻松得到,同理拿到[1,1]这个"区间状态"对应的线段树,它的[-1e9,0],[1,1e9]的size也可以轻松得到,二者作差就变为了以上所说的过程了

正式实现前的核心点说明:

这里在正式实现前要提出一个问题:每一个点都建立线段树,那么很显然的是=>空间MLE、时间爆TLE
这里不得不提提动态开点的本质了:
思考查找数x,[l,mid],[mid+1,r],x<=mid找左否则找右,每次向下走一层,向下开点,本质上这中间开辟了一条新链,其他信息和原来一样.
那么这里我们就涉及到了继承信息机制:

C++
//继承前一个版本,使用它的子节点信息(old),但当前节点为新节点
    Node(Node *pre) {
        l = pre->l;
        r = pre->r;
        mid = pre->mid;
        left = pre->left;
        right = pre->right;
        cnt = pre->cnt + 1;//在前一个版本上区间里多加了一个点
    }

使用这样的一个复制写法,我们就可以继承上一棵线段树的其他信息,而只更新当前点信息.
image.png
所以我们需要build一棵空树(就是val为0但节点开满的那种),每次采用"动态开点"更新

正式写法

根据以上所说,我们需要"离散化",因为需要build第一棵"空树",其余的就和线段树相同了,具体看我代码注释
(Golang借助splay去重离散化,没有平衡树内置不喜欢写各种map和二分)

Go
C++
Python
package main

import (
	"bufio"
	. "fmt"
	"io"
	"os"
	"runtime/debug"
)

type T struct {
	val   int
	child [2]*T
	cnt   int
	size  int
	fa    *T
}

func NewT(val int, fa *T) *T {
	return &T{val: val, fa: fa, child: [2]*T{nil, nil}, cnt: 1, size: 1}
}
func (node *T) update() {
	node.size = node.cnt
	if node.child[0] != nil {
		node.size += node.child[0].size
	}
	if node.child[1] != nil {
		node.size += node.child[1].size
	}
}
func (node *T) idx() int {
	if node.fa != nil && node.fa.child[1] == node {
		return 1
	}
	return 0
}
func (node *T) rotate() {
	fa := node.fa
	if fa != nil {
		faf := fa.fa
		i := node.idx()
		if faf != nil {
			faf.child[fa.idx()] = node
		}
		node.fa = faf
		fa.child[i] = node.child[i^1]
		if node.child[i^1] != nil {
			node.child[i^1].fa = fa
		}
		fa.fa = node
		node.child[i^1] = fa
		fa.update()
		node.update()
	}
}
func (node *T) splay() *T {
	for node.fa != nil {
		if node.fa.fa != nil {
			if node.idx() == node.fa.idx() {
				node.fa.rotate()
			} else {
				node.rotate()
			}
		}
		node.rotate()
	}
	return node
}

type Splay struct {
	root *T
}

func NewSplay() *Splay {
	return &Splay{root: nil}
}
func (node *Splay) insert(val int) {
	if node.root == nil {
		node.root = NewT(val, nil)
		return
	}
	tmp := node.root
	for {
		if tmp.val == val {
			tmp.cnt++
			tmp.size++
			node.root = tmp.splay()
			return
		} else {
			i := 0
			if val > tmp.val {
				i = 1
			}
			if tmp.child[i] == nil {
				tmp.child[i] = NewT(val, tmp)
				tmp.update()
				node.root = tmp.splay()
				return
			} else {
				tmp = tmp.child[i]
			}
		}
	}
}
func (node *Splay) Get() (ans []int) {
	var dfs func(*T)
	dfs = func(t *T) {
		if t != nil {
			dfs(t.child[0])
			ans = append(ans, t.val)
			dfs(t.child[1])
		}
	}
	dfs(node.root)
	return
}

var (
	M    []int //映射,下标->value
	n, m int   //读入
)

// Node 线段树节点
type Node struct {
	l, mid, r, cnt int   //表示的区间范围
	left, right    *Node //左右子树
}

// NewNode 新节点
func NewNode(l int, r int) *Node {
	return &Node{l: l, r: r, mid: (l + r) >> 1, cnt: 0, left: nil, right: nil}
}

// Extend 新节点为继承了前一个版本的节点
func Extend(pre *Node) *Node {
	root := NewNode(pre.l, pre.r)
	root.cnt = pre.cnt + 1
	root.left = pre.left
	root.right = pre.right
	return root
}

// Build 第一个版本[l,r]都有点的树
func Build(node *Node) {
	if node.l == node.r {
		return
	}
	node.left = NewNode(node.l, node.mid)
	node.right = NewNode(node.mid+1, node.r)
	Build(node.left)
	Build(node.right)
}

// Add 往[x,x]添加一点,整体是开一条新链
func Add(pre *Node, x int) *Node {
	node := Extend(pre) //继承上一个版本的新节点
	if node.l == node.r {
		return node
	}
	//已经继承版本了,可以直接将自身子树作为pre
	if x <= node.mid {
		node.left = Add(node.left, x)
	} else {
		node.right = Add(node.right, x)
	}
	return node
}

// Query [l,r]上第k大的值
func (t *Seg) Query(l int, r int, k int) int {
	//在权值线段树上二分找到下标的size,使得前缀size和>=k(min)
	var binary func(*Node, *Node, int) int
	binary = func(left *Node, right *Node, i int) int {
		//如果遍历到了一个点,说明这个就是当前的下标
		if left.l == left.r {
			return left.l
		}
		//前缀和作差为当前区间的size总数,先看左半边区间
		diff := right.left.cnt - left.left.cnt
		//左半边区间,否则右半边区间,记得去掉左半边区间的数量,转为一个新的子问题
		if i <= diff {
			return binary(left.left, right.left, i)
		} else {
			return binary(left.right, right.right, i-diff)
		}
	}
	//类似前缀和作差sum[r]-sum[l-1]得到的就是[l,r]这个区间的所有线段树,M是下标从0开始的有序序列
	return M[binary(t.p[l-1], t.p[r], k)-1]
}

type Seg struct {
	h map[int]int //离散化 值=>下标
	p []*Node     //各个版本的根节点
}

func NewSeg(a []int) *Seg {
	h := make(map[int]int)
	var splay = NewSplay() //splay用于去重排序
	for _, v := range a {
		splay.insert(v)
	}
	M = splay.Get() //用于拿到整个有序序列
	for i, v := range M {
		h[v] = i + 1
	}
	var p = make([]*Node, len(a)+1)
	p[0] = NewNode(1, len(M))
	Build(p[0])
	for i, v := range a {
		p[i+1] = Add(p[i], h[v])
	}
	return &Seg{h, p}
}

func main() {
	debug.SetGCPercent(-1)
	Solve(os.Stdin, os.Stdout)
}

func Solve(_r io.Reader, _w io.Writer) {
	in := bufio.NewReader(_r)
	out := bufio.NewWriter(_w)
	defer out.Flush()
	//-----------------------------------------------------
	Fscan(in, &n, &m)
	var a = make([]int, n)
	for i := range a {
		Fscan(in, &a[i])
	}
	var seg = NewSeg(a)
	for m > 0 {
		m--
		var l, r, k int
		Fscan(in, &l, &r, &k)
		Fprintln(out, seg.Query(l, r, k))
	}
}

可持久化数组

可持久化数组嘛就是一个类似上面的东西,这里采用数组动态开点,洛谷那题卡常严重,可以直接预先分配root之类的大小,不过这题数据太大了,对rust来说预先static容易爆,c++可以预先分配(数学上证明一般开至少n(logn+3)的大小,当m(操作次数)和n一个数量级)

Rust
C++
C++(常树较大写法,超大数据难以AC,但便于理解,作为参考)
#![allow(unused_variables)]
#![warn(clippy::large_stack_arrays)]
#![warn(unused_macros)]

use crate::raw::in_out;
use crate::scanner::Scanner;
use std::io::{BufRead, Write};//----------------------------递归闭包---------------------------
// struct Func<'a, A, F>(&'a dyn Fn(Func<'a, A, F>, A) -> F);
//
// impl<'a, A, F> Clone for Func<'a, A, F> {
//     fn clone(&self) -> Self {
//         Self(self.0)
//     }
// }
//
// impl<'a, A, F> Copy for Func<'a, A, F> {}
//
// impl<'a, A, F> Func<'a, A, F> {
//     fn call(&self, f: Func<'a, A, F>, x: A) -> F {
//         (self.0)(f, x)
//     }
// }
//
// fn y<A, R>(g: impl Fn(&dyn Fn(A) -> R, A) -> R) -> impl Fn(A) -> R {
//     move |x| (|f: Func<A, R>, x| f.call(f, x))(Func(&|f, x| g(&|x| f.call(f, x), x)), x)
// }

//Y组合子使用示例:(多参采用元组传参)
// let dfs = | f: & dyn Fn((usize, i32,bool)) -> bool, (i,sum,s): (usize,i32,bool) | -> bool{
//      if i == n {
//          return sum == 0 & & s;
//       }
//      return f((i + 1, sum + a[i], true)) | | f((i + 1, sum, s)) | |
// f((i + 1, sum - a[i], true));
// };
//----------------------------递归闭包---------------------------
//----------------------------常用函数----------------------------
#[allow(dead_code)]
// #[inline]
fn prefix_array<T>(a: &Vec<T>, start: T) -> Vec<T> where T: std::ops::Add<Output=T> + Copy + std::ops::AddAssign {
    (0..=a.len()).scan(start, |x, y| if y == 0 { Some(start) } else {
        *x += a[y - 1];
        Some(*x)
    }).collect::<Vec<T>>()
}

#[allow(dead_code)]
// #[inline]
fn suffix_array<T>(a: &Vec<T>, end: T) -> Vec<T> where T: std::ops::Add<Output=T> + Copy + std::ops::AddAssign {
    let mut tmp = (0..=a.len()).rev().scan(end, |x, y| if y == a.len() { Some(end) } else {
        *x += a[y];
        Some(*x)
    }).collect::<Vec<T>>();
    tmp.reverse();
    tmp
}

#[allow(dead_code)]
extern "C" {
    fn getchar() -> i64;
}

#[allow(dead_code)]
fn next() -> i64 {
    let mut r = 0;
    let mut c;
    loop {
        unsafe { c = getchar(); }
        if c >= 48 && c < 48 + 10 { break; }
    }
    loop {
        r = r * 10 + c - 48;
        unsafe { c = getchar(); }
        if c < 48 || c >= 48 + 10 { break; }
    }
    r
}

//----------------------------常用函数----------------------------
//----------------------------Test----------------------------
const N: usize = 1000010;
static mut L: [i32; 25000250] = [0_i32; 25 * N];
static mut R: [i32; 25000250] = [0_i32; 25 * N];
static mut VAL: [i32; 25000250] = [0_i32; 25 * N];
//m次操作,每次多log(n)个数,可以适当开大不少,开到3e7级别
//分配节点
static mut ID: i32 = 0;
// //原数组
// static mut A: [i32; 1000010] = [0; N];
//根
static mut ROOT: Vec<i32> = vec![];

unsafe fn build(A: &Vec<i32>, l: i32, r: i32, curr: &mut i32) {
    if *curr == 0 {
        ID += 1;
        *curr = ID;
    }
    if l == r {
        VAL[*curr as usize] = A[l as usize];
        return;
    }
    let mid = (l + r) >> 1;
    build(A, l, mid, &mut L[*curr as usize]);
    build(A, mid + 1, r, &mut R[*curr as usize]);
}

unsafe fn update(pre: i32, curr: &mut i32, l: i32, r: i32, pos: i32, val: i32) {
    ID += 1;
    *curr = ID;
    L[*curr as usize] = L[pre as usize];
    R[*curr as usize] = R[pre as usize];
    if l == r {
        VAL[*curr as usize] = val;
        return;
    }
    let mid = (l + r) >> 1;
    if pos <= mid {
        update(L[pre as usize], &mut L[*curr as usize], l, mid, pos, val);
    } else {
        update(R[pre as usize], &mut R[*curr as usize], mid + 1, r, pos, val);
    }
}

unsafe fn query(curr: i32, l: i32, r: i32, pos: i32) -> i32 {
    if l == r {
        return VAL[curr as usize];
    }
    let mid = (l + r) >> 1;
    return if pos <= mid {
        query(L[curr as usize], l, mid, pos)
    } else {
        query(R[curr as usize], mid + 1, r, pos)
    };
}

//----------------------------Test----------------------------
// #[inline]
pub unsafe fn solve<R: BufRead, W: Write>(mut scanner: Scanner<R>, out: &mut W) {
    //---------------------------------------------常用宏---------------------------------------------
    #[allow(unused_macros)]
    macro_rules! puts {($($format:tt)*) => (let _ = writeln!(out,$($format)*););}
    #[allow(unused_macros)]
    macro_rules! r_usize {() => {scanner.next::<usize>()};}
    #[allow(unused_macros)]
    macro_rules! r_i32 {() => {scanner.next::<i32>()};}
    #[allow(unused_macros)]
    macro_rules! r_i64 {() => {scanner.next::<i64>()};}
    #[allow(unused_macros)]
    macro_rules! r_i128 {() => {scanner.next::<i128>()};}
    #[allow(unused_macros)]
    macro_rules! r_isize {() => {scanner.next::<isize>()};}
    #[allow(unused_macros)]
    macro_rules! r_s_u8 {() => {scanner.next::<String>().into_bytes()};}
    #[allow(unused_macros)]
    macro_rules! read_usize {($n:expr) => {(0..$n).map(|_|scanner.next::<usize>()).collect::<Vec<usize>>()};}
    #[allow(unused_macros)]
    macro_rules! read_i32 {($n:expr) => {(0..$n).map(|_|scanner.next::<i32>()).collect::<Vec<i32>>()};}
    #[allow(unused_macros)]
    macro_rules! read_i64 {($n:expr) => {(0..$n).map(|_|scanner.next::<i64>()).collect::<Vec<i64>>()};}
    #[allow(unused_macros)]
    macro_rules! read_i128 {($n:expr) => {(0..$n).map(|_|scanner.next::<i128>()).collect::<Vec<i128>>()};}
    #[allow(unused_macros)]
    macro_rules! read_tow_array_usize {($n:expr,$m:expr) => {(0..$n).map(|_|read_usize!($m)).collect::<Vec<Vec<usize>>>()};}
    #[allow(unused_macros)]
    macro_rules! read_tow_array_i32 {($n:expr,$m:expr) => {(0..$n).map(|_|read_i32!($m)).collect::<Vec<Vec<i32>>>()};}
    #[allow(unused_macros)]
    macro_rules! read_tow_array_i64 {($n:expr,$m:expr) => {(0..$n).map(|_|read_i64!($m)).collect::<Vec<Vec<i64>>>()};}
    #[allow(unused_macros)]
    macro_rules! count_bit {($n:expr) => {{let(mut ans,mut k)=(0_usize,$n);while k>0{ans+=1;k&=k-1;}ans}};}
    #[allow(unused_macros)]
    macro_rules! print_all {($A:expr) => {{for &v in &$A{let _ = write!(out, "{} ", v);}puts!();}};}
    //-----------------------------------------------------------------------------------------------
    let n = r_usize!();
    let m = r_usize!();
    let mut A = vec![0; n + 1];
    for i in 1..=n {
        A[i] = scanner.next::<i32>();
    }
    ROOT.push(0);
    build(&A, 1, n as i32, &mut ROOT[0]);
    for i in 1..=m {
        ROOT.push(0);
        let v = r_i32!();//版本号
        if r_usize!() == 1 {
            let pos = r_i32!();
            let val = r_i32!();
            update(ROOT[v as usize], &mut ROOT[i], 1, n as i32, pos, val);
        } else {
            let pos = r_i32!();
            puts!("{}",query(ROOT[v as usize],1,n as i32,pos));
            ROOT[i] = ROOT[v as usize];
        }
    }
}

//-----------------------------main-------------------------------------
fn main() {
    let (stdin, mut stdout) = in_out();
    unsafe { solve(Scanner::new(stdin), &mut stdout); }
}

// --------------------------- tools -----------------------------------
mod raw {
    use std::fs::File;
    use std::io::{BufRead, BufReader, BufWriter, stdin, stdout, Write};


    #[cfg(windows)]
    pub fn in_out() -> (impl BufRead, impl Write) {
        use std::os::windows::prelude::{AsRawHandle, FromRawHandle};
        unsafe {
            let stdin = File::from_raw_handle(stdin().as_raw_handle());
            let stdout = File::from_raw_handle(stdout().as_raw_handle());
            (BufReader::new(stdin), BufWriter::new(stdout))
        }
    }

    #[cfg(unix)]
    pub fn in_out() -> (impl BufRead, impl Write) {
        use std::os::unix::prelude::{AsRawFd, FromRawFd};
        unsafe {
            let stdin = File::from_raw_fd(stdin().as_raw_fd());
            let stdout = File::from_raw_fd(stdout().as_raw_fd());
            (BufReader::new(stdin), BufWriter::new(stdout))
        }
    }
}

mod scanner {
    use std::io::BufRead;

    pub struct Scanner<R> {
        reader: R,
        buf_str: Vec<u8>,
        buf_iter: std::str::SplitAsciiWhitespace<'static>,
    }

    impl<R: BufRead> Scanner<R> {
        pub fn new(reader: R) -> Self {
            Self { reader, buf_str: Vec::new(), buf_iter: "".split_ascii_whitespace() }
        }
        pub fn next<T: std::str::FromStr>(&mut self) -> T {
            loop {
                if let Some(token) = self.buf_iter.next() {
                    return token.parse().ok().expect("Failed parse");
                }
                unsafe { self.buf_str.set_len(0); }
                self.reader.read_until(b'\n', &mut self.buf_str).expect("Failed read");
                self.buf_iter = unsafe {
                    let slice = std::str::from_utf8_unchecked(&self.buf_str);
                    std::mem::transmute(slice.split_ascii_whitespace())
                }
            }
        }
    }
}
评论 (0)