很牛的观察题,神秘 dp 优化。
题意:给出一个长为 \(n\) 的序列 \(a\) 和 \(k\),要求你选出一个长为 \(k\) 的序列 \(b\),\(1\le b_i\le n\),\(b\) 中元素可以重复。定义这样的序列 \(b\) 权值为 \(\sum\limits_{i=1}^k a_{b_i}-\prod\limits_{i=1}^k b_i\),求最大权值。
\(n\le 10^6,a_i\le 10^9\)。
做法:
首先发现,如果我选多于 \(30\) 个大于 \(1\) 的元素那么肯定不优,因为左侧增大最多 \(10^9\) 而减去的至少多了 \(2^{30}\)。
然后还是不会做,只有朴素的 \(V\log V\) 直接枚举我目前选择的元素的乘积是多少。
先思考一个部分分,\(k=2\) 怎么做,这个就是求 \(a_i+a_j-i\times j\) 的最大值,直接移项,固定 \(a_i\) 去找 \(a_j\),直接维护 \((j,a_j)\) 的凸包在上面切就行了。
拓展到 \(k=30\) 貌似没什么用,我们先来证明一个引理:
- 对于一个可重集合 \(S\),\(\forall x\in S,0\le x\le \frac2 3,\sum\limits_{x\in S}x =1\),那么会存在一个集合 \(T\subset S\),满足 \(\frac{1}{2}\le\sum\limits_{x\in T}x\le\frac 2 3\)。
考虑证明,我们重复以下过程:
-
如果集合中存在一个数 \(x\) 满足 \(\frac 1 2\le x\le \frac 2 3\),直接得证。
-
否则我们选取最小的两个数 \(x,y\),删掉他们并往集合里扔入 \(x+y\)。
当元素只有两个的时候,较大的一定满足方式 1,得证;而如果是方式 2,我们加进去的 \(x+y\) 一定不大于 \(\frac2 3\),因为如果大于,那么全集减去这两个数就小于 \(\frac 1 3\),肯定这两个数有一个大于 \(\frac 1 3\),选的就不是最小的两个了,所以我们一定能从一个大小为 \(|S|\) 的集合变成大小减一的集合,归纳就可以证明。
那么这个结论有什么好处呢?我们直接让每个 \(b\) 对我目标的 \(\prod\) 取个 \(\log\),那就等于我要构造一个这样的 \(S\)。直接构造很麻烦,但是因为有这样一个 \(T\subset S\),我可以先对 \(V^{\frac 2 3}\) 范围内的先做出来答案,然后两个 \(V^{\frac23}\) 的答案拟合一下就可以得到 \(V\) 范围内的答案!
所以我们现在只需要对 \(\le V^{\frac23}\) 的跑直接选的做法,然后枚举 \(|T|\),因为 \(T,S/T\) 都是在这个范围内的,直接用 \(k=2\) 的方式拟合在一起就是答案了。
复杂度 \(O(V^{\frac2 3}\log V^{\frac 2 3}\min(k,\log V))\)。
官方题解中貌似提到了可以优化成 \(O(n\log V^{\frac 2 3}\min(k,\log V))\),但是 \(n\) 和 \(V^{\frac 2 3}\) 不是同阶吗?并不理解这么做的意义,但是可以看看进一步的引理及证明,在此略去。
有点卡常。
代码:
#include <bits/stdc++.h>
using namespace std;
#pragma GCC optimize(2)
#pragma GCC optimize("Ofast")
const int maxn = 1e6 + 5, V = 1e6;
int n, k, a[maxn];
long long val;
inline long double slope(long long *dp, int i, int j) {return (long double)(dp[j] - dp[i]) / (j - i);
}
inline long long calc(long long *f, long long *g, int i, int j) {return f[i] + g[j] - 1ll * i * j;
}
int st[maxn], top;
long long dp[31][maxn], ans = -1, t[maxn];
long long solve(long long *f, long long *g) {top = 0;for (int i = 2; i <= V; i++) {while(top > 1 && slope(f, st[top - 1], st[top]) <= slope(f, st[top], i))top--;st[++top] = i;}long long ans = -1;for (int i = 2; i <= V; i++) {while(top > 1 && calc(f, g, st[top], i) <= calc(f, g, st[top - 1], i))top--;ans = max(ans, calc(f, g, st[top], i));}return ans;
}
signed main() {ios::sync_with_stdio(false);cin >> n >> k;for (int i = 1; i <= n; i++)cin >> a[i];val = a[1];memset(dp, -0x3f, sizeof(dp));for (int i = n; i >= 1; i--)a[i] -= a[1], dp[1][i] = a[i];int lim = min(30, k);for (int i = 2; i <= lim; i++) {for (int j = 2; j <= n; j++)for (long long k = 1ll * j * j; k <= V; k += j)dp[i][k] = max(dp[i][k], dp[i - 1][k / j] + a[j]);}for (int i = 1; i <= lim; i++)for (int j = 2; j <= V; j++) ans = max(ans, dp[i][j] - j);memset(t, -0x3f, sizeof(t));for (int i = lim - 1; i >= 1; i--) {for (int j = 1; j <= V; j++)t[j] = max(t[j], dp[lim - i][j]);ans = max(ans, solve(dp[i], t));}cout << ans + k * val << endl;return 0;
}