티스토리 뷰

CCO 2022 Day 2를 돌아 49/75점(7 17 25)을 받았다. 당시 대회 스코어보드가 남아있지 않아 정확한 난이도를 알 수 없지만, 평범한 점수인 것 같다. 

[BOJ 25224] Good Game (3:26)

세 문제를 처음 읽었을 때 쉬워 보이는 문제가 없어서 열심히 긁었다. 결국 거의 모든 서브태스크를 하나하나씩 긁으면서 시간이 꽤 걸렸는데, 긁다 보니 3번 풀이가 보였다.

 

Subtask 2는 파일 합치기 유형 DP를 이용해 해결할 수 있다. 사실 이거랑 정확히 똑같은 건 아니고, 조금 다른 점이 있긴 하지만 코드가 거의 똑같다. $O(N^3)$에 풀 수 있다.

 

Subtask 3부터는 조금의 관찰이 필요하다. 스택으로 접근하자. 스택에 문자를 계속 쌓다가, 어떤 문자가 맨 위에 2개 이상 쌓이는 순간 그 문자 전체를 제거할 수 있다. 어떤 문자가 2개 이상 쌓이면 2개 있는 것과 다를 바가 없다. 또 2개 이상 쌓였을 때 당장 없앨 게 아니면, 나중에 뒤에 나오는 문자와 함께 없애는 거니 1개로 줄여도 상관 없다. 따라서 큐의 형태가 항상 ABAB...또는 BABA...임을 알 수 있다. 이걸 가지고 DP를 한다. $DP_{AA}[i][j]$는 $i$번 문자까지 봤을 때 큐의 시작과 끝이 $A$이고 큐의 길이가 $j$인 것이 가능한지를 저장한다. 비슷한 식으로 $DP_{AB}$, $DP_{BA}$, $DP_{BB}$를 정의한다. 이렇게 점화식을 세우면 $O(N^2)$에 문제를 해결할 수 있다.

 

Full Task를 풀기 위해서는 $DP_{s} [i][j]=1$인 $j$가 구간을 이룸을 관찰하면 된다. 이 구간을 DP로 관리하면 $O(N)$에 문제를 해결할 수 있다.

 

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

struct Interval{
    int l, r;
    Interval(){
        l=1e9, r=0;
    }
    Interval(int _l, int _r){
        l=_l, r=_r;
        if(l>r) l=1e9, r=0;
    }

    Interval operator+(const Interval &nxt)const{
//        assert(max(0, r-l+1) + max(0, nxt.r-nxt.l+1) >= max(0,max(r,nxt.r)-min(l,nxt.l)+1));
        return Interval(min(l, nxt.l), max(r, nxt.r));
    }

    void operator+=(const Interval &nxt){
        *this = *this + nxt;
    }

    Interval operator+(const int &x)const{
        return Interval(l+x, r+x);
    }

    Interval operator-(const int &x)const{
        return Interval(max(1, l-x), r-x);
    }

    bool include(int x){
        return l<=x && x<=r;
    }
};

int n;
char arr[1000002];
char str[1000002];

Interval AtoA[1000002];
Interval AtoB[1000002];
Interval BtoA[1000002];
Interval BtoB[1000002];
bool zero[1000002];

