TLで流れていた解法よりも直感的に理解しやすい(と個人的に思っている)のでメモ
問題概要
大きさが $ 1 $ から $ N $ の $ N $ 枚のパンケーキを下から積んでいく。この時、$ i $ 枚積んだ時点で上から見えるパンケーキの数は $ A_i $ 枚だったとする。
$ N! $ 通りのパンケーキの積み重ね方のうち、条件を満たすものの総数を $10^{9}+7$ で割った余りで求めよ。
制約
Small(10pt)
- $ 2 \leq N \leq 13$
Large(21pt)
- $ 2 \leq N \leq 10^{5}$
解法
大まかな方針として、一番大きいパンケーキの場所を固定して、区間を分割していくことを考える。
大きさ $ N $ のパンケーキを置くとその時点で上から覗けるパンケーキの数は1つになるため、置ける候補となる点は $ A_i = 1$ となる $ i $ である。
この $ i $ の前後で区間を分割する。
重要な考察として、$ i $ 番目以前に置いたパンケーキは $ i $ 番目のパンケーキによって覆われてしまうため、それ以降に影響を及ぼさない。
つまり、長さ $ A $ と 長さ $ B $ の区間に分割する時、$\binom{A+B}{A}$ 通り数字の選び方があって、それぞれ $1,2,\ldots$ と数字を振り直していけば元の問題に帰着される。ただし、右側の区間は見えるパンケーキの数が元々 $+1$ される。
後はこれを再帰的に実装すれば部分点が得られる。
もう一つ考察をここから進めると、一番大きいパンケーキを置く候補が複数ある場合必ず一番右を選ばないと矛盾することが分かるため、これをシミュレートすれば良い。計算量は $O(N\log{N})$ で解くことができる。
実装
#pragma GCC optimize("Ofast") #include <bits/stdc++.h> using namespace std; typedef long long int ll; typedef unsigned long long ull; mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count()); ll myRand(ll B) { return (ull)rng() % B; } constexpr ll mod=1e9+7; ll mod_pow(ll a,ll b){ a%=mod; if(b==0)return 1; if(b==1)return a; ll res=mod_pow(a,b/2)%mod; res*=res; res%=mod; if(b%2)res*=a; return res%mod; } struct perm{ private: int sz; vector<ll> p,invp; public: perm(int n){ sz=n+1; p.resize(sz),invp.resize(sz); p[0]=1; for(int i=1;i<=sz-1;i++){ p[i]=p[i-1]*i%mod; } invp[sz-1]=mod_pow(p[sz-1],mod-2); for(int i=sz-2;i>=0;i--){ invp[i]=invp[i+1]*(i+1)%mod; } } ll comb(ll x,ll y){ if(x<y||y<0)return 0; return (p[x]*invp[x-y]%mod)*invp[y]%mod; } }; perm p(1<<20); const int N=100001; vector<int> v[N]; int search(int val,int l,int r){ // [l,r) にある最右のvalのindexを返す // 見つからなかったら-1 auto it=lower_bound(v[val].begin(), v[val].end(),r); if(it==v[val].begin())return -1; it--; int id=*it; if(l<=id and id<r){ return id; } else{ return -1; } } int main(){ cin.tie(nullptr); ios::sync_with_stdio(false); int t; cin >> t; int no=0; while(t--){ no++; int n; cin >> n; for(int i=0;i<N;i++){ v[i].clear(); } vector<int> a(n); for(int i=0;i<n;i++){ cin >> a[i]; v[a[i]].push_back(i); } auto cal=[&](auto cal,int l,int r,int cnt)->ll{ if(l==r)return 1; int len=r-l; int id=search(cnt+1,l,r); if(id==-1)return 0; return p.comb(len-1,id-l)*cal(cal,l,id,cnt)%mod*cal(cal,id+1,r,cnt+1)%mod; }; ll res=cal(cal,0,n,0); printf("Case #%d: %lld\n",no,res); } }