分享|聊聊树链剖分与树上启发式合并基础
2150
2023.08.19
2023.08.19
发布于 未知归属地

树链剖分,一个听上去比较高大上的算法,实际比较容易.为什么要将树上启发式合并(dsu on tree)与树链剖分一起讲了,因为二者都利用到了轻重链的思想.(后续有空更新其他树上问题例如LCT、点分治、点分树等等)

剖分略谈

剖分主要是将树上问题转化为线性问题.那么这有什么好处?线性问题回忆下常见的数组、查询等问题.会发现我们有异常多的ds可供使用.比如线段树、树状数组、ST表等等.
剖分一般分为三种:重链剖分、实链剖分、长链剖分
像轻重链就是我们一般用的剖分了,实链剖分为LCT(动态树最常见的一种)的实现方式(前提得至少会splay平衡树的核心代码书写),长链剖分(解决k级祖先之类的问题)
=>简单说说树链剖分能解决的问题局限性(在树不存在删点、加点的情况下)去维护树上的信息的查询与修改.
=>动态树可以在原来的基础上解决删点和加点以后的信息维护.你大致可以类比下动态开点线段树和静态线段树的区别吧.其中最常用的就是LCT(Link Cut tree),会spaly其实挺好写的.就是函数比较多.

DFS序

简单说说剖分的过程.剖分实际上我们回忆下一个经典的东西--DFS序,这个东西在后文也会用到.
image.png
image.png
简单搬两张图说明下.思考下dfs的过程. 这里主要思考一个点,进入某个点的dfs(curr),当递归完毕以后出来某个点的dfs(curr)

Python
def dfs(curr: int, fa: int):
    for v in child[curr]:
        if v != fa:
            dfs(v, curr)

一个常见的树上dfs观察一下,容易看出来当进入某个节点时==进入以这个节点为根的子树,递归结束出来某个节点的时候=>刚好遍历完以这个节点为根的子树所有节点.我们可以将这个遍历的过程遇到的所有节点按先后顺序全部依次排成一个数组.容易发现每个节点一定都会被访问两次,分别是进入和结束.这个序列我们称之为:"dfs序",有什么好处呢?容易观察到,在两个相同点之间出现的子序列,其实就是以这个点为根的子树所有节点

轻重链剖分

先说说啥叫链吧,类似线段树访问一样,从某个节点每次往下走一步一直到叶子节点,这就是一条链了.引用一张oiwiki的图吧
image.png
图上有个叫dfn序的,其实就是节点dfs的访问顺序,某个节点出现一次,恰不多就这样,简单提提.
图上绿色框起来的其实就是一条条链.轻重链的话,根据字面意思你也大概猜到了.对于某个节点而言,它是有很多儿子的,每个以儿子为根的子树也有自己的节点数.我们将拥有最多节点数的儿子叫做重儿子,其他的就是轻儿子.概念是很easy的吧.

第一次dfs

先不用考虑什么轻重链是啥,我们第一次dfs只需要找到每个节点的重儿子就行,顺便维护点比如深度之类的信息.

C++
void dfs1(int curr, int parent) {
    fa[curr] = parent;
    deep[curr] = deep[parent] + 1;
    siz[curr] = 1;
    for (auto nxt: child[curr]) {
        if (nxt == parent)continue;
        dfs1(nxt, curr);
        siz[curr] += siz[nxt];
        if (siz[nxt] > siz[son[curr]])son[curr] = nxt;
    }
}

!!!:我们这里聊的树是有根树,是有一个明确的树根的,而不是无根树
这个dfs是一个非常常规的dfs,很多都是显而易见的.例如当前节点的父亲,当前节点的深度=父亲节点深度+1,当前节点的子树大小有dfs结束依次加上每个子树为根的子树大小.在这些常规的基础操作下,我们只需要加一句话,找节点数量最多的那个子树的根作为"重儿子".
(关于树上dfs和常规的图的dfs不太一样,只需要维护一个fa不往回走就行,以一个点为根不往回走就能走到底了)

可以看到第一次遍历我们处理了比较多的信息,有fa数组=>表示树的每个节点的父亲,deep数组=>表示每个点的深度,siz数组=>表示以每个节点为根的子树大小,son数组=>每个点的重儿子.
到此以上还是很easy的,聊点题外话

建树

树是一种特殊的图,所以图的三种建图方式树也可以用:领接矩阵在树中用的比较少,这里忽略.

