Sushi
给你 \(n(1\leq n\leq 300)\) 碗寿司,每碗寿司有 \(a_i(1\leq a_i\leq 3)\) 个,每次进行如下操作:
- 等概率地得到一个 \(i\),然后如果当前的 \(a_i>0\),就让 \(a_i-1\),否则就不需要做任何事。
求全部吃完的期望操作次数。
题目分析
注意到 \(a_i\in[1,3]\),故设 \(f_{i,j,k}\) 表示现在碗里只有 \(1\) 个的有 \(i\) 碗,只有 \(2\) 个的有 \(j\) 碗,只有 \(3\) 个的有 \(k\) 碗。
我们发现从 \(f_{x,y,z}\) 推到 \(f_{0,0,0}\) 是复杂的。
不妨将整个过程反过来变成生产寿司从 \(f_{0,0,0}\) 推到 \(f_{x,y,z}\) 即可。
那么我们有转移:
\[f_{i,j,k}=\frac{n-i-j-k}{n}(f_{i,j,k}+1)+\frac{i}{n}(f_{i-1,j,k}+1)+\frac{j}{n}(f_{i+1,j-1,k}+1)+\frac{k}{n}(f_{i,j + 1,k-1}+1)
\]
化简有:
\[(i+j+k)f_{i,j,k}=\frac{i}{n}f_{i-1,j,k}+\frac{j}{n}f_{i+1,j-1,k}+\frac{k}{n}f_{i,j+1,k-1}+n
\]
最后再除过去就可以了,于是你写出了以下的代码:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <stdlib.h>
#include <vector>
#include <iomanip>
#define int long long
#define N 305
using namespace std;
int n,a[N];
double f[N][N][N];
int sum1,sum2,sum3;
signed main(){cin >> n;for (int i = 1;i <= n;i ++) scanf("%lld",&a[i]);f[0][0][0] = 0.0;for (int i = 1;i <= n;i ++) sum1 += (a[i] == 1),sum2 += (a[i] == 2),sum3 += (a[i] == 3);for (int i = 0;i <= n;i ++)for (int j = 0;j <= n;j ++)for (int k = 0 + (i == j && j == 0);i + j + k <= n;k ++) {if (i) f[i][j][k] = f[i - 1][j][k] * i + f[i][j][k];if (j) f[i][j][k] = f[i + 1][j - 1][k] * j + f[i][j][k];if (k) f[i][j][k] = f[i][j + 1][k - 1] * k + f[i][j][k];f[i][j][k] = (f[i][j][k] + n) / (i + j + k);}cout << fixed << setprecision(10) << f[sum1][sum2][sum3];return 0;
}
这是错的,因为在转移 \(j\) 可行的时候,f_{i+1,j-1,k} 还没有被更新。
这似乎更那个区间 \(dp\) 直接枚举 \(i,j\) 一样是有问题的,于是我们可以像区间 \(dp\) 那样子先枚举长度就行了。
于是你得到:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <stdlib.h>
#include <vector>
#include <iomanip>
#define int long long
#define N 305
using namespace std;
int n,a[N];
double f[N][N][N];
int sum1,sum2,sum3;
signed main(){cin >> n;for (int i = 1;i <= n;i ++) scanf("%lld",&a[i]);f[0][0][0] = 0.0;for (int i = 1;i <= n;i ++) sum1 += (a[i] == 1),sum2 += (a[i] == 2),sum3 += (a[i] == 3);for (int len = 1;len <= n;len ++)for (int i = 0;i <= n;i ++)for (int j = 0;i + j <= len;j ++) {int k = len - i - j;f[i][j][k] = 0;if (i) f[i][j][k] += f[i - 1][j][k] * i;if (j) f[i][j][k] += f[i + 1][j - 1][k] * j;if (k) f[i][j][k] += f[i][j + 1][k - 1] * k;f[i][j][k] = (f[i][j][k] + n) / (i + j + k);}cout << fixed << setprecision(10) << f[sum1][sum2][sum3];return 0;
}
这还是错的,同样的错误,应该从 \(k\) 到 \(j\) 枚举。
代码
时间复杂度 \(\mathcal{O}(n^3)\),代码如下:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <stdlib.h>
#include <vector>
#include <iomanip>
#define int long long
#define N 305
using namespace std;
int n,a[N];
double f[N][N][N];
int sum1,sum2,sum3;
signed main(){cin >> n;for (int i = 1;i <= n;i ++) scanf("%lld",&a[i]);f[0][0][0] = 0.0;for (int i = 1;i <= n;i ++) sum1 += (a[i] == 1),sum2 += (a[i] == 2),sum3 += (a[i] == 3);for (int len = 1;len <= n;len ++)for (int k = 0;k <= len;k ++)for (int j = 0;k + j <= len;j ++) {int i = len - k - j;f[i][j][k] = 0;if (i) f[i][j][k] += f[i - 1][j][k] * i;if (j) f[i][j][k] += f[i + 1][j - 1][k] * j;if (k) f[i][j][k] += f[i][j + 1][k - 1] * k;f[i][j][k] = (f[i][j][k] + n) / (i + j + k);}cout << fixed << setprecision(10) << f[sum1][sum2][sum3];return 0;
}