int main(){
    scanf("%d %s", &n, arr+1);

    zero[0] = 1;
    for(int i=1; i<=n; i++){
        if(arr[i] == 'A'){
            AtoA[i] = AtoA[i-1] + (AtoB[i-1] + 1);
            if(zero[i-1]) AtoA[i] += Interval(1, 1);
            BtoA[i] = BtoA[i-1] + BtoB[i-1];
            AtoB[i] = AtoA[i-1] - 1;
            BtoB[i] = BtoA[i-1];
            if(AtoA[i-1].include(1)) zero[i] = 1;
        }
        else{
            BtoB[i] = BtoB[i-1] + (BtoA[i-1] + 1);
            if(zero[i-1]) BtoB[i] += Interval(1, 1);
            AtoB[i] = AtoB[i-1] + AtoA[i-1];
            BtoA[i] = BtoB[i-1] - 1;
            AtoA[i] = AtoB[i-1];
            if(BtoB[i-1].include(1)) zero[i] = 1;
        }
//        printf("%d: AA [%d, %d] AB [%d, %d] BA [%d, %d] BB [%d, %d] zero %d\n",
//               i, AtoA[i].l, AtoA[i].r, AtoB[i].l, AtoB[i].r, BtoA[i].l, BtoA[i].r, BtoB[i].l, BtoB[i].r, zero[i]);
    }

    if(!zero[n]){
        puts("-1");
        return 0;
    }

    int state = 0, len = 0; /// 0, AA, AB, BA, BB
    for(int i=n; i>=1; i--){
        if(arr[i] == 'A'){
            if(state == 0) str[i] = '-', state = 1, len = 1;
            else if(state == 1){
                if(AtoA[i-1].include(len)) str[i] = '=';
                else if(AtoB[i-1].include(len-1)) str[i] = '+', state = 2, len--;
                else str[i] = '+', state = 0, len = 0;
            }
            else if(state == 2) str[i] = '-', state = 1, len++;
            else if(state == 3){
                if(BtoA[i-1].include(len)) str[i] = '=';
                else str[i] = '+', state = 4;
            }
            else str[i] = '-', state = 3;
        }
        else{
            if(state == 0) str[i] = '-', state = 4, len = 1;
            else if(state == 4){
                if(BtoB[i-1].include(len)) str[i] = '=';
                else if(BtoA[i-1].include(len-1)) str[i] = '+', state = 3, len--;
                else str[i] = '+', state = 0, len = 0;
            }
            else if(state == 3) str[i] = '-', state = 4, len++;
            else if(state == 2){
                if(AtoB[i-1].include(len)) str[i] = '=';
                else str[i] = '+', state = 1;
            }
            else str[i] = '-', state = 2;
        }
    }

//    printf("%s\n", str);

    vector<pair<int, int> > vec;
    vector<pair<int, int> > ans;
    for(int i=1; i<=n; i++){
        if(str[i] == '+') vec.push_back(make_pair(vec.empty() ? 1 : vec.back().first+vec.back().second, 1));
        else if(str[i] == '=') vec.back().second++;
        else{
            int s = vec.back().first, l = vec.back().second + 1;
            vec.pop_back();
            while(l > 3) ans.push_back(make_pair(s, 2)), l-=2;
            ans.push_back(make_pair(s, l));
            assert(l>1);
        }
    }
    printf("%d\n", (int)ans.size());
    for(auto p: ans) printf("%d %d\n", p.first, p.second);
}

 

[BOJ 25223] Phone Plans (17점)

대회 중에 Subtask 3까지 풀었다. Subtask 3까지는 핵심 아이디어를 알면 구현만 조금 다르고 쉽게 풀린다. 먼저 통신사 A의 간선들, 통신사 B의 간선들을 비용이 증가하는 순으로 정렬한다. 그리고 비용이 작은 것부터 union find를 했을 때, 이미 연결되어 있는 정점을 잇는 간선은 필요가 없으므로 버린다. (다른 말로, kruscal algorithm에서 필요한 간선만을 남긴다)

 

