给你三个长为 \(n\) 的序列 \(a,b,c\)。
求所有满足一下条件的 \([1,2,\cdots,n]\) 的长为 \(m\) 的子序列 \(p_1,p_2,\cdots,p_m\) 中,\(\sum_{i=1}^m c_{p_i}\) 的最大值
- \(a_{p_1}\le a_{p_2}\le\cdots\le a_{p_m}\)。
- \(\forall i \neq j,b_{p_i} \neq b_{p_j}\),即 \(b_i\) 互异。
\(1\le a_i,b_i\le n\le 3\times 10^3,1\le m\le 5\)
我们发现 \(b_i\) 互不相同这个条件非常难办,所以肯定要把它给转化掉。
有两种方法。
第一种是我们考虑,假如把第二个条件改成 \(b\) 单增,那么就是二维偏序,可以 \(\mathcal O(nm\log^2 n)\) 解决。
那么我们每次把 \(b\) 映射到一个随机的排列,有 \(\frac{1}{m!}\approx .008\) 的概率答案中的 \(b\) 被映射到单增序列上。
那么做多次即可,时间复杂度 \(\mathcal O(nm\log^2n \times m!)\)。
第二种是题解做法,我们考虑把每种颜色映射到 \(1\sim m\) 中的随机一种,然后直接状压,复杂度 \(O(n\log n2^m)\)。
正确率为答案中的 \(b\) 刚好被映射到排列上,也就是 \(\frac{m!}{m^m} \approx .03\)。
时间复杂度 \(\mathcal O(2^mn\log n \times \frac{m^m}{m!})\)。
我写的第一种,卡了好久才过 /ll
#include <algorithm>
#include <iostream>
#include <random>const int N = 3001, M = 512;std::mt19937 rnd(std::random_device{}());auto upd = [](auto& x, auto&& y) {x = std::max(x, y);
};struct Matr {int matr[N][M];inline void add(int x, int y, int z) {for(; x < N; x += x & -x)for(int j = y; j < M; j += j & -j)upd(matr[x][j], z);}inline int sum(int x, int y) {int z = 0;for(; x; x -= x & -x)for(int j = y; j; j -= j & -j)upd(z, matr[x][j]);return z;}inline void clear(int x, int y) {for(; x < N; x += x & -x)for(int j = y; j < M; j += j & -j)matr[x][j] = 0;}
};Matr Mat[4];
int n, m, a[N], rb[N], b[N], c[N], w[N], q[N];int solve() {std::shuffle(w + 1, w + n + 1, rnd);for(int i = 1; i <= n; ++i) b[i] = w[rb[i]];int ans = 0;for(int i = 1; i <= n; ++i) {Mat[0].add(a[i], b[i], c[i]);for(int t = 1; t < m - 1; ++t)if(auto val = Mat[t-1].sum(a[i], b[i]-1); val)Mat[t].add(a[i], b[i], val + c[i]);if(auto val = Mat[m-2].sum(a[i], b[i]-1); val)upd(ans, val + c[i]);}for(int i = 1; i <= n; ++i) for(int x = a[i]; x < N; x += x & -x)for(int y = b[i]; y < M; y += y & -y)for(int t = 0; t < m - 1; ++t)Mat[t].matr[x][y] = 0;return ans;
}int T = clock();int main() {std::ios::sync_with_stdio(0), std::cin.tie(0), std::cout.tie(0);std::cin >> n >> m;for(int i = 1; i <= n; ++i) std::cin >> a[i];for(int i = 1; i <= n; ++i) std::cin >> rb[i];for(int i = 1; i <= n; ++i) std::cin >> c[i];for(int i = 1; i <= n; ++i) w[i] = i % (M - 1) + 1;if(m == 1) {int _ans = 0;for(int i = 1; i <= n; ++i)upd(_ans, c[i]);std::cout << _ans << "\n";} else {int _ans = 0;for(int T = 666; T--; )upd(_ans, solve());if(_ans == 0) std::cout << "-1\n";else std::cout << _ans << "\n";}
}
