본문 바로가기

코딩/문제풀이

Problem Solving Diary #1

IOI 2020 Day 1. Comparing Plants

대회 때 멘탈이 제대로 갈려나간 문제입니다. 다음 대회를 망치지 않기 위해서는 이런 문제를 풀어야 한다고 생각해서 잡고 풀었습니다.

 

서브태스크 1 (5점)

$k=2$로, 바로 옆의 식물과의 대소관계를 확실히 알 수 있습니다. 어떤 식물 $x$가 식물 $y$보다 확실히 크려면 (문제 조건에 의해 $x<y$), 아래 둘 중 하나가 성립해야 합니다. 단 $h_x$는 $x$번 식물의 높이입니다.

 

  • $h_x > h_{x+1} > \cdots > h_{y-1} > h_y$
  • $h_x > h_{x-1} > \cdots > h_1 > h_n > \cdots > h_{y+1} > h_y$

첫 번째 조건은 $s_x = s_{x+1} = \cdots = s_{y-1} = 0$에 대응되고, 두 번째 조건은 $s_1 = \cdots = s_{x-1} = s_y = \cdots = s_n = 1$에 대응됩니다. 이것이 참인지 판별하면 됩니다.

 

이걸 빠르고 편리하게 하려면 $h_x$의 누적합으로 판정하면 됩니다. 코드는 아래와 같습니다.

 

더보기
#include <bits/stdc++.h>
#include "plants.h"

using namespace std;

typedef long long ll;

int n, k;
int arr[200002];
int sum[200002];

void init(int _k, vector<int> r) {
	n = (int)r.size();
	k = _k;
	for(int i=1; i<=n; i++) arr[i] = r[i-1];

	for(int i=1; i<=n; i++) sum[i] = sum[i-1] + arr[i];
}

int compare_plants(int x, int y) {
	x++, y++;
	if(sum[y-1] - sum[x-1] == (y-1) - (x-1)) return -1;
	if(sum[y-1] - sum[x-1] == 0) return 1;
	if(sum[x-1] == x-1 && sum[n] - sum[y-1] == n-(y-1)) return 1;
	if(!sum[x-1] && sum[n] == sum[y-1]) return -1;
    return 0;
}

 

서브태스크 2 (19점)

$2k > n$이라는 조건이 있는데, 이 조건을 통해 무엇을 얻을 수 있을까요? 바로 전체 수들 중 최댓값을 잡을 수 있게 되었습니다. 최댓값인 위치의 $r_i$ 값은 분명히 0일 텐데, 기존에는 $r_i=0$인 지점이 여러 개였기 때문에 답을 찾을 수가 없었습니다. 하지만 이제는 $r_i=0$인 지점 중에서 다른 $r_i$를 모두 길이 $k$의 범위 안에 포함하는 $i$가 최댓값임이 증명되기 때문에, 최댓값을 찾을 수가 있게 됩니다.

 

예를 들어, $n=5$, $k=3$, $r=[5, 3, 4, 2, 1]$이라고 해 봅시다. 이때 $r_0 = r_2 = 0$인데, 2번 식물이 0번 식물의 범위에 포함되므로 0번 식물(높이 5)이 가장 높다는 것을 알 수 있습니다.

 

이제 0번 식물의 높이를 최소로 만들어 버리고, 이 변화에 따라 $r$을 수정하면 다음 최댓값, 그다음 최댓값 등을 모두 구할 수 있고, 아예 모든 식물 간의 크기 관계를 결정해 버릴 수 있습니다! 따라서 문제가 해결됩니다. 풀고 보니 서브태스크 2와 3은 서브태스크 4의 부분집합이었네요. (왜 IOI에 subtask dependency를 참가자들에게 공개하지 않는지 알 것 같습니다.)

 

서브태스크 3 (32점)

위 작업을 조금 빠르게 수행해야 합니다. 단순하게 하면 $O(N^2)$의 시간이 들 것이기 때문입니다.

 

먼저 위 풀이는 두 부분으로 나눌 수 있습니다. 하나는 최댓값의 위치를 찾는 것이고, 하나는 최댓값을 최솟값으로 바꾸면서 $r$을 갱신하는 것입니다. $r$ 갱신은 구간에 1을 빼거나 원소 하나를 갱신하는 쉬운 연산인데, 최댓값의 위치를 어떻게 빠르게 찾을 수 있을까요?

 

세그먼트 트리를 구축해, 각 노드별로 (범위 내 최솟값, 최솟값의 위치: 여러 개라면 가장 왼쪽 위치)를 들고 있게 합니다. 이때 최솟값은 항상 0임이 보장됩니다. 0 중에서 가장 index가 작은 것을 찾으면, 아래 두 경우 중 하나에 해당합니다.

  • 찾은 위치가 최댓값이다.
  • 찾은 위치는 최댓값이 아니며, 최댓값은 배열의 오른쪽 $k-1$개 칸 중 하나이다.

두 번째 경우 오른쪽 $k-1$개 칸 중에서 0인 가장 왼쪽 칸이 최댓값일 것입니다. 만약 거기에 0이 없다면 첫 번째 경우에 해당할 것입니다. 이제 레이지 연산을 지원하는 세그먼트 트리를 짜면 문제가 $O(N \log N)$에 풀립니다.

 

대회 중에 여기까지 풀었습니다.

 

더보기
#include <bits/stdc++.h>
#include "plants.h"

using namespace std;

typedef long long ll;

struct segTree{
    pair<int, int> tree[800002];
    int lazy[800002];

    void init(int i, int l, int r, int *A){
        lazy[i] = 0;
        if(l==r){
            tree[i] = make_pair(A[l], l);
            return;
        }
        int m = (l+r)>>1;
        init(i*2, l, m, A);
        init(i*2+1, m+1, r, A);
        tree[i] = min(tree[i*2], tree[i*2+1]);
    }

    void propagate(int i, int l, int r){
        tree[i].first += lazy[i];
        if(l!=r){
            lazy[i*2] += lazy[i];
            lazy[i*2+1] += lazy[i];
        }
        lazy[i] = 0;
    }

    void add(int i, int l, int r, int s, int e, int val){
        propagate(i, l, r);
        if(r<s || e<l) return;
        if(s<=l && r<=e){
            lazy[i] += val;
            propagate(i, l, r);
            return;
        }
        int m = (l+r)>>1;
        add(i*2, l, m, s, e, val);
        add(i*2+1, m+1, r, s, e, val);
        tree[i] = min(tree[i*2], tree[i*2+1]);
    }

    pair<int, int> findMin(int i, int l, int r, int s, int e){
        propagate(i, l, r);
        if(r<s || e<l) return make_pair(INT_MAX, INT_MAX);
        if(s<=l && r<=e) return tree[i];
        int m = (l+r)>>1;
        return min(findMin(i*2, l, m, s, e), findMin(i*2+1, m+1, r, s, e));
    }
} tree;

int n, k;
int arr[200002];
int sum[200002];
int ans[200002];

