考虑到可以将每个数最后一次出现与第一次出现的位置之差拆成若干个相邻位置之差:
\[last_i - first_i = \sum i-pre_i
\]
且每次修改一个点,对 \(pre\) 的影响是 \(O(1)\) 的,所以我们可以将所求的答案转为一个(带权的)二维偏序:
\[\sum_{l \le pre_i, i\le r} i - pre_i
\]
将时间维加上,差分将这个“四维偏序”转为三维偏序:在每一个点出现时,添加一个贡献为 \(+v\) 的点,删除时,添加一个贡献为 \(-v\) 的点。
最后,用 CDQ 分治做三维偏序即可。复杂度 \(O(n\log ^2 n)\)。
const int MAXN = 1e5 + 5;
int n, m, a[MAXN], pre[MAXN], ans[MAXN], cnt;
struct _point {int type;int x, y, v, id;
};
vector<_point> points;
set<int> st[MAXN];
struct _bit {int tr[MAXN];int lowbit(int x) { return x & (-x); };void modify(int x, int v) {while (x <= n) {tr[x] += v;x += lowbit(x);}}int query(int x) {int ret = 0;while (x) {ret += tr[x];x -= lowbit(x);}return ret;}
} t;void solve(int l, int r) {if (l == r) return;solve(l, mid); solve(mid + 1, r);vector<_point> v;for (int i = l; i <= mid; ++i) {if (points[i].type == 1) v.push_back(points[i]);}for (int i = mid + 1; i <= r; ++i) {if (points[i].type == 2) v.push_back(points[i]);}sort(v.begin(), v.end(), [](_point x, _point y) {if (x.y == y.y) return x.type < y.type;return x.y < y.y;});for (auto p:v) {if (p.type == 1) {t.modify(p.x, p.v);} else {ans[p.id] += t.query(n) - t.query(p.x - 1);}}for (auto p:v) {if (p.type == 1) {t.modify(p.x, -p.v);}}assert(t.query(n) == 0);
}void work() {cin >> n >> m;for (int i = 1; i <= n; ++i)cin >> a[i];for (int i = 1; i <= n; ++i) {if (pre[a[i]]) points.push_back({1, pre[a[i]], i, i - pre[a[i]]});pre[a[i]] = i;st[a[i]].insert(i);}for (int i = 1; i <= m; ++i) {int op; cin >> op;if (op == 1) {int p, x; cin >> p >> x;if (a[p] == x) continue;auto it = st[a[p]].find(p);if (next(it) != st[a[p]].end()) {auto nxt = next(it);points.push_back({1, p, *nxt, p - *nxt});if (it != st[a[p]].begin()) {auto lst = prev(it);points.push_back({1, *lst, *nxt, *nxt - *lst});}}if (it != st[a[p]].begin()) {auto lst = prev(it);points.push_back({1, *lst, p, *lst - p});}st[a[p]].erase(it);it = st[x].lower_bound(p);if (it != st[x].begin() && it != st[x].end()) {auto lst = prev(it);points.push_back({1, *lst, *it, *lst - *it});}if (it != st[x].begin()) {auto lst = prev(it);points.push_back({1, *lst, p, p - *lst});}if (it != st[x].end()) {points.push_back({1, p, *it, *it - p});}st[x].insert(p);a[p] = x;} else {int l, r; cin >> l >> r;points.push_back({2, l, r, 0, ++cnt});}}int N = points.size();solve(0, N - 1);for (int i = 1; i <= cnt; ++i)cout << ans[i] << endl;
}