https://www.acmicpc.net/problem/9465

 

9465번: 스티커

첫째 줄에 테스트 케이스의 개수 T가 주어진다. 각 테스트 케이스의 첫째 줄에는 n (1 ≤ n ≤ 100,000)이 주어진다. 다음 두 줄에는 n개의 정수가 주어지며, 각 정수는 그 위치에 해당하는 스티커의

www.acmicpc.net


이 문제를 무식하게 모든 경우의 수를 따져서 풀려고 하면 어떻게 할 수 있을까? 각 열에서 하나의 스티커만 뽑을 수 있으므로 최대 n개의 스티커를 뽑을 수 있다. 즉 각 열에 대해 위의 스티커를 뽑던가, 아래의 스티커를 뽑던가, 안 뽑던가의 3가지 갈래가 있고 n열까지 있으니 3^n개의 조합들이 나오고, 여기서 변을 공유하는 스티커를 뽑은 케이스를 쳐내야한다. 즉 이 방법의 시간복잡도는 O(3^n)정도가 될 것이고 n은 문제에서 최대 100,000으로 주어진다 했으니 연산횟수는 천문학적인(?) 숫자가 된다. 주어진 시간제한 1초 내에 이 방법으로 문제를 푸는 것은 매우 힘들다..고 봐야 한다.

 

그럼 어떻게 접근할 수 있는가. 그리디하게 접근한다면 일단 각 열에서 무조건 스티커를 떼는데, 가장 높은 값의 스티커들만 떼가야 한다. 그러나 당연히 이는 불가능하다. 변을 공유하는 스티커를 뗄 수 없으니 1번째 열의 스티커를 떼면 그 다음 열에서 뗄 수 있는 스티커, 또 그 다음 열에서 뗄 수 있는 스티커가 정해진다. 이 방법은 최적의 답을 도출할 수 없다.

 

이번엔 dp를 의심해보자. 일단 이 문제는 부분 문제로 쪼갤 수 있다. 1번째 열부터 차례차례 스티커를 떼고 n번째 열의 스티커를 뗀다고 해보자. 물론 각 열의 스티커는 안 뗄 수도 있다. 이 때 정답은 가장 큰 점수를 얻는 스티커 조합이므로, 위쪽 스티커든 아래쪽 스티커든 간에 n번째 열의 스티커는 떼야 한다. 그렇다면 답은

 

마지막 스티커(n번째 열)을 위쪽을 뽑은 경우

마지막 스티커(n번째 열)을 아래쪽을 뽑은 경우

 

이 둘 중 하나다. 여기서 단순히 n번째 열 스티커를 위쪽을 뽑았다면, 이 점수를 n - 1번째 스티커를 아래쪽을 뽑은 경우에 나오는 값에 더해야 합니다. 라고 하면 안된다.  다음 예시를 보자

  1열 2열 3열
1행 10 2 15
2행 2 1 25

1 ~ 3열까지 스티커를 뽑았을 때 나오는 최대 점수는 몇인가? 라는 문제라면 1, 2열에서 뭘 뽑던 3열에서 스티커 하나를 뽑은 케이스가 정답이 된다. 3열에서 아래쪽을 뽑은 케이스를 알기 위해 2열에서 위쪽을 뽑은 케이스를 더한다면 이 때 구해지는 점수는 2 + 2 + 25 = 29가 된다. 그러나 3열에서 아래쪽을 뽑은 케이스의 최대값은 35다! 1열에서 10을, 2열에서 아무것도 안 뽑고 3열에서 25를 뽑은 케이스. 즉, n번째 열의 위쪽을 뽑은 경우 얻는 점수의 최대는

 

n - 1번째 스티커를 아래쪽을 뽑은 경우의 최대누적합

n - 2번째 스티커를 뽑은 경우(위, 아래 상관없음)의 최대누적합

 

이 둘 중 더 큰 값에 n번째 열의 위쪽스티커의 점수를 합한 것이다. dp[1][i]이 i번째 열에서 위쪽 스티커를 뽑은 경우의 최대누적합, dp[2][i]가 i번째 열에서 아래쪽 스티커를 뽑은 경우의 최대누적합이라 하면 점화식은 다음과 같이 얻어진다.

 

dp[1][n] = max(dp[2][n-1], max(dp[1][n-2], dp[2][n-2]))

dp[2][n] = max(dp[1][n-1], max(dp[1][n-2], dp[2][n-2]))

 

이걸 고대로 코드로 옮기면 된다.

 

import sys

T = int(sys.stdin.readline())

for _ in range(T):
    n = int(sys.stdin.readline())
    sticker = [[0]]
    sticker.append([0] + list(map(int, sys.stdin.readline().split())))
    sticker.append([0] + list(map(int, sys.stdin.readline().split())))

    dp = [[0 for i in range(n + 1)] for j in range(3)]
    dp[1][1] = sticker[1][1]
    dp[2][1] = sticker[2][1]

    for i in range(2, n + 1):
        dp[1][i] = max(dp[2][i - 1], max(dp[1][i - 2], dp[2][i - 2])) + sticker[1][i]
        dp[2][i] = max(dp[1][i - 1], max(dp[1][i - 2], dp[2][i - 2])) + sticker[2][i]

    print(max(dp[1][n], dp[2][n]))

 

 

 

 

 

+ Recent posts