题目概述
题目链接:https://www.luogu.com.cn/problem/P3643。
给你 \(n\) 个班级,每个班级要么不选数要么选的数在 \([a_i,b_i]\),且选的数比编号比他小的班级选的数都要大,问有多少种方案(对 \((10^9+7)\) 取模)。
分析
感觉挺经典的,而且这里的 trick 很通用,所以记录一下。
首先我们不难想到一个最最最暴力的 \(dp\)。
设 \(f_{i,j}\) 表示前 \(i\) 个班级已经处理完毕,当前 \(i\) 必选且选择的数为 \(j\) 的方案数。
显然有:
\[f_{i,j}=\sum_{lst=0}^{i-1}\sum_{w=a_{lst}}^{b_{lst}}f_{lst,w}
\]
那么我们一分也拿不了。
我们考虑怎么离散化让这个变得简单起来(这是重点)。
我们这个离散化主要是想要把一个区间离散化成一个更小的区间,可以想到区间覆盖的东东。
那么所以我们可以考虑这样一个 \([x,y)\) 区间离散化即可,为什么不用 \([x,y]\) 呢,会有一些边界问题,前者更好处理。
然后我们考虑怎么映射回来。
先重新定义 \(f\) 表示前 \(i\) 个班级处理完毕,当前 \(i\) 必选且选择的数落在区间 \(j\) 的方案数(这里的区间最多是 \(\mathcal{O}(n)\) 的)。
对于之前的班级选择的数不在 \(j\) 这个区间,也就是在之前,这是好转移的。
但是如果在这个区间呢?似乎有有点难办。
我们考虑有 \(m\) 个班级在这个区间,\(len\) 表示这个区间的长度,严格递增地选数,且可以不选的方案。
那么我们可以补 \(0\),然后从里面任意选 \(m\) 个,也就是 \(C_{m+len-1}^{m}\)。
这个显然是对的。
于是我们的转移就有了:
\[f_{i,j}=\sum_{lst=0}^{i-1}\sum_wC_{m+len-1}^{m}f_{lst,w}
\]
显然可以用前缀和优化。
代码
时间复杂度 \(\mathcal{O}(n^3)\),这个代码实现得十分精妙。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <stdlib.h>
#include <vector>
#define int long long
#define N 505
// #define M 1505
#define getid(x) (lower_bound(ls.begin(),ls.end(),x) - ls.begin())
using namespace std;
const int mod = 1e9 + 7;
int jc[N],inv[N];
// int C(int a,int b) {
// if (a < 0 || b < 0 || a < b) return 0;
// return jc[a] * inv[b] % mod * inv[a - b] % mod;
// }
int n,a[N],b[N],f[N],g[N];
signed main(){inv[0] = inv[1] = 1;for (int i = 2;i < N;i ++) inv[i] = (mod - mod / i) * inv[mod % i] % mod;vector<int> ls;cin >> n;for (int i = 1;i <= n;i ++) {scanf("%lld%lld",&a[i],&b[i]);b[i] ++;ls.push_back(a[i]),ls.push_back(b[i]);}stable_sort(ls.begin(),ls.end());ls.erase(unique(ls.begin(),ls.end()),ls.end());for (int i = 1;i <= n;i ++) a[i] = getid(a[i]),b[i] = getid(b[i]);f[0] = 1;for (int j = 0;j < (int)ls.size() - 1;j ++) {int len = ls[j + 1] - ls[j];g[0] = 1;for (int i = 1;i <= n;i ++) g[i] = g[i - 1] * (i + len - 1) % mod * inv[i] % mod;for (int i = n;i;i --)if (a[i] <= j && j < b[i]) {for (int c = 1,k = i - 1;k >= 0;k --) {f[i] = (f[i] + g[c] * f[k] % mod) % mod;if (a[k] <= j && j < b[k]) c ++;}}}int ans = 0;for (int i = 1;i <= n;i ++) ans = (ans + f[i]) % mod;cout << ans;return 0;
}