Universal Cup のシーズン2が個人的に不甲斐ない結果で終わってしまったので、シーズン3開幕まで復習を頑張ろうと思います。
問題のリンク:https://qoj.ac/contest/1339/problem/7108
問題概要
長さ $N$ の数列 $A$ が与えられる。次のクエリを $N$ 回処理せよ。
使用可能な連続部分列の転倒数の最大値を出力し、数列の1要素を次回以降使用不能とする。ただし、これらはオンラインで解く必要がある(求めた転倒数の値から、次使用不能になる値が複号できる)。
マルチテストケースで、$N$ の総和は $10 ^ 6$ で抑えられている。
解法
基本方針:操作の過程で現れる全ての区間の転倒数を保持
マージテクの逆を考えると、区間が新たに2分割された時、短い方の区間長に依存するような更新をすれば、(更新部分以外は) $O(N\log{N})$ になることが分かります。
更新方法として、短い方の区間に対しては BIT を再度作り直して転倒数を計算、長い方は差分更新をすることで、全体で $O(N\log{N}^{2})$ で解くことができます。
肝心の実装は中々神経を必要として大変でした・・・
実装
#pragma GCC optimize("Ofast") #include <bits/stdc++.h> using namespace std; typedef long long int ll; typedef unsigned long long int ull; mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count()); ll myRand(ll B) { return (ull)rng() % B; } // 0-indexed template <typename T> struct BIT { int n; vector<T> bit, ary; BIT(int n = 0) : n(n), bit(n + 1), ary(n) {} T operator[](int k) { return ary[k]; } // [0, i) T sum(int i) { T res = 0; for (; i > 0; i -= (i & -i)) { res += bit[i]; } return res; } // [l, r) T sum(int l, int r) { return sum(r) - sum(l); } void add(int i, T a) { ary[i] += a; i++; for (; i <= n; i += (i & -i)) { bit[i] += a; } } }; void slv() { int n; cin >> n; vector<int> a(n); for (int i = 0; i < n; ++i) { cin >> a[i]; } vector<BIT<int>> bit_memo(n); // [l, r) の転倒数と左端を始点としたBITを返す auto calc = [&](int l, int r) -> pair<ll, BIT<int>> { // 座圧 vector<int> z(r - l); for (int i = l; i < r; ++i) { z[i - l] = a[i]; } sort(z.begin(), z.end()); z.erase(unique(z.begin(), z.end()), z.end()); for (int i = l; i < r; ++i) { a[i] = lower_bound(z.begin(), z.end(), a[i]) - z.begin(); } ll inv = 0; BIT<int> bit(z.size()); for (int i = l; i < r; ++i) { inv += (i - l - bit.sum(a[i] + 1)); bit.add(a[i], 1); } return {inv, bit}; }; // setで区間を管理 [l, r) set<pair<int, int>> st; st.insert({0, n}); // 転倒数の値を管理 multiset<ll> mst; vector<ll> inv_memo(n); // 初期化 auto init = calc(0, n); mst.insert(init.first); inv_memo[0] = init.first; bit_memo[0] = init.second; for (int i = 0; i < n; ++i) { ll val = *mst.rbegin(); cout << val << " "; ll mid; cin >> mid; mid = (mid ^ val) - 1; auto it = st.lower_bound({mid, 1e9}); auto [l, r] = *(--it); st.erase(it); val = inv_memo[l]; mst.erase(mst.find(val)); // 左端を消す場合 if (l == mid) { if (mid + 1 != r) { val -= bit_memo[l].sum(a[l]); bit_memo[l].add(a[l], -1); swap(bit_memo[l], bit_memo[l + 1]); st.insert({l + 1, r}); mst.insert(val); inv_memo[l + 1] = val; } continue; } // 右端を消す場合 if (mid + 1 == r) { val -= (r - l) - bit_memo[l].sum(a[r - 1] + 1); bit_memo[l].add(a[r - 1], -1); st.insert({l, r - 1}); mst.insert(val); inv_memo[l] = val; continue; } // 区間が2つに分かれる場合 if (mid - l <= r - mid) { for (int j = l; j <= mid; ++j) { val -= bit_memo[l].sum(a[j]); bit_memo[l].add(a[j], -1); } swap(bit_memo[l], bit_memo[mid + 1]); st.insert({mid + 1, r}); mst.insert(val); inv_memo[mid + 1] = val; auto nxt = calc(l, mid); bit_memo[l] = nxt.second; st.insert({l, mid}); mst.insert(nxt.first); inv_memo[l] = nxt.first; } else { for (int j = r - 1; j >= mid; --j) { val -= (j + 1 - l) - bit_memo[l].sum(a[j] + 1); bit_memo[l].add(a[j], -1); } st.insert({l, mid}); mst.insert(val); inv_memo[l] = val; auto nxt = calc(mid + 1, r); bit_memo[mid + 1] = nxt.second; st.insert({mid + 1, r}); mst.insert(nxt.first); inv_memo[mid+1] = nxt.first; } } cout << "\n"; } int main() { ios::sync_with_stdio(false); cin.tie(nullptr); int q; cin >> q; while (q--) { slv(); } }