ABC425 G
题面
给定两个正整数 \(N,M\) 以及一个长度为 \(N\) 的非负整数序列 \(A = A_1, A_2, ..., A_N\)。求
\(1 \le N \le 2 \times 10^5\)
\(1 \le M \le 10^9\)
\(0 \le A_i \le 10^9\)
题解
首先,如果 \(M \le 10^5\),那么是不是可以想到把 \(A\) 建成一棵 01trie,然后依次在 tire 上查询每个 \(x\)。
我们设 \(d\) 表示最多有多少层,也就是 \(\log V\),其中 \(V\) 是值域,那么时间复杂度就是 \(O(Md)\)。
那么我们可以先建出一棵 trie,然后考虑如何处理 \(M \le 10^9\) 情况。
思考我们在 trie 上找异或最小值或者异或最大值,一般都是从上往下贪心的去遍历 trie,对于一个数是这样,对于一段连续的数,也是一样的,我们只关注最高位的情况。
所以我们可以将 \(0 \sim val\) 这一个区间的询问当成一个询问,然后对这个询问的情况进行分类讨论。
我们设当前询问为 \(0 \sim val\),保证到第 \(k\) 层时 \(0 \sim val\) 二进制最高位的 \(1\) 处在第 \(k\) 位及以下,分两种情况来讨论:
-
\(val < 2^k\):也就说明 \(0 \sim val\) 中的所有数的第 \(k\) 位都是 \(0\),尽量走左子树,如果走右子树,答案要加上 \(2^k \times (val + 1)\)。
-
\(val \ge 2^k\):这时要分成两段分别处理:
\(0 \sim 2^k - 1\):和上面的情况一样,尽量走左子树,否则答案加 \(2^k \times 2^k\)。
\(2^k \sim val\):这一段的第 \(k\) 位都是 \(1\),所以应该尽量走右子树,否则答案加 \((val + 1) \times 2^k\)。这一段走到下一层的时候应该减去 \(2^k\),因为要保证每次到下一层时二进制最高位的 \(1\) 处在第 \(k - 1\) 位及以下。
我们这样去处理的话,时间复杂度最坏还是 \(O(Md)\) 的,因为如果每层都分成两段,那么每到一层,询问个数都会翻倍,后面就会有 \(10^9\) 个询问,这是不能接受的。
要解决这个问题,就要用到一个神奇的 trick。
假如有两个询问,都是 \(0 \sim val\),那么我们将两个询问合并,再记录一个 \(cnt\) 表示 \(0 \sim val\) 的询问出现了多少次。这样就能将时间复杂度降到 \(O(\frac d 2 \times 2^{d/2} \log V)\)。
考虑证明,我们将 dfs 的层分成上 15 层和下 15 层:
对于上 15 层,我们最多分出 \(2^{15}\) 个询问。
对于下 15 层,因为值域为 \(2^{15}\),所以经过合并,最多也只有 \(2^{15}\) 个询问。
所以在每一层的询问个数都不会超过 \(2^{15}\) 个。
那么最后的时间复杂度就是 \(O(\frac d 2 \times 2^{d/2} \log V)\),其中 \(\frac d 2\) 表示每层将 \(2^{d / 2}\) 个询问进行合并,排序的时间复杂度。(经过yyb大佬指点,这个排序的时间复杂度是可以通过实现消去的。因为只有拆分会影响顺序,通过维护两个队列,一个是拆分的,一个是未拆分的,这样两个队列都是有序的,合并的时候就能实现线性合并了)
关于实现细节,因为每个询问都是从 \(0\) 开始的,所以我们只记录右端点即可,在每个节点的询问可以用一个 vector 来储存。
code
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <vector>using namespace std;const int N = 2e5 + 10, M = N * 30;typedef long long ll;
typedef pair <int, int> pii;int n, m;
int a[N], t[M][2], idx;
ll ans = 0;void insert (int x) {int p = 0;for (int i = 30; i >= 0; i --) {int ch = (x >> i) & 1;if (!t[p][ch]) t[p][ch] = ++ idx;p = t[p][ch];}
}void dfs (vector <pii> Q, int p, int k) {if (k < 0) return;vector <pii> ql, qr;sort (Q.begin (), Q.end ());for (int i = 0; i < (int)Q.size (); i ++) {while (i + 1 < (int)Q.size () && Q[i].first == Q[i + 1].first) {Q[i + 1].second += Q[i].second;i ++;}int val = Q[i].first, cnt = Q[i].second;if (val < (1 << k)) {if (t[p][0]) ql.emplace_back (val, cnt);else {ans += 1ll * (val + 1) * (1 << k) * cnt;qr.emplace_back (val, cnt);}} else {if (t[p][0]) ql.emplace_back ((1 << k) - 1, cnt);else {ans += 1ll * (1 << k) * (1 << k) * cnt;qr.emplace_back ((1 << k) - 1, cnt);}if (t[p][1]) qr.emplace_back (val - (1 << k), cnt);else {ans += 1ll * (val - (1 << k) + 1) * (1 << k) * cnt;ql.emplace_back (val - (1 << k), cnt);}}}if (t[p][0] && ql.size ()) dfs (ql, t[p][0], k - 1);if (t[p][1] && qr.size ()) dfs (qr, t[p][1], k - 1);
}int main () {cin >> n >> m;for (int i = 1; i <= n; i ++) {cin >> a[i];insert (a[i]);}vector <pii> Q;Q.emplace_back (m - 1, 1);dfs (Q, 0, 30);cout << ans << endl;return 0;
}