$$\newcommand{\Z}{\mathbb{Z}} \newcommand{\R}{\mathbb{R}} \newcommand{\Q}{\mathbb{Q}} \newcommand{\N}{\mathbb{N}}\newcommand{\C}{\mathbb{C}} \newcommand{\oiv}[1]{\left] #1 \right[} \newcommand{\civ}[1]{\left[ #1 \right]} \newcommand{\ad}[1]{\text{ad}(#1)} \newcommand{\acc}[1]{\text{acc}(#1)} \newcommand{\Setcond}[2]{ \left\{\, #1 \mid #2 \, \right\}} \newcommand{\Set}[1]{ \left\{ #1 \right\}} \newcommand{\abs}[1]{ \left\lvert #1 \right\rvert}\newcommand{\norm}[1]{ \left\| #1 \right\|}\newcommand{\prt}{\mathcal{P}}\newcommand{\st}{\text{ such that }}\newcommand{\for}{\text{ for }} \newcommand{\cl}[1]{\text{cl}(#1)}\newcommand{\oiv}[1]{\left] #1 \right[}\newcommand{\interior}[1]{\text{int}(#1)}$$

BOJ 1806 부분합::::Gratus' Blog

BOJ 1806 부분합

알고리즘 문제풀이/BOJ 2019. 8. 11. 03:35

문제 설명

10,000 이하 자연수로 이루어진 길이 $n$ 짜리 수열에서, 어떤 subsegment의 합이 $S$ 이상이 되는 가장 짧은 subsegment의 길이 $L$ 을 찾는 문제. 길이는 최대 $100, 000$ 이다.

 

 

사실 문제 풀이는 별로 어렵지 않은데, 약간 이런 것도 된다 (왜 되지?) 같은게 있어서..

 

 

풀이 1 : $\mathcal{O}(n)$ : 투 포인터

포인터 변수 두개를 들고 다니면서, 합이 충분히 크면 왼쪽 포인터를 올려서 길이를 줄이고, 합이 작으면 오른쪽 포인터를 올려서 길이를 늘리는 식으로 보면서 성공할 때마다 길이를 보관한다. 

$O(n)$ 에 해결할 수 있는 방법인데, 처음에는 이렇게 풀 생각도 안하고 다르게 풀어서 이건 나중에 듣고 알았다.

 

 

풀이 2 : $\mathcal{O}(n \log n)$ : 이분 탐색

처음 코딩했을 때 이분탐색을 이용해서 풀었다.

Prefix Sum을 저장하면, 임의의 구간 $(a, b)$ 의 합을 $O(1)$에 계산할 수 있다. 이제, subsegment의 왼쪽 끝을 $i$번으로 고정하고, 오른쪽 인덱스를 이분탐색으로 찾자. 이게 가능한 이유는 전체 배열이 양수밖에 없기 때문에, $(i, x)$ 가 valid하다면 $(i, x+k)$ 가 임의의 양수 $k$에 대해 valid한 답이기 때문에 결국은 가능한 최소를 찾는 문제가 되기 때문이다.

각 $i$에 대해 $O(\log n)$ 에 최소의 가능한 오른쪽 끝을 찾을 수 있으므로, 전체 $O(n \log n)$ 에 해결할 수 있다.

 

#include <bits/stdc++.h>
#define ll long long
#define usecppio ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
using namespace std;

#define int ll
int arr[101010];
int cumul[101010];

inline int sum(int a, int b)
{
	return cumul[b]-cumul[a-1];
}

int n, s;
int ans = 0;

int32_t main()
{
	usecppio
	cin >> n >> s;
	for (int i = 1; i<=n; i++)
	{
		cin >> arr[i];
		cumul[i] = cumul[i-1]+arr[i];
	}
	int ans = INT_MAX;
	for (int i = 1; i<=n; i++)
	{
		if (sum(i,n)<s)
			continue;
		else
		{
			int lo = i, hi = n;
			while(lo+1 < hi)
			{
				int mid = (lo + hi) / 2;
				if (sum(i,mid) < s)
					lo = mid;
				else
					hi = mid;
			}
			ans = min(ans, (hi-i+1));
		}
	}
	cout << (ans==INT_MAX?0:ans);
}

풀이 3(?) $\mathcal{O}(n^2)$ : Naive

요즘 컴파일러는 지나치게 엄청 최적화를 잘 해 주고, 사실 대부분의 채점 서버는 1초에 1억보다 많은 연산을 할 수 있다. 실제로 bitset 같은걸로 $n^2 / 32$ 같은건 꽤 많이 쓰이는 테크닉이기도 하고.... 반복문 내부가 정말 많이 단순하다면, 가끔 $10^9$ 내지는 그보다 많은 연산도 1초에 돌아가는 신기한 일들이 벌어진다.

 

그러니까, 상수를 잘 커팅하면 $O(n^2)$ 도 비벼지지 않을까? 라는 생각에 이걸 시도해 봤다. 

 

이 문제를 가장 단순하게 풀면, 

- 각 인덱스 $i$에 대하여, 

- 오른쪽으로 가면서 하나씩 더해 보고, 합이 $S$를 넘어가면 길이를 보관하고 멈춘다.

- 최소 길이를 확인한다.

 

이 알고리즘의 시간 복잡도는 당연히 $O(n^2)$ 이다.

 

