题意简述
以结点 为根。自底向上地为每个非叶结点 选重儿子:在儿子们各自的重链长度 确定后, 以正比于 的概率选 ,于是 的重链长度变为 。求每个结点到根路径上轻边数量的期望之和,对 取模。
解题思路
拆贡献到边。 结点 到根的路径经过边 ( 为儿子)当且仅当 在子树 内,这样的 有 个;而该边是轻边当且仅当 没选 。由期望的线性性,
求选择概率。 设 的各儿子重链长度为随机变量 (不同子树相互独立),,则 。分母里的求和是难点,用积分恒等式 把它拆开:记 为 的概率生成函数,由独立性
维护生成函数。 ,故 。同样用上面的积分展开,记 、,可得
而 恰是上式中第 个儿子对所有 的贡献之和。其中 ,在模意义下即 的逆元。
实现。 自底向上 DFS:叶子 ;非叶结点先把所有儿子的 乘成 ,再对每个 用多项式除法得到 ,按上式累加出 与各 。每个结点的多项式次数等于子树高度,由树形背包的配对计数,总复杂度 。
一个常数优化:纯链(含叶子)的 是单项式 ,乘除都退化成移位。把这类儿子单独拎出来——它们对同一父结点贡献的那个积分值完全相同(链长在分子分母里恰好抵消),一次算出再乘以各自的 即可,无需逐个做稠密多项式运算。这样星形、菊花、链等结构从 降到近 。
时间复杂度为 。
参考代码
#include <bits/stdc++.h>
using namespace std;
using ll=long long;
const int mod=998244353;
const int N=5005;
int siz[N];
ll inv[N],ans;
vector<int> G[N];
ll Pow(ll x,ll y)
{
x%=mod;
ll res=1;
while(y)
{
if(y&1)res=res*x%mod;
x=x*x%mod;
y>>=1;
}
return res;
}
void init()
{
for(int i=1;i<N;i++)inv[i]=i==1?1:(mod-mod/i)*inv[mod%i]%mod;
}
vector<ll> dfs(int u,int fa)
{
siz[u]=1;
vector<vector<ll>> fc;
vector<int> cv,mL,ms;
for(int v:G[u])
{
if(v==fa)continue;
vector<ll> fv=dfs(v,u);
siz[u]+=siz[v];
int nz=0,deg=0;
for(int i=0;i<fv.size();i++)if(fv[i]){nz++;deg=i;}
if(nz==1){mL.push_back(deg);ms.push_back(siz[v]);}
else{fc.push_back(move(fv));cv.push_back(v);}
}
if(fc.empty()&&mL.empty())return {0,1};
vector<ll> p={1};
for(auto &f:fc)
{
vector<ll> g(p.size()+f.size()-1,0);
for(int j=0;j<f.size();j++)
{
if(!f[j])continue;
ll fj=f[j];
for(int i=0;i<p.size();i++)if(p[i])g[i+j]=(g[i+j]+p[i]*fj)%mod;
}
p=move(g);
}
int dd=p.size()-1,s=0,mxd=0;
for(int l:mL)
{
s+=l;
mxd=max(mxd,l);
}
for(auto &f:fc)
{
int d=f.size()-1;
mxd=max(mxd,d);
}
vector<ll> fu(mxd+2,0);
if(!mL.empty())
{
ll im=0;
for(int m=0;m<=dd;m++)im=(im+p[m]*inv[s+m])%mod;
for(int i=0;i<mL.size();i++)
{
int l=mL[i];
ll pk=(ll)l*im%mod;
fu[l+1]=(fu[l+1]+pk)%mod;
ans=(ans+(ll)ms[i]*((1-pk+mod)%mod))%mod;
}
}
for(int c=0;c<fc.size();c++)
{
auto &f=fc[c];
int dc=f.size()-1,dr=dd-dc;
vector<int> nz;
for(int j=0;j<=dc;j++)if(f[j])nz.push_back(j);
vector<ll> w(p),r(dr+1,0);
ll il=Pow(f[dc],mod-2);
for(int i=dd;i>=dc;i--)
{
ll q=w[i]*il%mod;
r[i-dc]=q;
for(int j:nz)w[i-dc+j]=((w[i-dc+j]-q*f[j])%mod+mod)%mod;
}
ll pp=0;
for(int l:nz)
{
if(!l)continue;
ll t=0;
for(int m=0;m<=dr;m++)t=(t+r[m]*inv[l+s+m])%mod;
t=t*l%mod*f[l]%mod;
fu[l+1]=(fu[l+1]+t)%mod;
pp=(pp+t)%mod;
}
ans=(ans+(ll)siz[cv[c]]*((1-pp+mod)%mod))%mod;
}
return fu;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
init();
int cas,t;
cin>>cas>>t;
while(t--)
{
int n;
cin>>n;
for(int i=1;i<=n;i++)G[i].clear();
for(int i=1;i<n;i++)
{
int u,v;
cin>>u>>v;
G[u].push_back(v);
G[v].push_back(u);
}
ans=0;
dfs(1,0);
cout<<ans<<'\n';
}
return 0;
}