← All Articles

[BAEKJOON] 11437. LCA

Posted on

문제 :

N(2 ≤ N ≤ 50,000)개의 정점으로 이루어진 트리가 주어진다. 트리의 각 정점은 1번부터 N번까지 번호가 매겨져 있으며, 루트는 1번이다.

두 노드의 쌍 M(1 ≤ M ≤ 10,000)개가 주어졌을 때, 두 노드의 가장 가까운 공통 조상이 몇 번인지 출력한다.


입력 :

첫째 줄에 노드의 개수 N이 주어지고, 다음 N-1개 줄에는 트리 상에서 연결된 두 정점이 주어진다. 그 다음 줄에는 가장 가까운 공통 조상을 알고싶은 쌍의 개수 M이 주어지고, 다음 M개 줄에는 정점 쌍이 주어진다.


출력 :

M개의 줄에 차례대로 입력받은 두 정점의 가장 가까운 공통 조상을 출력한다.




풀이 :

원래는 이 문제가 아닌 다른 그래프 이론 문제를 풀다가 LCA 알고리즘을 활용해야 한다는 것을 알게 된 후 차근차근 LCA부터 공부해보자라는 취지로 문제를 풀어보았다.

  • LCA (Lowest Common Ancester) Algorithm
  • LCA 알고리즘은 트리에서 두 노드의 최소 공통 부모 노드를 찾는 알고리즘을 말한다. 알고리즘의 경우 로직이 크게 어려운 것은 없으나 만약 LCA 알고리즘에 대해 전혀 모른다면 문제를 조금만 변형하더라도 어렵게 다가올 수 있겠다 싶은 생각이 들었다.
  • LCA 알고리즘은 우선 가장 핵심 포인트는 두 노드의 depth 중 더 낮은 depth 의 노드로 깊이를 맞춰줘야 한다는 점이다.
  • 가령, 1번 노드가 depth가 4이고, 2번 노드가 depth가 6일 경우 2번 노드를 depth가 4가 될 때까지 부모 노드로 타고 올라와서 계산해주어야 한다.
  • depth를 같게 맞춰줬을 때부턴 차례로 같은 depth의 노드가 동일한지 검사해보면 된다.

난 본 문제의 테스트케이스가 꽤 크다 생각하여 dp 방식으로 각 노드별 부모 노드를 모두 저장하여 계산해 주려 했다.

하지만 본 문제는 dp를 굳이 사용하지 않아도 풀 수 있었으며 오히려 dp 를 사용하게 되면 메모리 초과 결과가 출력되었다.




코드 :

코드 보기/접기
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <utility>

#define pii pair<int, int>

using namespace std;
vector<vector<int>> adj;
int *parent, *depth;

void bfs(int n) {
    bool *visit = new bool[n + 1]{false};
    queue<pii > q;
    q.push(pii(1, 1));
    depth[1] = 1;

    while (!q.empty()) {
        int curidx = q.front().first, curdepth = q.front().second, size = adj[curidx].size();
        q.pop();
        visit[curidx] = true;
        for (int k = 0; k < size; k++) {
            int cmp = adj[curidx][k];
            if (visit[cmp]) continue;
            parent[cmp] = curidx;
            depth[cmp] = curdepth + 1;
            q.push(pii(cmp, curdepth + 1));
        }
    }
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    int n, m, i, fir, sec, idx;
    cin >> n;
    parent = new int[n + 1];
    depth = new int[n + 1];
    adj.resize(n + 1);
    for (i = 1; i < n; i++) {
        cin >> fir >> sec;
        adj[fir].push_back(sec);
        adj[sec].push_back(fir);
    }

    bfs(n);
    cin >> m;
    while (m--) {
        cin >> fir >> sec;
        idx = min(depth[fir], depth[sec]);
        if (depth[fir] != idx)
            for (int i = depth[fir]; i > idx; i--)
                fir = parent[fir];
        else if (depth[sec] != idx)
            for (int i = depth[sec]; i > idx; i--)
                sec = parent[sec];
        for (; idx >= 0; idx--) {
            if (fir == sec) {
                cout << fir << '\n';
                break;
            }
            fir = parent[fir];
            sec = parent[sec];
        }
    }
    return 0;
}

AlgorithmalgorithmbaekjoonC++