void init(int _k, vector<int> r) {
	n = (int)r.size();
	k = _k;
	for(int i=1; i<=n; i++) arr[i] = r[i-1];

    tree.init(1, 1, n, arr);

    for(int turn=n; turn>=1; turn--){
        int minLoc = tree.findMin(1, 1, n, 1, n).second;
        if(minLoc < k){
            auto tmp = tree.findMin(1, 1, n, n-(k-minLoc)+1, n);
            if(tmp.first == 0) minLoc = tmp.second;
        }

        ans[minLoc] = turn;
        tree.add(1, 1, n, minLoc, minLoc, 1000000000);
        if(minLoc >= k) tree.add(1, 1, n, minLoc-k+1, minLoc, -1);
        else{
            tree.add(1, 1, n, 1, minLoc, -1);
            tree.add(1, 1, n, n-(k-minLoc)+1, n, -1);
        }
    }
}

int compare_plants(int x, int y) {
	x++, y++;

	if(ans[x] > ans[y]) return 1;
	return -1;
}

 

서브태스크 4 (49점)

서브태스크 2와 3에서는 해를 하나 구축한 뒤, compare_plants 함수에서는 그 해에서의 결과만 따졌습니다. 이것은 해가 하나이기 때문에 가능했습니다. 마찬가지로, 서브태스크 4에서도 유효한 해를 하나 구축할 수만 있다면 이러한 풀이가 가능할 것입니다.

 

서브태스크 3의 풀이가 통하지 않는 이유는 무엇일까요? 우리가 앞에서 사용한 과정을 보면, 먼저 0이 있는 가장 왼쪽 위치를 구한 뒤, 그 0을 포함하고 있는 다른 0이 (원형이기 때문에) 오른쪽 끝에 있는지 봐 주었습니다. 그런데 이제 $2k > n$이 아니기 때문에, 우리가 그렇게 해서 찾은 0을 포함하는 0이 또 있을 수 있습니다. 그 0을 포함하는 0 역시 또 있을 수 있고... 이러한 방식으로는 풀기 어렵습니다. 어쨌든 다른 0의 구간에 포함되지 않는 0을 찾으면 그것을 최댓값으로 볼 수 있다는 것은 보장될 것입니다.

 

그렇다면 이러한 0은 어떻게 찾을까요? 이러한 0의 왼쪽 $k-1$칸은 0이 아닌 수들만으로 이루어져 있을 것입니다. 따라서, 어떤 칸이 0이면 0, 0이 아니면 1로만 놓는 세그먼트 트리를 생각하고 최대 연속 합의 위치를 구했을 때, 그 위치를 바탕으로 우리가 찾고자 하는 0을 찾을 수 있을 것입니다. 따라서 이런 방식으로 금광 세그를 짜면 됩니다.

 

이게 말이 쉽지 고려해야 하는 부분이 은근히 많은데, 대충 요약해 보면

  1. 구간 내에서 1을 빼며, 0이 되는 점들을 모두 순회한다.
  2. 구간 내 최대 구간 합과 그 위치를 찾는다.

이 두 가지 정도가 추가됩니다. 2번은 금광 세그 위에서 이분 탐색을 구현하거나, 금광 세그의 노드에 몇 가지 인자를 더 넣으면 됩니다. 문제는 1번인데, 1을 뺄 구간 내에서 먼저 최솟값을 찾고, 최솟값이 1이면 그 점을 순회하고 그 점 왼쪽을 자른 뒤 다시 최솟값을 찾고, ... 이런 식으로 반복하면 조금 귀찮긴 하지만 문제가 풀릴 것입니다.

 

더보기
#include <bits/stdc++.h>
#include "plants.h"

using namespace std;

typedef long long ll;

struct segTree{
    pair<int, int> tree[1600002];
    int lazy[1600002];

    void init(int i, int l, int r, int *A){
        lazy[i] = 0;
        if(l==r){
            tree[i] = make_pair(A[l], l);
            return;
        }
        int m = (l+r)>>1;
        init(i*2, l, m, A);
        init(i*2+1, m+1, r, A);
        tree[i] = min(tree[i*2], tree[i*2+1]);
    }

    void propagate(int i, int l, int r){
        tree[i].first += lazy[i];
        if(l!=r){
            lazy[i*2] += lazy[i];
            lazy[i*2+1] += lazy[i];
        }
        lazy[i] = 0;
    }

    void add(int i, int l, int r, int s, int e, int val){
        propagate(i, l, r);
        if(r<s || e<l) return;
        if(s<=l && r<=e){
            lazy[i] += val;
            propagate(i, l, r);
            return;
        }
        int m = (l+r)>>1;
        add(i*2, l, m, s, e, val);
        add(i*2+1, m+1, r, s, e, val);
        tree[i] = min(tree[i*2], tree[i*2+1]);
    }

    pair<int, int> findMin(int i, int l, int r, int s, int e){
        propagate(i, l, r);
        if(r<s || e<l) return make_pair(INT_MAX, INT_MAX);
        if(s<=l && r<=e) return tree[i];
        int m = (l+r)>>1;
        return min(findMin(i*2, l, m, s, e), findMin(i*2+1, m+1, r, s, e));
    }
} tree;

struct mineSeg{
    struct Node {
        ll lmax; int lidx;
        ll rmax; int ridx;
        ll ans; int ansl, ansr;
        ll sum;

        Node(){}
        Node(ll x, int idx){
            lmax = rmax = ans = sum = x;
            lidx = ridx = ansl = ansr = idx;
        }

        Node merge(const Node &r)const{
            Node ret = *this;
            if(ret.lmax < sum + r.lmax){
                ret.lmax = sum + r.lmax;
                ret.lidx = r.lidx;
            }

            ret.rmax += r.sum;
            if(ret.rmax < r.rmax){
                ret.rmax = r.rmax;
                ret.ridx = r.ridx;
            }

            if(ret.ans < r.ans) ret.ans = r.ans, ret.ansl = r.ansl, ret.ansr = r.ansr;
            if(ret.ans < rmax + r.lmax){
                ret.ans = rmax + r.lmax;
                ret.ansl = ridx, ret.ansr = r.lidx;
            }

            ret.sum += r.sum;
            return ret;
        }
    } tree[1600002];

    void init(int i, int l, int r, int *A){
        if(l==r){
            tree[i] = Node(A[l] ? 1 : -1e9, l);
            return;
        }
        int m = (l+r)>>1;
        init(i*2, l, m, A);
        init(i*2+1, m+1, r, A);
        tree[i] = tree[i*2].merge(tree[i*2+1]);
    }

    void change(int i, int l, int r, int idx, int x){
        if(l==r){
            tree[i] = Node(x ? 1 : -1e9, l);
            return;
        }
        int m = (l+r)>>1;
        if(idx <= m) change(i*2, l, m, idx, x);
        else change(i*2+1, m+1, r, idx, x);
        tree[i] = tree[i*2].merge(tree[i*2+1]);
    }
} mineTree;

int n, k;
int arr[400002];
int sum[400002];
int ans[400002];