이제 나머지 간선들은 하나를 추가할 때마다 연결된 정점 집합이 바뀌는 중요한 간선들이다. 이 문제를 어렵게 만드는 주 요인은, 두 정점 쌍이 A의 간선들만을 이용해서 연결되어 있거나, B의 간선들만을 이용해서 연결되어 있는 경우를 세어야 하기 때문이다. 이 부분을 해결할 방법을 대회 시간 내에 찾지 못했다. Subtask 1부터 3까지 푸는 방법을 요약하면,

  • Subtask 1은 각 정점 쌍이 처음으로 연결되는 시점을 찾는 것이 주요 과제이다. Union Find를 하는데, 각 집합의 정점 목록을 가지고 다닌다. 이후 두 점이 합쳐질 때 양쪽 집합에서 점을 하나씩 뽑는 모든 쌍에 대해 시각을 기록한다. 이 과정 전체는 $O(N^2)$의 시간 복잡도가 걸린다. 이제 $(a, b)$ 쌍이 트리 A에서 합쳐지는 시점을 $x$좌표로, 트리 B에서 합쳐지는 시점을 $y$좌표로 놓고 좌표평면 위에 플로팅을 하면 누적합을 이용해 답을 찾는 것은 어렵지 않다.
  • Subtask 2는 두 트리 A와 B가 독립적이기 때문에 각각에 대해서 쌍 개수를 세기만 하면 된다. Union Find의 각 집합의 크기를 관리하면서 다니면 된다.
  • Subtask 3은 트리 A의 크기가 매우 작기 때문에 A에서 연결된 쌍의 개수가 적다. 이 쌍들에 대해서만 B에서 연결되었는지를 매번 확인해 주면 된다.

나는 Full Task를 이상한 방향으로 접근해서 오랫동안 풀지 못했다. 다음은 mingyu331과 토론한 후 알게 된 Full Task 풀이이다. 요점은 A와 B 모두에서 이동 가능한 정점 쌍 수를 어떻게 빼 주는가인데, 트리 A와 B에서 union find로 간선을 하나씩 합쳐 가면서 집합의 변화를 $O(N \log N)$에 모두 관리할 수 있다. small to large를 한다고 생각해도 되고, union find에서 size optimization을 해준다고 생각해도 된다. 둘이 같은 이야기이긴 하다. 어쨌든 이제 임의 시점에서 어떤 점 x가 어떤 집합에 있었는지를 알 수 있다.

 

그 다음으로 전체적인 풀이의 방향을 정하자. 투 포인터를 사용한다. 두 트리 간선을 비용 순으로 정렬했을 때, 트리 A에서 $L_a$번 간선까지, 트리 B에서 $L_b$번 간선까지 사용했다고 하자. 이 상태를 $(L_a, L_b)$로 나타낸다. $(0, b)$에서 시작한다. 현재 상태 $(x, y)$에서 $K$개 이상의 쌍을 만들 수 있으면 답에 갱신하고 $(x-1, y)$로 간다. 그렇지 않다면 $(x, y+1)$로 간다. 이제 중요한 것은 $x$와 $y$가 바뀔 때 두 집합 값이 동일한 수를 갱신하는 것인데, $a$가 증가하는 경우 Union Find를 실시간으로 하나 관리하며 트리 A 내의 집합 번호가 바뀌는 정점들을 수동으로 갱신해 주면 된다. $b$가 감소하는 경우는 실시간으로 처리하긴 어려우므로, 먼저 다 계산해두고 거꾸로 롤백하는 식으로 갱신해주면 된다. 이렇게 하면 $O(N \log^2 N)$ 또는 $O(N \log N)$ (hash map)에 문제를 해결할 수 있다.

 

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

struct dat{
    int x, a, b;
    dat(){}
    dat(int x, int a, int b): x(x), a(a), b(b){}
};

struct UnionFind{
    int par[200002];
    int sz[200002];
    vector<int> vec[200002];

    void init(int n){
        for(int i=1; i<=n; i++){
            par[i] = i, sz[i] = 1;
            vec[i].clear();
            vec[i].push_back(i);
        }
    }

    int find(int x){
        if(x == par[x]) return x;
        return par[x] = find(par[x]);
    }

    ll merge(int x, int y, vector<dat> &history){
        x = find(x), y = find(y);
        if(x==y) return 0;
        if(sz[x] < sz[y]) swap(x, y);
        par[y] = x;
        ll tmp = ll(sz[x]) * sz[y];
        sz[x] += sz[y];
        history.clear();
        for(auto X: vec[y]){
            vec[x].push_back(X);
            history.push_back(dat(X, y, x));
        }
        vec[y].clear();
        return tmp;
    }
} dsuA, dsuB;