pragma 사용

일단 컴파일러를 믿자. https://codeforces.com/blog/entry/66279 에도 나와있는데, 컴파일러에게 마법의 주문을 몇줄 외워주면 (....) 가끔 빨라진다. 왜인지는 사실 나도 잘 모르겠는데 vectorization이라는 테크닉이랑, Intel Intrinsic 같은데를 봐야 나오는 256비트를 들고 PPAP추는 이상한 명령어들을 써서 빨라지는거 같다. 코드포스의 dmkozyrev라는 유저에 의하면 어떨때는 8배쯤 빨라져서 $O(n^2)$ 나 $O(nq)$ 솔루션이 뚫리는 경우가 있다고 한다. 

#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC target("avx,avx2,fma")
#pragma GCC optimization ("unroll-loops")

이것만으로는 충분하지 않았고, TLE를 받았다. 로컬에서 테스트하면 두배 정도 빨라지던데 BOJ는 이미 -O2를 붙여주기 때문에 얼마나 빠른지는 알기 어렵다. 

 

자잘한 최적화들

주어진 배열에서 최댓값이 $M$이라고 하자. 이때, 길이가 $S / M$ 이하인 배열은 아예 볼 필요가 없다. 절대 합이 $S$를 넘을 수가 없으니까. 그러므로, 미리 $S / M$ 을 계산해 놓고 매번 $i$마다 그만큼씩 건너뛰어서 그 뒤만 보자. 

이러면 추가로 끝에서부터 $S / M$ 만큼은 아예 답이 될 수 없으므로 그냥 버리면 된다.

일단은 $M$을 계산하지 않고 그냥 돌리면 어떨지 궁금해서, $M = 10, 000$ (가능한 최댓값이다) 으로 잡아놓고 돌렸다. 왜인지 잘 모르겠지만 612ms로 통과하던데, 아마도 최댓값 저격 데이터가 $S$가 1억이나 아무튼 매우 큰 수로 잡혀 있는듯? 데이터가 약해서 뚫리는건가 싶어서 로컬에서 데이터를 이것저것 만들어서 실행해 봤는데, 로컬이 빠른건지 뭐가 문제인지는 잘 모르겠지만 나는 1000ms를 확실히 넘는 데이터를 못 찾았다. 

(로컬에서는 $n = 100, 000$, $s = 100, 000$에 $10,000$이 10만개 들어오는 데이터도 500ms 안에 잘 들어온다.)

 

그리고 사실 큰 차이가 있나 싶지만 입력도 이렇게 받자.

char *p=t;
fread(t, 1, 20000000, stdin);
for (n = *p & 15; *++p & 16; n = n * 10 + (*p & 15));

숫자 하나를 읽어오는 방법인데, 저렇게 읽으면 훨씬 빠르다고 한다. UCPC 전날 과방에서 고인물 친구한테 배운 FastIO..인데, 해보니까 겨우 4ms~8ms 줄어들었다. 10만개 입력받는데 그렇게 차이가 많이 나도 이상하지만... 

 

위에서도 함수 호출 오버헤드를 줄이기 위해 inline을 썼는데, 이런건 당연히 컴파일러가 인라인처리 해주겠지만 일단 #define 매크로로 교체했다.

#define sum(a,b) cumul[b]-cumul[a-1]

 

이런 자잘한 것들을 다 넣어주면 552ms에 통과할 수 있다. 

#include <bits/stdc++.h>
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC target("avx,avx2,fma")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#pragma GCC optimization ("unroll-loops")
#define ll long long
using namespace std;
using pii = pair<int, int>;

#define sum(a,b) cumul[b]-cumul[a-1]
int arr[101010], n;
ll cumul[101010];
ll s;
ll total = 0;

char t[1000000];
int32_t main()
{
	ll c;
	char *p=t;
	fread(t, 1, 1000000, stdin);
	for (n = *p & 15; *++p & 16; n = n * 10 + (*p & 15));
	for (s = *p & 15; *++p & 16; s = s * 10 + (*p & 15));
	for (int i = 1; i <= n; i++)
	{
		for (c = *++p & 15; *++p & 16; c = c * 10 + (*p & 15));
		arr[i] = c;
		cumul[i] = cumul[i-1] + c;
		total += c;
	}
	if (total < s)
		return !printf("0");
	int minc = s/10000;
	int ans = 102000;
	for (int i = 1; i<=n; i++)
	{
		if (sum(i,n)<s)
			continue;
		ll cur = sum(i,i+minc-1);
		int u = min(n, i + ans);
		for (int j = i+minc; j<=u; j++)
		{
			cur += arr[j];
			if (cur >= s)
				ans = min(ans, j-i+1);
		}
	}
	cout << (ans);
}

JAVA로는 추가 시간 때문에 저렇게까지 커팅하지 않아도 $O(n^2)$ 이 뚫린다는것 같다 :( 

데이터는 언제 시간날때 더 만들어 봐야지...

'알고리즘 문제풀이 > BOJ' 카테고리의 다른 글

BOJ 12972 GCD 테이블  (0) 2019.09.13
BOJ 900문제 달성!  (0) 2019.09.01
BOJ 1655 가운데를 말해요  (0) 2019.08.18
BOJ 2515 전시장  (0) 2019.08.09
BOJ 3653 영화 수집  (1) 2019.07.30
admin