Description
有一个王国,共有 \(n\) 座城市,这些城市编号为 \(1\) 到 \(n\)(包含两端)。
王国中有 \(n-1\) 条双向道路将这些城市相连,并且保证任意两座城市之间都可以通过这些道路到达。
女王最近决定新建 \(k\) 座工厂。为了避免污染,每座城市至多只能建一座工厂。
你,作为王国的御用设计师,需要安排这些工厂的建设位置,并且尽量 最大化所有工厂两两之间距离的总和。
两座工厂之间的距离定义为:它们所在城市之间的最短路径长度。路径的长度等于路径上各条边长度的和。
\(n\leq 10^5\)。
Solution
对于每条边算贡献,假设一条边下面那个点的子树里选了 \(c\) 个点,那么贡献为 \(c(k-c)\),所以考虑对于子树树从下往上 dp。
设 \(f_{u,i}\) 表示 \(u\) 的子树里选了 \(i\) 个点,\(u\) 子树里的边(包括 \(u\) 与其父亲的边)的最大总贡献。
那么首先有转移:\(f_{u,i+j}\leftarrow f'_{u,i}+f_{v,j}\)。
得到子树后的贡献后,再加上 \(u\) 父亲上的边的贡献,即:\(f_{u,i}\leftarrow f_{u,i}+wi(k-i)\)。
但是这里有个 \((\max,+)\) 卷积,不能暴力转移。注意到 \(wi(k-i)\) 是下凸的,其差分为 \(w(k-2i+1)\),所以 dp 数组也是下凸的。用平衡树维护差分数组,需要支持启发式合并以及加等差数列,直接打标记即可。
时间复杂度:\(O(n\log^2n)\)。
Code
#include <bits/stdc++.h>#define int int64_tconst int kMaxN = 1e5 + 5;int n, k;
int rt[kMaxN];
std::vector<std::pair<int, int>> G[kMaxN];
std::mt19937 rnd(114514);struct FHQTreap {int tot, ls[kMaxN], rs[kMaxN], sz[kMaxN], val[kMaxN], rd[kMaxN], tag1[kMaxN], tag2[kMaxN];int newnode(int v) {sz[++tot] = 1, val[tot] = v, ls[tot] = rs[tot] = tag1[tot] = tag2[tot] = 0, rd[tot] = rnd();return tot;}void pushup(int x) {sz[x] = sz[ls[x]] + sz[rs[x]] + 1;}void addtag1(int x, int v) { val[x] += v, tag1[x] += v; }void addtag2(int x, int v) {val[x] += v * (sz[ls[x]] + 1), tag2[x] += v;}void pushdown(int x) {if (tag1[x]) {if (ls[x]) addtag1(ls[x], tag1[x]);if (rs[x]) addtag1(rs[x], tag1[x]);tag1[x] = 0;}if (tag2[x]) {if (ls[x]) addtag2(ls[x], tag2[x]);if (rs[x]) addtag1(rs[x], tag2[x] * (sz[ls[x]] + 1)), addtag2(rs[x], tag2[x]);tag2[x] = 0;}}int merge(int x, int y) {// std::cerr << "??? " << x << ' ' << y << ' ' << ls[x] << ' ' << ls[y] << ' ' << rs[x] << ' ' << rs[y] << '\n';if (!x || !y) return x + y;pushdown(x), pushdown(y);if (rd[x] < rd[y]) {rs[x] = merge(rs[x], y), pushup(x);return x;} else {ls[y] = merge(x, ls[y]), pushup(y);return y;}}void split(int x, int v, int &a, int &b) {if (!x) return void(a = b = 0);pushdown(x);if (val[x] >= v) {a = x, split(rs[x], v, rs[x], b);pushup(a);} else {b = x, split(ls[x], v, a, ls[x]);pushup(b);}}void ins(int &rt, int x) {int a, b;// std::cerr << "??? " << rt << ' ' << x << '\n';split(rt, val[x], a, b);rt = merge(a, merge(x, b));}void update(int rt, int v1, int v2) {// std::cerr << "fuck " << rt << ' ' << v1 << ' ' << v2 << '\n';addtag1(rt, v1), addtag2(rt, v2);}void insall(int &x, int y) {if (sz[x] < sz[y]) std::swap(x, y);// std::cerr << sz[x] << ' ' << sz[y] << ' ';std::vector<int> id;std::function<void(int)> dfs = [&] (int x) {if (!x) return;pushdown(x);if (ls[x]) dfs(ls[x]);if (rs[x]) dfs(rs[x]);ls[x] = rs[x] = tag1[x] = tag2[x] = 0, sz[x] = 1;id.emplace_back(x);};dfs(y);for (auto i : id) ins(x, i);// std::cerr << id.size() << ' ' << sz[x] << '\n';}void print(int x) {if (!x) return;pushdown(x);if (ls[x]) print(ls[x]);std::cerr << val[x] << ' ';if (rs[x]) print(rs[x]);}int getsum(int x, int k) {if (!x) return 0;// std::cerr << "??? " << x << ' ' << val[x] << ' ' << k << ' ' << sz[ls[x]] << '\n';pushdown(x);assert(sz[x] >= k);if (k <= sz[ls[x]]) return getsum(ls[x], k);else if (k <= sz[ls[x]] + 1) return getsum(ls[x], sz[ls[x]]) + val[x];else return getsum(ls[x], sz[ls[x]]) + val[x] + getsum(rs[x], k - sz[ls[x]] - 1);}
} t;void dfs(int u, int fa, int faw) {rt[u] = t.newnode(0);// std::cerr << t.sz[rt[u]] << '\n';for (auto [v, w] : G[u]) {if (v == fa) continue;dfs(v, u, w);t.insall(rt[u], rt[v]);}t.update(rt[u], faw * (k + 1), -2 * faw);// t.print(rt[u]), std::cerr << '\n';
}void dickdreamer() {std::cin >> n >> k;for (int i = 1; i < n; ++i) {int u, v, w;std::cin >> u >> v >> w;G[u].emplace_back(v, w), G[v].emplace_back(u, w);}dfs(1, 0, 0);// t.print(rt[1]), std::cerr << '\n';std::cout << t.getsum(rt[1], k) << '\n';// std::cerr << t.sz[rt[1]] << '\n';
}int32_t main() {
#ifdef ORZXKRfreopen("in.txt", "r", stdin);freopen("out.txt", "w", stdout);
#endifstd::ios::sync_with_stdio(0), std::cin.tie(0), std::cout.tie(0);int T = 1;// std::cin >> T;while (T--) dickdreamer();// std::cerr << 1.0 * clock() / CLOCKS_PER_SEC << "s\n";return 0;
}