struct Line{
    int x, y, c;
    Line(){}
    Line(int x, int y, int c): x(x), y(y), c(c){}
    bool operator<(const Line &r)const{
        return c<r.c;
    }
};

int n, a, b; ll L;
Line edgeA[200002], edgeB[200002];
vector<dat> historyA;
vector<dat> historyB[200002];
ll changeB[200002];

inline ll pr(int x, int y){
    return ll(x)*n*2+y;
}

map<ll, ll> mp;
ll cnt = 0;
ll total = 0;
int ans = INT_MAX;
int bWhere[200002];

inline void add(ll x){
    cnt += mp[x]++;
//    printf("add %lld\n", x);
}

inline void del(ll x){
    cnt -= --mp[x];
//    printf("del %lld\n", x);
}

int main(){
    scanf("%d %d %d %lld", &n, &a, &b, &L);
    for(int i=1; i<=a; i++) scanf("%d %d %d", &edgeA[i].x, &edgeA[i].y, &edgeA[i].c);
    for(int i=1; i<=b; i++) scanf("%d %d %d", &edgeB[i].x, &edgeB[i].y, &edgeB[i].c);
    sort(edgeA+1, edgeA+a+1);
    sort(edgeB+1, edgeB+b+1);

    dsuA.init(n), dsuB.init(n);
    for(int i=1; i<=b; i++){
        total += changeB[i] = dsuB.merge(edgeB[i].x, edgeB[i].y, historyB[i]);
    }
    mp.clear(); cnt = 0;
    for(int i=1; i<=n; i++) add(pr(i, dsuB.find(i))), bWhere[i] = dsuB.find(i);

    for(int ax=0, bx=b; bx>=0&&ax<=a; ){
        if(total - cnt >= L){ /// ÃæºÐ
            ans = min(ans, edgeA[ax].c + edgeB[bx].c);
            if(!bx) break;
            for(dat p: historyB[bx]){
                del(pr(dsuA.find(p.x), p.b));
                add(pr(dsuA.find(p.x), p.a));
                bWhere[p.x] = p.a;
            }
            total -= changeB[bx];
            bx--;
        }
        else{
            ax++;
            if(ax>a) break;
            historyA.clear();
            total += dsuA.merge(edgeA[ax].x, edgeA[ax].y, historyA);
            for(dat p: historyA){
                del(pr(p.a, bWhere[p.x]));
                add(pr(p.b, bWhere[p.x]));
            }
        }
    }

    printf("%d", ans == INT_MAX ? -1 : ans);
}

 

[BOJ 25222] Bi-ing Lottery Treekets (7점)

Subtask 1은 백트래킹, Subtask 2는 정렬 뒤 간단한 코딩으로 풀 수 있다. 둘 다 정해와는 거리가 멀다. 

 

Subtask 3부터는 접근을 위해 트리 DP를 이용해야 한다. 모든 공을 떨어뜨린 뒤 $i$번 서브트리 아래에 총 $j$개의 지점에 공이 들어갔을 때, 이 공 번호의 집합 $S$ 하나가 정해져 있다면 각 공의 위치로 가능한 가짓수를 $DP[i][j]$라고 정의하면 된다. 일단은 공이 각 점에서 최대한 하나이기 때문에 복잡한 케이스 처리가 들어있진 않다. 하지만 DP의 정의 때문에 헷갈리는 일이 자주 발생한다. 공을 떨어뜨리는 순서가 달라도 마지막 상황이 같을 수 있기 때문에 이런 사소한 처리에서 실수할 만한 부분이 많다.

 

더보기
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
const ll MOD = 1000000007;

void quit(){
    puts("0");
    exit(0);
}

int n, k;
int L[4002], R[4002], cnt[4002], dir[4002], sum[4002];
int sz[4002];
ll DP[4002][4002];
ll fact[4002], comb[4002][4002];

inline ll d(int x, int y){
    return (x<0||y<0)?0:DP[x][y];
}