邻接表比较常用,以c++为例.通常开做=>
vector<int> child[N],其中N为预估的节点的上界数

那么加点为:

C++
child[x].push_back(y);
child[y].push_back(x);

当然int可以换成pair<int,int>之类的存储更多信息

链式前向星

这个可能挺多人比较陌生的,也是属于常数比较小的建图方式.(很多人有点望而生畏)

其实这玩意和链表恰不多.举个例子,我们将比如和节点1相邻的所有点取出,按照题目的输入顺序(任意顺序都行),排成一行:
比如?题目输入当中有:
1 2
1 4
1 3
这样的三条边(只给出与1相连的)
那么2、4、3就与1相邻,我们排成一行:
1 2 4 3,类似链表的形式我们这样写:1    2<-4<-3
很容易知道类似有nxt[2]=4,nxt[4]=3,nxt[3]=null,而head[1]=2.这是灰常灰常简单的链表知识.现在我们将一条条链放在了一个数组上:
1 2
2 3
2 4
1 5
无标题.png
容易看出有

对于1来说  2<-5
对于2来说  3<-4
对于3来说  2
对于4来说  2
对于5来说  1
那么我们可以把这些链分布在数组上(很显然,一条链的每个点不一定相连)
例如根据插入顺序我们依次放点:
1 2
2 3
2 4
1 5
每个边分别插入两次:例如1->2,2->1
1 2 2 1 2 3 3 2 2 4 4 2 1 5 5 1(全部写出来,实际上我们只需要写出出边,例如1->2只写2)

2 1 3 2 4 2 5 1(分别省略一个点)
把这些点分布在数组上,可以看出每个点被分配了下标,我们可以用类似链表的形式开一个数组:
nxt[下标]=这条链的下一个点的下标.特殊的,空节点分配为0/-1
那么对于1的相邻的有2和5,它们对应的下标为1和7,那么很显然有nxt[1]=7,nxt[7]=0

So?就这样了吗?链式"前向"星,多读几遍名字,实际上这个链表和我们常见的链表不一样!!!

是倒着建的: nxt[7]=1,nxt[1]=0,这样做有一个好处,后文提到,接下来继续说说其他东西:
我们光知道下标,还需要知道下标对应的数呢,显然开一个to数组,存储下标的数,比如to[1]=2,to[7]=5,当然当然当然咯,这里如果不仅存数,还存边的权也可以开pair<int,int>这类,最后的最后,起码要记录链表的头是吧?head[1]=7,对于1的相邻的点的组成的链表的头部显然是下标7位置的数.

这里很容易hua现我们每次多加一个点是在链表"头部"加点,并更改head数组,而不是常规的尾部加,这就是这个建图方式名字的由来了.

C++
const int N = 1e6 + 10;
int head[N];
int nxt[N];
int to[N];
int cnt;

void add(int x, int y) {
    nxt[++cnt] = head[x];//新分配的位置的下一个点显然为原来的"头部"
    to[cnt] = y;//这个点显然是y
    head[x] = cnt;//新的头部为新的点
}

测试:

C++
const int N = 1e6 + 10;
int head[N];
int nxt[N];
int to[N];
int cnt;

void add(int x, int y) {
    nxt[++cnt] = head[x];//新分配的位置的下一个点显然为原来的"头部"
    to[cnt] = y;//这个点显然是y
    head[x] = cnt;//新的头部为新的点
}

void dfs(int curr, int fa) {
    cout << curr << " ";
    for (int i = head[curr]; i; i = nxt[i]) {
        int v = to[i];
        if (v == fa)continue;
        dfs(v, curr);
    }
}

int n;

void solve() {
    cin >> n;
    forn(i, 1, n - 1) {
        int x, y;
        cin >> x >> y;
        add(x, y);
        add(y, x);
    }
    dfs(1, 0);
}

输出dfn序(上文已经提到过了)就是这样子的了.
输出dfs序只需要加一行:

C++
void dfs(int curr, int fa) {
    cout << curr << " ";
    for (int i = head[curr]; i; i = nxt[i]) {
        int v = to[i];
        if (v == fa)continue;
        dfs(v, curr);
    }
    cout << curr << " ";
}