void init(int _k, vector<int> r) {
	n = (int)r.size();
	k = _k;
	for(int i=1; i<=n; i++) arr[i] = r[i-1];
	for(int i=n+1; i<=n*2; i++) arr[i] = arr[i-n];

    tree.init(1, 1, n*2, arr);
    mineTree.init(1, 1, n*2, arr);

    for(int turn=n; turn>=1; turn--){
        auto tmp = mineTree.tree[1];
        if(tmp.ans < k-1) exit(1);
        int loc = tmp.ansr % n + 1;
        ans[loc] = turn;

        /// 1에서 0이 되는 위치들을 검색
        int lim = loc+n, leftmost = lim-k+1;
        while(leftmost < lim){
            pair<int, int> tmp2 = tree.findMin(1, 1, n*2, leftmost, lim-1);
            if(tmp2.first > 1) break;
            int x = tmp2.second;
            mineTree.change(1, 1, n*2, (x-1)%n+1, 0);
            mineTree.change(1, 1, n*2, (x-1)%n+1+n, 0);
            leftmost = x+1;
        }
        mineTree.change(1, 1, n*2, loc, 1);
        mineTree.change(1, 1, n*2, loc+n, 1);

        /// tree 변경
        tree.add(1, 1, n*2, loc, loc, 1000000000);
        tree.add(1, 1, n*2, loc+n, loc+n, 1000000000);

        if(loc >= k){
            tree.add(1, 1, n*2, loc-k+1, loc, -1);
            tree.add(1, 1, n*2, loc-k+1+n, loc+n, -1);
        }
        else{
            tree.add(1, 1, n*2, 1, loc, -1);
            tree.add(1, 1, n*2, loc-k+1+n, loc+n, -1);
            tree.add(1, 1, n*2, 2*n-(k-loc)+1, n*2, -1);
        }
    }
}

int compare_plants(int x, int y) {
	x++, y++;

	if(ans[x] > ans[y]) return 1;
	return -1;
}

 

서브태스크 5 (60점)

이제 답으로 0이 가능해집니다. 4번 서브태스크를 풀면서 얻은 힌트를 토대로 5번 서브태스크를 풀어 봅시다.

 

어떤 식물 $x$를 잡아 $x$번 식물에 최대한 높은 높이를 배정하는 것을 시도해 봅시다. 최대한 높은 높이로 배정하려면 세그먼트 트리의 $r_x$ 값을 최대한 빠르게 0을 만들어야 합니다. 그러려면 현재 최댓값을 배정할 수 있는 후보들 중에 $x$번 식물에서 오른쪽으로의 거리가 가장 짧은 것부터 선택하는 것이 항상 이득일 것입니다. 따라서 위 세그먼트 트리에 우선순위를 배정해 주면 $x$번 식물보다 항상 높은 식물의 목록을 얻을 수 있을 것 같습니다.

 

물론 위 가설은 증명이 필요합니다. $x$번 식물보다 높이가 높은 식물의 집합을 $S$라고 하면, 가능한 $S$의 후보는 당연히 많습니다. 우리는 이 중에서 $S$의 크기가 최소인 것이 유일하다고 보이고 싶은 것입니다. 상당히 비자명한 추측이지만, 일단 저는 맞다고 가정하고 코드를 짰습니다. 하지만 오답이 나왔습니다. 

 

잘 생각해 보면, $r_x$ 값을 빠르게 0을 만드는 것도 중요하지만, $r_x$ 이전 $k-1$개 값이 0이 아니도록 만드는 것도 중요합니다. 따라서 이 부분까지 고려해 보면 적당한 그리디를 찾는 것이 쉽지 않습니다. 따라서 완전히 다른 접근이 필요합니다.

 

위와 같은 방식으로 식물 하나를 고정하고 그것보다 무조건 큰 식물을 찾아나가는 풀이는 아무리 못해도 $O(N^2 \log N)$일 텐데, $N \le 300$이라는 것은 완전히 다른 방향으로 접근해야 함을 뜻합니다. 마침 세제곱이 돌 것 같이 생겼으니, 다른 방법을 고민해 봅시다. 

 

(작성중)

 

USACO 2021 US Open Contest. Balanced Subsets

어떤 올바른 영역 $S$에 대해, $S$의 가장 위쪽 칸이 $s$번 행에, 가장 아래쪽 칸이 $e$번 행에 있다고 가정하고, $s \le i \le e$인 $i$에 대해 $i$행의 칸 중 $S$에 포함된 칸을 $l_i$열부터 $r_i$열까지라고 합시다. 이때, $l$은 단조 감소하다가 단조 증가하고, $r$은 단조 증가하다가 단조 감소한다는 관찰을 할 수 있습니다.

 

기본적인 틀을 아래와 같이 잡습니다.

$DP[i][l][r]$: $S$의 가장 아래쪽 행이 $i$번 행이고, $i$번 행의 칸 중 $S$에 포함된 칸이 $l$열부터 $r$열까지인 올바른 $S$의 개수로 정의합니다. 이때 전이를 구하는 것이 쉽지 않기 때문에, 추가적인 인자 $u$와 $v$를 도입합니다.

 

$u$: $l$이 한 번이라도 감소한 적 있다면 1, 없다면 0

$v$: $r$이 한 번이라도 증가한 적 있다면 1, 없다면 0

 

따라서 $DP[i][l][r][u][v]$의 5차원 DP를 관리하면, 각 $u, v$로 가능한 4가지 상태에 대해 전이를 누적합을 이용해 따로따로 처리할 수 있고 문제가 해결됩니다. 이때 시간 복잡도는 $O(N^3 \log N)$ 또는 $O(N^3)$입니다. 코드는 $u$와 $v$를 하나의 인자로 합쳐 구현했습니다.

 

더보기
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
const ll MOD = 1000000007;

int n;
int arr[152][152];
vector<int> loc[152];
int sum[152][152];
ll DP[152][152][152][4]; /// 0: 위 / 아래 \, 1: 위 / 아래 /, 2: 위 \ 아래 \, 3: 위 \ 아래 / (같으면 왼쪽으로 판정)
ll ans;

void addRange(int d, int x1, int x2, int y1, int y2, int mode, ll val){
    if(x1 > x2 || y1 > y2) return;
    DP[d][x1][y1][mode] = (DP[d][x1][y1][mode] + val) % MOD;
    DP[d][x1][y2+1][mode] = (DP[d][x1][y2+1][mode] - val + MOD) % MOD;
    DP[d][x2+1][y1][mode] = (DP[d][x2+1][y1][mode] - val + MOD) % MOD;
    DP[d][x2+1][y2+1][mode] = (DP[d][x2+1][y2+1][mode] + val) % MOD;
}

