2025.10.26 闲话-单位根反演
起因:正在和 zxk 探讨 k 叉 Bostan–Mori。
jijidawang:直接单位根反演。
所以就来学习单位根反演了。
Part.1 主体
首先引入这样一个问题:
求:\(\sum_{i=0}^{\lfloor n/2\rfloor}\binom{n}{2i}\)
可以构造 \(f(x)=\sum_{i=0}^{n}{\binom{n}{i}x^i}\)
然后要求的是所有偶数次项的系数之和。发现如果将 \(-x\) 代入,恰好所有奇数次项都变号、偶数次项不变,那么有 \(f(1)+f(-1)=2\sum_{i=0}^{\lfloor n/2\rfloor}\binom{n}{2i}\),
所求即为 \(\frac{f(1)+f(-1)}{2}\)
另一个问题,如何求:\(\sum_{i=0}^{\lfloor n/3\rfloor}\binom{n}{3i}\)
发现不好找一个合适的数使得次数不是 3 的倍数的项被干掉,但是通过数学知识可得(\(w_3=\cos\frac{2\pi}{3}+\mathrm{i}\sin\frac{2\pi}{3}\) 为 3 次单位根):
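\(w_3^0+w_3^1+w_3^2=0\),这启发我们:对任意整数 \(i\),\(w_3^{0\cdot i}+w_3^{1\cdot i}+w_3^{2\cdot i}\) 在 \(3\mid i\) 时为 \(3\),否则为 \(0\),即 \([3\mid i]=\frac{1}{3}\sum_{j=0}^{2}w_3^{ij}\),所以所求即为 \(\frac{f(1)+f(w_3)+f(w_3^2)}{3}\)。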
那么不妨推广一下:\([k\mid n]=\frac{1}{k}\sum_{i=0}^{k-1}w_k^{in}\)
考虑怎么证明。首先发现右边的和式是一个公比为 \(w_k^n\) 的等比数列。先处理公比为 1 的情况,即 \(w_{k}^{n}=1\):此时可得 \(k\mid n\),和式每一项都是 1,右边等于 \(\frac{1}{k}\cdot k=1\),代入成立。
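否则 \(w_k^n\neq 1\)(此时 \(k\nmid n\),左边为 0),由等比数列求和公式:\(\frac{1}{k}\sum_{i=0}^{k-1}(w_k^n)^i=\frac{1}{k}\cdot\frac{(w_k^n)^k-1}{w_k^n-1}=\frac{1}{k}\cdot\frac{(w_k^k)^n-1}{w_k^n-1}=0\),同样成立。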
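以上便是单位根反演。更常用的是放到多项式中:设 \(f(x)=\sum_{i=0}^{n}a_ix^i\),则 \(\sum_{i=0}^{n}a_i[k\mid i]=\sum_{i=0}^{n}a_i\cdot\frac{1}{k}\sum_{j=0}^{k-1}w_k^{ij}=\frac{1}{k}\sum_{j=0}^{k-1}\sum_{i=0}^{n}a_i(w_k^j)^i\)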
注意到 \(\sum_{i=0}^{n}{a_i(w_k^j)^i}=f(w_k^j)\)
所以:\(\sum_{i=0}^{n}a_i[k\mid i]=\frac{1}{k}\sum_{j=0}^{k-1}f(w_k^j)\)
十分优美。
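模意义下可以用原根代替单位根:若 \(k\mid P-1\)(\(P\) 为模数),\(g\) 是模 \(P\) 的原根,则 \(g^{(P-1)/k}\) 就是模 \(P\) 意义下的 \(k\) 次单位根(例如 \(998244353\) 的原根为 \(3\))。不过我不会原根。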
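容易发现一件事情:如果 \(k\) 为 \(2^t\),那么 \(f(w_k^j)\ (j=0,\dots,k-1)\) 其实就是对 \(f\) 做长度为 \(k\) 的 NTT 后得到的点值(若 \(f\) 的次数不小于 \(k\),先把 \(x^i\) 的系数累加到下标 \(i\bmod k\) 上,因为 \(w_k^k=1\))。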
扩展一点:如果将 \([k\mid i]\) 改为 \([i\bmod k=t]\),怎么做?
容易发现,将 \(f(x)\) 变为 \(x^{k-t}f(x)\) 后与原问题等价:乘上 \(x^{k-t}\) 会把 \(x^i\) 变为 \(x^{i+k-t}\),而对 \(0\le t<k\) 有 \(k\mid i+k-t\iff i\bmod k=t\)。
Part.2 例题
luogu P10664 PYXFIB
模板题,上述过程可以推广到矩阵。
设 \(I\) 为单位矩阵,\(F,G\) 为矩阵求斐波那契转移矩阵。
复杂度 \(O(k\log n)\)(求原根不算在内),注意处理原根。
luogu P5591 小猪佩奇学数学
这题有一万种做法,不过我自己搞了一种简单易懂的做法,通过看题解学会了另一种更简单的做法,那做法吊打了我的做法。
首先推式子:
\( \begin{aligned} ans &= \sum_{i=0}^n \binom n ip^{i}\left\lfloor \frac{i}{k} \right\rfloor \\ &= \sum_{i=0}^n \binom n i p^{i} \frac{i-(i\bmod k)}{k} \\ &= \sum_{i=0}^n \binom n i p^{i} \frac{i}{k} - \sum_{i=0}^n \binom n i p^{i} \frac{(i\bmod k)}{k} \\ &= \frac{1}{k}\sum_{i=0}^n \frac{n!}{i!(n-i)!} p^{i} i - \frac{1}{k}\sum_{i=0}^n \binom n i p^{i} (i\bmod k) \\ &= \frac{np}{k}\sum_{i=1}^n \frac{(n-1)!}{(i-1)!(n-i)!} p^{i-1} - \frac{1}{k}\sum_{t=0}^{k-1}\sum_{i=0}^n \binom n i p^{i} t [i\bmod k=t] \\ &= \frac{np}{k}(p+1)^{n-1} - \frac{1}{k}\sum_{t=0}^{k-1}\sum_{i=0}^n \binom n i p^{i} t [i\bmod k=t] \\ \end{aligned} \)
先讲我的做法,主要是处理后面:
设 \(f(x)=\sum_{i=0}^{n}\binom{n}{i}p^ix^{i}=(1+xp)^n\)。
那么可得:
\( \begin{aligned} ans' &= \frac{1}{k}\sum_{t=0}^{k-1}\sum_{i=0}^n \binom n i p^{i} t [i\bmod k=t] \\ &= \frac{1}{k^2}\sum_{i=0}^{k-1}\sum_{t=0}^{k-1}f(w_k^i)(w_k^i)^{k-t}t \\ \end{aligned} \)
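此时把 \(w_k^i\) 看成自变量 \(x\),需要处理的是 \(g(x)=\sum_{t=0}^{k-1}t\,x^{k-t}\)(由于在单位根处 \(x^k=1\),\(t=0\) 的项落在 \(x^0\) 上且系数为 0,即 \(g\) 的 \(x^0\) 系数为 0,\(x^i\) 系数为 \(k-i\)),于是 \(ans'=\frac{1}{k^2}\sum_{i=0}^{k-1}f(w_k^i)\,g(w_k^i)\)。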
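发现只需要对 \(1+px\) 和 \(g(x)\) 各做一次长度为 \(k\) 的 NTT,把前者的点值逐点取 \(n\) 次幂(得到 \(f\) 的点值)后与后者的点值相乘再求和即可,最后:\(ans=\frac{np}{k}(p+1)^{n-1}-\frac{1}{k^2}\sum_{i=0}^{k-1}f(w_k^i)\,g(w_k^i)\)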
复杂度 \(O(k\log k)\)。
做法 1
// Luogu P5591, method 1: computes sum_{i=0}^{n} C(n, i) * p^i * floor(i / k) mod 998244353
// via two length-k NTTs (k a power of two) and a pointwise product of the point values.
#include <iostream>
#include <set>
#include <vector>
#include <algorithm>
#include <queue>
#include <cstring>
#include <unordered_map>
#include <map>
#include <ctime>

using namespace std;

// N bounds every global array; mod = 998244353 = 119 * 2^23 + 1 has primitive root 3 (used by NTT).
const int N = 4e6 + 10, mod = 998244353;

// From here on `int` means `long long`, so a product of two residues (< 2^60) does not overflow.
#define int long long
#define emp emplace_back
#define pb push_back
#define fi first
#define se second

using pii = pair <int, int>;

// Shared template storage. This solution only uses f and g; fac/inv back C()/A() but are never
// filled here, and s/p are unused (main's local `p` shadows the array).
int fac[N], inv[N], s[N], p[N], f[N], g[N];

namespace Poly
{
    // x^b mod `mod` by binary exponentiation. Requires b >= 0: with a negative b the
    // arithmetic right shift never reaches 0 and the loop does not terminate.
    int qpow(int x, int b)
    {
        int res = 1;
        while (b)
        {
            if (b & 1) res = res * x % mod;
            x = x * x % mod;
            b >>= 1;
        }
        return res;
    }

    int rev[N];

    // In-place NTT of length k over Z_mod. k must be a power of two and every a[i] in [0, mod).
    // op == 0: a[j] <- sum_i a[i] * w^(i*j), with w = 3^((mod - 1) / k).
    // op == 1: inverse transform, including the division by k.
    void NTT(int *a, int k, bool op = 0)
    {
        // Bit-reversal permutation of the input.
        for (int i = 0; i < k; i++) rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? k >> 1 : 0);
        for (int i = 0; i < k; i++) if (i < rev[i]) swap(a[i], a[rev[i]]);
        for (int len = 2; len <= k; len <<= 1)
        {
            int wn = qpow(3, (mod - 1) / len);
            for (int l = 0, mid = (len >> 1) - 1; l + len - 1 < k; l += len, mid += len)
            {
                int w = 1;
                for (int i = l; i <= mid; i++, w = w * wn % mod)
                {
                    int x = a[i], y = a[i + (len >> 1)] * w % mod;
                    a[i] = x + y;
                    if (a[i] >= mod) a[i] -= mod;
                    a[i + (len >> 1)] = x - y;
                    if (a[i + (len >> 1)] < 0) a[i + (len >> 1)] += mod;
                }
            }
        }
        if (op)
        {
            // Reversing a[1..k-1] turns w^(ij) into w^(-ij); then scale by 1/k.
            reverse(a + 1, a + k);
            int inv = qpow(k, mod - 2);
            for (int i = 0; i < k; i++) a[i] = a[i] * inv % mod;
        }
    }

    // Sets f[i] = v for i in [l, min(N, r)).
    void fill(int *f, int l, int r, int v) { for (int i = l; i < min((long long)N, r); i++) f[i] = v; }

    // Copies f[l, r) into h[l, r).
    void copy(int *f, int *h, int l, int r) { for (int i = l; i < r; i++) h[i] = f[i]; }

    int mulf[N], mulg[N];

    // h <- f * g, where f has n coefficients and g has m coefficients.
    // h[n + m - 1, len) is cleared afterwards, len being the transform length.
    void mul(int *f, int *g, int *h, int n, int m)
    {
        int len = 1;
        while (len < n + m) len <<= 1;
        fill(mulf, 0, len, 0), fill(mulg, 0, len, 0);
        // Copy only the declared operand lengths; whatever the caller keeps past them is not
        // part of the polynomial (previously `len` entries were copied, leaking that data in).
        copy(f, mulf, 0, n), copy(g, mulg, 0, m);
        NTT(mulf, len, 0), NTT(mulg, len, 0);
        for (int i = 0; i < len; i++) h[i] = mulf[i] * mulg[i] % mod;
        NTT(h, len, 1);
        for (int i = n + m - 1; i < len; i++) h[i] = 0;
    }

    int invh[N], invf[N];  // invh is unused

    // h <- f^{-1} mod x^n via the Newton step h <- h * (2 - f * h). Requires f[0] != 0.
    // NOTE(review): assumes h is zero past the recursively computed prefix on entry
    // (top-level callers zero it); h[n, len) is cleared on exit.
    void Inv(int *f, int *h, int n)
    {
        if (n == 1) return h[0] = qpow(f[0], mod - 2), void();
        Inv(f, h, (n + 1) >> 1);
        int len = 1;
        while (len < 2 * n) len <<= 1;
        fill(invf, 0, len, 0);
        copy(f, invf, 0, n);
        NTT(invf, len, 0), NTT(h, len, 0);
        for (int i = 0; i < len; i++) h[i] = h[i] * (2 - h[i] * invf[i] % mod + mod) % mod;
        NTT(h, len, 1);
        fill(h, n, len, 0);
    }

    // In-place derivative of a polynomial with len coefficients; f[len - 1] becomes 0.
    void dev(int *f, int len) { for (int i = 1; i < len; i++) f[i - 1] = i * f[i] % mod; f[len - 1] = 0; }

    // In-place integral with zero constant term. Writes f[len], so f needs room for len + 1 entries.
    void redev(int *f, int len) { for (int i = len - 1; i >= 0; i--) f[i + 1] = f[i] * qpow(i + 1, mod - 2) % mod; f[0] = 0; }

    int lnf[N], lng[N];

    // h <- ln f mod x^n, computed as the integral of f' / f (assumes f[0] == 1).
    void ln(int *f, int *h, int n)
    {
        fill(lnf, 0, 4 * n, 0), fill(lng, 0, 4 * n, 0);
        copy(f, lnf, 0, n);
        dev(lnf, n);
        Inv(f, lng, n);
        mul(lnf, lng, h, n, n);
        redev(h, n);
        fill(h, n, 2 * n, 0);
    }

    int _expf[N], expg[N];

    // h <- exp f mod x^n via the Newton step h <- h * (1 - ln h + f) (assumes f[0] == 0).
    void exp(int *f, int *h, int n)
    {
        if (n == 1) return h[0] = 1, void();
        exp(f, h, (n + 1) >> 1);
        fill(_expf, 0, 2 * n, 0);
        fill(expg, 0, 2 * n, 0);
        copy(h, _expf, 0, n);
        fill(_expf, n, 2 * n, 0);
        ln(_expf, expg, n);
        for (int i = 0; i < n; i++) expg[i] = (-expg[i] + f[i] + mod) % mod;
        expg[0]++;
        mul(_expf, expg, h, n, n);
        fill(h, n, 2 * n, 0);
    }
}

using namespace Poly;

// Binomial coefficient from the fac/inv tables; 0 unless 0 <= m <= n.
// NOTE: this program never fills fac/inv, so C/A are unused template helpers here.
int C(int n, int m) { return n >= m && m >= 0 ? fac[n] * inv[m] % mod * inv[n - m] % mod : 0; }

// Number of arrangements n! / (n - m)!; 0 unless 0 <= m <= n.
int A(int n, int m) { return m >= 0 ? C(n, m) * fac[m] % mod : 0; }

signed main()
{
    // freopen("data.in", "r", stdin); freopen("data.out", "w", stdout);
    // freopen("mission.in", "r", stdin); freopen("mission.out", "w", stdout);
    ios :: sync_with_stdio(false), cin.tie(0), cout.tie(0);
    int n, p, k;
    cin >> n >> p >> k;
    // answer = n*p*(p+1)^(n-1) / k - (1/k^2) * sum_j F(w^j)^n * G(w^j), w a primitive k-th root,
    // F(x) = 1 + p*x and G(x) = sum_{t=1}^{k-1} t * x^(k-t) (so g[i] = k - i, g[0] = 0).
    f[0] = 1;
    f[1 % k] = (f[1 % k] + p) % mod;  // x^1 folds onto x^0 when k == 1, since x^k = 1 at the roots
    for (int i = 0; i < k; i++) g[i] = k - i;
    g[0] = 0;  // residue t = 0 contributes nothing
    NTT(f, k, 0), NTT(g, k, 0);
    for (int i = 0; i < k; i++) f[i] = qpow(f[i], n) * g[i] % mod;
    int ans = 0, invk = qpow(k, mod - 2);
    for (int i = 0; i < k; i++) ans = (ans + mod - invk * f[i] % mod) % mod;
    // n*p*(p+1)^(n-1) = sum_i C(n, i) p^i i. It is 0 for n == 0, where qpow(p + 1, -1) would hang.
    if (n > 0) ans = (ans + n % mod * p % mod * qpow(p + 1, n - 1)) % mod;
    cout << ans * invk % mod;
    return 0;
}
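更为简单的做法是发现长度为 \(k\) 的 NTT 是在模 \(x^k-1\) 意义下进行的(循环卷积),也就是如果多项式的次数不低于 NTT 的长度,那么 \(x^i\) 的系数会累加到下标 \(i\bmod k\) 上。这和上述问题中要求的 \(c_t=\sum_{i\bmod k=t}\binom{n}{i}p^i\) 恰好匹配上了,所以直接对 \(1+px\) 做 NTT、把点值逐点取 \(n\) 次幂、再 INTT 就得到了所有 \(c_t\),答案为 \(\frac{np}{k}(p+1)^{n-1}-\frac{1}{k}\sum_{t=0}^{k-1}t\,c_t\)。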
做法 2
// Luogu P5591, method 2: computes sum_{i=0}^{n} C(n, i) * p^i * floor(i / k) mod 998244353.
// A length-k NTT is a cyclic transform (arithmetic modulo x^k - 1), so NTT -> pointwise n-th
// power -> INTT yields c[t] = sum_{i mod k = t} C(n, i) p^i directly.
#include <iostream>
#include <set>
#include <vector>
#include <algorithm>
#include <queue>
#include <cstring>
#include <unordered_map>
#include <map>
#include <ctime>

using namespace std;

// N bounds every global array; mod = 998244353 = 119 * 2^23 + 1 has primitive root 3 (used by NTT).
const int N = 4e6 + 10, mod = 998244353;

// From here on `int` means `long long`, so a product of two residues (< 2^60) does not overflow.
#define int long long
#define emp emplace_back
#define pb push_back
#define fi first
#define se second

using pii = pair <int, int>;

// Shared template storage. This solution only uses f; fac/inv back C()/A() but are never
// filled here, and s/p/g are unused (main's local `p` shadows the array).
int fac[N], inv[N], s[N], p[N], f[N], g[N];

namespace Poly
{
    // x^b mod `mod` by binary exponentiation. Requires b >= 0: with a negative b the
    // arithmetic right shift never reaches 0 and the loop does not terminate.
    int qpow(int x, int b)
    {
        int res = 1;
        while (b)
        {
            if (b & 1) res = res * x % mod;
            x = x * x % mod;
            b >>= 1;
        }
        return res;
    }

    int rev[N];

    // In-place NTT of length k over Z_mod. k must be a power of two and every a[i] in [0, mod).
    // op == 0: a[j] <- sum_i a[i] * w^(i*j), with w = 3^((mod - 1) / k).
    // op == 1: inverse transform, including the division by k.
    void NTT(int *a, int k, bool op = 0)
    {
        // Bit-reversal permutation of the input.
        for (int i = 0; i < k; i++) rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? k >> 1 : 0);
        for (int i = 0; i < k; i++) if (i < rev[i]) swap(a[i], a[rev[i]]);
        for (int len = 2; len <= k; len <<= 1)
        {
            int wn = qpow(3, (mod - 1) / len);
            for (int l = 0, mid = (len >> 1) - 1; l + len - 1 < k; l += len, mid += len)
            {
                int w = 1;
                for (int i = l; i <= mid; i++, w = w * wn % mod)
                {
                    int x = a[i], y = a[i + (len >> 1)] * w % mod;
                    a[i] = x + y;
                    if (a[i] >= mod) a[i] -= mod;
                    a[i + (len >> 1)] = x - y;
                    if (a[i + (len >> 1)] < 0) a[i + (len >> 1)] += mod;
                }
            }
        }
        if (op)
        {
            // Reversing a[1..k-1] turns w^(ij) into w^(-ij); then scale by 1/k.
            reverse(a + 1, a + k);
            int inv = qpow(k, mod - 2);
            for (int i = 0; i < k; i++) a[i] = a[i] * inv % mod;
        }
    }

    // Sets f[i] = v for i in [l, min(N, r)).
    void fill(int *f, int l, int r, int v) { for (int i = l; i < min((long long)N, r); i++) f[i] = v; }

    // Copies f[l, r) into h[l, r).
    void copy(int *f, int *h, int l, int r) { for (int i = l; i < r; i++) h[i] = f[i]; }

    int mulf[N], mulg[N];

    // h <- f * g, where f has n coefficients and g has m coefficients.
    // h[n + m - 1, len) is cleared afterwards, len being the transform length.
    void mul(int *f, int *g, int *h, int n, int m)
    {
        int len = 1;
        while (len < n + m) len <<= 1;
        fill(mulf, 0, len, 0), fill(mulg, 0, len, 0);
        // Copy only the declared operand lengths; whatever the caller keeps past them is not
        // part of the polynomial (previously `len` entries were copied, leaking that data in).
        copy(f, mulf, 0, n), copy(g, mulg, 0, m);
        NTT(mulf, len, 0), NTT(mulg, len, 0);
        for (int i = 0; i < len; i++) h[i] = mulf[i] * mulg[i] % mod;
        NTT(h, len, 1);
        for (int i = n + m - 1; i < len; i++) h[i] = 0;
    }

    int invh[N], invf[N];  // invh is unused

    // h <- f^{-1} mod x^n via the Newton step h <- h * (2 - f * h). Requires f[0] != 0.
    // NOTE(review): assumes h is zero past the recursively computed prefix on entry
    // (top-level callers zero it); h[n, len) is cleared on exit.
    void Inv(int *f, int *h, int n)
    {
        if (n == 1) return h[0] = qpow(f[0], mod - 2), void();
        Inv(f, h, (n + 1) >> 1);
        int len = 1;
        while (len < 2 * n) len <<= 1;
        fill(invf, 0, len, 0);
        copy(f, invf, 0, n);
        NTT(invf, len, 0), NTT(h, len, 0);
        for (int i = 0; i < len; i++) h[i] = h[i] * (2 - h[i] * invf[i] % mod + mod) % mod;
        NTT(h, len, 1);
        fill(h, n, len, 0);
    }

    // In-place derivative of a polynomial with len coefficients; f[len - 1] becomes 0.
    void dev(int *f, int len) { for (int i = 1; i < len; i++) f[i - 1] = i * f[i] % mod; f[len - 1] = 0; }

    // In-place integral with zero constant term. Writes f[len], so f needs room for len + 1 entries.
    void redev(int *f, int len) { for (int i = len - 1; i >= 0; i--) f[i + 1] = f[i] * qpow(i + 1, mod - 2) % mod; f[0] = 0; }

    int lnf[N], lng[N];

    // h <- ln f mod x^n, computed as the integral of f' / f (assumes f[0] == 1).
    void ln(int *f, int *h, int n)
    {
        fill(lnf, 0, 4 * n, 0), fill(lng, 0, 4 * n, 0);
        copy(f, lnf, 0, n);
        dev(lnf, n);
        Inv(f, lng, n);
        mul(lnf, lng, h, n, n);
        redev(h, n);
        fill(h, n, 2 * n, 0);
    }

    int _expf[N], expg[N];

    // h <- exp f mod x^n via the Newton step h <- h * (1 - ln h + f) (assumes f[0] == 0).
    void exp(int *f, int *h, int n)
    {
        if (n == 1) return h[0] = 1, void();
        exp(f, h, (n + 1) >> 1);
        fill(_expf, 0, 2 * n, 0);
        fill(expg, 0, 2 * n, 0);
        copy(h, _expf, 0, n);
        fill(_expf, n, 2 * n, 0);
        ln(_expf, expg, n);
        for (int i = 0; i < n; i++) expg[i] = (-expg[i] + f[i] + mod) % mod;
        expg[0]++;
        mul(_expf, expg, h, n, n);
        fill(h, n, 2 * n, 0);
    }
}

using namespace Poly;

// Binomial coefficient from the fac/inv tables; 0 unless 0 <= m <= n.
// NOTE: this program never fills fac/inv, so C/A are unused template helpers here.
int C(int n, int m) { return n >= m && m >= 0 ? fac[n] * inv[m] % mod * inv[n - m] % mod : 0; }

// Number of arrangements n! / (n - m)!; 0 unless 0 <= m <= n.
int A(int n, int m) { return m >= 0 ? C(n, m) * fac[m] % mod : 0; }

signed main()
{
    // freopen("data.in", "r", stdin); freopen("data.out", "w", stdout);
    // freopen("mission.in", "r", stdin); freopen("mission.out", "w", stdout);
    ios :: sync_with_stdio(false), cin.tie(0), cout.tie(0);
    int n, p, k;
    cin >> n >> p >> k;

    // Point values of 1 + p*x at the k-th roots of unity; x^1 folds onto x^0 when k == 1.
    f[0] = 1;
    f[1 % k] = (f[1 % k] + p) % mod;
    NTT(f, k, 0);
    // (1 + p*x)^n at the same points; the inverse transform reduces modulo x^k - 1, giving
    // f[t] = c[t] = sum_{i mod k = t} C(n, i) * p^i.
    for (int i = 0; i < k; i++) f[i] = qpow(f[i], n);
    NTT(f, k, 1);

    // sum_i C(n, i) p^i (i mod k) = sum_t t * c[t].
    int remainderSum = 0;
    for (int t = 0; t < k; t++) remainderSum = (remainderSum + t * f[t]) % mod;

    // sum_i C(n, i) p^i i = n*p*(p+1)^(n-1); it is 0 for n == 0, where qpow(p + 1, -1) would hang.
    int linearSum = n > 0 ? n % mod * p % mod * qpow(p + 1, n - 1) % mod : 0;

    // floor(i / k) = (i - i mod k) / k.
    int invk = qpow(k, mod - 2);
    cout << (linearSum - remainderSum + mod) % mod * invk % mod;
    return 0;
}
速度相差不大。
