$$\newcommand{\Z}{\mathbb{Z}} \newcommand{\R}{\mathbb{R}} \newcommand{\Q}{\mathbb{Q}} \newcommand{\N}{\mathbb{N}}\newcommand{\C}{\mathbb{C}} \newcommand{\oiv}[1]{\left] #1 \right[} \newcommand{\civ}[1]{\left[ #1 \right]} \newcommand{\ad}[1]{\text{ad}(#1)} \newcommand{\acc}[1]{\text{acc}(#1)} \newcommand{\Setcond}[2]{ \left\{\, #1 \mid #2 \, \right\}} \newcommand{\Set}[1]{ \left\{ #1 \right\}} \newcommand{\abs}[1]{ \left\lvert #1 \right\rvert}\newcommand{\norm}[1]{ \left\| #1 \right\|}\newcommand{\prt}{\mathcal{P}}\newcommand{\st}{\text{ such that }}\newcommand{\for}{\text{ for }} \newcommand{\cl}[1]{\text{cl}(#1)}\newcommand{\oiv}[1]{\left] #1 \right[}\newcommand{\interior}[1]{\text{int}(#1)}$$

FFT (Fast Fourier Transform) 와 실수오차::::Gratus' Blog

FFT (Fast Fourier Transform) 와 실수오차

알고리즘 공부/Mathematics 2020. 2. 27. 21:31

FFT (Fast Fourier Transform) 은 기본적으로, 다항식의 계수 표현 (Coefficient form. 정식 용어는 아닌거 같지만 어디선가 본 말이다) 과 점 표현 (Point-value form) 을 오가는 알고리즘이다.

FFT를 이용하면, 두 다항식의 곱을 $O(n \log n)$ 에 계산하는 등 일반적으로는 할 수 없는 연산들을 빠르게 수행할 수 있기 때문에 다양한 용도로 사용할 수 있고, PS에서는 나름 고인물 알고리즘의 시작(?) 같은 느낌으로 생각되는 것 같다. ICPC Regional에도 가끔 나오는 것 같고...물론 요즘같이 다들 고여버린 때는 별로 의미가 없지만 ㅋㅋㅋ

 

FFT가 왜 작동하는지, 어떻게 작동하는지, 어떤 일들을 할 수 있는지는 다양한 자료를 통해 공부할 수 있고 나도 완벽히 이해하지 못한 부분도 있는 것 같아서, 여기서는 설명하지 않으려고 한다. FFT를 잘 모른다면 다음과 같은 자료들을 강력히 추천(?) 한다. 

이렇게 FFT를 공부하더라도, 실제로는 보통 라이브러리 (팀노트 또는 개인 레포) 에 넣어놓고 필요할 때마다 베껴 쓰게 되는게 어느정도 현실(?) 이다. 구현이 상당히 귀찮고 은근히 쓸데가 많아서, 외울 자신이 없다면 베껴 쓰는게 괜찮은것 같다. (나는 개인적으로 두번째 링크의 코드를 일부 수정해서 쓰고 있다) 

 

이런식으로 적당히 FFT 라이브러리를 하나 찾아서 대충 이해하고 대충 복붙하는 것으로 해결할 수 있는 FFT 꿀문제가 몇개 있다. 전부 SAC Platinum 이상의 문제들이니, 말 그대로 경험치 파밍이 가능하다 :) 

  • BOJ 1067 "이동"
  • BOJ 14756, 한국 ICPC 2017 인예 "Telescope"

그런데, 결국 이 문제들은 "다항식 곱셈을 빨리 할 수 있다면, 풀 수 있을 텐데" 의 형태로 환원된다. 그리고 보통 이보다 어려운 FFT 문제는 어떻게 이 문제를 다항식으로 해결할지 생각하기가 어려운 문제들이 많다. 

이와는 달리, 오늘 주제(?) 로 삼은 문제는 https://www.acmicpc.net/problem/11385 로, 정말 단순히 "두 다항식을 곱하세요" 라는 느낌의 문제이다. 그럼에도 SAC 난이도는 다이아 3으로 책정되어 있다.