int main(){
    scanf("%d", &n);
    for(int i=1; i<=n; i++){
        loc[i].push_back(0);
        for(int j=1; j<=n; j++){
            char c;
            scanf(" %c", &c);
            arr[i][j] = (c=='.');
            sum[i][j] = sum[i][j-1] + arr[i][j];
            if(arr[i][j]) loc[i].push_back(j);
        }
        loc[i].push_back(n+1);
    }

    for(int x=1; x<=n; x++){
        for(int l=1; l<=n; l++){
            for(int r=l; r<=n; r++){
                if(sum[x][r] - sum[x][l-1]) continue;
                DP[x][l][r][0] = (DP[x][l][r][0] + 1) % MOD;

                for(int j=0; j<4; j++){
//                    printf("%d %d %d %d: %lld\n", x, l, r, j, DP[x][l][r][j]);
                }
                if(x==n) continue;

                if(sum[x+1][r] - sum[x+1][l-1] == 0){ /// DP[x+1][0]으로 전이
                    ll toZero = DP[x][l][r][0];
                    int lLim = *prev(lower_bound(loc[x+1].begin(), loc[x+1].end(), l)) + 1;
                    int rLim = *upper_bound(loc[x+1].begin(), loc[x+1].end(), r) - 1;
                    addRange(x+1, lLim, l, r, rLim, 0, toZero);
                }

                { /// DP[x+1][1]로 전이
                    ll toOne = DP[x][l][r][0];
                    int rLim = *upper_bound(loc[x+1].begin(), loc[x+1].end(), r) - 1;
                    addRange(x+1, l+1, r, r, rLim, 1, toOne);

                    toOne = DP[x][l][r][1];
                    addRange(x+1, l, r, r, rLim, 1, toOne);
                }

                { /// DP[x+1][2]로 전이
                    ll toTwo = DP[x][l][r][0] % MOD;
                    int lLim = *prev(lower_bound(loc[x+1].begin(), loc[x+1].end(), l)) + 1;
                    addRange(x+1, lLim, l, l, r-1, 2, toTwo);

                    toTwo = DP[x][l][r][2];
                    addRange(x+1, lLim, l, l, r, 2, toTwo);
                }

                { /// DP[x+1][3]으로 전이
                    addRange(x+1, l+1, r-1, l+1, r-1, 3, DP[x][l][r][0]);
                    addRange(x+1, l, r-1, l, r-1, 3, DP[x][l][r][1]);
                    addRange(x+1, l+1, r, l+1, r, 3, DP[x][l][r][2]);
                    addRange(x+1, l, r, l, r, 3, DP[x][l][r][3]);
                }
            }
        }

        for(int l=1; l<=n; l++){
            for(int r=1; r<=n; r++){
                for(int d=0; d<4; d++){
                    DP[x+1][l][r][d] += DP[x+1][l][r-1][d] + DP[x+1][l-1][r][d] - DP[x+1][l-1][r-1][d] + MOD;
                    DP[x+1][l][r][d] %= MOD;
                }
            }
        }

        for(int l=1; l<=n; l++){
            for(int r=1; r<=n; r++){
                if(sum[x+1][r] - sum[x+1][l-1] == 0 && l<=r) continue;
                for(int d=0; d<4; d++) DP[x+1][l][r][d] = 0;
            }
        }
    }

    for(int x=1; x<=n; x++) for(int l=1; l<=n; l++) for(int r=l; r<=n; r++){
        if(sum[x][r] - sum[x][l-1]) continue;
        for(int d=0; d<4; d++) ans = (ans+DP[x][l][r][d]) % MOD;
    }
    printf("%lld", ans);
}

 

IOI 2019 Day 1. Rectangles

서브태스크 1 (8점)

모든 가능한 사각형에 대해 나이브하게 확인하면 $O(N^6)$의 풀이가 나옵니다.

 

서브태스크 2, 3 (27점)

각 행에 대해 $l$번째 열부터 $r$번째 열까지의 최댓값을 전처리해 두고, 열에 대해서도 같은 방식으로 전처리를 하면 직사각형 하나를 $O(N+M)$에 확인이 가능해져, $O(N^5)$에 문제를 풀 수 있습니다. 커팅을 대충이라도 하면 상수가 작아서 서브태스크 3까지 맞습니다.

 

더보기
#include <bits/stdc++.h>
#include "rect.h"

using namespace std;

typedef long long ll;

int n, m;
int arr[2502][2502];
int max1[202][202][202];
int max2[202][202][202];
int ans;

ll count_rectangles(vector<vector<int> > _vec){
    n = _vec.size();
    m = _vec[0].size();

    for(int i=1; i<=n; i++){
        for(int j=1; j<=m; j++){
            arr[i][j] = _vec[i-1][j-1];
        }
    }

    for(int i=1; i<=n; i++){
        for(int j=1; j<=m; j++){
            for(int k=j; k<=m; k++){
                max1[i][j][k] = max(max1[i][j][k-1], arr[i][k]);
            }
        }
    }
    for(int i=1; i<=m; i++){
        for(int j=1; j<=n; j++){
            for(int k=j; k<=n; k++){
                max2[i][j][k] = max(max2[i][j][k-1], arr[k][i]);
            }
        }
    }

    for(int x1=2; x1<n; x1++){
        for(int x2=x1; x2<n; x2++){
            for(int y1=2; y1<m; y1++){
                for(int y2=y1; y2<m; y2++){
                    bool able = 1;
                    for(int x=x1; x<=x2; x++){
                        if(max1[x][y1][y2] >= min(arr[x][y1-1], arr[x][y2+1])){
                            able = 0;
                            break;
                        }
                    }
                    for(int y=y1; y<=y2; y++){
                        if(max2[y][x1][x2] >= min(arr[x1-1][y], arr[x2+1][y])){
                            able = 0;
                            break;
                        }
                    }
                    if(able) ans++;
                }
            }
        }
    }
    return ans;
}

 

서브태스크 5 (36점)

$n=3$이므로, 모든 영역은 2행 위의 인접한 몇 칸으로 이루어져 있고, 각 영역은 두 인자 $l$, $r$로 표현될 수 있습니다. 또한 2행 위 칸들 중 위나 아래에 자기 이하의 값이 있는 칸은 선택해선 안됩니다. 구간 최솟값을 $O(M^2)$에 전처리해 두면, 몇 가지 조건만 만족하는지 확인해서 전체 시간 복잡도 $O(M^2)$에 문제를 풀 수 있습니다.

 

서브태스크 6 (50점)

모두 0으로 이루어져 있고, 사방이 1로 둘러싸여 있는 직사각형 영역의 개수를 세면 됩니다. 따라서 0으로 이루어진 모든 component를 flood-fill 하듯이 찾고, 그 영역이 직사각형을 이루는지 검사해 주면 (주변 경계는 모두 1일 것이므로) 답에 1을 더해 주면 됩니다. 직사각형인지 검사하는 방법은 컴포넌트의 크기를 $S$라고 할 때 $(max_x - min_x + 1)(max_y - min_y + 1) = S$인지 검사해 주면 될 것입니다.

 

서브태스크 4 (72점)

