题目概述
题目链接:https://www.luogu.com.cn/problem/P2605。
有 \(n\) 个村庄,你需要建立不超过 \(k\) 个基站,每一户人家都有参数 \(d_i,s_i,w_i,c_i\) 分别表示距离第一户人家的距离、在不超过 \(s_i\) 的地方有基站才能覆盖此地、没有被基站覆盖的补偿费用、在此建立基站的费用。
求最小费用。
分析
一道经典题目,来记录一下。
首先不难处理 \(l_i,r_i\) 表示在这个范围内的人家只要有建基站那么就能覆盖此地。
二分即可。
设 \(f_{i,j}\) 表示前 \(i\) 个人家建立 \(j\) 个基站的最小代价(当前也要建)。
转移是显然的:
直接搞是 \(\mathcal{O}(n^3)\) 的。
先考虑如何把 \(cost\) 算掉。
在我每一次循环 \(i\) 的时候,遍历计算 \(cost\) 即可。
那么就可以得到一个客观的 \(\mathcal{O}(n^2)\) 的代码,有 \(60pts\)。
#include <iostream>
#include <stdlib.h>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#define N 200005
#define M 105
using namespace std;
#define min(x,y) ((x)<(y)?(x):(y))
const int INF = 2e9;
int n,k,d[N],c[N],s[N],w[N],l[N],r[N];
int f[N][M],cost[N],ans;
signed main() {scanf("%d%d",&n,&k);for (int i = 2;i <= n;i ++) scanf("%d",&d[i]);for (int i = 1;i <= n;i ++) scanf("%d",&c[i]);for (int i = 1;i <= n;i ++) scanf("%d",&s[i]);for (int i = 1;i <= n;i ++) scanf("%d",&w[i]);for (int i = 1;i <= n;i ++) {l[i] = lower_bound(d + 1,d + 1 + n,d[i] - s[i]) - d;r[i] = lower_bound(d + 1,d + 1 + n,d[i] + s[i]) - d;r[i] -= (d[i] + s[i] < d[r[i]]);}for (int i = 1;i <= n;i ++) {for (int j = 1;j <= k;j ++) f[i][j] = INF;f[i][0] = f[i-1][0] + w[i];}ans = f[n][0];for (int i = 1;i <= n;i ++) {memset(cost,0,sizeof cost);for (int j = i - 1;j;j --)if (r[j] < i) cost[l[j] - 1] += w[j];for (int j = i - 1;j >= 0;j --) cost[j] += cost[j + 1];for (int j = 1;j <= k && j <= i;j ++) {if(j == 1) f[i][1] = cost[0] + c[i];elsefor(int p = i - 1;p >= j - 1;p --)f[i][j] = min(f[i][j],f[p][j - 1] + cost[p] + c[i]);int sum = 0;for (int p = i + 1;p <= n;p ++)if (l[p] > i) sum += w[p];ans = min(ans,f[i][j] + sum);} }printf("%d\n",ans);return 0;
}
考虑怎么优化。
显然,可以先枚举有多少个基站,这在斜率优化 \(dp\) 种颇有体现,以及 \(wqs\) 二分的处理方式也有。
可惜的是如果你这么改了之后,你的暴力只有 \(40pts\) 了。
#include <iostream>
#include <stdlib.h>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#define N 200005
#define M 105
using namespace std;
#define min(x,y) ((x)<(y)?(x):(y))
const int INF = 2e9;
int n,k,d[N],c[N],s[N],w[N],l[N],r[N];
int f[N],cost[N],ans;
signed main() {scanf("%d%d",&n,&k);for (int i = 2;i <= n;i ++) scanf("%d",&d[i]);for (int i = 1;i <= n;i ++) scanf("%d",&c[i]);for (int i = 1;i <= n;i ++) scanf("%d",&s[i]);for (int i = 1;i <= n;i ++) scanf("%d",&w[i]);for (int i = 1;i <= n;i ++) {l[i] = lower_bound(d + 1,d + 1 + n,d[i] - s[i]) - d;r[i] = lower_bound(d + 1,d + 1 + n,d[i] + s[i]) - d;r[i] -= (d[i] + s[i] < d[r[i]]);}for (int i = 1;i <= n;i ++) f[i] = f[i - 1] + w[i];ans = f[n];// for (int i = 1;i <= n;i ++) {// memset(cost,0,sizeof cost);// for (int j = i - 1;j;j --)// if (r[j] < i) cost[l[j] - 1] += w[j];// for (int j = i - 1;j >= 0;j --) cost[j] += cost[j + 1];// for (int j = 1;j <= k && j <= i;j ++) {// if(j == 1) f[i][1] = cost[0] + c[i];// else// for(int p = i - 1;p >= j - 1;p --)// f[i][j] = min(f[i][j],f[p][j - 1] + cost[p] + c[i]);// int sum = 0;// for (int p = i + 1;p <= n;p ++)// if (l[p] > i) sum += w[p];// ans = min(ans,f[i][j] + sum);// } // }for (int i = 1;i <= n;i ++) {memset(cost,0,sizeof cost);for (int j = i - 1;j;j --)if (r[j] < i) cost[l[j] - 1] += w[j];for (int j = i - 1;j >= 0;j --) cost[j] += cost[j + 1];f[i] = cost[0] + c[i];int sum = 0;for (int p = i + 1;p <= n;p ++)if (l[p] > i) sum += w[p];ans = min(ans,f[i] + sum);}for (int j = 2;j <= k;j ++) {for (int i = j;i <= n;i ++) {memset(cost,0,sizeof cost);for (int p = 1;p < i;p ++)if (r[p] < i) cost[l[p] - 1] += w[p];for (int p = i - 1;p >= 0;p --) cost[p] += cost[p + 1];for (int p = i - 1;p >= j - 1;p --)f[i] = min(f[i],f[p] + cost[p] + c[i]);int sum = 0;for (int p = i + 1;p <= n;p ++)if (l[p] > i) sum += w[p];ans = min(ans,f[i] + sum);}}printf("%d\n",ans);return 0;
}
不过确实少了不少限制。
主要的时间开销为计算 \(cost\) 以及我们的暴力转移,我们的 \(k\) 是省略不掉的。
我们现在的转移是:
首先处理 \(k=1\) 的 \(cost\)。
考虑怎么快速转移。
我们发现限制我们 \(cost\) 的只有 \(r\),那么我们可以根据 \(i\) 的需要,去尺取变化就行了。
那你 \(k=1\) 都用 \(\mathcal{O}(n\log n)\) 解决了,那你再来个 \(k\),再来个线段树维护 \(dp\) 值加上前面的 \(cost\) 就行了。
代码
时间复杂度 \(\mathcal{O}(kn\log n)\)。
#include <iostream>
#include <stdlib.h>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#define N 200005
#define M 105
using namespace std;
const int INF = 2e9;
int min(int a,int b) {return a < b ? a : b;}
int n,k,d[N],c[N],s[N],w[N],l[N],r[N];
int f[N],cost[N],ans;
struct node{int r,id;
}ls[N];
int tr[N << 2],lz[N << 2];
#define ls(x) (x << 1)
#define rs(x) (x << 1 | 1)
void pushup(int x) {tr[x] = min(tr[ls(x)], tr[rs(x)]);
}
void pushdown(int x) {tr[ls(x)] += lz[x],tr[rs(x)] += lz[x];lz[ls(x)] += lz[x],lz[rs(x)] += lz[x];lz[x] = 0;
}
void build(int x,int l,int r) {lz[x] = 0;if (l == r) {tr[x] = f[l];return;}int mid = l + r >> 1;build(ls(x),l,mid),build(rs(x),mid + 1,r);pushup(x);
}
void update(int x,int l,int r,int L,int R,int val) {if (l > R || r < L) return ;if (L <= l && r <= R) {tr[x] += val;lz[x] += val;return;}if (lz[x]) pushdown(x);int mid = l + r >> 1;update(ls(x),l,mid,L,R,val),update(rs(x),mid + 1,r,L,R,val);pushup(x);
}
int query(int x,int l,int r,int L,int R) {if (l > R || r < L) return INF;if (L <= l && r <= R) return tr[x];if (lz[x]) pushdown(x);int mid = l + r >> 1;return min(query(ls(x),l,mid,L,R),query(rs(x),mid + 1,r,L,R));
}
signed main() {scanf("%d%d",&n,&k);for (int i = 2;i <= n;i ++) scanf("%d",&d[i]);for (int i = 1;i <= n;i ++) scanf("%d",&c[i]);for (int i = 1;i <= n;i ++) scanf("%d",&s[i]);for (int i = 1;i <= n;i ++) scanf("%d",&w[i]);for (int i = 1;i <= n;i ++) {l[i] = lower_bound(d + 1,d + 1 + n,d[i] - s[i]) - d;r[i] = lower_bound(d + 1,d + 1 + n,d[i] + s[i]) - d;r[i] -= (d[i] + s[i] < d[r[i]]);ls[i] = {r[i],i};}// for (int i = 1;i <= n;i ++) {// memset(cost,0,sizeof cost);// for (int j = i - 1;j;j --)// if (r[j] < i) cost[l[j] - 1] += w[j];// for (int j = i - 1;j >= 0;j --) cost[j] += cost[j + 1];// for (int j = 1;j <= k && j <= i;j ++) {// if(j == 1) f[i][1] = cost[0] + c[i];// else// for(int p = i - 1;p >= j - 1;p --)// f[i][j] = min(f[i][j],f[p][j - 1] + cost[p] + c[i]);// int sum = 0;// for (int p = i + 1;p <= n;p ++)// if (l[p] > i) sum += w[p];// ans = min(ans,f[i][j] + sum);// } // }// for (int i = 1;i <= n;i ++) {// memset(cost,0,sizeof cost);// for (int j = i - 1;j;j --)// if (r[j] < i) cost[l[j] - 1] += w[j];// for (int j = i - 1;j >= 0;j --) cost[j] += cost[j + 1];// f[i] = cost[0] + c[i];// int sum = 0;// for (int p = i + 1;p <= n;p ++)// if (l[p] > i) sum += w[p];// ans = min(ans,f[i] + sum);// }// for (int j = 2;j <= k;j ++) {// for (int i = j;i <= n;i ++) {// memset(cost,0,sizeof cost);// for (int p = 1;p < i;p ++)// if (r[p] < i) cost[l[p] - 1] += w[p];// for (int p = i - 1;p >= 0;p --) cost[p] += cost[p + 1];// for (int p = i - 1;p >= j - 1;p --)// f[i] = min(f[i],f[p] + cost[p] + c[i]);// int sum = 0;// for (int p = i + 1;p <= n;p ++)// if (l[p] > i) sum += w[p];// ans = min(ans,f[i] + sum);// }// }stable_sort(ls + 1,ls + 1 + n,[](node x,node y) {return x.r < y.r;});for (int i = 1;i <= n;i ++) ans += w[i];for (int i = n;i;i --) cost[l[i] - 1] += w[i];for (int i = n - 1;i;i --) cost[i] += cost[i + 1];int now = 0;for (int i = 1;i <= n;i ++) {f[i] = now + c[i];int it = lower_bound(ls + 1,ls + 1 + n,(node){i,0},[](node x,node y) {return x.r < y.r;}) - ls;while(ls[it].r == i && it <= n) {now += w[ls[it].id];it ++;}ans = min(ans,f[i] + cost[i]);}for (int j = 2;j <= k;j ++) {build(1,1,n);for (int i = j;i <= n;i ++) {f[i] = query(1,1,n,j - 1,i - 1) + c[i];ans = min(ans,f[i] + cost[i]);int it = lower_bound(ls + 1,ls + 1 + n,(node){i,0},[](node x,node y) {return x.r < y.r;}) - ls;while(ls[it].r == i && it <= n) {update(1,1,n,1,l[ls[it].id] - 1,w[ls[it].id]);it ++;}}}printf("%d\n",ans);return 0;
}