void dfs(int x){
    int l = L[x], r = R[x];
    sz[x] = 1, sum[x] = cnt[x];
    if(l) dir[l] = 0, dfs(l), sz[x] += sz[l], sum[x] += sum[l];
    if(r) dir[r] = 1, dfs(r), sz[x] += sz[r], sum[x] += sum[r];

    if(!l && !r){ /// 리프 노드
        if(cnt[x] > 2) quit();
        else if(cnt[x] == 1) DP[x][1] = 1;
        else DP[x][0] = DP[x][1] = 1;
    }
    else if(l && !r){ /// 왼쪽 자식만 있음
        if(!cnt[x]){ /// 추가 X
            for(int i=sum[x]; i<=sz[x]; i++){
                DP[x][i] += DP[l][i];
            }
            DP[x][sz[x]] += DP[l][sz[l]]; /// 위에서 하나 추가해서 x에 쌓임
        }
        else{ /// 추가 O
            for(int i=sum[x]; i<sz[x]; i++){
                DP[x][i] += DP[l][i] * (i-sum[x]+1) % MOD;
            }
            DP[x][sz[x]] += DP[l][sz[l]] * (sz[x]-sum[x]+1) % MOD;
        }
    }
    else if(r && !l){ /// 오른쪽 자식만 있음
        if(!cnt[x]){ /// 추가 X
            for(int i=0; i<=sz[x]; i++){
                DP[x][i] += DP[r][i];
            }
            DP[x][sz[x]] += DP[r][sz[r]]; /// 위에서 하나 추가해서 x에 쌓임
        }
        else{ /// 추가 O
            for(int i=0; i<sz[x]; i++){
                DP[x][i] += DP[r][i] * (i-sum[x]+1);
            }
            DP[x][sz[x]] += DP[r][sz[r]] * (sz[x]-sum[x]+1) % MOD;
        }
    }
    else{ /// 양쪽 자식 모두 있음
        if(!cnt[x]){
            for(int i=sum[l]; i<=sz[l]; i++){
                for(int j=sum[r]; j<=sz[r]; j++){
                    int L = i - sum[l], R = j - sum[r];
                    if((dir[x] == 0 && R && i!=sz[l]) || (dir[x] == 1 && L && j!=sz[r]) || i+j>k) continue;
                    DP[x][i+j] += DP[l][i] * DP[r][j] % MOD;
                    if(i==sz[l] && j==sz[r])
                        DP[x][i+j+1] += DP[l][i] * DP[r][j] % MOD;
                }
            }
        }
        else{
            for(int i=sum[l]; i<=sz[l]; i++){
                for(int j=sum[r]; j<=sz[r]; j++){
                    if(i > sum[l]){ /// 1. 내 것이 왼쪽으로 감
                        int L = i - sum[l] - 1, R = j - sum[r];
                        if((dir[x] == 0 && R && i!=sz[l]) || (dir[x] == 1 && L && j!=sz[r]) || i+j>k) {}
                        else DP[x][i+j] += DP[l][i] * DP[r][j] % MOD * (L+1) % MOD;
                    }
                    if(j > sum[r]){ /// 2. 내 것이 오른쪽으로 감
                        int L = i - sum[l], R = j - sum[r] - 1;
                        if((dir[x] == 0 && R && i!=sz[l]) || (dir[x] == 1 && L && j!=sz[r]) || i+j>k) {}
                        else DP[x][i+j] += DP[l][i] * DP[r][j] % MOD * (R+1) % MOD;
                    }
                    if(i==sz[l] && j==sz[r]){ /// 3. 이 서브트리가 꽉 참
                        int L = i - sum[l], R = j - sum[r];
                        DP[x][i+j+1] += DP[l][i] * DP[r][j] % MOD; /// 현재에 멈추는 경우
                        DP[x][i+j+1] += DP[l][i] * DP[r][j] * L % MOD; /// 왼쪽에 멈추는 경우
                        DP[x][i+j+1] += DP[l][i] * DP[r][j] * R % MOD; /// 오른쪽에 멈추는 경우
                    }
                }
            }
        }
    }

    for(int i=0; i<=sz[x]; i++) DP[x][i] %= MOD;

    #ifdef TEST
    printf("%d: ", x);
    for(int i=0; i<=sz[x]; i++) printf("%lld ", DP[x][i]);
    puts("");
    #endif
}

