#edu3010. 【教程】树上启发式合并

【教程】树上启发式合并

广义的启发式合并就是同一个数据结构之间的合并。例如并查集合并、线段树合并等。

其核心思想就是:将小集合内元素逐一暴力并入大集合,注意大集合内元素不改变

那为何时间复杂度是 O(nlogn)O(n\log{n}) ?

考虑 “小集合规模+大集合规模 大于等于 2倍小集合规模”,小集合每一次合并之后的规模都会至少翻倍,那么对于总共 nn 个元素的问题,合并次数一定是 O(logn)O(\log{n}) , 每次合并操作元素个数不超过 nn , 可以得到总时间复杂度为 O(nlogn)O(n\log{n})

如果使用 set,map 等容器,访问、插入元素时间复杂度会增加 O(logn)O(\log{n}) 倍,时间复杂度变为 O(nlog2n)O(n\log^2{n})

下面利用一道例题讲解“树上启发式合并”基本写法。

例. U41492 树上数颜色

【题意简化】给定一颗 n(1n105)n(1 \le n\le 10^5) 个节点的树,每个节点颜色给定,多次询问某个子树内颜色种类数。

【分析】

  • 思路1: 求出 dfs序,子树 xx 对应区间 [dfn[x],dfn[x]+siz[x]1][dfn[x],dfn[x]+siz[x]-1] 。问题转化为区间数颜色,离线高效思路就是 HH项链,时间复杂度为 O(nlogn)O(n\log{n})。 也可以使用莫队,时间复杂度为 O(nn)O(n\sqrt{n}) 。如果颜色带修改,就可以转成带修莫队。

  • 思路2:启发式合并

这里介绍常见启发式合并写法。

1.利用容器合并

每个节点维护一个 set , DFS 求出每个节点重儿子。首先将节点容器与重儿子交换(swap, 时间复杂度为 O(1)O(1)),然后根节点信息插入,依次将轻儿子元素暴力插入根节点容器。

核心点在于,没有移动重儿子容器元素,只移动轻儿子容器元素。时间复杂度为 O(nlog2n)O(n\log^2{n}) ,空间复杂度为 O(nlogn)O(n\log{n}).

注意:除静态数组容器 array 外所有容器,swap() 只交换一次容器内指针,所以时间复杂度就是 O(1)O(1)

参考代码如下:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+10;
int n,c[N],m,ans[N];
vector<int>g[N];
set<int>st[N];
int siz[N],son[N];  //子树大小 重儿子