직사각형의 위쪽 변과 아래쪽 변을 고정합시다. 이때 위쪽 경계와 아래쪽 경계보다 큰 수가 중간에 들어 있는 열은 선택이 불가능한 열인데, 이러한 열은 사전에 모두 제외해 둡시다. 이렇게 하면 왼쪽 경계와 오른쪽 경계를 결정해야 하는 상황이 됩니다. 왼쪽 경계와 오른쪽 경계를 잡을 때는 우리가 잡은 경계 안에 있는 모든 행에 대해서 문제의 조건을 만족해야 합니다.

 

행이 여러 개이면 힘들어지니, 행이 하나라고 가정해봅시다. (즉, 직사각형의 위쪽 변과 아래쪽 변의 거리는 1인 경우) 행이 하나이므로, 선택 가능한 범위의 칸 중 최댓값을 잡을 수 있습니다. 이 최댓값은 직사각형의 내부에 존재할 수 없습니다. 따라서 이 최댓값을 기준으로 양쪽으로 범위를 나눠서 분할 정복하듯이 생각할 수 있습니다.

 

예를 들어 해당하는 행에 적힌 수가 [3, 1, 2, 5, 3, 4]이라고 해 봅시다. 편의상 0-index를 사용합시다. 일단 범위 전체가 하나의 후보가 될 수 있을 것입니다(1-4). 전체 최댓값은 3번째 칸의 5이므로, 5는 직사각형 내부에 있을 수 없습니다. 따라서 5를 기준으로 왼쪽과 오른쪽으로 나뉜 부분이 후보가 될 수 있습니다(1-2, 4-4). 1-2 구간에서 2번 칸의 2가 최댓값이므로, 그 왼쪽인 (1-1) 구간이 후보가 됩니다. 이렇게 네 개의 후보가 존재합니다. 일반적으로 이러한 후보의 개수는 $M$개임이 쉽게 증명됩니다.

 

따라서 맨 위쪽 변에 대해서 생각해 보기만 해도 이렇게 후보를 $M$개로 줄일 수 있습니다. 이것을 바탕으로 답을 효율적으로 계산하기 위해서는, 우선 위쪽 경계를 정하고 (N가지), 그 경계 아래의 가장 위에 있는 행에 대해서 위와 같이 분할 정복으로 가능한 후보의 개수를 $M$개 이하로 줄이고 (세그먼트 트리로 $O(M \log M)$에 가능), 아래쪽 경계를 한 칸씩 늘려갈 때마다 각 후보가 가능한지를 검사해서 후보를 줄이면 됩니다. 구현 디테일은 생략합니다.

 

서브태스크 7 (100점)

현재 가능한 후보의 개수는 $O(N^2 M)$개입니다. 후보를 더 줄일 수 있지 않을까요?

 

분할 정복 관점으로 돌아갑시다. 모든 직사각형은 아래 둘 중 하나를 만족합니다.

 

  • $i$번 행을 지난다. (Case 1)
  • $i$번 행을 지나지 않는다. (Case 2)

만약 $i=\frac{N}{2}$ 정도로 잡아서, Case 1을 $O(NM)$ 정도에 처리할 수 있다면 전체 시간복잡도 $O(NM \log N)$에 문제가 해결될 것입니다. 그런데 직사각형이 $i$번 행을 지나는 이상, 가능한 (왼쪽 경계, 오른쪽 경계)의 후보를 서브태스크 6과 같이 $M$개로 간추릴 수 있습니다. 이제 각 후보에 대해 $O(N)$에 (위쪽 경계, 아래쪽 경계)로 가능한 후보의 개수를 세어 주면 될 것입니다. 여기서 로그를 붙이지 않는 것이 가장 큰 관건입니다. 이제부터는 (왼쪽 경계, 오른쪽 경계)는 정했다고 생각하고 설명해 보겠습니다.

 

먼저 좌/우 조건을 만족해서 직사각형에 포함될 수 있는 열의 종류를 구해 봅시다. 단순히 구간 최댓값 세그먼트 트리를 관리하면 $O(N \log M)$이 걸리는데, 업데이트가 없기 때문에 스파스 테이블을 이용해 구간 최댓값을 구해도 됩니다. 이렇게 하면 전처리 $O(NM \log M)$에 시간복잡도 $O(N)$으로 이 부분을 마무리해줄 수 있습니다.

 

이제 상/하 조건을 만족시킬 수 있는 (위, 아래) 경계의 수를 구해 봅시다. 먼저 그 전에 (위쪽 경계) 따로, (아래쪽 경계) 따로에 대해 가능한 경계 종류를 구해 봅시다. 위쪽 경계로 가능한 행은 $i$번 행에서 위로 올라가면서 최댓값이 등장하는 행들일 것입니다. 하지만 이런 식으로 판별하는 것을 $O(NL)$ ($L$은 구간 길이)보다 빠르게 하는 것은 힘들어 보입니다. 따라서 여기서 추가적인 관찰을 활용해야 합니다. 카르테시안 트리를 그리듯이 (왼쪽 경계, 오른쪽 경계)의 후보를 간추렸으므로, 이들은 서로 포함 관계나 상호 배제 관계를 이루고 이는 트리로 나타낼 수 있습니다. 따라서 일단 모든 열에 대해 저 경계 후보를 구해 두고, 우리가 필요한 열에 대해 그때그때 경계 후보의 교집합을 구하는 방식으로 하면 행 $i$를 고정했을 때 이 부분을 전체 시간복잡도 $O(NM)$에 풀 수 있음이 보장됩니다.

 

위쪽 경계와 아래쪽 경계 각각에 대해 구했으니, 이제 이 둘을 합쳐 보겠습니다. 그런데 사실 이 부분도 쉽지가 않아서, 이것도 사전에 각 열에 대해서 따로따로 구해놓은 뒤에 합치는 게 어떨까 하는 생각을 할 수 있습니다. 여기서 또 하나의 관찰이 들어갑니다. 열을 하나만 잡으면, 역시 카르테시안 트리에 의해 후보의 수가 $N$개 이하가 됩니다. 따라서, 이 후보들을 그냥 열별로 그대로 들고 다니는게 가능합니다. 이 둘을 합치는 것은 사전에 열별로 목록을 정렬해 두면 선형 시간에 가능할 것입니다.

 

그럼 이제 마지막으로 각 열에 대해 저 목록을 어떻게 빠르게 들고 다닐지 생각해볼 필요가 있습니다. 그런데 이건 그냥 맨 처음에 열별로 전처리를 해둘 수 있습니다. 그리고 우리는 단지 그 목록 중 $i$번 행을 포함하는 범위만 빼내면 되는 것입니다. 이렇게 하면 문제가 $O(NM \log N)$에 풀립니다. N과 M을 바꿔써서 1시간 정도 해멨는데, 혹시나 해서 코드 맨 윗줄에 N=M=max(N, M);을 추가해보니 AC를 받을 수 있었습니다. 간단한 팁 정도로 알아두셔도 좋을 것 같습니다.

 

이걸 그냥 짜면 메모리가 터지고, 스파스 테이블을 short 자료형으로 관리하는 추한 짓(...)을 해야 맞을 수 있습니다.

 

더보기
#include <bits/stdc++.h>
#define LIM 2502
#include "rect.h"

using namespace std;

