小 A 的树
题目描述
小 A 有一棵 \(N\) 个点的树,每个点都有一个小于 \(2^{20}\) 的非负整数权值。现在小 A 从树中随机选择一个点 \(x\),再随机选择一个点 \(y\)(\(x\)、\(y\)可以是同一个点),并对从 \(x\) 到 \(y\) 的路径上 所有的点的权值分别做按位与、按位或、异或运算,最终会求得三个整数。小 A 想知道,他求出的三个数的期望值分别是多少。
输入描述
输入包含多组测试数据。
第一行,一个整数 \(T\),表示测试数据的组数。
接下来 \(T\) 节,每节表示一组测试数据,格式如下:
- 第一行,一个整数 \(N\)。
- 第二行,\(N\) 个整数,其中第 \(i\) 个整数表示第 \(i\) 个点的权值。
- 接下来 \(N-1\) 行,每行两个整数 \(u\)、\(v\),表示树中有一条连接 \(u\)、\(v\) 的边。
输出描述
共 \(T\) 行,每行三个浮点数,保留三位小数,其中第 \(i\) 行的三个浮点数表示第 \(i\) 组数据对应的按位与、按位或、异或的期望。
输入输出描述 #1
输入样例 #1
1
4
1 2 3 4
1 2
2 3
2 4
输出样例 #1
0.875 4.250 3.375
提示/说明
数据范围
- 对于 \(20\)% 的数据,\(1 \leq N \leq 10^3\)。
- 另外 \(20\)% 的数据,\(N\) 个点构成一条链。
- 对于 \(100\)% 的数据,\(1 \leq N \leq 10^5\),\(1\leq T\leq 5\)。
对于不同路径的判断
设有两个树上的点 \(u\)、\(v\)。
- 若 \(u=v\),则路径 \(u \rightarrow v\) 和路径 \(v \rightarrow u\) 是相同路径。
- 若 \(u \neq v\),则路径 \(u \rightarrow v\) 和路径 \(v \rightarrow u\) 是不同路径。
解题报告
神秘树形 DP。
首先要有一个思路:按位与、按位或、异或三个运算都是各进制位独立运算,可以考虑分开进行。
所以,我们可以求出三个运算使每个二进制位为 \(1\) 的期望,最后统计总期望。
由于总路径数量一定,为 \(N \times N\),所以我们对于每个运算只需求出可以使每个二进制位为 \(1\) 的路径数。
一个很常见的思路:对于每个在以 \(u\) 为根的子树,统计以 \(u\) 为一个端点的路径的价值,再转换成每条在以 \(u\) 为根的子树的路径的价值。其实就是把 \(u\) 作为 LCA 的路径 \(s \rightarrow t\),把这条路径分成 \(s \rightarrow u\) 和 \(u \rightarrow t\) 两天路径,单独处理出每个 \(s \rightarrow u\) 和 \(u \rightarrow t\) 的路径的价值,既可以计算出每个路径的价值。
然后就是一个简单的树形 DP。
设 \(dp[u][i][0/1/2]\) 分别表示以节点 \(u\) 的子树中以 \(u\) 为端点的路径中按位与、按位或、异或后可以使第 \(i\) 个二进制位为 \(1\) 的路径数。
设 \(val[u][i]\) 表示节点 \(u\) 的权值第 \(i\) 个二进制位状态。
设 \(siz[u]\) 表示子树 \(u\) 的大小,同时也等价于子树内以 \(u\) 为端点的路径的条数。
设 \(v\) 为 \(u\) 的一个子节点。
转移方程很好推,代码如下:
for(int i=1;i<M;j++)
{if(val[u][i]){dp[u][i][0]+=dp[v][i][0];dp[u][i][1]+=siz[v];dp[u][i][2]+=siz[v]-dp[v][i][2];}else{dp[u][i][1]+=dp[v][j][1];dp[u][j][2]+=dp[v][i][2];}
}
然后就可以对子树 \(u\) 分别统计三个运算中使第 \(i\) 个二进制位为 \(1\) 的路径总数 \(cnt[0/1/2][i]\) :
// 统计u->u的路径
for(int j=1;j<M;j++)
{cnt[0][j]+=val[u][j];cnt[1][j]+=val[u][j];cnt[2][j]+=val[u][j];
}// 统计子树内经过 u 且起始点不同的路径
for(auto v:e[u])
{if(v==fa) continue;for(int j=1;j<M;j++){cnt[0][j]+=2*dp[u][i][0]*dp[v][i][0];cnt[1][j]+=2*(siz[u]*siz[v]-(siz[u]-dp[u][j][1])*(siz[v]-dp[v][j][1]));cnt[2][j]+=2*(dp[u][j][2]*(siz[v]-dp[v][j][2])+dp[v][j][2]*(siz[u]-dp[u][j][2]));}
}
然后统计答案就好了,总代码如下:
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int INF=0x3f3f3f3f;
const int N=1001100;
const int M=32;#define ckmax(x,y) ( x=max(x,y) )
#define ckmin(x,y) ( x=min(x,y) )inline int read()
{int f=1,x=0; char ch=getchar();while(!isdigit(ch)) { if(ch=='-') f=-1; ch=getchar(); }while(isdigit(ch)) { x=x*10+ch-'0'; ch=getchar(); }return f*x;
}struct node
{int siz;int f[3][M];bool val[M];
}p[N];
int n;
int cnt[3][M];
vector<int> e[N];inline void addedge(int u,int v)
{e[u].push_back(v);e[v].push_back(u);
}inline void Clear()
{memset(cnt,0,sizeof(cnt));for(int i=1;i<=n;i++){memset(p[i].val,0,sizeof(p[i].val));memset(p[i].f,0,sizeof(p[i].f));p[i].siz=0;e[i].clear();}
}inline void debug(int u)
{printf("%d\n",u);for(int i=1;i<M;i++)printf("%d ",p[u].val[i]);cout<<endl;for(int j= 0;j<3;j++,putchar('\n'))for(int i=1;i<M;i++,putchar(' '))cout<<p[u].f[j][i];cout<<endl<<endl;
}void dfs(int u,int fa)
{p[u].siz=1;for(int i=0;i<e[u].size();i++){int v=e[u][i];if(v==fa) continue;dfs(v,u);for(int j=1;j<M;j++){cnt[0][j]+=2*p[u].f[0][j]*p[v].f[0][j];cnt[1][j]+=2*(p[u].siz*p[v].siz-(p[u].siz-p[u].f[1][j])*(p[v].siz-p[v].f[1][j]));cnt[2][j]+=2*(p[u].f[2][j]*(p[v].siz-p[v].f[2][j])+p[v].f[2][j]*(p[u].siz-p[u].f[2][j]));}p[u].siz+=p[v].siz;for(int j=1;j<M;j++){if(p[u].val[j]){p[u].f[0][j]+=p[v].f[0][j];p[u].f[1][j]+=p[v].siz;p[u].f[2][j]+=p[v].siz-p[v].f[2][j];}else{p[u].f[1][j]+=p[v].f[1][j];p[u].f[2][j]+=p[v].f[2][j];}}}// debug(u);
}signed main()
{freopen("tree.in","r",stdin);freopen("tree.out","w",stdout);int Q=read();while(Q--){n=read();for(int i=1;i<=n;i++){int x=read();for(int j=1;j<M;j++){p[i].val[j]=(x>>j-1)&1;p[i].f[0][j]=p[i].val[j];p[i].f[1][j]=p[i].val[j];p[i].f[2][j]=p[i].val[j];cnt[0][j]+=p[i].val[j];cnt[1][j]+=p[i].val[j];cnt[2][j]+=p[i].val[j];}}for(int i=1;i<n;i++)addedge(read(),read());dfs(1,0);int tot=n*n;double ans0=0,ans1=0,ans2=0;for(int i=1;i<M;i++){int tmp=(1<<i-1);ans0+=(double)tmp*cnt[0][i]/(double)tot;ans1+=(double)tmp*cnt[1][i]/(double)tot;ans2+=(double)tmp*cnt[2][i]/(double)tot;}printf("%.3lf %.3lf %.3lf\n",ans0,ans1,ans2);Clear();}return 0;
}