본문 바로가기
Algorithm/BAEKJOON

[백준/C++] 트리의 독립집합 (No. 2213)

by code_pie 2024. 6. 9.

 

 

 

풀이

 

[문제 풀이]

 

처음에 이 문제를 봤을 때, 아무 생각 없이 하나의 노드를 기준으로 두번 검사하면 되겠다는 생각을 했다.

 

하지만 예제의 입력을 보자마자 착각했다는 것을 깨닫고 다시 풀었다...

 

새로 푼 방법은 두가지 경우로 나눠서 계산하는 방법으로 트리를 그리면서 계산하는 방법이다.

 

먼저 그래프가 트리의 형태이므로 한 정점을 기준으로 트리를 그린다.

 

이후 현재 노드와 하위 노드들(자식 노드들)의 값을 포함해 만들 수 있는 독립집합의 최대값을 DP라는 배열에 저장한다.

 

DP라는 배열을 그리는 이유는 중복된 계산을 줄이기 위해서다.

 

어떤 식으로 중복된 계산을 줄였는지 아래 그림을 참고하자.

 

예제

 

파란색으로 표시된 경우는 현재 노드의 가중치 비용을 더한 값이다.

 

즉, 현재 노드가 독립 집합에 포함된 경우를 의미한다.

 

현재 노드가 독립집합에 포함된 경우를 생각해보면, 하위에 있는 노드들은 선택이 되면 안된다.

그러므로 현재 노드가 독립집합에 포함된 경우에 만들 수 있는 최대값은 [현재 노드의 가중치 + 하위 노드들이 선택되지 않은 경우의 가중치합] 이 된다.

 

반대로 현재 노드가 독립 집합에 포함되지 않은 경우는 어떻게 구할 수 있을까?

 

현재 노드가 독립집합에 포함되지 않은 경우에는 하위 자식을 포함해도 되고 포함하지 않아도 된다.

그러므로 현재 노드가 독립집합에 포함되지 않은 경우에는 [하위 노드들이 선택되거나 선택되지 않은 경우의 최대 가중치들을 더한 값] 이 현재 노드가 독립집합에 포함되지 않은 경우의 최대값이 된다.

 

이렇게 최대값을 이용해 계산하면 루트 노드에는 루트노드를 포함한 경우의 최대값, 루트노드를 포함하지 않은 경우의 최대값이 저장된다.

 

 

마찬가지로 독립집합의 원소들을 찾는 방법도 비슷하다.

 

이제 두 값을 비교해 최대값을 선택하고, 다시 역으로 어떤 노드들이 선택 되었는지 탐색해 나가면 된다.

 

만약 루트노드가 포함된 경우가 최대값이었다면, 하위 노드들은 선택되면 안된다.

(독립집합에 포함되면 안되므로)

 

하위노드들은 선택 되지 않았으므로, 선택되지 않은 하위 노드들의 자식은 독립집합에 포함될 수 있고, 독립집합에 포함되지 않을 수 있다.

 

즉, 자식노드의 두 경우 중 최대값을 선택하면 된다.

 

만약 자식노드가 포함된 경우가 값이 더 크다면, 자식노드는 독립집합에 포함된 것이다.

만약 자식노드가 포함되지 않은 경우가 더 값이 크다면, 자식 노드는 독립집합에 포함되지 않는다.

 

이를 이용해 어떤 자식노드가 포함되었는지 DP배열을 역으로 탐색해 나가면서 구하면 된다.

 

 

[아이디어 정리]

  1. 특정 노드를 기준으로 트리를 그리고, 트리를 그릴 때 현재노드가 포함된 경우의 최대값, 포함되지 않은 경우의 최대값을 계산한다.
  2. 트리를 그리고 나면 루트노드에는 루트노드가 포함된 경우의 독립집합의 최대값, 루트노드가 포함되지 않은 경우의 독립집합의 최대값이 저장된다.
  3. 어떤 원소가 독립집합에 포함되었는지는 최대값을 이용해 계산할 수 있다.
  4. 현재 노드가 독립집합에 포함된 경우에는 자식 노드는 독립집합에 포함될 수 없다.
  5. 현재 노드가 독립집합에 포함되지 않은 경우에는 자식노드가 포함된 경우의 최대값과 자식노드가 포함되지 않은 경우의 최대값을 비교한다.
  6. 만약 자식노드가 포함된 경우의 최대값이 더 크다면, 그 자식노드를 독립집합에 포함하면 된다.

 

 

 