比较easy吧,为啥从后往前建呢,因为你仔细想想啦,我们其实是并不知道"尾部"节点下标的,要么开个数组维护,要么遍历找(想想都不可能吧),倒着往前加是因为链表的头部才是我们一直已知的(想想你在乐扣写的链表题,是不是给了你链表的头部)
到此建树常见的两种写法完毕(当然还有二叉树之类的一些其他存储方式暂且不谈

回到第一次dfs位置

我们从dfs位置拿到了所有重儿子,沿着重儿子走的链就是重链,其他的则为轻链.很显然的是,这些链我们可以通过再次dfs立马取出来

C++
void dfs2(int curr, int root) {
    idx[curr] = ++T;
    rev[T] = curr;
    top[curr] = root;
    if (!son[curr])return;
    dfs2(son[curr], root);
    for (auto nxt: child[curr])if (nxt != son[curr] && nxt != fa[curr])dfs2(nxt, nxt);
}

T吗?其实和上面那个链表前向星的cnt一样,都是维护下一个未使用的位置的.简单说说咯.
idx表示的是当前curr节点的下标,说白咯,idx就是类似维护dfn序,参照上面dfs的代码理解.
当这个点第一次被访问时就会开辟一个位置把它放在后面.
来提提其他几个重要的数组.

rev数组吗其实就跟to数组恰不多,就是当前位置是哪个数:比如咯,当前节点curr为2,它前面已经dfs过了3个点
1,3,4,那么显然我们就把数组第四个位置给它,idx[2]=4,rev[4]=2.这两个数组看了上文是很容易get到意思的吧?top数组吗,就是链的起点(类比head数组),想象一下链的样子,树链都是从上到下的,我们就是记录这个"顶部起点".如果没有重儿子就停止遍历咯.否则先遍历"重儿子"再遍历轻儿子.其实嘛,这玩意不就是更改了你dfs的先后顺序嘛,你一般来说dfs就是随便遍历的,这里只是更改为重儿子优先遍历,很容易知道叶子节点是没得重儿子的,本身就不往下遍历咯.

你问我有啥好处呀?咱回忆下最开始讲的dfn序是啥?dfn序记录了dfs序遍历的顺序,那么你优先重儿子遍历,很显而易见的是,所有的重儿子会在dfn序上从最开始往后排,连成一起,形成一个"连续序列",同理,其他轻链紧随其后,最终你会hua现,我们将一棵树拆成若干个链huang在了一个数组上,一个树上问题变为了线性问题.

嗷,你问我为啥要要这么拆,我不是可以随便拆链吗,为啥辉要轻重儿子拆,这里可以参照oiwiki上的证明,这里本质目的是为了"降低复杂度",虽说变为了线性,但你的链的长度吗,如果有轻重儿子控制,你可以想象一下链长会长啥样.本质上你随便取几个链用dfn序也能达到化树上问题为线性,但显而易见需要利用一些"启发式"的思想去降低hu杂度.证明复杂度比较多的文章,具体可以参照相关文献.主要提提应用.

线性问题

思考下咯,我们现在有哪些比较有用的东东:
有top数组、deep数组、idx数组、rev数组,fa数组(当然咯,如果每个节点都有val,显而易见的是,我们需要开个数组存储)
那么,那么,那么,那么?先看看一些有意思的东西,我们在rev数组上有若干个点,这些点在数组上被划分成了一块块区域也就是我们剖分的"链".soga,思考下这些区域的点显而易见在树上也是连续的?先来说说第一个问题咯:LCA
LCA最常见的求法莫过于倍增.这里提提树链剖分咋求LCA.
模版题:LCA

Go
func lca(x int, y int) int {
	for top[x] != top[y] {
		if deep[top[x]] > deep[top[y]] {
			x = fa[top[x]]
		} else {
			y = fa[top[y]]
		}
	}
	if deep[x] < deep[y] {
		return x
	}
	return y
}

灰常的easy.大致解释下,思考下求x和y的lca,如果它俩的top相同会咋样?噢,显而易见吗,这两个点处于同一条链上,显而易见的是它们深度越浅显而相对越高.所以如果他们在同一条链上事,显而易见的是,他们就该该直接返回深度低的点作为LCA(因为剖分序列是连续的啦).如果链起点不一样,显而易见嘛.我们类似倍增LCA,让top深度更深的那个点往上跳.那么很显然的是至少要"跳出当前链",进入top的父亲所在的链底(top深的往上跳一格至少不会比那个top浅的还要浅,深度一样跳谁都无所谓)
引用下网上找的图(自己画的太难看咯~~~~)
image.png
就这个图而言你能一眼看出哪几条链吗?
显然1->5->6->9->(10/11,看你怎么建树的)就是重链了.7->8是一条轻链,求8和11的LCA,假如10作为重链的一部分,那么11单独是一条轻链,11的top为11,8的top为7,显然咯,11的top深一点,往上走一格到fa[11],不就是9了嘛.再看看8和9,9的top显然是1咯,而8的是7,很显然咯,该8跳咯.8一跃而起,跳到6,这个时候看看6和9,top显然一样咯,哪个浅?6!!!,肉眼可见.它就是LCA
参考代码:

Golang
package main

import (
	"bufio"
	. "fmt"
	"io"
	"os"
	"runtime/debug"
)

const N = 500010

var (
	n, m, s int
	fa      [N]int
	size    [N]int
	son     [N]int
	deep    [N]int
	idx     [N]int
	top     [N]int
	rev     [N]int
	v       [N][]int
	T       int
)

func dfs1(curr int, parent int) {
	fa[curr] = parent
	size[curr] = 1
	deep[curr] = deep[parent] + 1
	for _, t := range v[curr] {
		if t == parent {
			continue
		}
		dfs1(t, curr)
		size[curr] += size[t]
		if size[t] > size[son[curr]] {
			son[curr] = t
		}
	}
}
func dfs2(curr int, root int) {
	top[curr] = root
	T++
	idx[curr] = T
	rev[T] = curr
	if son[curr] == 0 {
		return
	}
	dfs2(son[curr], root)
	for _, t := range v[curr] {
		if t != fa[curr] && t != son[curr] {
			dfs2(t, t)
		}
	}
}
func lca(x int, y int) int {
	for top[x] != top[y] {
		if deep[top[x]] > deep[top[y]] {
			x = fa[top[x]]
		} else {
			y = fa[top[y]]
		}
	}
	if deep[x] < deep[y] {
		return x
	}
	return y
}
func main() {
	debug.SetGCPercent(-1)
	Solve(os.Stdin, os.Stdout)
}

func Solve(_r io.Reader, _w io.Writer) {
	in := bufio.NewReader(_r)
	out := bufio.NewWriter(_w)
	defer out.Flush()
	//-----------------------------------------------------
	Fscan(in, &n, &m, &s)
	for i := 1; i <= n; i++ {
		v[i] = make([]int, 0)
	}
	for i := 0; i < n-1; i++ {
		var x, y int
		Fscan(in, &x, &y)
		v[x] = append(v[x], y)
		v[y] = append(v[y], x)
	}
	dfs1(s, 0)
	dfs2(s, s)
	for i := 0; i < m; i++ {
		var x, y int
		Fscan(in, &x, &y)
		Fprintln(out, lca(x, y))
	}
}

最常见的应用

树链剖分最常见的应用莫过于结合线段树维护一些修改信息.比如就拿上面的LCA来说吧,这个过程我们可以观察到每个链作为连续的,可以看做线段树的区间,比如修改1->8的简单路径,上面每个点都加加加.这里很显然可以用到线段树的区改.显而易见的是,不同链的点,显然可以可以拆成几段区间修改.这个复杂度的证明可以见oiwiki.
来一道洛谷的题吧:树链剖分
image.png
嚯,好hu杂.肿么一上来就这么猛.在会这些的前提,先会个lazy区间修改线段树吧(zkw之类的也行),反正就是会一个灰常基础的区间修改线段树咯
线段树的代码就不提了,这里关于lazy线段树的话,有兴趣可以读我的一篇题解:

  1. 更新数组后处理求和查询题解:关于线段树lazy的那些事
    也可以看文章出处:关于线段树lazy的那些事

    本篇详细讲了下lazy线段树.

    进入正文了.先说说第一个操作吧,结合代码:
C++
void update(int x, int y, int z) {
    int topx = top[x];
    int topy = top[y];
    while (topx != topy) {
        if (deep[topx] >= deep[topy])
            seg->update_range(idx[topx], idx[x], z, seg->root), x = fa[topx];
        else
            seg->update_range(idx[topy], idx[y], z, seg->root), y = fa[topy];
        topx = top[x];
        topy = top[y];
    }
    if (idx[x] < idx[y]) {
        seg->update_range(idx[x], idx[y], z, seg->root);
    } else {
        seg->update_range(idx[y], idx[x], z, seg->root);
    }
}

嚯,很easy吧?(哪里easy了^v^),其实在剖分中你要学会一招:利用top到处跳,利用deep控制该怎么跳
首先吗?老规矩,先讨论下top一不一样,top一样不就是最常规的区改了嘛,在一条链上咯,注意的是,我们需要判断哪个点作为左端点,哪个点作为右端点

C++
if (idx[x] < idx[y]) {
    seg->update_range(idx[x], idx[y], z, seg->root);
} else {
    seg->update_range(idx[y], idx[x], z, seg->root);
}

idx嘛.就是前文提到的每个点对应的dfn序,说白咯,就是线段树维护的序列下标.灰常easy吧.
如果top不同嘛,众所周知,链都是连续的,那就网上跳咯,看看谁在上方,在跳的过程中,把跳过的地方全部做区间修改就行咯.是不是灰常的easy啊.其实难点就在于你会不会写区改线段树.跟上文的LCA几乎一致.
操作2呢?嚯?操作是三不是显而易见的嘛(手动划掉),忘记了?忘记了?忘记了?
dfn序可以看成dfs序去掉最后的出点得到的,先回忆下dfs序,每个相同的点之间就是啥<<<<<<子树!!!>>>>>>,现在只是去掉了相同点,该是子树的部分还是子树部分,只是我们不能"肉眼看出咯",有没有法子法子法子???有的!!我们不知道子树的终点,但恰到!(子树起点和大小),siz最开始记录了每个点的子树大小,那么久sosososo easy了(不好好搞懂每个点很容易迷糊的)

void updateSon(int x, int val) {
    //它的子树范围是连续的
    seg->update_range(idx[x], idx[x] + siz[x] - 1, val, seg->root);
}

so easy吧?什么,其他两个?你这两个都"fei"了,其他两个肯定也会了

C++
ll query(int x, int y) {
    ll ans = 0;
    int topx = top[x];
    int topy = top[y];
    while (topx != topy) {
        if (deep[topx] >= deep[topy])
            ans += seg->query_range(idx[topx], idx[x], seg->root), x = fa[topx];
        else
            ans += seg->query_range(idx[topy], idx[y], seg->root), y = fa[topy];
        ans %= p;
        topx = top[x];
        topy = top[y];
    }
    if (idx[x] < idx[y]) {
        ans += seg->query_range(idx[x], idx[y], seg->root);
    } else {
        ans += seg->query_range(idx[y], idx[x], seg->root);
    }
    return ans % p;
}
ll querySon(int x) {
    return seg->query_range(idx[x], idx[x] + siz[x] - 1, seg->root) % p;
}

查找无非多个取模操作.一模一样咯
最后贴上一份参考代码:

C++
#include <bits/stdc++.h>


using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
#define all(v) v.begin(),v.end()
#define yes cout<<"YES";
#define no cout<<"NO";
#define print(v); for(auto i:v){cout<<i<<" ";}
#define print_rev_ARR(v, n); for(int i=n-1;i>=0;i--){cout<<v[i]<<" ";}
#define in(v); for(auto &i:v){cin>>i;}
#define divisors(n, v);  for (long long i = 1; i*i <=n; i++){if(n%i==0){v.push_back(i);if((n/i!=i)){v.push_back(n/i);}}}
#define prefix(pref, arr, n); pref[0]=arr[0]; for (int i = 1; i < n; i++){pref[i]=pref[i-1]+arr[i];}
#define suffix(suff, arr, n); suff[n-1]=arr[n-1]; for (int i = n-2; i >=0 ; i--){suff[i]=suff[i+1]+arr[i];}
#define Spider ios_base::sync_with_stdio(false);cin.tie(nullptr);cout.tie(nullptr);
#define MyFile freopen("input.txt", "r", stdin),freopen("output.txt", "w", stdout);
#define forn(i, a, b) for(int i = a; i <= b; i++)

template<typename T, T N>
T ModInt(T value) { return (value % N + N) % N; }

#define mod(m, n) ModInt<int,n>(m);
const int N = 1e5 + 10;
int T = 0;
int deep[N], siz[N], son[N], fa[N];
int top[N], idx[N], rev[N];
int n, m, r, p;
vector<int> child[N];
ll value[N];

void dfs1(int curr, int parent) {
    fa[curr] = parent;
    deep[curr] = deep[parent] + 1;
    siz[curr] = 1;
    for (auto nxt: child[curr]) {
        if (nxt == parent)continue;
        dfs1(nxt, curr);
        siz[curr] += siz[nxt];
        if (siz[nxt] > siz[son[curr]])son[curr] = nxt;
    }
}

void dfs2(int curr, int root) {
    idx[curr] = ++T;
    rev[T] = curr;
    top[curr] = root;
    if (!son[curr])return;
    dfs2(son[curr], root);
    for (auto nxt: child[curr])if (nxt != son[curr] && nxt != fa[curr])dfs2(nxt, nxt);
}

struct Node {
    int l, mid, r, len;
    ll val;
    Node *left, *right;
    int lazy;

    Node(int l, int r) : l(l), r(r) {
        len = r - l + 1;
        mid = (l + r) >> 1;
        val = 0;
        left = right = nullptr;
        lazy = 0;
    }
};

typedef Node *A;

void Mod(A node) {
    node->val %= p;
    node->lazy %= p;
}

void push_down(A node) {
    if (!node->left)node->left = new Node(node->l, node->mid);
    if (!node->right)node->right = new Node(node->mid + 1, node->r);
    if (node->lazy) {
        node->left->val += node->left->len * node->lazy;
        node->right->val += node->right->len * node->lazy;
        node->left->lazy += node->lazy;
        node->right->lazy += node->lazy;
        node->lazy = 0;
        Mod(node->left);
        Mod(node->right);
    }
}

void push_up(A node) {
    node->val = node->left->val + node->right->val;
    Mod(node);
}

void build(A node) {
    if (node->l == node->r) {
        node->val = value[rev[node->l]];
        return;
    }
    push_down(node);
    build(node->left);
    build(node->right);
    push_up(node);
}

struct Seg {
    A root;

    Seg() {
        root = new Node(1, n);
        build(root);
    }

    void update_range(int left, int right, int val, A node) {
        if (left <= node->l && node->r <= right) {
            node->val += node->len * val;
            node->lazy += val;
            Mod(node);
            return;
        }
        push_down(node);
        if (left <= node->mid)update_range(left, right, val, node->left);
        if (right > node->mid)update_range(left, right, val, node->right);
        push_up(node);
    }

    ll query_range(int left, int right, A node) {
        if (left <= node->l && node->r <= right) {
            return node->val;
        }
        push_down(node);
        ll ans = 0;
        if (left <= node->mid)ans = (ans + query_range(left, right, node->left)) % p;
        if (right > node->mid)ans = (ans + query_range(left, right, node->right)) % p;
        return ans;
    }
};

Seg *seg;

void update(int x, int y, int z) {
    int topx = top[x];
    int topy = top[y];
    while (topx != topy) {
        if (deep[topx] >= deep[topy])
            seg->update_range(idx[topx], idx[x], z, seg->root), x = fa[topx];
        else
            seg->update_range(idx[topy], idx[y], z, seg->root), y = fa[topy];
        topx = top[x];
        topy = top[y];
    }
    if (idx[x] < idx[y]) {
        seg->update_range(idx[x], idx[y], z, seg->root);
    } else {
        seg->update_range(idx[y], idx[x], z, seg->root);
    }
}

void updateSon(int x, int val) {
    //它的子树范围是连续的
    seg->update_range(idx[x], idx[x] + siz[x] - 1, val, seg->root);
}

ll query(int x, int y) {
    ll ans = 0;
    int topx = top[x];
    int topy = top[y];
    while (topx != topy) {
        if (deep[topx] >= deep[topy])
            ans += seg->query_range(idx[topx], idx[x], seg->root), x = fa[topx];
        else
            ans += seg->query_range(idx[topy], idx[y], seg->root), y = fa[topy];
        ans %= p;
        topx = top[x];
        topy = top[y];
    }
    if (idx[x] < idx[y]) {
        ans += seg->query_range(idx[x], idx[y], seg->root);
    } else {
        ans += seg->query_range(idx[y], idx[x], seg->root);
    }
    return ans % p;
}

ll querySon(int x) {
    return seg->query_range(idx[x], idx[x] + siz[x] - 1, seg->root) % p;
}

void solve() {
    cin >> n >> m >> r >> p;
    forn(i, 1, n)cin >> value[i];
    forn(i, 1, n - 1) {
        int x, y;
        ll v;
        cin >> x >> y;
        child[x].emplace_back(y);
        child[y].emplace_back(x);
    }
    dfs1(r, 0);
    dfs2(r, r);
    seg = new Seg();
    while (m--) {
        int t;
        cin >> t;
        if (t == 1) {
            int x, y, z;
            cin >> x >> y >> z;
            update(x, y, z);
        } else if (t == 2) {
            int x, y;
            cin >> x >> y;
            cout << query(x, y) << endl;
        } else if (t == 3) {
            int x, z;
            cin >> x >> z;
            updateSon(x, z);
        } else {
            int x;
            cin >> x;
            cout << querySon(x) << endl;
        }
    }
}

int main() {
    Spider
    //------------------------------------------------------
    int test = 1;
    //    cin >> test;
    while (test--)solve();
}

小总结

树链剖分本身并不难,它的好处在于将树上问题转化为了线性问题,只需要考虑如何维护这些序列信息就行.其实代码挺固定的,无非就学会两点,利用top关系"爬树",利用deep关系怎么爬?时刻注意dfn序列的使用即可.祝你早日掌握重链剖分

树上启发式合并

篇幅太多咯,留给dsu on tree的空间不多咯.
因为树上启发式合并最常用的也是利用轻重儿子思想来实现的.所以就留在这讲.直接上题吧.
模版题:树上启发式合并
一道很典的dsu on tree模版题
image.png
这题是有毒!!!!的,不需要按照询问输出,直接按1-n的顺序输出答案就行,具体代码oiwiki上有.
这里先提提启发式思想,其实早在A*、并查集按秩合并我们就对这种思想有过接触.我们习惯性地将小的合并到大的当中.这样一来实际上会发现最后每块的大小是恰不多大的.不加思索地感觉这样好像挺不错的,比起一块特别大,一块特别小的块来说似乎感觉"好看点"(手动划掉).树上启发式合并是干嘛的咯?一般用于解决树上计数类问题,当子树的答案对父亲是有贡献的情况下使用.
大致提提思路: 大部分可能直接喜欢用的是类似按秩合并的思路合并块统计答案,这里提一个比较类似于树链剖分的思想.也是oiwiki上的一个思想:常常有时候对于某个点来说,它的重链的贡献比其他轻链加起来还要多,直接重新计算是灰常耗时间的.所以对于某个贡献我们常常保留重链贡献而重新算轻链贡献.具体看代码:

C++
void dfs1(int curr, int parent) {
    idx[curr] = ++T;
    rev[T] = curr;
    siz[curr] = 1;
    fa[curr] = parent;
    for (auto &nxt: child[curr]) {
        if (nxt == parent)continue;
        dfs1(nxt, curr);
        siz[curr] += siz[nxt];
        if (siz[nxt] > siz[son[curr]])son[curr] = nxt;
    }
    down[curr] = T;
}

嚯?怎么和树链剖分很像.这里就是使用dfs序了,idx记录起点,down记录终点,中间就是子树咯.同理的记算重儿子和子树大小.灰常得常规.
这题嘛由于要计算颜色种类,我们需要两个函数

C++
int color[N];//每个点的颜色
int cnt[N];//表示当前每种颜色个数
int colorCount;//表示当前不同颜色的个数

void add(int x) {
    if (cnt[color[x]] == 0) ++colorCount;
    cnt[color[x]]++;
}

void del(int x) {
    cnt[color[x]]--;
    if (cnt[color[x]] == 0) --colorCount;
}

增加某个点颜色color[x],很easy吧,如果原本没有这个颜色,这个颜色种类显然++,del同理.
这里思考下dfs过程,假如有某个点:它有两棵子树,当统计完left的答案时,在统计right的时候,是不是要删掉left答案,再去统计right,然后最后要算curr的答案,是不是left又重新算了一遍?

这里可以看出right算了一遍,left算了两遍,可以手玩n棵子树发现,最后一定有一棵子树在算curr的时候不需要重新算一遍,而其余每棵子树都要算两遍.
这里我们将算一遍的子树用于重儿子所在的子树.而其他的则为轻儿子所在子树.很容易感觉到一个重儿子抵得上好几个轻儿子的计算量(比较玄学吧,具体可以参照oiwiki上的复杂度证明,这个复杂度证明并不是太容易的).所以思路很简单了,先计算每个轻儿子所在的子树的答案,然后去掉算下一个轻儿子,最后计算重儿子所在子树的答案,然后不删掉,然后把轻儿子加回来算当前curr的答案往上走咯
参考代码:

C++
//isW是否是重儿子所在子树
void dfs2(int curr, bool isW) {
    //算轻儿子
    for (auto &v: child[curr])if (v != fa[curr] && v != son[curr])dfs2(v, false);
    if (son[curr])dfs2(son[curr], true);
    for (auto &v: child[curr]) if (v != fa[curr] && v != son[curr]) forn(i, idx[v], down[v]) add(rev[i]);
    add(curr);
    ans[curr] = colorCount;
    if (!isW)forn(i, idx[curr], down[curr])del(rev[i]);
}

1、dfs所有轻儿子,直到叶子节点.如果当前不是叶子节点肯定有重儿子了,在dfs一次重儿子.统计答案只需要注意到轻子树的答案在算完curr以后需要删掉,而重子树的则不需要,留到算curr.可以自己模拟下dfs的过程就会hua现(好像另类的暴力噢!划掉)咱是复杂度很好的"思想".当然更多dsu on tree的题,cf上倒是有一些,蓝桥杯的颜色统计、植物学家(印象大致叫这玩意方向的,反正可以看标签),注意点的是这类问题难点都在统计上,统计方式都是大同小异的.轻子树算两遍,重子树算一遍,及时清理轻子树答案算下个轻子树.
参考代码:

C++
Python
#include <bits/stdc++.h>


using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
#define all(v) v.begin(),v.end()
#define yes cout<<"YES";
#define no cout<<"NO";
#define print(v); for(auto i:v){cout<<i<<" ";}
#define print_rev_ARR(v, n); for(int i=n-1;i>=0;i--){cout<<v[i]<<" ";}
#define in(v); for(auto &i:v){cin>>i;}
#define divisors(n, v);  for (long long i = 1; i*i <=n; i++){if(n%i==0){v.push_back(i);if((n/i!=i)){v.push_back(n/i);}}}
#define prefix(pref, arr, n); pref[0]=arr[0]; for (int i = 1; i < n; i++){pref[i]=pref[i-1]+arr[i];}
#define suffix(suff, arr, n); suff[n-1]=arr[n-1]; for (int i = n-2; i >=0 ; i--){suff[i]=suff[i+1]+arr[i];}
#define Spider ios_base::sync_with_stdio(false);cin.tie(nullptr);cout.tie(nullptr);
#define MyFile freopen("input.txt", "r", stdin),freopen("output.txt", "w", stdout);
#define forn(i, a, b) for(int i = a; i <= b; i++)

template<typename T, T N>
T ModInt(T value) { return (value % N + N) % N; }

#define mod(m, n) ModInt<int,n>(m);
const int N = 1e5 + 10;

int cnt[N];//表示当前每种颜色个数
int colorCount;//表示当前不同颜色的个数
int T;//存当前dfs序列下标
int idx[N];//点到下标
int rev[N];//下标到点
int down[N];//当前这段序列对应的终点
int siz[N];//算子树大小
int ans[N];//统计答案
int son[N];//算重儿子
int fa[N];//父亲
vector<int> child[N];//邻接表集
int color[N];//每个点的颜色
void add(int x) {
    if (cnt[color[x]] == 0) ++colorCount;
    cnt[color[x]]++;
}

void del(int x) {
    cnt[color[x]]--;
    if (cnt[color[x]] == 0) --colorCount;
}

void dfs1(int curr, int parent) {
    idx[curr] = ++T;
    rev[T] = curr;
    siz[curr] = 1;
    fa[curr] = parent;
    for (auto &nxt: child[curr]) {
        if (nxt == parent)continue;
        dfs1(nxt, curr);
        siz[curr] += siz[nxt];
        if (siz[nxt] > siz[son[curr]])son[curr] = nxt;
    }
    down[curr] = T;
}

//isW是否是重儿子所在子树
void dfs2(int curr, bool isW) {
    //算轻儿子
    for (auto &v: child[curr])if (v != fa[curr] && v != son[curr])dfs2(v, false);
    if (son[curr])dfs2(son[curr], true);
    for (auto &v: child[curr]) if (v != fa[curr] && v != son[curr]) forn(i, idx[v], down[v]) add(rev[i]);
    add(curr);
    ans[curr] = colorCount;
    if (!isW)forn(i, idx[curr], down[curr])del(rev[i]);
}

int n;

void solve() {
    cin >> n;
    forn(i, 1, n - 1) {
        int l, r;
        cin >> l >> r;
        child[l].push_back(r);
        child[r].push_back(l);
    }
    forn(i, 1, n)cin >> color[i];
    dfs1(1, 0);
    dfs2(1, false);
    forn(i, 1, n)cout << ans[i] << endl;
}

int main() {
    Spider
    //------------------------------------------------------
    int test = 1;
    //    cin >> test;
    while (test--)solve();
}

总结

篇幅有限,如果有啥不太准确的地方欢迎评论区指正.对于这类个常见的"工具",可以多多做题熟悉,本身原理也不是很难

评论 (8)