가장 일반적인 FFT 코드 (위 블로그의 코드 등) 을 사용하면, 보통은 바로 WA를 적립하게 된다. 왜인지 생각해 보면,

- 먼저, FFT를 위해서는 Complex 를 써야 하는데, 대부분이 double 두개로 각각 real, complex 를 나타내는 클래스를 사용한다. (딱히 이외의 방법이 없다) 

- 이를 이용해서 sin, cos 등의 삼각함수를 이용해야 한다.

두 부분 모두 실수오차에 시달리기 딱 좋은 부분이다. 특히 C++ 의 double 자료형은 대략 소숫점 아래 15자리 정도의 정밀도를 가지고 있는데, 이 문제의 경우 long long 범위 전체에서 정수 반올림을 했을 때 오차가 없어야 한다. 이런 끔찍한 경우에 어떻게 FFT를 잘 사용할 수 있는지 생각해 보자.

 

먼저, 당연히 Number Theoretic Transform 을 사용해서 이 문제를 해결할 수 있다. NTT가 뭔지는 이 글을 읽으러 온 PS에 관심이 있는 사람이라면 들어봤을 확률이 높지만, 복소수 대신 $\mathbb{Z}_p$ 에서 FFT를 수행하는 것이다. 어떤 소수 $p$의 원시근 $w$ 를 이용해서 FFT-tic 한 방법을 사용할 수 있음이 알려져 있다.

보통 FFT 문제 -> mod 998244353인 경우, $p = 998,244,353$, $w = 3$ 의 NTT를 쓰는 문제인 경우가 많다. 다만 이 해법의 문제는, 씽크스몰처럼 모듈러를 취하는 부분이 없거나, 10억 7같은 수의 모듈러를 취하는 경우 쓸 수 없다는 점이다. 구체적으로는 $p = a \times 2^b + 1$ 일 때 좋은데 10억 7은 그게 없는 소수라서...

 

다만, 모듈러가 아예 없는 문제의 경우, 적당히 큰 소수 두어개를 골라서 모듈러를 취한 NTT를 쓴 다음, 이를 Chinese Remainder Theorem 을 이용해서 원래의 수를 복원해내는 방법이 있다. 이렇게 하고 싶다면, $998244353 = 1+7*17*2^{23}$ 과 $167772161 = 1+5*2^{25}$ 를 쓰면 된다. 이렇게 하면 2배 + CRT 돌리는데 필요한 추가 시간 만큼 으로 오차 없는 FFT를 할 수 있다. 

 

그 외의 방법으로는, 계수를 쪼개서 다항식을 2개 쓰는 방법이 있다. 적당한 상수 C를 잡고

$$A(x) = A_1(x) + CA_2(x)$$와 같이 다항식을 쪼개고 나면, 각 다항식의 계수는 줄어든 상태이므로 같은 방법으로 쪼갠 $B_1, B_2$ 와 곱하여, 

$$A(x)B(x) = A_1(x)B_1(x) + C(A_1(x)B_2(x)+A_2(x)B_1(x)) + C^2(A_2(x)B_2(x))$$ 로 나타낼 수 있다. 이렇게 하려면 FFT를 4번 해야 하므로 (Complex 부분을 잘 이용해서 3번 하는 방법도 있는 것 같다. 이부분은 더 복잡한 것 같고 아직 잘 모르겠어서 일단은 뺐는데, 혹시 나중에 알아내면 추가할 계획이다. 설마 3번하면 통과하고 4번하면 잘리는 문제가 나오겠어?) 느리지만, 각각의 계수를 크게 줄일 수 있어서 오차가 줄어들게 된다. $C = \sqrt{M}$ 을 고른다면, 각각의 계수를 최대 $\sqrt{M}$ 이하로 제한하는 셈이 되므로, long long 범위 내의 정수계수 다항식 곱셈을 하기 위해 10자리 정도 정밀도만 있으면 충분해진다. 즉, double을 쓰더라도 이렇게 하면 오차가 없게 된다.


