跳到主要内容

树链剖分

参考资料

实现

vector<int> G[N];
int fa[N],son[N],siz[N],dep[N];
int top[N],dfn[N],rnk[N],out[N];
int cnt=0;
void dfs1(int u)
{
siz[u]=1;
dep[u]=dep[fa[u]]+1;
for(auto v:G[u])
{
if(v==fa[u])continue;
fa[v]=u;
dfs1(v);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void dfs2(int u,int t)
{
top[u]=t;
dfn[u]=++cnt;
rnk[cnt]=u;
if(son[u])dfs2(son[u],t);
for(auto v:G[u])
{
if(v==fa[u]||v==son[u])continue;
dfs2(v,v);
}
out[u]=cnt;
}

应用

详见 最近公共祖先

例题

洛谷 P3384 【模板】重链剖分/树链剖分

给定一棵 nn 个结点的树,每个节点上包含一个数值,支持以下操作:

  • 1 x y z:表示将树从 xxyy 结点最短路径上所有节点的值都加上 zz
  • 2 x y:表示求树从 xxyy 结点最短路径上所有节点的值之和。
  • 3 x z:表示将以 xx 为根节点的子树内所有节点值都加上 zz
  • 4 x:表示求以 xx 为根节点的子树内所有节点值之和。
代码(1)
#include <bits/stdc++.h>
#define ls (u<<1)
#define rs (u<<1|1)
#define mid (l+r>>1)
using namespace std;

using ll=long long;
const int N=100005;
vector<int> G[N];
int fa[N],son[N],siz[N],dep[N];
int top[N],dfn[N],rnk[N],out[N];
int mod,cnt=0;
void dfs1(int u)
{
siz[u]=1;
dep[u]=dep[fa[u]]+1;
for(auto v:G[u])
{
if(v==fa[u])continue;
fa[v]=u;
dfs1(v);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void dfs2(int u,int t)
{
top[u]=t;
dfn[u]=++cnt;
rnk[cnt]=u;
if(son[u])dfs2(son[u],t);
for(auto v:G[u])
{
if(v==fa[u]||v==son[u])continue;
dfs2(v,v);
}
out[u]=cnt;
}
ll a[N],val[N<<2],tag[N<<2];
void gx(int u,ll v,int len)
{
val[u]=(val[u]+v*len%mod)%mod;
tag[u]=(tag[u]+v)%mod;
}
void push_up(int u)
{
val[u]=(val[ls]+val[rs])%mod;
}
void push_down(int u,int l,int r)
{
gx(ls,tag[u],mid-l+1);
gx(rs,tag[u],r-mid);
tag[u]=0;
}
void build(int u,int l,int r)
{
if(l==r){val[u]=a[rnk[l]]%mod;return;}
build(ls,l,mid);
build(rs,mid+1,r);
push_up(u);
}
void update(int u,int l,int r,int x,int y,ll v)
{
if(x<=l&&r<=y){gx(u,v,r-l+1);return;}
push_down(u,l,r);
if(x<=mid)update(ls,l,mid,x,y,v);
if(y>mid)update(rs,mid+1,r,x,y,v);
push_up(u);
}
ll query(int u,int l,int r,int x,int y)
{
if(x<=l&&r<=y)return val[u]%mod;
push_down(u,l,r);
ll res=0;
if(x<=mid)res=(res+query(ls,l,mid,x,y))%mod;
if(y>mid)res=(res+query(rs,mid+1,r,x,y))%mod;
return res;
}
void update_path(int n,int x,int y,ll v)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
update(1,1,n,dfn[top[x]],dfn[x],v);
x=fa[top[x]];
}
if(dep[x]<dep[y])swap(x,y);
update(1,1,n,dfn[y],dfn[x],v);
}
ll query_path(int n,int x,int y)
{
ll res=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
res=(res+query(1,1,n,dfn[top[x]],dfn[x]))%mod;
x=fa[top[x]];
}
if(dep[x]<dep[y])swap(x,y);
res=(res+query(1,1,n,dfn[y],dfn[x]))%mod;
return res;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n,m,r;
cin>>n>>m>>r>>mod;
for(int i=1;i<=n;i++)
{
cin>>a[i];
}
for(int i=1;i<n;i++)
{
int u,v;
cin>>u>>v;
G[u].push_back(v);
G[v].push_back(u);
}
dfs1(r);
dfs2(r,r);
build(1,1,n);
while(m--)
{
int op,x,y;
ll z;
cin>>op;
if(op==1)
{
cin>>x>>y>>z;
update_path(n,x,y,z);
}
else if(op==2)
{
cin>>x>>y;
cout<<query_path(n,x,y)<<'\n';
}
else if(op==3)
{
cin>>x>>z;
update(1,1,n,dfn[x],out[x],z);
}
else if(op==4)
{
cin>>x;
cout<<query(1,1,n,dfn[x],out[x])<<'\n';
}
}
return 0;
}

洛谷 P5903 【模板】树上 K 级祖先

给定一棵 nn 个点的有根树。

qq 次询问,第 ii 次询问给定 xi,kix_i, k_i,要求点 xix_ikik_i 级祖先,答案为 ansians_i。特别地,ans0=0ans_0 = 0

代码(1)
#include <bits/stdc++.h>
using namespace std;

using ll=long long;
using uint=unsigned int;
const int N=500005;
vector<int> G[N];
int fa[N],son[N],siz[N],dep[N];
int top[N],dfn[N],rnk[N],out[N];
int cnt=0;
void dfs1(int u)
{
siz[u]=1;
dep[u]=dep[fa[u]]+1;
for(auto v:G[u])
{
fa[v]=u;
dfs1(v);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void dfs2(int u,int t)
{
top[u]=t;
dfn[u]=++cnt;
rnk[cnt]=u;
if(son[u])dfs2(son[u],t);
for(auto v:G[u])
{
if(v==fa[u]||v==son[u])continue;
dfs2(v,v);
}
out[u]=cnt;
}
int query(int u,int k)
{
while(dep[u]-dep[top[u]]<k)
{
k-=dep[u]-dep[top[u]]+1;
u=fa[top[u]];
}
return rnk[dfn[u]-k];
}
uint s;
inline uint get(uint x)
{
x^=x<<13;
x^=x>>17;
x^=x<<5;
return s=x;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n,q;
cin>>n>>q>>s;
int r;
for(int i=1;i<=n;i++)
{
int f;
cin>>f;
if(f==0){r=i;continue;}
G[f].push_back(i);
}
dfs1(r);
dfs2(r,r);
ll ans=0,last=0;
for(int i=1;i<=q;i++)
{
int x=(get(s)^last)%n+1,k=(get(s)^last)%dep[x];
last=query(x,k);
ans^=i*last;
}
cout<<ans<<'\n';
return 0;
}

洛谷 P3379 【模板】最近公共祖先(LCA)

给定一棵有根多叉树,请求出指定两个点直接最近的公共祖先。

代码(2)
#include <bits/stdc++.h>
using namespace std;

const int N=500005;
int a[N][25],dep[N];
vector<int> G[N];
bool vis[N];
void dfs(int u,int fa)
{
a[u][0]=fa;
for(int i=1;i<=20;i++)a[u][i]=a[a[u][i-1]][i-1];
dep[u]=dep[fa]+1;
for(auto v:G[u])if(v!=fa)dfs(v,u);
}
int lca(int u,int v)
{
if(dep[u]<dep[v])swap(u,v);
for(int i=20;i>=0;i--)
{
if(dep[a[u][i]]>=dep[v])u=a[u][i];
}
if(u==v)return u;
for(int i=20;i>=0;i--)
{
if(a[u][i]!=a[v][i])
{
u=a[u][i];
v=a[v][i];
}
}
return a[u][0];
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n,m,s;
cin>>n>>m>>s;
for(int i=1;i<n;i++)
{
int u,v;
cin>>u>>v;
G[u].push_back(v);
G[v].push_back(u);
}
dfs(s,0);
while(m--)
{
int u,v;
cin>>u>>v;
cout<<lca(u,v)<<'\n';
}
return 0;
}

洛谷 P2590 [ZJOI2008] 树的统计

一棵树上有 nn 个节点,每个节点都有一个权值 wiw_i,支持以下操作:

  • CHANGE u t:把结点 uu 的权值改为 tt
  • QMAX u v:询问从点 uu 到点 vv 的路径上的节点的最大权值。
  • QSUM u v:询问从点 uu 到点 vv 的路径上的节点的权值和。
代码(1)
#include <bits/stdc++.h>
#define ls (u<<1)
#define rs (u<<1|1)
#define mid (l+r>>1)
using namespace std;

const int inf=0x3f3f3f3f;
const int N=30005;
vector<int> G[N];
int fa[N],son[N],siz[N],dep[N];
int top[N],dfn[N],rnk[N],out[N];
int cnt=0;
void dfs1(int u)
{
siz[u]=1;
dep[u]=dep[fa[u]]+1;
for(auto v:G[u])
{
if(v==fa[u])continue;
fa[v]=u;
dfs1(v);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void dfs2(int u,int t)
{
top[u]=t;
dfn[u]=++cnt;
rnk[cnt]=u;
if(son[u])dfs2(son[u],t);
for(auto v:G[u])
{
if(v==fa[u]||v==son[u])continue;
dfs2(v,v);
}
out[u]=cnt;
}
int a[N],val[N<<2],mx[N<<2];
void push_up(int u)
{
val[u]=val[ls]+val[rs];
mx[u]=max(mx[ls],mx[rs]);
}
void build(int u,int l,int r)
{
if(l==r){val[u]=mx[u]=a[rnk[l]];return;}
build(ls,l,mid);
build(rs,mid+1,r);
push_up(u);
}
void update(int u,int l,int r,int x,int v)
{
if(l==r){val[u]=mx[u]=v;return;}
if(x<=mid)update(ls,l,mid,x,v);
else update(rs,mid+1,r,x,v);
push_up(u);
}
int query_max(int u,int l,int r,int x,int y)
{
if(x<=l&&r<=y)return mx[u];
int res=-inf;
if(x<=mid)res=max(res,query_max(ls,l,mid,x,y));
if(y>mid)res=max(res,query_max(rs,mid+1,r,x,y));
return res;
}
int query_sum(int u,int l,int r,int x,int y)
{
if(x<=l&&r<=y)return val[u];
int res=0;
if(x<=mid)res+=query_sum(ls,l,mid,x,y);
if(y>mid)res+=query_sum(rs,mid+1,r,x,y);
return res;
}
int query_path_max(int x,int y,int n)
{
int res=-inf;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
res=max(res,query_max(1,1,n,dfn[top[x]],dfn[x]));
x=fa[top[x]];
}
if(dep[x]<dep[y])swap(x,y);
res=max(res,query_max(1,1,n,dfn[y],dfn[x]));
return res;
}
int query_path_sum(int x,int y,int n)
{
int res=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
res+=query_sum(1,1,n,dfn[top[x]],dfn[x]);
x=fa[top[x]];
}
if(dep[x]<dep[y])swap(x,y);
res+=query_sum(1,1,n,dfn[y],dfn[x]);
return res;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;
cin>>n;
for(int i=1;i<n;i++)
{
int u,v;
cin>>u>>v;
G[u].push_back(v);
G[v].push_back(u);
}
dfs1(1);
dfs2(1,1);
for(int i=1;i<=n;i++)cin>>a[i];
build(1,1,n);
int m;
cin>>m;
while(m--)
{
string op;
int u,v;
cin>>op>>u>>v;
if(op=="CHANGE")update(1,1,n,dfn[u],v);
else if(op=="QMAX")cout<<query_path_max(u,v,n)<<'\n';
else if(op=="QSUM")cout<<query_path_sum(u,v,n)<<'\n';
}
return 0;
}