关键词
多项式乘法,系数表示,点值表示,单位根
FFT基本思路
- 系数表示->点值多项式
- 点值下直接相乘,时间复杂度O(n)
- 点值多项式->系数表示
系数表示->点值多项式
- 分治思想,奇偶分开,单位根
- 假定\(f(x)=\sum_{i=0}^{n-1}a_ix^i\),其中n为2的幂次
- 对于一个有n个系数的多项式,点值表示需要n个不同的点,
- 那么考虑使用单位根\(x^n=1\)的n个解(\(\omega_{n}^0,\omega_{n}^1...\omega_{n}^{n-1}\)),来作为这n个点
- 那么我们只需求出\(f(\omega_{n}^0),f(\omega_{n}^1)...f(\omega_{n}^{n-1})\)就得到了这个多项式的点值表示
- 具体做法就是利用一点单位根的性质,我们将奇偶项分开
\(A(x)=a_0+a_2x^2+a_4x^4+...\)
\(B(x)=a1+a_3x^3+a_5x^5+...\)
\(f(x)=A(x^2)+xB(x^2)\)
\(f(\omega_{n}^{k})=A(\omega_{n}^{2k})+\omega_{n}^kB(\omega_{n}^{2k})\)
\(f(\omega_{n}^{k})=A(\omega_{\frac{n}{2}}^{k})+\omega_{n}^kB(\omega_{\frac{n}{2}}^{k})\)
\(f(\omega_{n}^{k+n/2})=A(\omega_{\frac{n}{2}}^{k})-\omega_{n}^kB(\omega_{\frac{n}{2}}^{k})\)
那么我们直接递归下去就行
op=1
void fft(cp* a, int n, int op) {if (n == 1) return;cp a1[n / 2], a2[n / 2];for (int i = 0;i * 2 < n;++i) {a1[i] = a[2 * i];a2[i] = a[2 * i + 1];}fft(a1, n / 2, op);fft(a2, n / 2, op);cp wn = (cp){ cos(2 * pi / n), op*sin(2 * pi / n) }; cp w = (cp){ 1,0 };for (int i = 0;i < n / 2;++i) {a[i] = a1[i] + w * a2[i];a[i + n / 2] = a1[i] - w * a2[i];w = w * wn;}
}
乘法
假如我们将两个多项式都使用点值表示,并且是相同的n个点,那么我们直接对应相乘,就得到了乘积多项式的点值表示
点值->系数表示
-
考虑使用拉格朗日插值将点值表示还原到系数表示
-
\(f(x)=\sum_{i=0}^{n-1}f(\omega_{n}^{i})L_i(x)\)
-
\(L_i(x)=\prod_{k\neq i}\frac{x-\omega_{n}^k}{\omega_{n}^i-\omega_{n}^k}\)
-
\(L_i(x)\)可以直接硬求,下面贴一个LLM的做法
-
稍微简单一点的做法,利用单位根的正交性质
-
\(\sum_{k=0}^{n-1} \omega_{n}^{k(i-j)}=n\delta_{i,j}\),\(\delta_{i,j}\)为克罗内克符号,\(\delta_{i,j}\)为1当且仅当\(i=j\)
-
那么我们要构造的\(L_i(x)\)本质上就是要让\(L_i(\omega_{n}^{j})=\delta_{i,j}\)
-
设\(L_i(x)=\sum_{k=0}^{n-1} c_{i,k}x^k\)
-
将\(\omega_{n}^{j}\)代入\(L_i(\omega_{n}^{j})=\sum_{k=0}^{n-1} c_{i,k}\omega_{n}^{j}\),那么我们对比一下它的正交性质的式子,只需令\(c_{i,k}=\frac{\omega_{n}^{-ik}}{n}\)就能搞定
-
因此有\(f(x)=\frac{1}{n}\sum_{i=0}^{n-1} f(\omega_{n}^{i})\sum_{k=0}^{n-1}\omega_{n}^{-ki}x^k\)
-
\(f(x)=\frac{1}{n}\sum_{k=0}^{n-1}x^k \sum_{i=0}^{n-1}f(\omega_{n}^i)\omega_{n}^{-ki}\)
-
\(\frac{1}{n}\sum_{i=0}^{n-1}f(\omega_{n}^i)\omega_{n}^{-ki}\)其实就是\(a_k\)
-
那么\(a_0,a_1...a_{n-1}\) 我们可以看作是求\(g(x)=\frac{1}{n}\sum_{i=0}^{n-1}f(\omega_{n}^i)x^i\)这个多项式在\(\omega_{n}^{-0},\omega_{n}^{-1}...\omega_{n}^{-(n-1)}\)的值
-
而我们第一部分求的是\(f(x)=\sum_{i=0}^{n-1}a_ix^i\)在\(\omega_{n}^{0},\omega_{n}^{1}...\omega_{n}^{(n-1)}\)的值,因此代码是可以复用的
蝶形优化
- 蝶形优化其实就是自底向上计算,那么首先需要求得每个数最后在哪里?
- 经过观察可以发现就是将它的二进制位进行一个翻转,比如n=8时(001->100,110->011)
- 那么将每个数放到最后一层的正确位置后,自底向上计算即可
#include<bits/stdc++.h>
#define lc (o<<1)
#define rc ((o<<1)|1)
using namespace std;
typedef long long ll;
typedef double db;
constexpr int N = 1 << 22;
constexpr ll inf = 1ll << 60;
const db pi = acos(-1);
struct cp {db x = 0, y = 0;cp(db x = 0, db y = 0) : x(x), y(y) {}
};
cp operator + (const cp& a, const cp& b) {return (cp) { a.x + b.x, a.y + b.y };
}
cp operator - (const cp& a, const cp& b) {return (cp) { a.x - b.x, a.y - b.y };
}
cp operator * (const cp& a, const cp& b) {return (cp) { a.x* b.x - a.y * b.y, a.x* b.y + a.y * b.x };
}
int n, m, r[N];
cp a[N], b[N], c[N];
void fft(cp* a, int n, int op) {for (int i = 0;i < n;++i) if (i < r[i]) swap(a[i], a[r[i]]);for (int i = 1;i < n;i *= 2) {cp wn = (cp){ cos(pi / i), sin(pi / i) * op };for (int j = 0;j < n;j += i << 1) {cp w = (cp){ 1,0 }, x, y;for (int k = 0;k < i;++k) {x = a[j + k];y = a[j + k + i];a[j + k] = x + w * y;a[j + k + i] = x - w * y;w = w * wn;}}}
}
void R(int& x) {int t = 0; char ch;for (ch = getchar();!('0' <= ch && ch <= '9');ch = getchar());for (;('0' <= ch && ch <= '9');ch = getchar()) t = t * 10 + ch - '0';x = t;
}
int main() {
#ifdef LOCALfreopen("data.in", "r", stdin);freopen("data.out", "w", stdout);
#endifcin >> n >> m;n++;m++;int t;for (int i = 0;i < n;++i) R(t), a[i].x = t;for (int i = 0;i < m;++i) R(t), b[i].x = t;int lim = 1;while (lim < n + m) lim <<= 1;for (int i = 0;i < lim;++i) {r[i] = r[i >> 1] >> 1;if (i & 1) r[i] += lim / 2;}fft(a, lim, 1);fft(b, lim, 1);for (int i = 0;i < lim;++i) c[i] = a[i] * b[i];fft(c, lim, -1);for (int i = 0;i <= n + m - 2;++i) printf("%d ", (int)(c[i].x / lim + 0.5));return 0;
}