かっつのメモ帳

主に競プロ 時々日記

ICPC国内予選2018 G-数式探し

過去記事の修正版です

問題のリンク

問題概要

 +, \times , () 及び一桁の数字からなる数式が与えられるので、この数式の連続部分文字列の数式で値が N に等しくなるものの個数を求めよ。

制約

 1 \leqq N \leqq 10\  ^9

解法

基本的な方針はこんな感じ。

  • 区間の長さに対して計算結果は単調に増加するので尺取り法の要領で数え上げが出来る
  • 選ぶ左端と右端に関する注意として、括弧の内外に分かれてはいけない
  • この対策として次のようなアルゴリズムを考える
  • 括弧の中身を優先的に処理→括弧の中身を計算結果に置き換えて元の数式を処理

例を挙げると、

N=2
(1+1)×1+1
括弧の中身について 1+1の1通り
括弧の中身を計算すると 1+1=2
ここで今見ている数式は 2×1+1となる
この区間について 2と2×1の2通り
合計で3通り

のようなイメージで実装しました。ここまでを実装するとこんな感じに

#include <iostream>
#include <vector>
#include <algorithm>
#include <stack>
#include <deque>
using namespace std;
typedef long long ll;

const ll INF=1e9+5;

ll n;
string s;
ll res=0;
int r; //閉じ括弧のindexを保持

ll cal(int l){
    vector<ll> st;
    vector<char> sc;
    for(int i=l;i<s.size();i++){
        if(s[i]=='('){
            ll p=cal(i+1);
            st.push_back(p);
            i=r;
        }
        else if(s[i]=='+'||s[i]=='*'){
            sc.push_back(s[i]);
        }
        else if('0'<=s[i]&&s[i]<='9'){
            st.push_back(s[i]-'0');
        }
        else if(s[i]==')'){
            r=i;
            break;
        }
    }
    // 括弧の中身で尺取り

    // 括弧の中身を計算して返す
    ll sum=0; ll last=st[0];
    for(int i=1;i<st.size();i++){
        if(sc[i-1]=='+'){
            sum+=last;
            last=st[i];
        }
        else{
            last*=st[i];
        }
        sum=min(sum,INF);
        last=min(last,INF);
    }
    sum+=last;
    sum=min(sum,INF);
    return sum;
}

void solve(){
    res=0;
    vector<ll> st;
    vector<char> sc;
    for(int i=0;i<s.size();i++){
        if(s[i]=='('){
            ll p=cal(i+1);
            st.push_back(p);
            i=r;
        }
        else if(s[i]=='+'||s[i]=='*'){
            sc.push_back(s[i]);
        }
        else{
            st.push_back(s[i]-'0');
        }
    }
    // 括弧を全部処理した後数式全体で尺取り

    printf("%lld\n",res);
}

int main(){
    while(cin >> n,n){
        cin >> s;
        solve();
    }
}

ここまでで1つ罠が存在して、括弧内の数式を愚直に計算しようとするとオーバーフローを起こすケースが存在します。この対策としては、その括弧の区間を丸々含む解は存在しないので適当なINF値で置き換えれば良いです。

後は尺取り部分を実装すれば終わりです。

ところでこの尺取りの操作における注意しなければならないポイントとして、×1の処理が挙げられます。

色々な解決策はあると思うんですが、左端を固定した時に二種類の右端を用意すると数え上げが容易になります。イメージとしては次の通りです。

r1:今の計算結果がN以下になるまで伸ばした時の右端
r2:今の計算結果がN以下になるまで伸ばした時の右端、ただしNと等しくなったら更新を止める
 

N=15, 3を区間の左端とする
3×5)×1×1)+8
 ↑r2 ↑r1
 
この時条件を満たす数式は(r1-r2+1)通り

ここまで来ると後は実装やるだけ感が出てきましたね。尺取りの操作にはdequeを使って行いました。以下実装です。

実装

#include <iostream>
#include <vector>
#include <algorithm>
#include <stack>
#include <deque>
using namespace std;
typedef long long ll;

const ll INF=1e9+5;

ll n;
string s;
ll res=0;
int r; //閉じ括弧のindexを保持

