题意:给出一个数 \(n\),现在我通过下面这个算法生成一个长为 \(n\) 的序列 \(a\)。
- 先进行 \(n\) 次随机扔一个硬币,然后如果你最后连续 \(k\) 次投出的是正面朝上,那么就将一个 \(k\) 加入序列末尾。
问序列 \(a\) 的 \(\operatorname{mex}\) 的期望值。
做法:
首先转化为 \(\operatorname{mex}\ge j\) 的概率和。
然后枚举 \(k\in[1,n]\),那么就要求 \([0,k-1]\) 全都要选至少一个,剩余的都可以。我们不妨认为,对于 \(i\in[0,k-1]\) 出现的概率为 \(2^{-(i+1)}\),而 \(i\ge k\) 的概率是 \(2^{-(k+1)}\)。那么我们枚举每个元素出现次数,注意我可以安排他们的位置,可以得到概率是:
把所有东西按 \(i\) 离开,发现就是个 exp 状物,那么可以改写为:
然后注意到中间这个 \(\prod\) 感觉很有性质,因为他们长得都很像。我们这里设 \(xf(x)=e^x-1\),至于为什么需要前面有个 \(x\) 后面会说,再改写柿子:
然后我们考虑,如果我们能算出来 \(g(x) = \prod\limits_{i=0}^{\infty}f(\frac{x}{2^{i+1}})\) 这个无限乘积,那么我就可以用 \(g(x)\times g^{-1}(\frac{x}{2^{k}})\) 直接算出来中间这个 \(\prod\)。
考虑怎么计算 \(g\),求积很麻烦,直接取 \(\log\) 换成求和,因为我们特意上面凑了一个 \(x\) 使得 \(f\) 的零次项是 \(1\) 所以可以直接取 \(\log\),得到:
然后直接展开右边的每一项,那么第 \(n\) 项会多带一个 \(\sum \frac{1}{2^{ni}}\) 的系数,也就是 \(\frac{1}{2^n-1}\),直接给这个 \(f\) 算出来然后很容易算出 \(g\)。
之后为了更方便看出变化的部分,会用颜色标记。
然后我们带回到整体的柿子里去,可以得到:
记 \(h(x) = g^{-1}(x)e^{x}\),有:
把后面这个东西的卷积展开,有:
这个东西中间用 Chirp Z 变换一下,\(k(n-k)=\binom{n}{2}-\binom k 2 -\binom {n-k}2\) 然后分离一下 \(n,k,n-k\),可以得到:
发现这个东西是可以递推计算的。
然后直接带回去解就可以了,复杂度 \(O(n\log n)\),瓶颈在于求 exp。
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 5e5 + 5, mod = 998244353, gb = 3, gi = (mod + 1) / gb;
int qpow(int x, int k, int p) {int res = 1;while(k) {if(k & 1)res = res * x % p;x = x * x % p, k >>= 1;}return res;
}
int rev[maxn], inv[maxn];
void prepare(int n) {inv[1] = inv[0] = 1;for (int i = 2; i <= n; i++)inv[i] = (mod - mod / i) * inv[mod % i] % mod;
}
void init(int len) {for (int i = 0; i < len; i++) {rev[i] = rev[i >> 1] >> 1;if(i & 1)rev[i] |= (len >> 1);}
}
struct Poly {vector<int> a;void resize(int N) {a.resize(N);}void clear() {a.clear();}int size() {return a.size();}int& operator[](int x) {return a[x];}Poly() {}Poly(int N) {a.resize(N);}void NTT(int f) {for (int i = 0; i < size(); i++)if(i < rev[i])swap(a[i], a[rev[i]]);for (int h = 2; h <= size(); h <<= 1) {int d = qpow((f == 1 ? gb : gi), (mod - 1) / h, mod);for (int i = 0; i < size(); i += h) {int nw = 1;for (int j = i; j < i + h / 2; j++) {int a0 = a[j], a1 = a[j + h / 2] * nw % mod;a[j] = (a0 + a1) % mod, a[j + h / 2] = (a0 - a1 + mod) % mod;nw = nw * d % mod;}}}if(f == -1) {int inv = qpow(size(), mod - 2, mod);for (int i = 0; i < size(); i++) a[i] = a[i] * inv % mod;}}friend Poly operator*(Poly f, Poly g) {int len = 1, t = f.size() + g.size() - 1;while(len < t)len <<= 1;init(len), f.resize(len), g.resize(len);f.NTT(1), g.NTT(1);for (int i = 0; i < len; i++)f[i] = f[i] * g[i] % mod;f.NTT(-1);f.resize(t);return f; }friend Poly operator+(Poly f, Poly g) {int d = max(f.size(), g.size());f.resize(d), g.resize(d);for (int i = 0; i < d; i++)f[i] = (f[i] + g[i]) % mod;return f;}friend Poly operator-(Poly f, Poly g) {int d = max(f.size(), g.size());f.resize(d), g.resize(d);for (int i = 0; i < d; i++)f[i] = (f[i] - g[i] + mod) % mod;return f;}void print() {for (int i = 0; i < size(); i++)cout << a[i] << " ";cout << endl;}friend Poly operator+(Poly f, int v) {f[0] = (f[0] + v) % mod;return f;}friend Poly operator-(Poly f, int v) {f[0] = (f[0] - v + mod) % mod;return f;}
} ;
Poly get_deriv(Poly f) {Poly g(f.size() - 1);for (int i = 0; i < g.size(); i++)g[i] = f[i + 1] * (i + 1) % mod;return g;
}
Poly get_integ(Poly f) {Poly g(f.size() + 1);for (int i = 1; i < g.size(); i++)g[i] = f[i - 1] * inv[i] % mod;return g;
}
Poly get_inv(Poly f, int lim) {if(lim == 1) {f.resize(1);f[0] = qpow(f[0], mod - 2, mod);return f;}Poly g = get_inv(f, lim + 1 >> 1);int len = 1;while(len < lim * 2)len <<= 1;init(len);f.resize(lim), f.resize(len), g.resize(len);f.NTT(1), g.NTT(1);for (int i = 0; i < len; i++)f[i] = (2 * g[i] - f[i] * g[i] % mod * g[i] % mod + mod) % mod;f.NTT(-1);f.resize(lim);return f;
}
Poly get_ln(Poly f, int lim) {return get_integ(get_deriv(f) * get_inv(f, lim));
}
Poly get_exp(Poly f, int lim) {if(lim == 1) {Poly ans(1); ans[0] = 1;return ans;}f.resize(lim);Poly g = get_exp(f, lim + 1 >> 1), h = get_ln(g, lim);h = (f - h + 1);g = g * h;g.resize(lim);return g;
}
Poly shift(Poly f) {for (int i = 0; i < f.size() - 1; i++)f[i] = f[i + 1];f.resize(f.size() - 1);return f;
}
int n;
Poly g, h;
void get_gh() {Poly f; f.resize(n + 2);f[1] = 1;f = shift(get_exp(f, n + 2) - 1);f = get_ln(f, n + 1);g.resize(n + 1);for (int i = 1; i <= n; i++)g[i] = f[i] * qpow(qpow(2, i, mod) - 1, mod - 2, mod) % mod;g = get_exp(g, n + 1);f.clear(), f.resize(n + 1);f[1] = 1, f = get_exp(f, n + 1);h = get_inv(g, n + 1) * f; h.resize(n + 1);
}
signed main() {cin >> n;prepare(n + 2);get_gh();Poly res; res.resize(n + 1);for (int i = 0; i <= n; i++)h[i] = h[i] * qpow(2, i * (i - 1) / 2, mod) % mod;int inv2 = (mod + 1) / 2;for (int i = 1; i <= n; i++) res[i] = (res[i - 1] * inv2 % mod + inv2 * h[i - 1] % mod) % mod;for (int i = 1; i <= n; i++)res[i] = res[i] * qpow(qpow(2, i * (i - 1) / 2, mod), mod - 2, mod) % mod;res = res * g;for (int i = 1; i <= n; i++)res[n] = res[n] * i % mod;cout << res[n] << endl;return 0;
}