序列与整数对
题面
给定一个长度为 \(n\) 的正整数序列 \(A_1, A_2, \cdots,A_n\) ,有 \(m\) 次询问,每次给定两个正整数 \(x, y\) ,求有多少个整数对 \((i,j)\) 满足 \(1 \le i < j \le n,A_i = x, A_j = y\)。
\(1 \le n, m \le 10^5\)
\(1 \le A_i, x, y \le 10^9\)
题解
这道题用根号分治思想来解决。
先挑出这道题的难点,就是对于每个询问,我们无法快速找到有多少个对应的整数对。
朴素做法:先离散化,然后对每个数记录其出现位置,询问枚举 \(x/y\) 的出现位置,然后再另外一个出现位置中查找,符合条件的解,单次时间复杂度 \(O(n^2)\) ,可以用双指针优化到 \(O(n)\)。总时间复杂度 \(O(nm)\)
我们应用轻重分治思想来解决这道题。
首先我们考虑朴素解法,如果我们不进行处理,那么每次查询是 \(O(n)\) 的。
我们定义 \(B\) 表示一个分界线,如果出现次数 \(\le B\) 那么我们定义其为轻,否则为重。
对每个询问分类讨论:
-
\(x, y\) 两个都为轻
我们枚举 \(x\) 的出现位置,然后用双指针即可实现 \(O(cnt_x + cnt_y)\),单次时间复杂度 \(O(B)\)
-
\(x, y\) 其中一个为重
那么如果我们还是枚举,那么最坏就是 \(O(n)\) 的,所以我们将其进行预处理,对于每个出现次数 \(>B\) 的 \(x\)。
我们预处理出 \(mp1[x][y]\) 表示询问 \((x,y)\) 的答案,\(mp2[x][y]\) 表示询问 \((y,x)\) 的答案。
对于每个 \(x\),这两个都可以 \(O(n)\) 求,因为每个 \(cnt_x > B\) 所以这样的 \(x\) 的个数不会超过 \(\frac n B\) 个
预处理时间复杂度 \(O(n \frac n B)\)。
总时间复杂度 \(O(n(B + \frac n B))\)。
code
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <vector>
#include <map>
#include <cmath>using namespace std;typedef long long ll;const int N = 1e5 + 10;int n, m;
int a[N], b[N], cnt;
int sum[N];
map <pair <int, int>, ll> mp1, mp2;
vector <int> pos[N];int main () {// freopen ("test/test.in", "r", stdin);// freopen ("test/test.out", "w", stdout);cin >> n >> m;int B = sqrt (n);for (int i = 1; i <= n; i ++) {cin >> a[i];b[i] = a[i];}sort (b + 1, b + 1 + n);cnt = unique (b + 1, b + 1 + n) - 1 - b;for (int i = 1; i <= n; i ++) {a[i] = lower_bound (b + 1, b + 1 + cnt, a[i]) - b;pos[a[i]].push_back (i);}for (int i = 1; i <= cnt; i ++) {if (pos[i].size () > B) {sum[0] = 0, sum[n + 1] = 0;for (int j = 1; j <= n; j ++) {if (a[j] == i) {sum[j] = sum[j - 1] + 1;} else {sum[j] = sum[j - 1];mp1[{i, a[j]}] += sum[j];}}for (int j = n; j >= 1; j --) {if (a[j] == i) {sum[j] = sum[j + 1] + 1;} else {sum[j] = sum[j + 1];mp2[{i, a[j]}] += sum[j];}}}}for (int i = 1; i <= m; i ++) {int fx, fy, x, y;cin >> fx >> fy;x = lower_bound (b + 1, b + 1 + cnt, fx) - b;y = lower_bound (b + 1, b + 1 + cnt, fy) - b;if (b[x] != fx || b[y] != fy) {cout << 0 << endl;continue;}ll ans = 0;if (x == y) {ll cnt = pos[x].size ();ans = cnt * (cnt - 1) / 2;} else if (pos[x].size () <= B && pos[y].size () <= B) {int cnt1 = pos[x].size (), cnt2 = pos[y].size ();int p1 = 0, p2 = 0;for (; p1 < cnt1; p1 ++) {// 保证 y 在 x 右边while (p2 < cnt2 && pos[x][p1] >= pos[y][p2]) p2 ++;ans += cnt2 - p2;}} else {if (pos[x].size () > B) ans = mp1[{x, y}];else ans = mp2[{y, x}];}cout << ans << endl;}return 0;
}