ll cal(int l){
    vector<ll> st;
    vector<char> sc;
    for(int i=l;i<s.size();i++){
        if(s[i]=='('){
            ll p=cal(i+1);
            st.push_back(p);
            i=r;
        }
        else if(s[i]=='+'||s[i]=='*'){
            sc.push_back(s[i]);
        }
        else if('0'<=s[i]&&s[i]<='9'){
            st.push_back(s[i]-'0');
        }
        else if(s[i]==')'){
            r=i;
            break;
        }
    }
    // 括弧の中身で尺取り
    ll now=st[0];
    int r1=1,r2=1;
    deque<ll> dq;
    deque<char> qc;
    dq.push_back(now);
    for(int i=0;i<st.size();i++){
        while(now<=n&&r1<st.size()){
            bool check=(now<n);
            if(sc[r1-1]=='+'){
                if(now+st[r1]>n)break;
                now+=st[r1];
                dq.push_back(st[r1]);
                qc.push_back(sc[r1-1]);
                r1++;
                if(check)r2=r1;
            }
            else{
                ll q=dq.back(); dq.pop_back();
                if(now-q+q*st[r1]>n){
                    dq.push_back(q);
                    break;
                }
                now-=q;
                q*=st[r1];
                now+=q;
                dq.push_back(q);
                qc.push_back(sc[r1-1]);
                r1++;
                if(check)r2=r1;
            }
        }
        if(now==n)res+=(r1-r2+1);
        if(qc.size()==0){
            dq.pop_back();
            dq.push_back(st[i+1]);
            now=st[i+1];
        }
        else if(qc.front()=='+'){
            now-=dq.front();
            dq.pop_front();
            qc.pop_front();
        }
        else{
            ll q=dq.front(); dq.pop_front();
            now-=q;
            q/=st[i];
            now+=q;
            dq.push_front(q);
            qc.pop_front();
        }
        r1=max(r1,i+2);
        r2=max(r2,i+2);
    }
    // 括弧の中身を計算して返す
    ll sum=0; ll last=st[0];
    for(int i=1;i<st.size();i++){
        if(sc[i-1]=='+'){
            sum+=last;
            last=st[i];
        }
        else{
            last*=st[i];
        }
        sum=min(sum,INF);
        last=min(last,INF);
    }
    sum+=last;
    sum=min(sum,INF);
    return sum;
}

void solve(){
    res=0;
    vector<ll> st;
    vector<char> sc;
    for(int i=0;i<s.size();i++){
        if(s[i]=='('){
            ll p=cal(i+1);
            st.push_back(p);
            i=r;
        }
        else if(s[i]=='+'||s[i]=='*'){
            sc.push_back(s[i]);
        }
        else{
            st.push_back(s[i]-'0');
        }
    }
    // 括弧を全部処理した後数式全体で尺取り
    // 上の丸々コピペするだけでOK
    ll now=st[0];
    int r1=1,r2=1;
    deque<ll> dq;
    deque<char> qc;
    dq.push_back(now);
    for(int i=0;i<st.size();i++){
        while(now<=n&&r1<st.size()){
            bool check=(now<n);
            if(sc[r1-1]=='+'){
                if(now+st[r1]>n)break;
                now+=st[r1];
                dq.push_back(st[r1]);
                qc.push_back(sc[r1-1]);
                r1++;
                if(check)r2=r1;
            }
            else{
                ll q=dq.back(); dq.pop_back();
                if(now-q+q*st[r1]>n){
                    dq.push_back(q);
                    break;
                }
                now-=q;
                q*=st[r1];
                now+=q;
                dq.push_back(q);
                qc.push_back(sc[r1-1]);
                r1++;
                if(check)r2=r1;
            }
        }
        if(now==n)res+=(r1-r2+1);
        if(qc.size()==0){
            dq.pop_back();
            dq.push_back(st[i+1]);
            now=st[i+1];
        }
        else if(qc.front()=='+'){
            now-=dq.front();
            dq.pop_front();
            qc.pop_front();
        }
        else{
            ll q=dq.front(); dq.pop_front();
            now-=q;
            q/=st[i];
            now+=q;
            dq.push_front(q);
            qc.pop_front();
        }
        r1=max(r1,i+2);
        r2=max(r2,i+2);
    }
    printf("%lld\n",res);
}

int main(){
    while(cin >> n,n){
        cin >> s;
        solve();
    }
}