백준 14595 동방 프로젝트(Large)

문제 링크

  • https://www.acmicpc.net/problem/14595

알고리즘

  • 분리 집합
  • 정렬
  • 그리디

풀이

이 문제는 분리 집합을 사용해서 풀 수도 있지만, 그냥 단순 정렬과 그리디 방법으로도 풀 수 있다.

동방 프로젝트(Small)를 풀고 바로 시도했던 문제다. 무려 4개월 전에 풀었던 문제인데 기억에 남아 정리해보려고 한다.

처음엔 Small 버전에서 제출했던 방식 그대로 시도했다가 당연하게도 시간초과 판정을 받았다. 고민 끝에 쓸데없는 반복 작업을 줄이기 위해서 이미 부순 방 목록을 저장해보기로 했다. 빌런이 부술 방 목록을 오름차순으로 정렬한 다음에, 이전에 부순 방보다 작은 번호의 방이면 유니온 함수 작업을 건너뛰도록 설계했다. 예를 들어서 1번부터 10번까지 방이 있을 때 만약에 3번에서 10번까지 방을 부쉈다면, 그 다음에 5번부터 10번까지 방을 부숴야 할 때는 건너뛰는 것이다.

분리 집합 알고리즘을 이용한 코드는 다음과 같다.

분리 집합 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import sys
input = sys.stdin.readline

n = int(input())
m = int(input())

parent = [0] * (n+1)
for i in range(1, n+1):
    parent[i] = i
        
def find_parent(x):
    if parent[x] != x:
        parent[x] = find_parent(parent[x])
        
    return parent[x]

def union(x, y):
    x = find_parent(x)
    y = find_parent(y)
    
    if x < y:
        parent[x] = y
        
    else:
        parent[y] = x
    
    
cnt = 0
lst = []
for i in range(m):
    x, y = map(int, input().split())
    lst.append((x, y))
    
lst = sorted(lst, key = lambda x : (x[0], x[1]))

memo = 0
for s, e in lst:
    if s < memo:
        s = memo
        
    for j in range(s, e+1):
        if find_parent(j) != find_parent(e):
            union(j, e)
            cnt += 1
            memo = e
            
print(n - cnt)            
    




그러나 더 간단하게 풀 수 있다. 사실 위에서 설명한 논리대로라면 분리 집합 알고리즘을 사용할 필요조차 없다.

그래서 그냥 그리디와 정렬 알고리즘만 써서 풀어봤다. 결과적으로 효율성이 향상되었다.

정렬 & 그리디 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import sys
input = sys.stdin.readline

n = int(input())
m = int(input())

cnt = 0
lst = []
for i in range(m):
    x, y = map(int, input().split())
    lst.append((x, y))
    
lst = sorted(lst, key = lambda x : (x[0], x[1]))

memo = 1
cnt = 1
for s, e in lst:
    if s <= memo:
        memo = max(memo, e)
        
    else:
        cnt += s - memo
        memo = e
        
cnt += n - memo
            
print(cnt)