typedef long long ll;

int N, M;
int arr[LIM][LIM];
int power2[LIM];
int columnDncLoc[LIM][LIM];
ll ans;

short maxLeft[LIM][LIM][12], maxRight[LIM][LIM][12];
short maxUp[LIM][LIM][12], maxDown[LIM][LIM][12];

vector<pair<short, short> > rangeRow, rangeColumn[LIM][LIM];

void init(vector<vector<int> > &);
void makeColumnDncLoc(int, int);
void makeSparse();
void findRange();

void divideAndConquer(int, int);

ll count_rectangles(vector<vector<int> > a) {
	init(a); /// 간단한 초기화를 합니다.
	M = N = max(N, M);
	makeColumnDncLoc(2, N-1);
    makeSparse(); /// 행 / 열 구간 최댓값의 위치를 찾기 위해 스파스 테이블을 만듭니다.
    findRange(); /// 각 행/열에 대해 카르테시안 트리를 만듭니다.
    divideAndConquer(2, N-1);
    return ans;
}

void init(vector<vector<int> > &a){
    N = a.size(); M = a[0].size();
	for(int i=0; i<N; i++){
        for(int j=0; j<M; j++){
            arr[i+1][j+1] = a[i][j];
        }
	}

	for(int i=0; i<=2500; i++){
        for(int d=1; d<=12; d++){
            if((1<<d) > i+1){
                power2[i] = d-1;
                break;
            }
        }
	}
}

void makeColumnDncLoc(int l, int r){
    if(l>r) return;
    int m = (l+r)>>1;
    for(int i=l; i<=m; i++){
        for(int j=m; j<=r; j++){
            columnDncLoc[i][j] = m;
        }
    }
    makeColumnDncLoc(l, m-1);
    makeColumnDncLoc(m+1, r);
}

void makeSparse(){
    for(int i=1; i<=N; i++){
        for(int j=1; j<=M; j++){
            maxLeft[i][j][0] = maxRight[i][j][0] = j;
            maxUp[i][j][0] = maxDown[i][j][0] = i;
        }
    }

    for(int d=1; d<12; d++){
        int v = (1<<(d-1));
        for(int i=1; i<=N; i++){
            for(int j=1; j<=M; j++){
                if(j>v){
                    if(arr[i][maxLeft[i][j][d-1]] > arr[i][maxLeft[i][j-v][d-1]]) maxLeft[i][j][d] = maxLeft[i][j][d-1];
                    else maxLeft[i][j][d] = maxLeft[i][j-v][d-1];
                }
                if(j+v<=M){
                    if(arr[i][maxRight[i][j][d-1]] > arr[i][maxRight[i][j+v][d-1]]) maxRight[i][j][d] = maxRight[i][j][d-1];
                    else maxRight[i][j][d] = maxRight[i][j+v][d-1];
                }
                if(i>v){
                    if(arr[maxUp[i][j][d-1]][j] > arr[maxUp[i-v][j][d-1]][j]) maxUp[i][j][d] = maxUp[i][j][d-1];
                    else maxUp[i][j][d] = maxUp[i-v][j][d-1];
                }
                if(i+v<=M){
                    if(arr[maxDown[i][j][d-1]][j] > arr[maxDown[i+v][j][d-1]][j]) maxDown[i][j][d] = maxDown[i][j][d-1];
                    else maxDown[i][j][d] = maxDown[i+v][j][d-1];
                }
            }
        }
    }
}

short rowMax(int i, int l, int r){
    short cand1 = maxRight[i][l][power2[r-l]];
    short cand2 = maxLeft[i][r][power2[r-l]];
    return arr[i][cand1] > arr[i][cand2] ? cand1 : cand2;
}

short columnMax(int i, int l, int r){
    short cand1 = maxDown[l][i][power2[r-l]];
    short cand2 = maxUp[r][i][power2[r-l]];
    return arr[cand1][i] > arr[cand2][i] ? cand1 : cand2;
}

void findRangeRow(int i, int l, int r){
    if(l==r){
        rangeRow.push_back(make_pair(l, r));
        return;
    }
    int m = rowMax(i, l, r);
    if(m!=l) findRangeRow(i, l, m-1);
    if(m!=r) findRangeRow(i, m+1, r);
    rangeRow.push_back(make_pair(l, r));
}

void findRangeColumn(int i, int l, int r){
    if(l==r){
        if(arr[l][i] < min(arr[l-1][i], arr[r+1][i])){
            rangeColumn[i][columnDncLoc[l][r]].push_back(make_pair(l, r));
        }
        return;
    }
    int m = columnMax(i, l, r);
    if(m!=l) findRangeColumn(i, l, m-1);
    if(m!=r) findRangeColumn(i, m+1, r);
    if(arr[columnMax(i, l, r)][i] < min(arr[l-1][i], arr[r+1][i])){
        rangeColumn[i][columnDncLoc[l][r]].push_back(make_pair(l, r));
    }
}

inline bool cmp(pair<short, short> it1, pair<short, short> it2){
    if(it1.second != it2.second) return it1.second < it2.second;
    return it1.first > it2.first;
}

void findRange(){
    for(int i=2; i<M; i++){
        findRangeColumn(i, 2, N-1);
    }
}

vector<pair<short, short> > vec[LIM]; /// 각 열별로 가능한 위/아래 후보에 관한 정보를 담고 있게 될 벡터입니다.
vector<pair<short, short> > ret;

void mergeVec(int l, int r){
    for(auto it1 = vec[l].begin(), it2 = vec[r].begin(); it1 != vec[l].end() && it2 != vec[r].end(); ){
        if(cmp(*it1, *it2)) ++it1;
        else if(cmp(*it2, *it1)) ++it2;
        else{
            assert(*it1 == *it2);
            ret.push_back(*it1);
            ++it1;
            ++it2;
        }
    }
    vec[l].swap(ret);
    vec[r].clear();
    vec[r].shrink_to_fit();
    ret.clear();
    ret.shrink_to_fit();
}

void divideAndConquer(int l, int r){
    if(l>r) return;
    int m = (l+r)>>1;

    for(int i=2; i<M; i++){ /// vec 벡터를 전처리합니다.
        vec[i].swap(rangeColumn[i][m]);
        rangeColumn[i][m].clear();
        rangeColumn[i][m].shrink_to_fit();
    }

    rangeRow.clear();
    findRangeRow(m, 2, M-1);
    for(auto _pair: rangeRow){ /// m번 행이 무조건 포함된다고 고정하고, 가능한 (왼쪽 경계, 오른쪽 경계) 쌍을 시도합니다.
        int s = _pair.first, e = _pair.second; /// 왼쪽 경계와 오른쪽 경계입니다.
        int tl = m, tr = m; /// 이 두 변수는 가능한 위쪽 한계 / 아래쪽 한계를 담당합니다.

        while(tl > l){
            if(arr[tl-1][rowMax(tl-1, s, e)] < min(arr[tl-1][s-1], arr[tl-1][e+1])) tl--;
            else break;
        }
        while(tr < r){
            if(arr[tr+1][rowMax(tr+1, s, e)] < min(arr[tr+1][s-1], arr[tr+1][e+1])) tr++;
            else break;
        }

        if(s != e){
            int mid = rowMax(m, s, e);
            if(mid != s) mergeVec(s, mid);
            if(mid != e) mergeVec(s, mid+1);
        }

        if(arr[m][rowMax(m, s, e)] >= min(arr[m][s-1], arr[m][e+1])) continue;
        for(auto p: vec[s]){
            if(tl <= p.first && p.second <= tr) ans++;
        }
    }

    divideAndConquer(l, m-1);
    divideAndConquer(m+1, r);
}

 