int main(){
    scanf("%d %d", &n, &k);
    for(int i=1; i<=k; i++){
        int x;
        scanf("%d", &x);
        cnt[x]++;
    }
    for(int i=1; i<=n; i++) scanf("%d %d", &L[i], &R[i]);

    fact[0] = 1;
    for(int i=1; i<=n; i++) fact[i] = fact[i-1] * i % MOD;
    for(int i=0; i<=n; i++){
        comb[i][0] = comb[i][i] = 1;
        for(int j=1; j<i; j++){
            comb[i][j] = (comb[i-1][j] + comb[i-1][j-1]) % MOD;
        }
    }

    dfs(1);

    printf("%lld", DP[1][k]);
}

 

Fulll Task는 큰 차이는 없고 $cnt[x]$, 즉 $x$에 떨어뜨릴 공 개수가 여러 개가 될 수 있는 점이 문제가 된다. 이건 왼쪽으로 떨어뜨릴 공 수, 오른쪽으로 떨어뜨릴 공 수, 자기 칸에 남을 공 수 조합이 $O(cnt_k)$ 가지라서 다 합치면 $O(k)$가 된다는 사실을 이용하면 시간 복잡도가 $O(N(N+K))$ 이하로 bound됨을 보일 수 있다. 실수할 구석이 정말 많고 케이스도 많아서 까다롭다.

 

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
const ll MOD = 1000000007;

void quit(){
    puts("0");
    exit(0);
}

int n, k;
int L[4002], R[4002], cnt[4002], dir[4002], sum[4002];
int sz[4002];
ll DP[4002][4002];
ll fact[4002], comb[4002][4002];

inline ll d(int x, int y){
    return (x<0||y<0||x>n||y>n)?0:DP[x][y];
}

