传送门
首先如果没有任何限制条件,则原问题即变为简单的「求方程 \(\sum_{i=1}^nx_i=m\) 的解的个数」。此时考虑插板法,等价于将 \(m\) 个 \(1\) 分成 \(n\) 份,这时有 \(m-1\) 个空隙,要插 \(n-1\) 个板,方案数就是 \(\binom{m-1}{n-1}\)。
现在有两类限制条件。第二类条件 \(x_i\ge a_i\) 容易考虑,因为所有 \(x_i\) 的共有限制是 \(x_i\ge1\),因此把 \(x_i\ge a_i\) 的不等号两端同时减去 \(a_i-1\) 即得 \(x_i-a_i+1\ge1\),转化成为所有 \(x_i\) 共有的条件。
第一类条件 \(x_i\le a_i\) 难以处理,因此考虑容斥。该式的逆命题即为 \(x_i\ge a_i+1\),类似地将不等式两端同时减去 \(a_i\),转化为 \(x_i-a_i\ge1\)。因为 \(n_1\le8\),可以使用状态压缩的方法枚举不满足哪几个条件。对于每一种枚举的结果,将 \(m\) 减去不满足的条件对应的 \(a_i\) 得到一个数 \(tmp\),用插板法计算出方案数 \(\binom{tmp-1}{n-1}\),乘上容斥系数之后再全部加起来就是答案。
求组合数时,观察数据范围发现 \(n,m\) 较大且 \(p\) 不一定是质数,因此使用扩展 Lucas 定理。但是直接写板子常数较大难以通过,可以使用提前处理(细节参考代码中 C
函数的前半部分)等方法进行优化,从而通过本题。
代码如下:
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1e6 + 10;
int T, p, n, n1, n2, m, a[20], f[N];
int qpow (int a, int b, int mod) {int res = 1;for (; b; b >>= 1, a = a * a % mod)if (b & 1)res = res * a % mod;return res;
}
int exgcd (int a, int b, int &x, int &y) {if (!b) {x = 1, y = 0;return a;}int d = exgcd(b, a % b, x, y), t = x;x = y, y = t - a / b * y;return d;
}
int fac (int n, int d, int t) {if (!n)return 1;if (n < d)return f[n];return qpow(f[t - 1], n / t, t) * f[n % t] % t * fac(n / d, d, t) % t;
}
int inv (int n, int mod) {int x, y;exgcd(n, mod, x, y);return (x % p + p) % p;
}
int crt (int a, int mod) {return a * (p / mod) % p * inv(p / mod, mod) % p;
}
int C (int n, int m, int d, int t) {if (n < m)return 0;f[0] = 1;for (int i = 1; i <= t; ++i) {if (i % d != 0)f[i] = f[i - 1] * i % t;elsef[i] = f[i - 1];}int fz = fac(n, d, t), fm1 = fac(m, d, t), fm2 = fac(n - m, d, t), k = 0;for (int i = n; i; i /= d)k += i / d;for (int i = m; i; i /= d)k -= i / d;for (int i = n - m; i; i /= d)k -= i / d;return fz * inv(fm1, t) % t * inv(fm2, t) % t * qpow(d, k, t) % t;
}
int solve (int n, int m, int d, int t) {int res = 0, s = 1 << n1;for (int i = 0; i < s; ++i) {int op = 1, tmp = n;for (int j = 0; j < n1; ++j)if (i >> j & 1)op = -op, tmp -= a[j + 1];(res += op * C(tmp, m, d, t)) %= t;}return res;
}
int exlucas (int n, int m) {if (n < m)return 0;int res = 0, tmp = p, d;for (int i = 2; i * i <= p; ++i)if (tmp % i == 0) {d = 1;while (tmp % i == 0)d *= i, tmp /= i;(res += crt(solve(n, m, i, d), d)) %= p;}if (tmp != 1)(res += crt(solve(n, m, tmp, tmp), tmp)) %= p;return (res % p + p) % p;
}
signed main () {ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);cin >> T >> p;while (T--) {cin >> n >> n1 >> n2 >> m;for (int i = 1; i <= n1 + n2; ++i)cin >> a[i];for (int i = n1 + 1; i <= n1 + n2; ++i)m -= (a[i] - 1);cout << exlucas(m - 1, n - 1) << endl;} return 0;
}