给定一个式子,包含 >
,<
,?
或者 \([0,m)\) 中的一个数字。其中每个数字代表一个数。
>
代表返回两边的最大值,<
代表返回两边的最小值,?
表示你要在上文的两个符号中选择一个符号替换它。
假设有 \(t\) 个 ?
,你需要对所有 \(2^t\) 中安排符号的方案求出结果之和。
一共有 \(n\) 组 \(m\) 个数,你要对这 \(n\) 组 \(a\),把 \([0,m)\) 中的每个数字 \(i\) 替换为 \(a_i\),然后求出答案的总和。
(感觉总结的有点乱,可以去看原题面)。
非常神仙的小常数做法,以至于代码写的一坨都能一发拿最优解(
首先考虑 ?
的转化。
我们发现,\(\{\max(a,b),\min(a,b)\}=\{a,b\}\),也就是说,a?b
的两种替换方案的和其实就是 \(a+b\)。
接下来,由于我们已经引入了加号,我们考虑用类似的方式把 \(\min\) 转化掉。
有,\(\min(a,b)=a+b-\max(a,b)\)。
因此,到这为止,我们可以把原来的式子转化为只包含若干 \(\max(S)\) 的线性表达式。
假设我们有两个函数 \(F=\sum_S f(S)x^S,G=\sum_S g(S)x^S\),我们考虑他们在这三种符号下叠加的结果。
不难发现这其实就是集合并卷积,我们直接 FMT
即可做到 \(O(m2^m)\) 预处理。
但是这样仍然不足以通过。
我们可以继续利用 FMT
的线性性,它说明了,我们可以在叶子节点进行一次 FMT
,然后在之后的过程中直接点值相乘。
最后到根节点再 IFMT
回去。
同时由于叶子节点上只有一个点是有值的,所以也能在 \(\mathcal O(2^m)\) 的时间内完成。
最后回答部分直接用预处理出来的表算答案就可以了,常数非常小,直接拿下最优解()
#include <algorithm>
#include <iostream>
#include <vector>
#include <tuple>
#include <stack>const int N = 2e5 + 7, M = 10;
#define rep(i,a,b) for(int i(a);i<=(b);++i)const int O = 1e9 + 7;
struct Modint {int x;inline int load() {return x;}inline operator void* () {return (void*)(x != 0);}Modint() {} Modint(int _x): x(_x) {}
};
inline Modint operator - (Modint a) { return a.x ? O - a.x : 0; }
inline Modint operator * (Modint a, Modint b) { return 1ll * a.x * b.x %O; }
inline Modint operator + (Modint a, Modint b) { return a.x + b.x - (a.x + b.x >= O) * O; }
inline Modint operator - (Modint a, Modint b) { return a.x - b.x + (a.x - b.x < 0) * O; }
inline Modint fpow(Modint x, int k) { Modint res = 1; for(; k; k >>= 1, x = x * x) if(k & 1) res = res * x; return res; }
inline Modint operator / (Modint a, Modint b) { return a * fpow(b, O-2); }
inline Modint& operator /= (Modint& a, Modint b) { return a = a / b; }
inline Modint& operator += (Modint& a, Modint b) { a.x += b.x, a.x -= (a.x >= O) *O; return a; }
inline Modint& operator -= (Modint& a, Modint b) { a.x -= b.x, a.x += (a.x < 0) *O; return a; }
inline Modint& operator *= (Modint& a, Modint b) { return a.x = a.x * b.x %O, a; }
typedef Modint mi;int n, m;
int vals[N][M];struct pattern {std::vector<mi> w;pattern(): w(1<<m, 0) {}void FMT(int inv = 0) {for(int i = 1; i < (1 << m); i <<= 1) {for(int j = 0; j < (1 << m); j += 2 * i) {for(int k = 0; k < i; ++k) {if(inv) w[j + k + i] -= w[j + k];else w[j + k + i] += w[j + k];}}}}mi& operator[] (int k) {return w[k];}const mi operator[] (int k) const {return w[k];}
};inline pattern operate(const pattern& a, const pattern& b, char opt) {pattern res;int t = (1 << m) - 1;if(opt == '>') {for(int i = 0; i < (1 << m); ++i) {res[i] = a[i] * b[i];}} else if(opt == '?') {for(int i = 0; i < (1 << m); ++i) {res[i] = a[i] * b[t] + a[t] * b[i]; }} else {for(int i = 0; i < (1 << m); ++i) {res[i] = a[i] * b[t] + a[t] * b[i] - a[i] * b[i];}}return res;
}void solve() {std::cin >> n >> m;for(int j = 0; j < m; ++j) { rep(i, 1, n) { std::cin >> vals[i][j]; } }std::stack<char> opt;std::stack<pattern> pat;auto proc = [&](char op) {auto x = pat.top(); pat.pop();auto y = pat.top(); pat.pop();pat.push(operate(x, y, op));};std::string str;std::cin >> str;for(char c: str) {if(c == '(') {opt.push('(');} else if(c == ')') {while(!opt.empty() && opt.top() != '(') { proc(opt.top()), opt.pop(); }opt.pop();} else if(!isdigit(c)) {while(!opt.empty() && opt.top() != '(') { proc(opt.top()), opt.pop(); }opt.push(c);} else {int x = c - '0';pattern p;for(int i = 0; i < (1 << m); ++i) if((i >> x) & 1) p[i] = 1;pat.push(p);}} while(!opt.empty()) proc(opt.top()), opt.pop();auto res = pat.top();res.FMT(1);mi ans = 0;for(int i = 1; i <= n; ++i) {static int g[1<<M]; g[0] = 0;for(int j = 0; j < m; ++j) {for(int k = 1 << j; k < 1 << (j + 1); ++k) {g[k] = std::max(g[k ^ (1 << j)], vals[i][j]);ans += res[k] * g[k];} }}std::cout << ans.load() << "\n";
}int main() {std::ios::sync_with_stdio(0), std::cin.tie(0), std::cout.tie(0);solve();
}