void dfs(int x){
    int l = L[x], r = R[x];
    sz[x] = 1, sum[x] = cnt[x];
    if(l) dir[l] = 0, dfs(l), sz[x] += sz[l], sum[x] += sum[l];
    if(r) dir[r] = 1, dfs(r), sz[x] += sz[r], sum[x] += sum[r];

    if(sz[x] < sum[x]) quit();

//    if(x==330){
//        printf("");
//    }

    if(!l && !r){ /// 리프 노드
        if(cnt[x] == 1) DP[x][1] = 1;
        else DP[x][0] = DP[x][1] = 1;
    }
    else if(l && !r){ /// 왼쪽 자식만 있음
        for(int i=sum[x]; i<=sz[l]; i++){
            if(i-sum[x]+cnt[x] > n || i-sum[x]+cnt[x] < 0) continue;
            DP[x][i] += DP[l][i] * comb[i-sum[x]+cnt[x]][cnt[x]] % MOD * fact[cnt[x]] % MOD;
        }
        if(sum[l]+cnt[x] <= k && 0 <= sz[x]-sum[x]+cnt[x] && sz[x]-sum[x]+cnt[x] <= n)
            DP[x][sz[x]] += DP[l][sz[l]] * comb[sz[x]-sum[x]+cnt[x]][cnt[x]] % MOD * fact[cnt[x]] % MOD; /// 위에서 하나 추가해서 x에 쌓임
    }
    else if(r && !l){ /// 오른쪽 자식만 있음
        for(int i=sum[x]; i<=sz[r]; i++){
            if(i-sum[x]+cnt[x] > n || i-sum[x]+cnt[x] < 0) continue;
            DP[x][i] += DP[r][i] * comb[i-sum[x]+cnt[x]][cnt[x]] % MOD * fact[cnt[x]] % MOD;
        }
        if(sum[r]+cnt[x] <= k && 0 <= sz[x]-sum[x]+cnt[x] && sz[x]-sum[x]+cnt[x] <= n)
            DP[x][sz[x]] += DP[r][sz[r]] * comb[sz[x]-sum[x]+cnt[x]][cnt[x]] % MOD * fact[cnt[x]] % MOD; /// 위에서 하나 추가해서 x에 쌓임
    }
    else{ /// 양쪽 자식 모두 있음
        for(int A=0; A<=cnt[x]; A++){ /// 왼쪽에 A, 오른쪽에 B 넣음
            int B = cnt[x]-A;
            for(int i=sum[l]+A; i<=sz[l]; i++){ /// 왼쪽 서브트리에 들어가야 할 총 개수
                for(int j=sum[r]+B; j<=sz[r]; j++){ /// 오른쪽 서브트리에 들어가야 할 총 개수
                    int L = i-sum[l]-A, R = j-sum[r]-B; /// x보다 위에서 넣어야 하는 각각의 개수
                    if((dir[x] == 0 && R && i!=sz[l]) || (dir[x]==1 && L && j!=sz[r]) || i+j>k || L<0 || R<0) continue;
                    if(i+j<=n && A+B<=n && A+L<=n && B+R<=n)
                    DP[x][i+j] += DP[l][i] * DP[r][j] % MOD * comb[A+B][A] % MOD * fact[A] % MOD * fact[B] % MOD
                                  * comb[A+L][A] % MOD * comb[B+R][B] % MOD;
                    DP[x][i+j] %= MOD;
                }
            }
        }
        /// 가운데가 차는 경우는 따로 고려하기로 한다.
        /// 1) 가운데가 x보다 위에서 들어오는 경우
        DP[x][sz[x]] += DP[x][sz[x]-1];
        /// 2) 가운데가 x인 겨우
        for(int A=0; A<cnt[x]; A++){
            int B = cnt[x]-1-A;
            int i=sz[l], j=sz[r];
            int L=i-sum[l]-A, R=j-sum[r]-B;
            if(L<0 || R<0) continue;
            if(i+j+1 <= n && A+B<=n && B+R <= n && A+L <= n){
                DP[x][i+j+1] += DP[l][i] * DP[r][j] % MOD * (A+B+1) % MOD * comb[A+B][A] % MOD * fact[A] % MOD * fact[B]
                            % MOD * comb[A+L][A] % MOD * comb[B+R][B] % MOD;
                DP[x][i+j+1] %= MOD;
            }
        }
    }

    for(int i=0; i<=sz[x]; i++) DP[x][i] %= MOD;

    #ifdef TEST
    printf("%d: ", x);
    for(int i=0; i<=sz[x]; i++) printf("%lld ", DP[x][i]);
    puts("");
    #endif
}

int main(){
//    freopen("out.txt", "r", stdin);
    scanf("%d %d", &n, &k);
    for(int i=1; i<=k; i++){
        int x;
        scanf("%d", &x);
        cnt[x]++;
    }
    for(int i=1; i<=n; i++) scanf("%d %d", &L[i], &R[i]);

    fact[0] = 1;
    for(int i=1; i<=n; i++) fact[i] = fact[i-1] * i % MOD;
    for(int i=0; i<=n; i++){
        comb[i][0] = comb[i][i] = 1;
        for(int j=1; j<i; j++){
            comb[i][j] = (comb[i-1][j] + comb[i-1][j-1]) % MOD;
        }
    }

    dfs(1);

    printf("%lld", DP[1][k]);
}

'코딩 > 국대 멘토링 교육' 카테고리의 다른 글

2023 4주차 멘토링 일지  (0) 2023.05.27
2023 3주차 멘토링 일지  (2) 2023.05.21
2023 1주차 멘토링 일지  (1) 2023.05.13
2021 국대 멘토링 일지 - 3주차  (0) 2021.04.24
2021 국대 멘토링 일지 - 1주차  (0) 2021.04.05
공지사항
최근에 올라온 글
Total
Today
Yesterday