跳到主要内容

树链剖分

参考资料

实现

vector<int> G[N];
int fa[N],son[N],siz[N],dep[N];
int top[N],dfn[N],rnk[N],out[N];
int cnt=0;
void dfs1(int u)
{
siz[u]=1;
dep[u]=dep[fa[u]]+1;
for(auto v:G[u])
{
if(v==fa[u])continue;
fa[v]=u;
dfs1(v);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void dfs2(int u,int t)
{
top[u]=t;
dfn[u]=++cnt;
rnk[cnt]=u;
if(son[u])dfs2(son[u],t);
for(auto v:G[u])
{
if(v==fa[u]||v==son[u])continue;
dfs2(v,v);
}
out[u]=cnt;
}

应用

例题

洛谷 P3384 【模板】重链剖分/树链剖分

给定一棵 nn 个结点的树,每个节点上包含一个数值,支持以下操作:

  • 1 x y z:表示将树从 xxyy 结点最短路径上所有节点的值都加上 zz
  • 2 x y:表示求树从 xxyy 结点最短路径上所有节点的值之和。
  • 3 x z:表示将以 xx 为根节点的子树内所有节点值都加上 zz
  • 4 x:表示求以 xx 为根节点的子树内所有节点值之和。
#include <bits/stdc++.h>
#define ls (u<<1)
#define rs (u<<1|1)
#define mid (l+r>>1)
using namespace std;

using ll=long long;
const int N=100005;
vector<int> G[N];
int fa[N],son[N],siz[N],dep[N];
int top[N],dfn[N],rnk[N],out[N];
int mod,cnt=0;
void dfs1(int u)
{
siz[u]=1;
dep[u]=dep[fa[u]]+1;
for(auto v:G[u])
{
if(v==fa[u])continue;
fa[v]=u;
dfs1(v);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void dfs2(int u,int t)
{
top[u]=t;
dfn[u]=++cnt;
rnk[cnt]=u;
if(son[u])dfs2(son[u],t);
for(auto v:G[u])
{
if(v==fa[u]||v==son[u])continue;
dfs2(v,v);
}
out[u]=cnt;
}
ll a[N],val[N<<2],tag[N<<2];
void gx(int u,ll v,int len)
{
val[u]=(val[u]+v*len%mod)%mod;
tag[u]=(tag[u]+v)%mod;
}
void push_up(int u)
{
val[u]=(val[ls]+val[rs])%mod;
}
void push_down(int u,int l,int r)
{
gx(ls,tag[u],mid-l+1);
gx(rs,tag[u],r-mid);
tag[u]=0;
}
void build(int u,int l,int r)
{
if(l==r){val[u]=a[rnk[l]]%mod;return;}
build(ls,l,mid);
build(rs,mid+1,r);
push_up(u);
}
void update(int u,int l,int r,int x,int y,ll v)
{
if(x<=l&&r<=y){gx(u,v,r-l+1);return;}
push_down(u,l,r);
if(x<=mid)update(ls,l,mid,x,y,v);
if(y>mid)update(rs,mid+1,r,x,y,v);
push_up(u);
}
ll query(int u,int l,int r,int x,int y)
{
if(x<=l&&r<=y)return val[u]%mod;
push_down(u,l,r);
ll res=0;
if(x<=mid)res=(res+query(ls,l,mid,x,y))%mod;
if(y>mid)res=(res+query(rs,mid+1,r,x,y))%mod;
return res;
}
void update_path(int n,int x,int y,ll v)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
update(1,1,n,dfn[top[x]],dfn[x],v);
x=fa[top[x]];
}
if(dep[x]<dep[y])swap(x,y);
update(1,1,n,dfn[y],dfn[x],v);
}
ll query_path(int n,int x,int y)
{
ll res=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
res=(res+query(1,1,n,dfn[top[x]],dfn[x]))%mod;
x=fa[top[x]];
}
if(dep[x]<dep[y])swap(x,y);
res=(res+query(1,1,n,dfn[y],dfn[x]))%mod;
return res;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n,m,r;
cin>>n>>m>>r>>mod;
for(int i=1;i<=n;i++)
{
cin>>a[i];
}
for(int i=1;i<n;i++)
{
int u,v;
cin>>u>>v;
G[u].push_back(v);
G[v].push_back(u);
}
dfs1(r);
dfs2(r,r);
build(1,1,n);
while(m--)
{
int op,x,y;
ll z;
cin>>op;
if(op==1)
{
cin>>x>>y>>z;
update_path(n,x,y,z);
}
else if(op==2)
{
cin>>x>>y;
cout<<query_path(n,x,y)<<'\n';
}
else if(op==3)
{
cin>>x>>z;
update(1,1,n,dfn[x],out[x],z);
}
else if(op==4)
{
cin>>x;
cout<<query(1,1,n,dfn[x],out[x])<<'\n';
}
}
return 0;
}

洛谷 P2590 [ZJOI2008] 树的统计

一棵树上有 nn 个节点,每个节点都有一个权值 wiw_i,支持以下操作:

  • CHANGE u t:把结点 uu 的权值改为 tt
  • QMAX u v:询问从点 uu 到点 vv 的路径上的节点的最大权值。
  • QSUM u v:询问从点 uu 到点 vv 的路径上的节点的权值和。
#include <bits/stdc++.h>
#define ls (u<<1)
#define rs (u<<1|1)
#define mid (l+r>>1)
using namespace std;

const int inf=0x3f3f3f3f;
const int N=30005;
vector<int> G[N];
int fa[N],son[N],dep[N],siz[N];
int top[N],dfn[N],rnk[N],out[N];
int cnt=0;
void dfs1(int u)
{
siz[u]=1;
dep[u]=dep[fa[u]]+1;
for(auto v:G[u])
{
if(v==fa[u])continue;
fa[v]=u;
dfs1(v);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void dfs2(int u,int t)
{
top[u]=t;
dfn[u]=++cnt;
rnk[cnt]=u;
if(son[u])dfs2(son[u],t);
for(auto v:G[u])
{
if(v==fa[u]||v==son[u])continue;
dfs2(v,v);
}
out[u]=cnt;
}
int a[N],val[N<<2],mx[N<<2];
void push_up(int u)
{
val[u]=val[ls]+val[rs];
mx[u]=max(mx[ls],mx[rs]);
}
void build(int u,int l,int r)
{
if(l==r){val[u]=a[rnk[l]];mx[u]=a[rnk[l]];return;}
build(ls,l,mid);
build(rs,mid+1,r);
push_up(u);
}
void update(int u,int l,int r,int x,int v)
{
if(l==r){val[u]=v;mx[u]=v;return;}
if(x<=mid)update(ls,l,mid,x,v);
else update(rs,mid+1,r,x,v);
push_up(u);
}
int query_max(int u,int l,int r,int x,int y)
{
if(x<=l&&r<=y)return mx[u];
int res=-inf;
if(x<=mid)res=max(res,query_max(ls,l,mid,x,y));
if(y>mid)res=max(res,query_max(rs,mid+1,r,x,y));
return res;
}
int query_sum(int u,int l,int r,int x,int y)
{
if(x<=l&&r<=y)return val[u];
int res=0;
if(x<=mid)res+=query_sum(ls,l,mid,x,y);
if(y>mid)res+=query_sum(rs,mid+1,r,x,y);
return res;
}
int query_path_max(int x,int y,int n)
{
int res=-inf;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
res=max(res,query_max(1,1,n,dfn[top[x]],dfn[x]));
x=fa[top[x]];
}
if(dep[x]<dep[y])swap(x,y);
res=max(res,query_max(1,1,n,dfn[y],dfn[x]));
return res;
}
int query_path_sum(int x,int y,int n)
{
int res=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
res+=query_sum(1,1,n,dfn[top[x]],dfn[x]);
x=fa[top[x]];
}
if(dep[x]<dep[y])swap(x,y);
res+=query_sum(1,1,n,dfn[y],dfn[x]);
return res;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;
cin>>n;
for(int i=1;i<n;i++)
{
int x,y;
cin>>x>>y;
G[x].push_back(y);
G[y].push_back(x);
}
dfs1(1);
dfs2(1,1);
for(int i=1;i<=n;i++)
{
cin>>a[i];
}
build(1,1,n);
int m;
cin>>m;
while(m--)
{
string op;
int x,y;
cin>>op>>x>>y;
if(op=="CHANGE")
{
update(1,1,n,dfn[x],y);
}
else if(op=="QMAX")
{
cout<<query_path_max(x,y,n)<<'\n';
}
else if(op=="QSUM")
{
cout<<query_path_sum(x,y,n)<<'\n';
}
}
return 0;
}