void dfs(int x,int fa)
{
    son[x]=0; siz[x]=1;
    for(int y:g[x])
        if(y!=fa){
            dfs(y,x);
            siz[x]+=siz[y];
            if(son[x]==0||siz[y]>siz[son[x]])son[x]=y;  //重儿子
        }

    if(son[x]>0)  //重儿子存在
        swap(st[x],st[son[x]]);  //交换 时间复杂度 O(1) 

    st[x].insert(c[x]);

    //暴力合并轻儿子
    for(int y:g[x])
        if(y!=fa && y!=son[x])
            for(int i:st[y])st[x].insert(i);
    
    ans[x]=st[x].size();

}
int main()
{
    ios::sync_with_stdio(false); cin.tie(0);
    cin>>n;
    for(int i=1;i<n;i++)
    {
        int u,v;
        cin>>u>>v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    for(int i=1;i<=n;i++)cin>>c[i];
    dfs(1,0);
    cin>>m;
    while(m--)
    {
        int q;
        cin>>q;
        cout<<ans[q]<<"\n";
    }
    return 0;
}

2.利用全局计数数组统计种类数

利用容器启发式合并,时间复杂度并不是最优的,可以进一步优化。

可以定义一个全局计数数组,利用计数数组统计种类数。时间复杂度可以优化到 O(nlogn)O(n\log{n})

核心思路如下:

递归时求出当前节点 xx 的重儿子,之后优先递归计算轻儿子,最后递归计算重儿子。

重儿子信息保留(在全局数组中),接下来将所有轻儿子信息加入全局计算数组,根节点信息加入全局数组。

回溯的时候,如果当前节点是轻儿子,要删除。示意图如下:

启发式合并示意图

也就是说,先处理轻儿子,最后处理重儿子。递归返回时,重儿子信息保留在全局数组中。只需要将轻儿子暴力再次加入,就是当前树的所有信息。

但是由于有多个轻儿子,当轻儿子访问结束时,需要删除掉所有轻儿子信息,即回溯时删除轻儿子信息。递归函数需要区别回溯时删除或者不删除,增加一个参数 keep,核心参加代码框架如下:

void dfs2(int x,int fa,int keep)
{
    //先计算轻儿子
    for(int y:g[x])
        if(y!=fa && y!=son[x])
            dfs2(y,x,0);

    //计算重儿子 并且会保留 重儿子信息
    if(son[x]>0)dfs2(son[x],x,1);

    //只计算轻儿子
    for(int y:g[x])
    	if(y!=fa && y!=son[x])
        Add(y,x); //将子树 y 加入 
    
    add(x);//加入 x 信息
  
    ans[x]=num;  //保留答案

    if(keep==0)  //删除轻儿子信息
        Del(x,fa);   
}

常见会有两种写法,本质都是相同的。

一种递归添加子树删除子树,参考如下。另外一种将所有子树利用 dfs 序变为区间,添加和删除子树时访问 dfs序上一段区间。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+10;
int n,c[N],m;
vector<int>g[N];
int son[N],siz[N],dfn[N],L[N],R[N],tim;  //dfs 序
int cnt[N],num,ans[N]; 

void dfs1(int x,int fa)
{
    dfn[++tim]=x; //dfs 序
    L[x]=tim;     //x 子树中 dfs 序最小值
    siz[x]=1;    son[x]=0;
    for(int y:g[x])
        if(y!=fa)
        {
            dfs1(y,x);
            siz[x]+=siz[y];
            if(son[x]==0||siz[y]>siz[son[x]])son[x]=y;  //重儿子
        }
    R[x]=tim; // x 子树中 dfs 序最大值
}
void add(int x)
{
    cnt[c[x]]++;
    if(cnt[c[x]]==1)num++;
}
void del(int x)
{
    cnt[c[x]]--;
    if(cnt[c[x]]==0)num--;
}
void Add(int x,int fa)
{
	cnt[c[x]]++; if(cnt[c[x]]==1)num++;
	for(int y:g[x])
		if(y!=fa)Add(y,x);
}
void Del(int x,int fa)
{
	cnt[c[x]]--; if(cnt[c[x]]==0)num--;
	for(int y:g[x])
		if(y!=fa)Del(y,x);
	
}
void dfs2(int x,int fa,int keep)
{
    //先计算轻儿子
    for(int y:g[x])
        if(y!=fa && y!=son[x])
            dfs2(y,x,0);

    //计算重儿子 并且会保留 重儿子信息
    if(son[x]>0)dfs2(son[x],x,1);

    //只计算轻儿子
    for(int y:g[x])
    	if(y!=fa && y!=son[x])Add(y,x);
    
    add(x);
    ans[x]=num;

    if(keep==0)
        Del(x,fa);   
}
int main()
{
    ios::sync_with_stdio(false); cin.tie(0);
    cin>>n;
    for(int i=1;i<n;i++)
    {
        int u,v;
        cin>>u>>v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    for(int i=1;i<=n;i++)
        cin>>c[i];

    dfs1(1,0);
    dfs2(1,0,0);

    cin>>m;
    for(int i=1;i<=m;i++)
    {
        int q;
        cin>>q;
        cout<<ans[q]<<"\n"; 
    }        
    
    return 0;
}

另外一种将子树转换为区间,参考代码:点击


其他例题

1.[ABC372E] K-th Largest Connected Components

2.GYM106161 L. Label Matching

3.P3201 [HNOI2009] 梦幻布丁

4.CF1899G Unusual Entertainment

5.CF600E Lomsat gelral

6.CF375D Tree and Queries


学习完毕

{{ select(1) }}

  • YES
  • NO