알고리즘/[ Baekjoon ]

[ BOJ ][JAVA][2213] 트리의 독립집합

kim.svadoz 2021. 4. 21. 23:18
반응형

www.acmicpc.net/problem/2213

 

2213번: 트리의 독립집합

첫째 줄에 트리의 정점의 수 n이 주어진다. n은 10,000이하인 양의 정수이다. 1부터 n사이의 정수가 트리의 정점이라고 가정한다. 둘째 줄에는 n개의 정수 w1, w2, ..., wn이 주어지는데, wi는 정점 i의

www.acmicpc.net

시간 제한 메모리 제한 제출 정답 맞은 사람 정답 비율
2 초 128 MB 2845 1367 1039 49.500%

문제

그래프 G(V, E)에서 정점의 부분 집합 S에 속한 모든 정점쌍이 서로 인접하지 않으면 (정점쌍을 잇는 에지가 없으면) S를 독립 집합(independent set)이라고 한다. 독립 집합의 크기는 정점에 가중치가 주어져 있지 않을 경우는 독립 집합에 속한 정점의 수를 말하고, 정점에 가중치가 주어져 있으면 독립 집합에 속한 정점의 가중치의 합으로 정의한다. 독립 집합이 공집합일 때 그 크기는 0이라고 하자. 크기가 최대인 독립 집합을 최대 독립 집합이라고 한다.

문제는 일반적인 그래프가 아니라 트리(연결되어 있고 사이클이 없는 그래프)와 각 정점의 가중치가 양의 정수로 주어져 있을 때, 최대 독립 집합을 구하는 것이다.

입력

첫째 줄에 트리의 정점의 수 n이 주어진다. n은 10,000이하인 양의 정수이다. 1부터 n사이의 정수가 트리의 정점이라고 가정한다. 둘째 줄에는 n개의 정수 w1, w2, ..., wn이 주어지는데, wi는 정점 i의 가중치이다(1 ≤ i ≤ n). 셋째 줄부터 마지막 줄까지는 에지 리스트가 주어지는데, 한 줄에 하나의 에지를 나타낸다. 에지는 정점의 쌍으로 주어진다. 입력되는 정수들 사이에는 콤마가 없고 대신 빈칸이 하나 혹은 그 이상 있다. 가중치들의 값은 10,000을 넘지 않는 자연수이다.

출력

첫째 줄에 최대 독립집합의 크기를 출력한다. 둘째 줄에는 최대 독립집합에 속하는 정점을 오름차순으로 출력한다. 최대 독립 집합이 하나 이상일 경우에는 하나만 출력하면 된다.

예제 입력 1

7
10 30 40 10 20 20 70
1 2
2 3
4 3
4 5
6 2
6 7

예제 출력 1

140
1 3 5 7

접근

트리의 독립집합

  1. 기본적으로 1번노드에서 시작해 그래프를 탐색하는데 , 현재노드를 선택과 비선택 두가지 가지로 나누어 메모이제이션을 채운다.
  2. 지금 노드에서 선택했다면 인접한 노드는 선택하면 안되고, 선택하지 않았다면 선택해도 되고 안해도 된다.
    트리는 비선형 구조이다. 탐색 순서를 정하기 위해서 dfs 트리를 만들어준다.
    이 때 list[]는 입력 데이터를 이용해 만든 인접리스트,
    p[][]는 dfs트리를 저장할 인접리스트이다.
    dp[1][0] 은 i번 노드를 루트로 하는 서브트리에서 i노드를 포함하지 않는 경우의 답
    dp[1][1] 은 i번 노드를 루트로 하는 서브트리에서 i노드를 포함하는 경우의 답 으로 정의.

코드

처음 푼 풀이

import java.io.*;
import java.util.*;

public class p2213 {
    static BufferedReader br;
    static StringTokenizer st;
    static int n, w[], dp[][];
    static List<Integer> list[];
    static List<Pair> p[][];
    static List<Integer> ans;
    static boolean visit[];
    public static void main(String[] args) throws IOException {
        br = new BufferedReader(new InputStreamReader(System.in));
        n = Integer.parseInt(br.readLine());

        dp = new int[2][n + 1]; // dp 배열
        w = new int[n + 1]; // 가중치 배열
        st = new StringTokenizer(br.readLine());
        for (int i = 1; i <= n; ++i) {
            w[i] = Integer.parseInt(st.nextToken());
        }

        list = new ArrayList[n + 1]; // 문제에 주어진 그래프
        p = new ArrayList[n + 1][2]; // dp의 계산을 위한 리스트배열

        // 0 : false, 1 : true
        for (int i = 1; i <= n; ++i) {
            list[i] = new ArrayList<>();
            p[i][0] = new ArrayList<Pair>();
            p[i][1] = new ArrayList<Pair>();
        }

        for (int i = 1; i < n; ++i) {
            st = new StringTokenizer(br.readLine());
            int u = Integer.parseInt(st.nextToken());
            int v = Integer.parseInt(st.nextToken());
            list[u].add(v);
            list[v].add(u);
        }

        int start = Solution();

        // ans에 정답 리스트가 들어있음.
        ans = new ArrayList<Integer>();
        StringBuilder sb = new StringBuilder();

        sb.append(getValue(1, start)).append('\n');

        Collections.sort(ans);
        for (int x : ans) {
            sb.append(x).append(' ');
        }
        System.out.println(sb.toString());
    }

