题目链接:Problem - D - Codeforces
本身是一道数学题,我们可以把字符串中的奇数项和偶数项分开,形成两条序列 A 和 B。易知一种字母一定在同一条序列上。
假如说在 A 序列上分配了 \(a,b,c\) 三种字母,\(sum = c_a + c_b + c_c\),那么 A 序列的方案数为:
同理,假如 B 上分配了 \(d,e,f\) 三种字母,\(sum^{\prime} = c_d + c_e + c_f\),那么 B 序列的方案数为:
那么对于这种分配字母种类的方式,字符串种类一共有 \(ans_A \times ans_B\) 种,即:
可以发现,无论如何分配字母种类,只要分配合法,这一种分配方式的方案数量均为 \(ans\) 不变。而 \(c_i\) 是固定的,因此我们可以直接求出来 \(ans\)。因为有取余下除法,所以需要求逆元。而恰好模数 \(998244353\) 是一个质数,因此可以直接用费马小定理求出。这就是一个数学问题。
处理完 \(ans\),那么问题就剩下这个——求出序列 A,B 的分配字母种类的方案数。因为字母种类有 \(26\) 种,且多测数量达到了 \(10^4\),直接使用 dfs 需要算约 \(10^{12}\) 次,会 TLE,但如果这个 \(26\) 变为 \(13\),那么计算次数就约为 \(10^8\) 可以接受,因此使用双向 dfs 统计可行数量即可。
而在双向 dfs 拼凑答案时为了效率,我们需要使用 unordered_map 去存储每一种 A,B 序列分配字母种类的方案数,这时候会涉及到二维 unordered_map,此时直接用进制转化为一维即可(记得开 longlong)。算出种类数 \(cnt\),那么最后我们的答案就是 \(cnt * ans\),最后取模即可。
最后,还有一点,这题有一个卡常,即在算 \(ans\) 时,你的阶乘不能预处理,这题很极限,100ms 的预处理会正好卡死这题(至少我是这样)。因此需要在线计算阶乘。
AC代码为:
#include <iostream>
#include <cstring>
#include <algorithm>
#include <map>
#include <unordered_map>using namespace std;typedef long long LL;
typedef pair<int, int> PII;
// typedef unordered_map<int, int> UII;const int mod = 998244353, N = 500010;int n = 26, m;
int ans, sum, tx, ty;
int g[30], cnt;
int f[N], in_f[N];
unordered_map<LL, int> q; int qmi(int a, int k, int q)
{int res = 1;while (k){if (k & 1) res = 1ll * res * a % mod;k >>= 1;a = 1ll * a * a % mod;}return res;
}int fd(int x)
{int res = 1;for (int i = 1; i <= x; i ++ ) res = 1ll * i * res % mod;return res;
}int in_fd(int x)
{int res = 1;for (int i = 1; i <= x; i ++ ) res = 1ll * i * res % mod;return qmi(res, mod - 2, mod);
}void dfs(int x, int s1, int s2)
{if (x > m) {q[s1 * N + s2] ++ ;return ;}if (g[x] == 0) dfs(x + 1, s1, s2);else {if (s1 + g[x] <= tx) dfs(x + 1, s1 + g[x], s2);if (s2 + g[x] <= ty) dfs(x + 1, s1, s2 + g[x]);}
}void dfs2(int x, int s1, int s2)
{if (x > n) {cnt += q[(tx - s1) * N + (ty - s2)];return ;}if (g[x] == 0) dfs2(x + 1, s1, s2);else {if (s1 + g[x] <= tx) dfs2(x + 1, s1 + g[x], s2);if (s2 + g[x] <= ty) dfs2(x + 1, s1, s2 + g[x]);}
}int main()
{int T;cin >> T;while (T -- ){q.clear();ans = 1, cnt = sum = 0;for (int i = 1; i <= n; i ++ ) scanf("%d", &g[i]), sum += g[i];sort(g + 1, g + 1 + n);reverse(g + 1, g + 1 + n);tx = sum / 2, ty = (sum + 1) / 2;m = 13;dfs(1, 0, 0);dfs2(m + 1, 0, 0);for (int i = 1; i <= n; i ++ ) ans = 1ll * ans * in_fd(g[i]) % mod;cout << 1ll * cnt * fd(tx) % mod * fd(ty) % mod * ans % mod << '\n';} return 0;
}