씽크스몰 코드 보기

더보기
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC target("avx,avx2,fma")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#define usecppio ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
#define all(x) ((x).begin()),((x).end())
using pii = pair <int, int>;
#define int ll
#define INF 0x7f7f7f7f7f7f7f7f
const bool debug = false;
int N, M;

/* FFT Library : Originally Written by Myungwoo *
 * https://blog.myungwoo.kr/54                  *
 * Nonrecursive, Bit-Flipping Trick             *
 * Several Modifications                        */
#define sz(v) ((int)(v).size())
typedef complex<double> base;
typedef vector <int> vi;
typedef vector <base> vb;
const double PI = acos(-1);
void fft(vector<base> &a, bool inv = false){
    int n = (int)a.size();
    for(int i=1,j=0;i<n;i++){
        int bit = n >> 1;
        for(;j>=bit;bit>>=1)j -= bit;
        j += bit;
        if(i < j)swap(a[i], a[j]);
    }
    vector<base> root(n/2);
    double ang = 2*acos(-1)/n*(inv?-1:1);
    for(int i=0;i<n/2;i++)root[i]=base(cos(ang*i), sin(ang*i));
    for(int idx = 2; idx <= n; idx <<= 1){
        int step = n / idx;
        for(int i=0;i<n;i+=idx){
            for(int j=0;j<idx/2;j++){
                base u = a[i+j], v = a[i+j+idx/2]*root[step*j];
                a[i+j] = u+v;
                a[i+j+idx/2] = u-v;
            }
        }
    }
    if(inv){
        for(auto &i : a) i /= n;
    }
}
const int LARGE = 1e6;
/* FFT polynomial Multiplication with higher precision */
void multiply(const vi &a, const vi &b, vi &res)
{
    vector <base> fa_big, fb_big;
    vector <base> fa_small, fb_small;
    int cut_val = sqrt(LARGE);
    int n = 1;
    while(n < max(sz(a),sz(b)))
        n <<= 1;

    fa_big.resize(n);
    fa_small.resize(n);
    fb_big.resize(n);
    fb_small.resize(n);
    for (int i = 0; i<sz(a); i++)
    {
        fa_big[i] = a[i]/cut_val;
        fa_small[i] = a[i]%cut_val;
    }
    for (int i = 0; i<sz(b); i++)
    {
        fb_big[i] = b[i]/cut_val;
        fb_small[i] = b[i]%cut_val;
    }
    fft(fa_big,0), fft(fb_big,0);
    fft(fa_small, 0), fft(fb_small, 0);
    vector <base> fa_big_2(all(fa_big));
    vector <base> fa_small_2(all(fa_small));
    for (int i = 0; i<n; i++)
    {
        fa_big_2[i] *= fb_big[i];
        fa_big[i] *= fb_small[i];
        fa_small[i] *= fb_small[i];
        fa_small_2[i] *= fb_big[i];
    }
    fft(fa_small,1);
    fft(fa_small_2, 1);
    fft(fa_big, 1);
    fft(fa_big_2, 1);
    res.resize(n);
    for (int i = 0; i<n; i++)
    {
        int ss = (int64_t)round(fa_small[i].real());
        int sb = (int64_t)round(fa_small_2[i].real());
        int bs = (int64_t)round(fa_big[i].real());
        int bb = (int64_t)round(fa_big_2[i].real());
        res[i] = ss;
        res[i] += (sb+bs)*cut_val;
        res[i] += bb*cut_val*cut_val;
    }
}

vector <int> f, g, res;
int32_t main()
{
    usecppio
    cin >> N >> M;
    f.resize(2*(N+1)+5);
    g.resize(2*(M+1)+5);
    for (int i = 0; i<=N; i++)
        cin >> f[i];
    for (int i = 0; i<=M; i++)
        cin >> g[i];
    multiply(f,g,res);
    int t = 0;
    for (int i = 0; i<res.size(); i++)
    {
        //printf("%lld ",res[i]);
        t ^= res[i];
    }
    //printf("\n");
    cout << t << '\n';
}
admin