线段树合并
首先指定根节点为 $1$,合法路径数量与树是否有根无关。
使用线段树合并维护节点 $i$ 子树中所有有效颜色的出现次数(如果路径 $(i,j)$ 上除 $j$ 外没有其他颜色为 $c_j$ 的节点,且节点 $j$ 在节点 $i$ 子树中,就认为 $c_j$ 是节点 $i$ 子树中的有效颜色),在线段树合并过程中可以计算所有路径端点以 $i$ 为最近公共祖先的合法路径数量。
对于端点颜色为 $c_i$ 的路径,由于 $i$ 的颜色已经是 $c_i$,可以通过在合并 $i$ 的所有儿子节点,不包含 $i$ 自身信息的线段树上查询有效颜色 $c_i$ 的出现次数 $cnt$ 计算贡献,其对答案的贡献就是 $cnt$ 本身。
对于端点颜色不为 $c_i$ 的路径,节点 $i$ 的每个孩子节点的子树中的有效颜色节点可以两两搭配,对答案的贡献是一个经典的乘积之和形式,这可以在合并线段树的叶子节点时计算。
在合并 $i$ 的所有儿子节点得到节点 $i$ 的线段树后,将节点 $i$ 的颜色信息加入线段树。也就是将有效颜色 $c_i$ 的出现次数置为 $1$,因为从 $i$ 子树外任意颜色为 $c_i$ 的节点到达 $i$ 子树内的节点都需要经过 $i$,这样一来原先所有颜色为 $c_i$ 的有效节点都不再是合法的路径端点。
笔者的实现中合并了计算端点颜色为 $c_i$ 的路径与将节点 $i$ 的信息加入线段树的操作,在计入颜色 $c_i$ 的贡献后直接将 $c_i$ 的出现次数置为 $1$ 以减少代码量。
时间复杂度 $O(n \log n)$。
#include<bits/stdc++.h>
using namespace std;
struct Segment{int w,ls,rs;}s[200'005*20];
int n,seg,a[200'005],ro[200'005];
vector<int> v[200'005];
long long ans;
void update(int &u,int l,int r,int p){
if(u==0) u=++seg;
if(l==r) return ans+=s[u].w,s[u].w=1,void();
int mid=(l+r)>>1;
if(p<=mid) update(s[u].ls,l,mid,p);
else update(s[u].rs,mid+1,r,p);
}
int merge(int x,int y,int l,int r,int co){
if(x==0||y==0) return x+y;
if(l==r){
if(l!=co) ans+=1ll*s[x].w*s[y].w;
s[x].w+=s[y].w;
}
else{
int mid=(l+r)>>1;
s[x].ls=merge(s[x].ls,s[y].ls,l,mid,co);
s[x].rs=merge(s[x].rs,s[y].rs,mid+1,r,co);
}
return x;
}
void dfs(int x,int las){
for(int i:v[x]) if(i!=las){
dfs(i,x);
ro[x]=merge(ro[x],ro[i],1,n,a[x]);
}
update(ro[x],1,n,a[x]);
}
void solution(){
for(int i=1;i<=seg;i++) s[i].w=0,s[i].ls=0,s[i].rs=0;
cin>>n,seg=0,ans=0ll;
for(int i=1;i<=n;i++) cin>>a[i];
for(int i=1;i<=n;i++) ro[i]=0,v[i].clear();
for(int i=1;i<=n-1;i++){
static int x,y;
cin>>x>>y;
v[x].push_back(y);
v[y].push_back(x);
}
dfs(1,0);
cout<<ans<<'\n';
}
int T;
int main(){
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
cin>>T;
while(T--) solution();
return 0;
}