    static int Solution() {
        Arrays.fill(dp[0], -1);
        Arrays.fill(dp[1], -1);
        visit = new boolean[n + 1];
        if (dfs(1, true) > dfs(1, false)) {
            return 1;
        } else {
            return 0;
        }
    }

    static int dfs(int n, boolean t) {
        int tt = t ? 1: 0;
        if(dp[tt][n] != -1) {
            return dp[tt][n];
        }
        int sum = t ? w[n] : 0;
        visit[n] = true;
        if(t) {
            // true라면 다음 노드는 반드시 false를 pick
            for (int next : list[n]) {
                if(!visit[next]) {
                    p[n][tt].add(new Pair(next, 0));
                    sum += dfs(next, false);
                }
            }
        } else {
            // false라면 다음 노드는 true or false pick
            for (int next : list[n]) {
                if (visit[next]) continue;
                if (dfs(next, false) > dfs(next, true)) {
                    p[n][tt].add(new Pair(next, 0));
                    sum += dfs(next, false);
                } else {
                    p[n][tt].add(new Pair(next, 1));
                    sum += dfs(next, true);
                }
            }
        }
        visit[n] = false;
        return dp[tt][n] = sum;
    }

    static int getValue(int n, int t) {
        int sum = 0;
        if (t == 1) {
            ans.add(n);
            sum = w[n];
        }
        for (Pair next : p[n][t]) {
            sum += getValue(next.node, next.t);
        }
        return sum;
    }

    static class Pair {
        int node;
        int t;

        Pair(int node, int t) {
            this.node = node;
            this.t = t;
        }
    }
}

더 좋은 풀이

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.PriorityQueue;
import java.util.StringTokenizer;
public class Main {
    private static PriorityQueue pq = new PriorityQueue();
    private static ArrayList<Integer>[] edges;
    private static int[] vertex;
    private static int[][] dp;
    private static final String NEW_LINE = "\n";
    private static final String SPACE = " ";
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int N = Integer.parseInt(br.readLine());
        vertex = new int[N + 1];
        edges = new ArrayList[N + 1];
        dp = new int[N + 1][3];

        StringTokenizer st = new StringTokenizer(br.readLine());
        for(int i = 0; i <= N; i++) {
            if(i != 0) vertex[i] = Integer.parseInt(st.nextToken());
            edges[i] = new ArrayList<>();
            dp[i][0] = dp[i][1] = dp[i][2] = -1;
        }
        int loop = N - 1;
        while(loop-- > 0) {
            st = new StringTokenizer(br.readLine());
            int node1 = Integer.parseInt(st.nextToken());
            int node2 = Integer.parseInt(st.nextToken());
            edges[node1].add(node2);
            edges[node2].add(node1);
        }
        StringBuilder sb = new StringBuilder();
        int[] result = {recursion(0, 1, 0), recursion(0, 1, 1)};
        sb.append(Math.max(result[0], result[1])).append(NEW_LINE);
        if(result[1] >= result[0]) vertexTrace(0, 1, 1);
        else vertexTrace(0, 1, 0);
        while(!pq.isEmpty()) {
            sb.append(pq.poll()).append(SPACE);
        }
        System.out.println(sb.toString());
    }
    private static void vertexTrace(int prev, int current, int select) {
        if(select == 1) pq.offer(current);
        for(int next: edges[current]) {
            if(prev == next) continue;
            int left = recursion(current,  next, 0);
            int right = recursion(current,  next, 1);
            if (left > right) {
                vertexTrace(current, next, 0);
            }
            else {
                if(select == 1) vertexTrace(current, next, 0);
                else vertexTrace(current, next, 1);
            }
        }
    }
    private static int recursion(int prev, int current, int select) {
        int result = dp[current][select];
        if(result != -1) return result;
        result = select == 1 ? vertex[current]: 0;
        for(int next: edges[current]) {
            if(prev == next) continue;
            if(select == 0) result += Math.max(recursion(current, next, 0), recursion(current, next, 1));
            else result += recursion(current, next, 0);
        }
        return dp[current][select] = result;
    }
}
반응형

'알고리즘 > [ Baekjoon ]' 카테고리의 다른 글

[ BOJ ][JAVA][2251] 물통  (0) 2021.04.22
[ BOJ ][JAVA][2231] 분해합  (0) 2021.04.21
[ BOJ ][JAVA][2193] 이친수  (0) 2021.04.21
[ BOJ ][JAVA][2186] 문자판  (0) 2021.04.21
[ BOJ ][JAVA][2178] 미로 탐색  (0) 2021.04.21