かっつのメモ帳

主に競プロ 時々日記

diverta 2019 Programming Contest 2 E-Balanced Piles

リンク

atcoder.jp

問題概要

N個のマスがあり最初はどのマスにも積み木は乗っていない。N個のマスのうち積み木が最小のマスを一つを選んで、現在の積み木の高さのMAX~MAX+Dの高さの好きな高さまで積み木を載せる操作を行う。N個のマスの積み木の高さが全てHになる積み重ね方の総数を 10^9+7で割った余りを求めよ。

制約

 2 \leqq N \leqq 10\  ^6  1 \leqq D \leqq 10\  ^6  1 \leqq H \leqq 10\  ^6

考察

これ難しくないですか 僕は難しいと思います(取っ掛かりが見つけにくい?)

色々考えた結果 O(NH)のDP遷移を考えるとかなり見通しがよくなりました、のでまず最初にそのDPの説明をします。

dp'[i][j]:高さの最大値がiでその個数がjとなる積み重ね方の総数

と定義します。求めたいのはdp'[H][N]となります。

最小値から遷移させないといけないので、dp配列に最小値の個数を持たないといけないのでは?という気持ちになるんですが、この方針は早々に行き詰まってしまいます。ここで積み木を載せる操作をしたマスは必ず最大値になる点に注目して、積み重ねた時点で既に載せる順番が決まってるものとすれば最小値の情報は一切持たなくて良いことになります。

つまりj+1≠Nの時、dp'[i][j+1]=dp'[i][j]×(j+1) と遷移させることが出来ます(イメージとしては既に順番が定まってるj個の間のどこに挿入するか、で考えると分かりやすいです)

また、dp'[i+1][1]+=dp'[i][j], dp'[i+2][1]+=dp'[i][j] ,…dp'[min(H,i+D)][1]+=dp'[i][j]とも遷移出来ます。

従ってdp'[i+1][1]+=(dp'[i][1]+dp'[i][2]+…+dp'[i][N]), dp'[i+2]+=(dp'[i][1]+dp'[i][2]+…+dp'[i][N]), …, dp'[min(i+D,H)][1]+=(dp'[i][1]+dp'[i][2]+…+dp'[i][N])

と表すことが出来ます。またdp'[i][j+1]=dp'[i][j]×(j+1)より、(dp'[i][1]+dp'[i][2]+…+dp'[i][N])=dp'[i][1]*(1!+2!+…+N!)と表すことが出来ます。(1!+2!+…+N!)は前計算可能、そして累積和を計算しておくことによってこの遷移は高速化可能です。

今までdp'[i][1]としていたのをdp[i]と定義し直すことにします。

すると以上の考察から、dp[i]=(dp[i-1]+dp[i-2]+…+dp[max(0,i-D)])×(1!+2!+…+N!) という遷移式になります。

で求める答えはdp[H]になります(上で求めたいのはdp'[H][N]って言ってて自分でも書いてて、ん?となったんですがdp'[H][1]まで確定した時点で後は残りのマスを選ぶ順番は順序付けされているので動かす順番は一意に定まっています。Hより高く積むことが無いのでこれでOKなはずです)

残る問題は初期値の設定です(僕はここで最後頭混乱して苦しみました)

始めの0がN個並んでいる状態では選ぶ順番が決まっていないので適切に処理してあげる必要があります。なのでdp[1]~dp[D]に対してN!(N個のうちどれを選んで最大値にするかがN通り、また残りのN-1個の順序が(N-1)!通りなので)を足してあげれば上手く行きます。以上より O(N+H)でこの問題が解くことが出来ました。

実装

#include <iostream>
using namespace std;
typedef long long int ll;

ll mod=1e9+7;
ll dp[1000100];
ll pr[1000100]; //累積和

int main(){
	cin.tie(nullptr);
	ios::sync_with_stdio(false);
	int n,d,h; cin >> n >> h >> d;
	ll p=1;
	ll s=0;
	for(ll i=1;i<=n;i++){
		(p*=i)%=mod;
		(s+=p)%=mod;
	}
	for(int i=1;i<=h;i++){
		pr[i]=pr[i-1];
		ll sum=(pr[i-1]-pr[max(0,i-d-1)])%mod;
		if(sum<0)sum+=mod;
		(dp[i]+=sum*s%mod)%=mod;
		if(i<=d)(dp[i]+=p)%=mod;
		(pr[i]+=dp[i])%=mod;
	}
	cout << dp[h] << endl;
}

 謝罪

数式見にくくてごめんなさい(途中で見やすく書くことを放棄した)