用于高效解决线性递推问题
前言
在算法竞赛和实际编程中,我们经常遇到需要计算矩阵的高次幂的问题。如果直接用朴素的矩阵乘法来计算,时间复杂度会达到O(n³ × k),其中n是矩阵的维度,k是幂次。当k非常大时(比如\(10^9\)),这样的时间复杂度是无法接受的。矩阵快速幂就是为了解决这个问题而诞生的算法。
什么是矩阵快速幂?
矩阵快速幂就是快速幂算法再矩阵运算中的应用。它基于这样的思想:
\(A^8\) = \((( A^2 )^2)^2\)
\(A^9\)=\(A × A^8\)
通过将指数进行二进制分解,我们可以在O(log k)次矩阵乘法内计算出\(A^k\)。
快速幂基本原理
以计算a^n为例:
-
如果n是偶数:\(a^n\) = \((a^{n/2})^2\)
-
如果n是奇数:\(a^n\) = a ×$ a^{n-1}$
这个过程可以用递归或迭代实现。
矩阵快速幂的实现步骤
矩阵定义
struct matrix {int mat[6][6];void init() {memset(mat, 0, sizeof(mat));}
};
矩阵乘法
由于幂的过程中,我们要实现一个矩阵的乘法,所以首先要将矩阵乘法写出;
matrix mul(matrix a, matrix b) { //return a*bmatrix c;c.init();for (int i = 0; i < 6; i++) {for (int j = 0; j < 6; j++) {for (int k = 0; k < 6; k++) {c.mat[i][j] += ((a.mat[i][k] % mod) * (b.mat[k][j] % mod)) % mod;c.mat[i][j] %= mod;}}}return c;
}
矩阵快速幂
matrix fast_pow(matrix A, int n) { //return A^n%modmatrix B;B.init();for (int i = 0; i < 6; i++) { //单位矩阵B.mat[i][i] = 1;}while (n) {if (n & 1) {B = mul(B, A);}A = mul(A, A);n >>= 1;}return B;
}
例题HDU - 2802
本题就是根据递推公式,求其中的某一项即F(n),那么可以用矩阵快速幂来解决此题(本题还可以打表找出规律求解,这里只介绍矩阵快速幂方法^^)
在写之前先求出转移矩阵是此题的关键
递推式化简:
\(F(N)=F(N−2)+N^3−(N−1)^3\)
展开:
\(N^3−(N−1)^3=3*N^2−3*N+16\)
所以:
\(F(N)=F(N−2)+3*N^2−3*N+1\)
状态向量构造:
我们需要同时存$ F(N)\(,以及和 N 相关的多项式项(\)N_2\(, N, 常数)。 因为转移里有\) F(N−2)$,所以状态至少要带上 \(F(N)\),\(F(N−1)\)。
定义:
转移矩阵:
根据递推式:
\(F(N)=F(N−2)+3*N^2−3*N+1\)
所以:
-
\(F(N)\) 依赖 \(F(N−2)\),而$ F(N−2) $就是上一步向量里的第二维。
-
多项式部分直接线性组合 \(N_2\),\(N\),1。
然后要写出矩阵,把 \(S(N)\) 表达为 \(M⋅S(N−1)\)。
完整代码
#include <iostream>
#include <cstring>
#define int long long
using namespace std;const int mod = 2009;struct matrix {int mat[6][6];void init() {memset(mat, 0, sizeof(mat));}
};matrix mul(matrix a, matrix b) { //return a*bmatrix c;c.init();for (int i = 0; i < 6; i++) {for (int j = 0; j < 6; j++) {for (int k = 0; k < 6; k++) {c.mat[i][j] += ((a.mat[i][k] % mod) * (b.mat[k][j] % mod)) % mod;c.mat[i][j] %= mod;}}}return c;
}matrix fast_pow(matrix A, int n) { //return A^n%modmatrix B;B.init();for (int i = 0; i < 6; i++) { //单位矩阵B.mat[i][i] = 1;}while (n) {if (n & 1) {B = mul(B, A);}A = mul(A, A);n >>= 1;}return B;
}signed main() {int N;while (cin >> N && N) {if (N == 1) {cout << 1 % mod << "\n";continue;}if (N == 2) {cout << 7 % mod << "\n";continue;}// 状态向量: [F(n), F(n-1), n^2, n, 1, dummy]// 6维里最后一个可以闲置// 我们要求 F(N),利用矩阵快速幂推到目标// 转移矩阵 M: S(k) -> S(k+1)matrix M;M.init();// F(k+1) = F(k-1) + 3(k+1)^2 - 3(k+1) + 1// => 依赖 F(k-1) 以及 (k+1)^2, (k+1), 1// 这里直接手写对应项// [F(k+1)] = [0 1 3 -3 1 0] * S(k)// [F(k)] = [1 0 0 0 0 0] * S(k)// [ (k+1)^2 ] = 转移自 N^2, N, 1// [ (k+1) ] = ...// [ 1 ] = [0 0 0 0 1 0]M.mat[0][1] = 1; // F(k-1)M.mat[0][2] = 3; // +3*N^2M.mat[0][3] = 3; // -3*NM.mat[0][4] = 1; // +1M.mat[1][0] = 1; // F(k)// (k+1)^2 = N^2 + 2N + 1M.mat[2][2] = 1;M.mat[2][3] = 2;M.mat[2][4] = 1;// (k+1) = N + 1M.mat[3][3] = 1;M.mat[3][4] = 1;M.mat[4][4] = 1; // 常数项保持 1// 取模修正负数for (int i = 0; i < 6; i++) {for (int j = 0; j < 6; j++) {M.mat[i][j] = (M.mat[i][j] % mod + mod) % mod;}}// 初始状态 S(2)int S[6] = {7, 1, 4, 2, 1, 0}; // [F(2)=7, F(1)=1, 2^2=4, 2, 1, 0]// 快速幂:从 S(2) 推到 S(N)matrix P = fast_pow(M, N - 2);// 结果向量 = P * Sint ans[6] = {0};for (int i = 0; i < 6; i++) {for (int j = 0; j < 6; j++) {ans[i] = (ans[i] + P.mat[i][j] * S[j]) % mod;}}cout << ans[0] % mod << "\n"; // F(N)}return 0;
}