还是记录一下,有点怪但是有点帅。
首先条件等价于和不进位,也就是说每一位只有一个 \(b\) 是 \(1\)。
对于 \(b_i\ge a_i\) 的限制我们可以当成 \(a_i\) 加若干非负数,考虑 check 一个答案是否合法,记得确定一个漂亮一点的过程,不然做不了。找到当前的最高位 \(p\),并找出最大的 \(a_i\),设其最高位为 \(A_i\):
- \(p<A_i\),直接寄了。
- \(p=A_i\),则 \(b_i\) 这一位肯定得为 \(1\),所以我们把 \(a_i\) 减去 \(2^p\) 继续做即可。
- \(p>A_i\),直接把 \(a_i\) 删掉即可。
为啥删最大的是对的?因为我们不妨可以让 \(a_i,b_i\) 按大小顺序一一对应。
这样从最高位开始 check 是 \(\mathcal{O}(L^2)\) 的。
考虑优化。此时我们无法直接确定答案,并且这个过程至少也得平方,考虑找更多性质加速。我们每次会先确定最高位,然后继续往后走,然而前面做过的操作后面显然也会保留,不过这不足以过这道题。
既然先确定最高位,我们不妨考察最高位的上下界,我们发现要想确定最高位那一定最高位及以前都会设为 \(1\),性质很强。考察其下界,不妨先给 \(a\) 从大到小排序,对于每个 \(a_i\),\(p\ge A_i+i-1\),于是其下界为 \(p=\max {A_i+i-1}\),而我们发现只要给 \(p\) 加 \(1\),访问到的每个 \(i\) 都会是第三种情况,这样就一定可以。
于是我们每次确定最高位只需要 check 其下界就可以判定,然后往后确定,不过发现整个过程仍然是平方复杂度。我们此时有了快速找最高位的方法,然而 check 时还是愚蠢地跑了一整遍过程,我们发现在 check 中如果适当加入性质有的时候似乎可以提前结束,但是这样复杂度仍然不对,它浪费在了对于一个前缀 check 了一遍又一遍,我们现在希望结合性质找到一个高效的算法。
考虑递归处理。设计递归函数 \(H(a,p)\),check 最高位是否可以 \(\le p\),首先如果下界大于 \(p\) 直接结束了。否则 check 是否能达到下界,我们可以简单确定第一个取到上界的 \(i\) 以前的方案,剩下的考虑直接递归到 \(H(a',p-i)\),如果返回 \(0\),则考察 \(p\) 是否等于下界,如果等于则返回 \(0\),否则答案就是下界加一,构造一下方案直接递归即可,显然此时一定返回 \(1\)。
于是我们可以利用这个函数构造出答案,求 \(\max A_i+i-1\) 可以使用线段树维护。
事实上如果能够快速求出上下界,且范围很小,我们是可以直接往一些递归方向思考的,因为答案确定起来很简单,它有助于我们直接确定答案,避免多次 check,递归进去就可以边 check 边构造,十分厉害。
复杂度其它题解有讲,是 \(\mathcal{O}(L\log L)\),这篇题解主要是梳理思路和讲解答案。
代码:
#include <bits/stdc++.h>
#define rep(i, l, r) for (int i (l); i <= (r); ++ i)
#define rrp(i, l, r) for (int i (r); i >= (l); -- i)
#define eb emplace_back
using namespace std;
#define pii pair <int, int>
#define inf 1000000001
#define ls (p << 1)
#define rs (ls | 1)
#define fi first
#define se second
constexpr int N = 3e5 + 5, M = 2e5 + 5;
typedef long long ll;
typedef unsigned long long ull;
inline ll rd () {ll x = 0, f = 1;char ch = getchar ();while (! isdigit (ch)) {if (ch == '-') f = -1;ch = getchar ();}while (isdigit (ch)) {x = (x << 1) + (x << 3) + (ch ^ 48);ch = getchar ();}return x * f;
}
int qpow (int x, int y, int p) {int ret (1);for (; y; y >>= 1, x = x * x % p) if (y & 1) ret = ret * x % p;return ret;
}
int n;
char s[N];
int nxt[N], tmp[N], a[N], cnt;
vector <int> vec[N];
class node {public:int mx, id, sum;friend node operator + (const node &a, const node &b) {if (a.mx > b.mx) return {a.mx, a.id, a.sum + b.sum}; return {b.mx, b.id, a.sum + b.sum};}
} t[N << 2];
int tag[N << 2];
void psu (int p) {t[p] = t[ls] + t[rs];
}
void build (int p, int l, int r) {if (l == r) {t[p].mx = -1e9, t[p].id = l; return ;} int mid (l + r >> 1);build (ls, l, mid), build (rs, mid + 1, r);psu (p);
}
void add (int p, int k) {t[p].mx += k, tag[p] += k;
}
void psd (int p) {if (tag[p]) {add (ls, tag[p]), add (rs, tag[p]), tag[p] = 0;}
}
void upd (int p, int l, int r, int L, int R, int k) {if (L <= l && r <= R) return add (p, k);int mid (l + r >> 1); psd (p);if (L <= mid) upd (ls, l, mid, L, R, k);if (R > mid) upd (rs, mid + 1, r, L, R, k);psu (p);
}
void upd (int p, int l, int r, int x, int k) {if (l == r) {if (k == -1) t[p].mx = -1e9, t[p].sum = 0;else t[p].mx = k, t[p].sum = 1; return ;}int mid (l + r >> 1); psd (p);if (x <= mid) upd (ls, l, mid, x, k);else upd (rs, mid + 1, r, x, k);psu (p);
}
int qry (int p, int l, int r, int L, int R) {if (L > R) return 0;if (L <= l && r <= R) return t[p].sum;int mid (l + r >> 1), ret (0);if (L <= mid) ret += qry (ls, l, mid, L, R);if (R > mid) ret += qry (rs, mid + 1, r, L, R);return ret;
}
set <int, greater <int> > st;
void ins (int k) {if (k > 1) upd (1, 1, cnt, 1, k - 1, 1);upd (1, 1, cnt, k, qry (1, 1, cnt, k + 1, cnt) + a[k]);st.insert (k);
}
void era (int k) {if (k > 1) upd (1, 1, cnt, 1, k - 1, -1);upd (1, 1, cnt, k, -1);st.erase (k);
}
int ans[N << 1];
bool solve (int p) {if (st.empty ()) {int v = 6e5;while (v > 0 && ! ans[v]) -- v;while (~ v) putchar (ans[v] + 48), -- v;return 1;}int h = t[1].mx;int id = t[1].id;if (p < h) return 0;int now = h;vector <int> tmp;for (; ;) {int x = * st.begin ();ans[now] = 1, tmp.eb (x), era (x); -- now;if (x == id) break;}if (nxt[id]) ins (nxt[id]);if (solve (now)) return 1;if (nxt[id]) era (nxt[id]);if (p == h) {for (auto x : tmp) ans[++ now] = 0, ins (x);return 0;}ans[now + 1] = 0, ans[h + 1] = 1;return solve (now + 1);
}
int32_t main () {// freopen ("1.in", "r", stdin);// freopen ("1.out", "w", stdout);n = rd (); int mx (0);rep (i, 1, n) {scanf ("%s", s + 1);int len = strlen (s + 1);reverse (s + 1, s + len + 1);rep (j, 1, len) if (s[j] == '1') vec[j].eb (i); mx = max (mx, len);}rep (i, 1, mx) {sort (vec[i].begin (), vec[i].end (), [&] (int x, int y) { return tmp[x] < tmp[y]; });for (auto j : vec[i]) nxt[cnt + 1] = tmp[j], tmp[j] = ++ cnt, a[cnt] = i - 1;}build (1, 1, cnt);rep (i, 1, n) ins (tmp[i]);solve (n + mx);
}