問題概要
次の条件を満たす順列の総数を $998244353$ で割った余りを求めよ。
- 長さが $N$ である
- $K \leqq i \lt N$ を満たす全ての $i$ について、前 $K$ 個の最小値よりも $i$ 番目の値の方が大きい
制約
- $1 \leqq N \leqq 10^{7}$
- $1 \leqq K \leqq 10^{7}$
解法
まず1の入る位置 $i$ を考える。
2番目の条件から $K \leqq i \lt N$ を満たす位置に1は入ることは無い。逆にそれ以外の位置には何処にでも入れることができる。
$i$ を固定した時の順列の総数を考える。以下 $i$ は1-indexedで扱う。
1の前後で独立して考えることができて、
- 1より前は何を置いても最小値条件に影響を与えない→ $_{N-1}C_{i-1} (i-1)!$ 通り
- 1より後ろについては残った数字を1,2,3…と書き換えると、長さ $N-i$ の順列の場合に帰着する
以上より長さ $N$ に対応する順列の総数を $dp\lbrack N \rbrack$ と定めると、
$$ dp[N]=\sum_{i=1}^{\min (N, K)}{ } _ {N-1} C _ {i-1}(i-1) ! \cdot dp[N-i] $$
$$ \therefore dp[N]=(N-1)! \sum_{i=1}^{\min (N, K)}\frac{dp[N-i]}{(N-i)!} $$
これは前から累積和を更新しながら計算することで $O(N)$ で求まる。
実装
#include <bits/stdc++.h> using namespace std; typedef long long int ll; const int N=1e7+5; constexpr ll mod=998244353; ll mod_pow(ll x,ll n){ x%=mod; ll res=1; while(n>0){ if(n&1LL)res=res*x%mod; x=x*x%mod; n>>=1LL; } return res; } int main(){ int n,k; cin >> n >> k; vector<ll> fac(N),inv(N); { // 前計算 fac[0]=1; for(ll i=1;i<N;i++) fac[i]=fac[i-1]*i%mod; inv[N-1]=mod_pow(fac[N-1],mod-2); for(ll i=N-2;i>=0;i--)inv[i]=inv[i+1]*(i+1)%mod; } vector<ll> dp(N),s(N); dp[0]=s[0]=1; ll sum=1; for(int i=1;i<N;i++){ if(i-k-1>=0){ sum=(sum-s[i-k-1]+mod)%mod; } dp[i]=fac[i-1]*sum%mod; s[i]=dp[i]*inv[i]%mod; sum=(sum+s[i])%mod; } cout << dp[n] << endl; }