かっつのメモ帳

主に競プロ 時々日記

Codeforces #641 Div1 D. Slime and Biscuits

一応分かった気がするのでメモ 議論間違ってたらすいません…

リンク

https://codeforces.com/contest/1349/problem/D

問題概要

N体のスライムがいてそれぞれ  a_{i}枚ずつビスケットを所持している。毎ターン毎に次の操作を繰り返す。

  •  \sum a_{i}枚から一つ等確率でビスケットを選び、その現在の持ち主以外のN-1体から等確率で選んだ1体に渡す

操作の終了条件は初めてある1体が全てのビスケットを所持することである。操作の期待回数をmod998244353 で求めよ。

制約

 2 \leqq N \leqq 10\  ^5

 1 \leqq  \sum a_{i} \leqq 3 \times 10\  ^5

考察

ある1体が全てのビスケットを所持するまでの期待回数

→位置iで操作が終了する確率を p_{i}、その期待値を e_{i}とすると求める期待値は \sum p_{i} e_{i}である

以下 E_{i}= p_{i} e_{i}とする、また\sum p_{i}=1が成立している

 

また操作の終了とは別に位置iにだけ注目した時に初めて全てのビスケットが揃うまでの期待回数を E'_{i}とすると、Cを定数として全てのj=1,2,…,nに対してE_{j}=E'_{j}-\displaystyle \sum_{i≠j} (p_{i} C + E_{i}) が成立する

ここでCは位置iでビスケットを全て占めた状態から位置j (i≠j)でビスケットを全て占めた状態になるまでの期待回数を表していて、これはi,jのペアに依らない定数として扱うことができる

上式を変形して、 \sum E_{i}=E'_{j}- C \displaystyle \sum_{i≠j} p_{i}

これをj=1,2,…,nについて和を取ることで、 n\sum E_{i}=\sum E'_{i}- C(n-1)

(この式変形には n\sum p_{i}-\sum p_{i}=n-1を用いた)

両辺をnで割ると左辺が求めたい値になっていることから E'_{i}及び定数Cを求めることができれば良い

 

ここで次のような値を考える  dp\lbrack i \rbrack=(ビスケットをi枚所持している時に全てのビスケットを所持するようになるまでの期待回数)

この時、 E'_i = dp\lbrack {a\lbrack i \rbrack} \rbrack, C=dp\lbrack 0 \rbrackである

このdpは次のように求めていくことができる
 dp\lbrack i \rbrack=(ビスケットをi枚所持している時にi+1枚のビスケットを所持するようになるまでの期待回数)を求めて後ろから累積和を取る

 

この値はi番目とi-1番目の関係を漸化式で表すことで求めることができる

dp\lbrack 0 \rbrack=\frac{1}{n-1}+\frac{n-2}{n-1}\times(dp\lbrack 0 \rbrack +1)

dp\lbrack i \rbrack=\frac{sum-i}{sum}\times\frac{1}{n-1}+\frac{sum-i}{sum}\times\frac{n-2}{n-1}\times(dp\lbrack i \rbrack+1)+\frac{i}{sum}\times(dp\lbrack i-1 \rbrack+dp\lbrack i \rbrack+1)

これを整理するとdp\lbrack 0 \rbrack=n-1, \frac{sum-i}{sum(n-1)}dp\lbrack i \rbrack=1+\frac{i}{sum}dp\lbrack i-1 \rbrack が得られる

 

これを実装することで問題を解くことができます

実装

#include <iostream>
#include <algorithm>
#include <vector>
using namespace std;
typedef long long int ll;

constexpr ll mod=998244353;

ll dp[300300];

ll mod_pow(ll a,ll b){
    a%=mod;
    if(b==0)return 1;
    if(b==1)return a;
    ll res=mod_pow(a,b/2)%mod;
    res*=res; res%=mod;
    if(b%2)res*=a;
    return res%mod;
}

int main(){
    cin.tie(nullptr);
    ios::sync_with_stdio(false);
    int n; cin >> n;
    vector<int> a(n);
    int sum=0;
    for(int i=0;i<n;i++){
        cin >> a[i];
        sum+=a[i];
    }
    dp[0]=n-1;
    // dp[i]=1+{(sum-i)/sum}*{1/(n-1)}+{(sum-i)/sum}*{(n-2)/(n-1)}*dp[i]+{i/sum}*(dp[i-1]+dp[i])
    for(int i=1;i<sum;i++){
        ll a=(sum-i)*mod_pow(sum,mod-2)%mod*mod_pow(n-1,mod-2)%mod;
        ll b=1+i*mod_pow(sum,mod-2)%mod*dp[i-1]%mod;
        dp[i]=b*mod_pow(a,mod-2)%mod;
    }
    for(int i=sum-1;i>=0;i--){
        (dp[i]+=dp[i+1])%=mod;
    }
    ll res=0;
    for(int i=0;i<n;i++){
        (res+=dp[a[i]])%=mod;
    }
    (res-=dp[0]*(n-1)%mod)%=mod;
    if(res<0)res+=mod;
    printf("%lld\n",res*mod_pow(n,mod-2)%mod);
}