かっつのメモ帳

主に競プロ 時々日記

Codeforces Round #721 Div2D. MEX Tree

本番ではLCAをゴチャゴチャして解いた。

問題のリンク

問題概要

$0, 1, \ldots , N-1$ の数字が各頂点に書かれた $ N $ 頂点の木が与えられる。

$ 0 \leq k \leq N $ を満たす各 $ k $ について、MEXが $ k $ に等しくなるようなパスの総数を求めよ。

制約

  • $ 2 \leq N \leq 2 \cdot 10^{5} $

解法

パスのMEXの値が $ k $ 以上になるものが存在するためには、 $ k - 1 $ 以下の頂点が全て一直線状に結べればよい。また、そのようなパスが分かっている場合、パスのMEXの値が $ k $ 以上になるペアの総数も簡単に求めることができる。

上記のものが求めれられている時、MEXの値が $ k $ となるようなパスの総数は、(MEXの値が $ k $ 以上になるペアの総数) - (MEXの値が $ k + 1 $ 以上になるペアの総数) と、前後で差分を取れば求めることができる。

従って次の問題が解ければよい。

  • 0から順に頂点集合に加えていく
  • 集合内の頂点を全て通るパスがあるか判定
  • もしあるならばパスの2端点を更新

この性質を用いると、今のパスの2端点を $x, y$ 、今加えようとしている頂点を $ s $ として、

$$ {dist} _{ max } = max( dist(x,y), dist(x,s) , dist(y,s)) \\ dist(x,y) + dist(x,s) + dist(y,s) = 2 {dist} _ { max } $$

を満たしているかを見ればよい。これはLCA等を用いれば確認することができる。

また、頂点を追加して行く度に一つずつ親の頂点を辿っていくことで線形時間で解くこともできる。以下の実装はこの方針で書いてみた。

実装

#include <bits/stdc++.h>
using namespace std;
typedef long long int ll;

int main(){
    int q; cin >> q;
    while(q--){
        ll n; cin >> n;
        vector<vector<int>> g(n);
        for(int i=1;i<n;i++){
            int x,y; cin >> x >> y;
            g[x].push_back(y);
            g[y].push_back(x);
        }
        vector<ll> ch(n); // 部分木の頂点数
        vector<int> pr(n); // 親の頂点番号
        auto dfs=[&](auto dfs,int s,int p)->void{
            ch[s]=1; pr[s]=p;
            for(int t:g[s]){
                if(t==p)continue;
                dfs(dfs,t,s);
                ch[s]+=ch[t];
            }
        };
        dfs(dfs,0,-1);
        vector<ll> res(n+2);
        for(int t:g[0]){
            res[0]+=ch[t]*(ch[t]-1)/2;
        }
        res[1]=n*(n-1)/2-res[0];
        int x=0,y=0;
        ll u=ch[0],v=ch[0];
        vector<bool> on(n,false);
        on[0]=true;
        for(int i=1;i<n;i++){
            int s=i;
            if(on[s]){
                res[i+1]=res[i];
                continue;
            }
            while(!on[s]){
                on[s]=1;
                s=pr[s];
            }
            if(s==x){
                u=ch[i];
                x=i;
                if(y==0 and v==ch[0]){
                    for(int t:g[0]){
                        if(on[t])v-=ch[t];
                    }
                }
            }
            else if(s==y){
                v=ch[i];
                y=i;
            }
            else break;
            res[i+1]=u*v;
        }
        for(int i=1;i<=n;i++){
            res[i]-=res[i+1];
        }
        for(int i=0;i<=n;i++){
            printf("%lld ",res[i]);
        }
        printf("\n");
    }
}