USACO 2018 January. Sprinklers

스프링클러가 왼쪽 아래와 오른쪽 위 모두에 존재하는 칸들을 찾아봅시다. 먼저 자신의 오른쪽 위에 스프링클러가 존재하는 칸은 x좌표를 고정했을 때 $0 \le y \le r_x$인 점에 해당합니다. $r_x$는 쉽게 구할 수 있습니다. 마찬가지로 자신의 왼쪽 아래에 스프링클러가 존재하는 칸은 x좌표를 고정했을 때 $l_x \le y \le N$인 점에 해당하며, 이 역시 쉽게 구할 수 있습니다. 이때 중요한 관찰로 $l$과 $r$이 단조감소합니다.

 

이제 이 영역 내에서 만들 수 있는 직사각형의 개수를 세어 봅시다. 먼저 직사각형의 오른쪽 경계의 x좌표를 고정합니다. 이제 해당 x좌표를 오른쪽 끝으로 가지는 직사각형의 수를 세면 됩니다. 이 오른쪽 끝의 x좌표를 $b$라고 합시다. 우리는 직사각형의 위쪽 경계를 $e$, 아래쪽 경계를 $s$, 왼쪽 경계를 $a$라고 할 때 각 $b$에 대해 가능한 $(a, s, e)$의 개수를 더하면 됩니다. 이때 $f(a, s, e)$를 $a \le x \le b$, $s \le y \le e$인 사각형이 조건을 만족하면 1, 아니면 0인 함수로 정의합니다.

 

이때 $l_b \le s \le e \le r_b$가 되도록 $(s, e)$를 잡을 것임은 자명합니다. 그런데 중요한 사실은, 이 조건을 만족하는 순간 $f(a, s, e)$는 $e$와 무관한 값이 됩니다. $r$이 단조감소하기 때문에 $r_a \ge e$일 것이기 때문입니다. 따라서 각 $s$에 대해 가능한 최소의 $a$를 빠르게 구할 수 있다면, 답 역시 빠르게 구할 수 있습니다.

 

여기서 세그먼트 트리를 관리합니다. 세그먼트 트리에는 x좌표 하나를 볼 때마다 그 x좌표에 해당하는 $l_x$ 이상 $r_x$ 이하인 값에 +1을 해줄 것입니다. 이렇게 하면 $s$에 대해서 그 직사각형을 왼쪽으로 몇 칸 확장할 수 있는지 구할 수 있습니다. 세그먼트 트리에 하나의 값을 더 추가해, 구간에 대해서 $a_r + 2a_{r-1} + \cdots + {r-l+1}a_l$을 빠르게 구하도록 할 수 있고, 이 값까지 빠르게 구하면 문제를 $O(N \log N)$의 시간에 풀 수 있습니다. 

 

더보기
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
const ll MOD = 1000000007;

inline ll s2(ll x){
    return (x * (x+1) / 2) % MOD;
}

struct segTree{
    struct Node{
        ll sum, val, len, lazy;
        Node(){
            sum = val = len = lazy = 0;
        }
        Node(ll sum, ll val, ll len, ll lazy): sum(sum), val(val), len(len), lazy(lazy){}

        Node operator+(const Node &r)const{
            return Node((sum+r.sum)%MOD, (val+sum*r.len+r.val)%MOD, len+r.len, 0);
        }
    } tree[400002];

    void init(int i, int l, int r){
        tree[i].len = r-l+1;
        if(l==r) return;
        int m = (l+r)>>1;
        init(i*2, l, m);
        init(i*2+1, m+1, r);
    }

    void propagate(int i, int l, int r){
        tree[i].sum += tree[i].lazy * (r-l+1) % MOD;
        tree[i].val += tree[i].lazy * s2(r-l+1) % MOD;
        if(l!=r){
            tree[i*2].lazy += tree[i].lazy;
            tree[i*2+1].lazy += tree[i].lazy;
        }
        tree[i].lazy = 0;
    }

    void update(int i, int l, int r, int s, int e){
        propagate(i, l, r);
        if(r<s || e<l) return;
        if(s<=l && r<=e){
            tree[i].lazy = 1;
            propagate(i, l, r);
            return;
        }
        int m = (l+r)>>1;
        update(i*2, l, m, s, e);
        update(i*2+1, m+1, r, s, e);
        tree[i] = tree[i*2] + tree[i*2+1];
    }

    Node query(int i, int l, int r, int s, int e){
        propagate(i, l, r);
        if(r<s || e<l) return Node();
        if(s<=l && r<=e) return tree[i];
        int m = (l+r)>>1;
        return query(i*2, l, m, s, e) + query(i*2+1, m+1, r, s, e);
    }
} tree;

int n;
int arr[100002];
int l[100002], r[100002];
ll ans;

int main(){
    scanf("%d", &n);
    for(int i=1; i<=n; i++){
        int x, y;
        scanf("%d %d", &x, &y);
        arr[x+1] = y+1;
    }

    l[0] = n+1;
    for(int i=1; i<=n; i++){
        l[i] = min(l[i-1], arr[i]);
    }
    for(int i=n-1; i>=1; i--){
        r[i] = max(r[i+1], arr[i+1]-1);
    }

    tree.init(1, 1, n);
    for(int i=1; i<n; i++){
        if(l[i] > r[i]) continue;
        tree.update(1, 1, n, l[i], r[i]);
        ans += tree.query(1, 1, n, l[i], r[i]).val;
        ans %= MOD;
    }
    printf("%lld", ans);
}

 

JOIOC 2021. Monster Game

무지성으로 머지 소트를 돌리면, 수들이 "거의 정렬됨"을 관찰할 수 있습니다. 거의 정렬된다는 것은, 수들이 아래와 같은 형태를 보인다는 것을 말합니다.

 

  • 3 2 1     6 5 4    8 7
  • 1    3 2     8 7 6 5 4

 

머지 소트를 돌리면서 최대 $N \log N - N$번의 쿼리를 날렸을 것이니, 앞으로 약 $N$번 정도의 쿼리밖에 더 사용하지 못합니다. 위와 같이 거의 정렬된 수열을 적은 횟수의 쿼리만을 사용하여 추가로 정렬하는 것이 목표입니다.

 

