かっつのメモ帳

主に競プロ 時々日記

May Cook-Off 2021 - FLGZRO

問題のリンク

問題概要

頂点1 を根とする $ N $ 頂点の根付き木が与えられる。各頂点 $ i $ に値 $ A_i ( > 0 ) $ が書かれていて、そのうち1頂点を選んでその値を $ 0 $ に書き換える。

値が $ 0 $ の頂点を $ u $ として、$ u $ が葉ノードに達するまで次の操作を行い続ける。

  • $ u $ の部分木内の頂点 $ v $ を一つ選んで、$ A _ {u} $ と $ A _ {v} $ を swap する

最終的に得られる木の総数を $ 10 ^ 9 + 7 $ で割った余りを求めよ。

制約

  • $ 1 \leq N \leq 10^{5} $

  • $ 1 \leq A _ {i} \leq 10^{9} $

解法

頂点 $ i $ の部分木に対して、操作を行った時に考えられる木の総数を $dp[i]$ とする木DPを考える。

求めるべき答えは $dp[1]$ であり、初期条件は各葉ノード $ v $ に対して $ dp[v] = 1 $ である。

遷移をその頂点が0として選ばれる場合とその子孫だけで操作が完了する場合の2通りに分けて考える。

前者の総数を $ dp_2 [ i ] $ とおく。この時、頂点 $ i $ の部分木に含まれる頂点の集合を $sub_{i}$ として、

$$ dp _ {2}[i] = \sum _ {v \in sub _ {i} , a[v] \neq a[i]} dp_{2}[v] $$

と表すことができる。

後者は、頂点 $ i $ の子ノードの集合を $ child_{i} $ として次のような遷移で表せる。

$$ dp[i] = \sum _ {v \in child _ {i} } dp[v] + dp_{2} [i] $$

前者の式は、この問題で数え上げる対象が根方向から葉にかけて隣接2項の値が異なるような頂点の選び方と言い換えられることに注目すると分かりやすい。

後者はいつもの木DPで、前者はマージテクの要領でmapを管理していけば良い。

実装

#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
using namespace std;
typedef long long int ll;

constexpr int mod=1e9+7;

void add(int &a,int b){
    a+=b;
    if(a>=mod)a-=mod;
}
void sub(int &a,int b){
    a-=b;
    if(a<0)a+=mod;
}

int main(){
    cin.tie(nullptr);
    ios::sync_with_stdio(false);
    int q; cin >> q;
    while(q--){
        int n; cin >> n;
        vector<vector<int>> g(n);
        for(int i=1;i<n;i++){
            int x,y; cin >> x >> y;
            x--; y--;
            g[x].push_back(y);
            g[y].push_back(x);
        }
        vector<int> a(n);
        for(int i=0;i<n;i++){
            cin >> a[i];
        }
        vector<int> dp(n);
        auto dfs=[&](auto dfs,int s,int p)->map<int,int>{
            map<int,int> dp2;
            int cnt=0;
            int sum=0;
            for(int t:g[s]){
                if(t==p)continue;
                cnt++;
                auto res=dfs(dfs,t,s);
                add(dp[s],dp[t]);
                add(sum,dp[t]);
                if(res.size()>dp2.size())swap(res,dp2);
                for(auto z:res){
                    add(dp2[z.first],z.second);
                }
            }
            if(cnt==0)sum=1;
            sub(sum,dp2[a[s]]);
            add(dp2[a[s]],sum);
            add(dp[s],sum);
            return dp2;
        };
        dfs(dfs,0,-1);
        cout << dp[0] << "\n";
    }
}