티스토리 뷰

문제 링크

풀이

문제를 보면 두 정점 사이의 최단 거리를 빠르게 구하는 유형이다. 그런데 일반 그래프에서 이건 엄청 어렵다. 따라서 주어진 그래프를 조금 제한적인 형태, 예를 들면 트리로 바꿀 수 있지 않을까 추측해볼 수 있다.

그리고 실제로도 주어진 상황을 트리로 변형할 수 있다. 왜 그런지 생각해보면 나름 자명한데, $(a, b)$에서 $(c, d)$로 이동할 때 다음과 같은 그리디한 전략이 통하기 때문이다.

  • $(a, b)$에서 시작하여, 현재 지점에서 High Jump를 하고 $(c, d)$로 이동할 수 없는 동안, 다음을 반복한다.
    • High Jump를 하고, 도달할 수 있는 점들을 보자.
    • 이들 중 높이가 제일 높은 점으로 이동한다.
  • 만약 High Jump를 통해 $(c, d)$로 도달할 수 있다면, 이동한다.

증명의 직관은, $(a, b)$에서 더 높은 점 $(x, y)$로 가면, $(x, y)$에서 갈 수 있는 점들은 $(a, b)$에서 갈 수 있는 점을 포함한다는 데에서 유래한다. 그렇다면 트리를 어떻게 빨리 만들까? 점들의 높이를 투 포인터로 관리하고, Union Find에 그룹 최댓값과 위치 등을 관리하는 식으로 어렵지 않게 할 수 있다.

따라서 각 $(a, b)$별로 갈 수 있는 가장 높은 점을 빨리 찾고, $(a, b) \rightarrow (x, y)$로 올라가는 간선을 통해 트리를 만든 후, $(a, b)$에서 LCA sparse table로 이분 탐색을 하며 $(c, d)$로 갈 수 있는 가장 깊이가 깊은 점을 찾으면 된다. $(c, d)$로 갈 수 있는지는 Union Find Tree를 구축하고 dfs ordering을 통해 정점별 in/out 값을 통해 서브트리 내부에 있는지 판별을 하면 된다.

시간 복잡도는 $O((Q + NM) \log NM)$ 근처일 것이다.

구현상 주의해줄 점은 똑같은 높이의 산이 여러 개 있을 때 Union Find에서 무엇이 부모로 올라오는지 순서를 지정해주는 것 등이 있다.


#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

struct Point{
    int x; ll v;
    Point(){}
    Point(int x, ll v): x(x), v(v){}
    bool operator<(const Point r)const{
        return make_pair(v, x)<make_pair(r.v,r.x);
    }
};

struct UnionFind{
    int n;
    int par[300005]; ll mx[300005]; int maxLoc[300005];
    int parEdge[300005];
    vector<int> childEdge[300005];
    int in[300005], out[300005], inCnt;

    void init(int _n, ll *A){
        n = _n;
        for(int i=1; i<=n; i++){
            par[i] = i;
            mx[i] = A[i];
            maxLoc[i] = i;
            parEdge[i] = 0;
        }
    }

    int find(int x){
        if(x==par[x]) return x;
        return par[x] = find(par[x]);
    }

    void merge(int x, int y){
//        printf("Merge request: %d %d -> %d %d\n", x, y, find(x), find(y));
        x = find(x), y = find(y);
        if(x==y) return;
        par[y] = x;
        parEdge[y] = x;
        if(mx[x] < mx[y] || (mx[x] == mx[y] && maxLoc[x] < maxLoc[y])) mx[x] = mx[y], maxLoc[x] = maxLoc[y];
    }

    void dfs(int x){
        in[x] = ++inCnt;
        for(int y: childEdge[x]) dfs(y);
        out[x] = inCnt;
    }

    void do_dfs(){
        int root = -1;
        for(int i=1; i<=n; i++){
            if(!parEdge[i]) root = i;
            else childEdge[parEdge[i]].push_back(i);
        }
        assert(root != -1);
        dfs(root);
    }
} dsu;

int n, m, nm; ll L;
ll arr[300005];
int par[300005];
bool vis[300005];

void addPoint(int x){
    vis[x] = 1;
    for(int ng: {x%m==1 ? -1 : x-1, x%m ? x+1 : -1, x-m, x+m}){
        if(ng<1||ng>nm||!vis[ng]) continue;
        if(make_pair(arr[x], x) > make_pair(arr[ng], ng)) dsu.merge(x, ng);
        else dsu.merge(ng, x);
    }
}

vector<int> child[300005];
int depth[300005], sps[300005][20];

void dfs(int x){
    sps[x][0] = par[x];
    for(int y: child[x]){
        depth[y] = depth[x] + 1;
        dfs(y);
    }
}

int getLCA(int x, int y){
    if(depth[x] > depth[y]) swap(x, y);
    for(int d=0; d<20; d++) if((depth[y] - depth[x]) & (1<<d)) y = sps[y][d];
    if(x==y) return x;
    for(int d=19; d>=0; d--) if(sps[x][d] != sps[y][d]) x = sps[x][d], y = sps[y][d];
    return par[x];
}

int main(){
    scanf("%d %d %lld", &n, &m, &L);
    nm = n*m;
    for(int i=1; i<=nm; i++) scanf("%lld", &arr[i]);

    /// 트리 만들기
    vector<Point> vec;
    for(int i=1; i<=nm; i++) vec.push_back(Point(i, arr[i]));
    sort(vec.begin(), vec.end());
    dsu.init(n*m, arr);

    int tmp = 0;
    for(Point &p: vec){
        while(tmp < nm && p.v + L >= vec[tmp].v) addPoint(vec[tmp++].x);
        par[p.x] = dsu.maxLoc[dsu.find(p.x)];
    }
    dsu.do_dfs();

    /// 트리 정리
    vector<int> roots;
    for(int i=1; i<=nm; i++){
        if(par[i] == i) roots.push_back(i);
        else child[par[i]].push_back(i);
    }
    for(int root: roots) dfs(root);
    for(int d=1; d<20; d++) for(int i=1; i<=nm; i++) sps[i][d] = sps[sps[i][d-1]][d-1];

    function<bool(int, int)> inside = [&](int z, int q){
        return dsu.in[z] <= dsu.in[q] && dsu.in[q] <= dsu.out[z];
    };

    int q;
    scanf("%d", &q);
    while(q--){
        int a, b, c, d;
        scanf("%d %d %d %d", &a, &b, &c, &d);
        int p = (a-1)*m+b, q = (c-1)*m+d;

        if(inside(par[p], q)) {puts("1"); continue;}

        int x = p;
        for(int d=19; d>=0; d--){
            int z = sps[x][d];
            if(inside(z, q)) continue;
            x = z;
        }

        if(inside(x, q)) exit(1);
        else if(!inside(par[x], q)) puts("-1");
        else printf("%d\n", depth[p] - depth[x] + 1);
    }
}

공지사항
최근에 올라온 글
Total
Today
Yesterday