일단 맨 첫 번째 수에서 쿼리를 계속 날려 가며, 자기보다 작은 수가 나왔다가 다시 자기보다 큰 수가 나오는 순간 종료합니다. 이때 자기보다 작은 마지막 수를 $r$이라고 합시다. 또 이 과정 동안 찾은 자신보다 작은 수의 개수를 $cnt$라고 합시다.

 

  • Case 1. $cnt \ge 2$인 경우
    이 경우 첫 수는 2 이상임을 알 수 있습니다. (추가적으로, 첫 수가 2일 경우 두번째 수는 1, 0이 자동결정되고, 세번째 수가 3일 수 없음을 알 수 있습니다.)

    여기서 $cnt$는 첫 수와 같거나, (첫 수-1)임을 알 수 있습니다. 첫 수를 $x$라고 하면, $x$보다 작은 수의 목록은 $0 \cdots x-2$, $x+1$입니다. 그런데 만약 $x$에서 시작해 $0$까지 내려간 다음에 나오는 수가 $x+1$이라면 $cnt=x$일 것이고, 아니면 $cnt=x-1$일 것입니다.

    두 경우를 분리하는 방법은 다음과 같습니다. $r$번 수는 0 또는 $x+1$입니다. $r$번 수가 0일 때 $r-2$번 수는 2이고, $r$번 수가 $x+1$일 때 $r-2$번 수는 1입니다. 이 둘을 비교하면 어떤 경우에 해당하는지를 알 수 있고, 0부터 $x$ (또는 $x+1$)까지의 수를 정렬할 수 있습니다.

    이 다음은 매우 간단합니다. 편의상 0부터 $x$까지의 수를 정렬했고, $x+1$ 이상의 수가 오른쪽 끝에 몰려 있는 상황이라고 가정합시다. 그럼 $x$보다 작은 수가 처음으로 나올 때까지, 그 오른쪽을 계속 탐방하면 됩니다. $x$보다 작은 수가 나오면 그 수는 $x+1$임을 알 수 있고, 거기까지 뒤집은 뒤 똑같은 과정을 반복하면 됩니다. 이렇게 하면 $O(N)$번의 쿼리로 정리가 가능합니다.

  • Case 2. $cnt=1$, $r \ge 3$인 경우
    이 경우는 아래 두 경우 중 하나입니다.
    1) $0, x, x-1, \cdots, 2, 1$
    2) $1, 0, x, x-1, \cdots 3, 2$

    이 둘을 분류하는 방법은, 두 번째 수와 마지막 수를 비교하면 간단히 알 수 있습니다. 나머지는 Case 1과 같은 방식으로 풀립니다.

  • Case 3. $cnt=1$, $r=1$인 경우
    이 경우는 첫 수가 0, 두 번째 수가 1인 경우뿐이므로 첫 두 수를 제거하고 재귀적으로 반복하면 됩니다.

  • Case 4. $cnt=1$, $r=2$인 경우
    이 경우 첫 세 수는 (1, 0, 2), (0, 2, 1), (2, 1, 0) 중 하나입니다. 따라서 이 세 수를 제외한 나머지 수들을 먼저 정렬한 뒤, (정렬이 끝나면 3이 맨 왼쪽에 왔을 것이므로) 각 수와 3을 비교해 3보다 큰 수를 2로 확정짓고 나머지 수를 배치하면 됩니다.

    이런 식으로 재귀적으로 수들을 제거하는 과정에서 $n=3$이면 수를 정렬할 수 없는데, 이를 방지하기 위해 $n \le 7$이면 $O(N^2)$에 정렬하도록 하면 문제가 해결됩니다.

 

더보기
#include <bits/stdc++.h>
#include "monster.h"

using namespace std;

typedef long long ll;

namespace {
    int n;
    int arr[1002];
    int arr2[1002];

    void Sort(int l, int r){
        if(l>=r) return;
        int m = (l+r)>>1;
        Sort(l, m);
        Sort(m+1, r);

        for(int i=l; i<=r; i++) arr2[i] = arr[i];
        for(int i=l, j=m+1, k=l; k<=r; k++){
            if(i==m+1) arr[k] = arr2[j++];
            else if(j==r+1) arr[k] = arr2[i++];
            else if(Query(arr2[i], arr2[j]) == 0) arr[k] = arr2[i++];
            else arr[k] = arr2[j++];
        }
    }

    void calculateElse(int tmp, int st){
        for(int i=st; i<n; i++){
            if(Query(arr[tmp], arr[i])){
                reverse(arr+tmp+1, arr+i+1);
                tmp = i;
            }
        }
    }

    void Reorder(int l, int r){
        if((r-l) <= 7){
            int cnt[1005] = {0};
            for(int i=l; i<=r; i++){
                for(int j=i+1; j<=r; j++){
                    if(Query(arr[i], arr[j])) cnt[arr[i]]++;
                    else cnt[arr[j]]++;
                }
            }
            sort(arr+l, arr+r+1, [&](int x, int y){
                return cnt[x] < cnt[y];
            });
            if(Query(arr[l], arr[l+1]) == 0) swap(arr[l], arr[l+1]);
            if(Query(arr[r], arr[r-1])) swap(arr[r], arr[r-1]);
            return;
        }

        int comp[1005] = {0};
        int cnt = 0;
        int lim = r;
        for(int i=l+1; i<=r; i++){
            if(Query(arr[l], arr[i])) comp[i] = 1, cnt++;
            if(comp[i-1] && !comp[i]){
                lim = i-1;
                break;
            }
        }
        if(cnt >= 2){ /// Case 1
            assert(lim >= l+3);
            if(Query(arr[lim], arr[lim-2])) reverse(arr+l, arr+lim);
            else reverse(arr+l, arr+lim+1);
            calculateElse(lim, lim+1);
            return;
        }
        else if(lim >= l+3){ /// Case 2
            if(Query(arr[l+1], arr[lim])) reverse(arr+l+1, arr+lim+1);
            else reverse(arr+l+2, arr+lim+1), swap(arr[l], arr[l+1]);
            calculateElse(lim, lim+1);
        }
        else if(lim==1){ /// Case 3
            Reorder(l+2, r);
        }
        else{ /// Case 4
            Reorder(l+3, r);
            if(Query(arr[l], arr[l+3])){
                reverse(arr+l, arr+l+3);
            }
            else if(Query(arr[l+1], arr[l+3])){
                swap(arr[l+1], arr[l+2]);
            }
            else{
                swap(arr[l], arr[l+1]);
            }
        }
    }
}

vector<int> Solve(int N){
    n = N;
    for(int i=0; i<n; i++) arr[i] = i;
    random_shuffle(arr, arr+n);
    Sort(0, n-1);

    Reorder(0, n-1);

//    for(int i=0; i<n; i++) printf("%d ", arr[i]);
//    puts("");

    vector<int> idx (n);
    for(int i=0; i<n; i++) idx[arr[i]] = i;
    return idx;
}

 

'코딩 > 문제풀이' 카테고리의 다른 글

BOJ 7469 K번째 수 (in Q log N!)  (3) 2021.08.21
Problem Solving Diary #2  (0) 2021.06.15
Problem Solving Diary #1  (0) 2021.06.11

태그