Code

 

 

#include <string>
#include <vector>
#include <iostream>
#include <queue>
#include <cmath>
#include <algorithm>
#include <unordered_map>
#include <set>
using namespace std;

void MakeTree(int nNode, vector<vector<int>>& graph, vector<int>& par, vector<int>& cost, vector<vector<int>>& DP)//비용 return
{
    // 0은 나를 포함, 1은 나를 포함 x
    DP[0][nNode] = cost[nNode];
    DP[1][nNode] = 0;
    int nextN;
    for (int i = 0; i < graph[nNode].size(); i++)
    {
        // 나를 포함하는 경우
        nextN = graph[nNode][i];
        if (par[nNode] != nextN) {
            par[nextN] = nNode;
            MakeTree(nextN, graph, par, cost, DP);
            DP[0][nNode] += DP[1][nextN];
        }
    }

    for (int i = 0; i < graph[nNode].size(); i++)
    {
        // 나를 포함하지 않는 경우
        nextN = graph[nNode][i];
        if (par[nNode] != nextN) {
            DP[1][nNode] += max(DP[1][nextN], DP[0][nextN]);
        }
    }
    return;
}

void FindA(int nNode, int idx, vector<vector<int>>& graph, vector<int>& par, vector<int>& cost, vector<vector<int>>& DP, vector<int>& answer)
{
    if (idx == 0) {
        answer.push_back(nNode);
    }
    int nextN;
    for (int i = 0; i < graph[nNode].size(); i++) {
        nextN = graph[nNode][i];
        if (par[nNode] != nextN) { 
            if (idx == 0) { //나를 포함했으므로 무조건 자식은 안 포함된 경우로
                FindA(nextN, 1, graph, par, cost, DP, answer);
            }
            else {
                if (DP[0][nextN] >= DP[1][nextN]) { // 자식이 독립집합에 포함될 경우
                    FindA(nextN, 0, graph, par, cost, DP, answer);
                }
                else {
                    FindA(nextN, 1, graph, par, cost, DP, answer);
                }
            }
        }
    }

}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);
    int N,st,ed;
    cin >> N;
    vector<int> cost(N + 1);
    for (int i = 1; i <= N; i++) {
        cin >> cost[i];
    }
    vector<vector<int>>graph(N + 1);
    vector<int> par(N + 1, 0);
    for (int i = 1; i < N; i++) {
        cin >> st >> ed;
        graph[st].push_back(ed);
        graph[ed].push_back(st);
    }
    vector<int>answer;
    vector<vector<int>> DP(2, vector<int>(N + 1));
    MakeTree(1, graph, par, cost, DP);
    if (DP[0][1] > DP[1][1]) {
        cout << DP[0][1]<< "\n";
        FindA(1, 0, graph, par, cost, DP, answer);
    }
    else {
        cout << DP[1][1] << "\n";
        FindA(1, 1, graph, par, cost, DP, answer);
    }
    sort(answer.begin(), answer.end());
    for (int i = 0; i < answer.size(); i++) {
        cout << answer[i] << " ";
    }

    return 0;
}

 


처음에 문제를 잘못 생각해서 코드가 길어진 부분이 있다;;

덕분에 루트노드인 1의 최대값을 비교하는 부분에서 DP[0][1]>DP[0][1]과 같이 같은 값을 비교해서 시간이 오래 걸렸다;;

 

그냥 처음에 문제를 잘못 풀면 비슷한 부분이 있어도 아예 지우고 푸는게 시간적으로 더 나은것 같다...

 

https://www.acmicpc.net/problem/2213

 

반응형