$$\newcommand{\Z}{\mathbb{Z}} \newcommand{\R}{\mathbb{R}} \newcommand{\Q}{\mathbb{Q}} \newcommand{\N}{\mathbb{N}}\newcommand{\C}{\mathbb{C}} \newcommand{\oiv}[1]{\left] #1 \right[} \newcommand{\civ}[1]{\left[ #1 \right]} \newcommand{\ad}[1]{\text{ad}(#1)} \newcommand{\acc}[1]{\text{acc}(#1)} \newcommand{\Setcond}[2]{ \left\{\, #1 \mid #2 \, \right\}} \newcommand{\Set}[1]{ \left\{ #1 \right\}} \newcommand{\abs}[1]{ \left\lvert #1 \right\rvert}\newcommand{\norm}[1]{ \left\| #1 \right\|}\newcommand{\prt}{\mathcal{P}}\newcommand{\st}{\text{ such that }}\newcommand{\for}{\text{ for }} \newcommand{\cl}[1]{\text{cl}(#1)}\newcommand{\oiv}[1]{\left] #1 \right[}\newcommand{\interior}[1]{\text{int}(#1)}$$

Codeforces 1251F (Educational 75) Red-White Fence::::Gratus' Blog

Codeforces 1251F (Educational 75) Red-White Fence

알고리즘 문제풀이/Codeforces 2020. 4. 24. 02:29

난이도 : Codeforces 2600

사용하는 알고리즘 : FFT, Combinatorics

 

1. 문제 설명

Red Fence $k < 5$ 개와, White Fence $n < 1e5$ 개의 길이가 주어진다. 이때, 다음과 같은 조건을 만족하는 Fence를 만들려고 한다.

- 가운데의 Red-Fence를 기준으로, 그 앞과 뒤에 0개 이상의 White Fence를 설치한다.

- 이때, Red Fence까지의 높이는 Strictly increasing해야 한다. 즉, 1 < 3 < 5 < 7, 7이 빨간색과 같은 형태.

- Red Fence부터 끝까지의 높이는 Strictly Decreasing해야 한다. 

즉, Red Fence를 중심으로 하는 Strict bitonic 을 요구한다.

쿼리가 잔뜩 주어지고, 각 쿼리는 둘레가 $q$인 올바른 Fence의 경우의 수를 출력하는 문제.

 

2. 풀이 설명

직사각형을 이러저리 붙여서 만든 도형이기 때문에, 둘레를 구하는 공식을 잘 써보면 다음과 같은 사실을 알 수 있다. 

먼저, Red fence가 적으므로, Red fence를 먼저 고정하고 White fence들을 골라 만든다고 생각하자.  
길이가 $L$ 인 Red fence를 쓰고, (길이와 무관하게) White fence를 $n$개 쓴다면 둘레는 $2(L+n+1)$ 이므로, Red fence를 고정한 채로 다음을 빠르게 답할 수 있으면 된다.  

 

"길이가 $L$ 인 Red fence를 골랐을 때, 만들 수 있는 서로 다른 Fence조합의 경우의 수"  

길이가 $L$ 보다 Strictly 작은 흰색 판자만 생각하자. 다음과 같은 다항식을 관리할 것이다.

$$f(x) = \sum_{n = 1} a_n x^n$$이때, $a_n$은 선택한 red fence를 포함, $n$개의 판자를 이용해서 만드는 fence의 경우의 수. 각각의 판자들이 이 다항식에 기여하는 부분을 찾자.

 

- 판자의 어떤 길이가 유일하다면 (길이가 $t$인 판자가 하나밖에 없다면), 이 판자를 "쓴다면", 왼쪽 또는 오른쪽에 넣을 수 있고 (+1개, 2), 안 쓰는 (+0개, 1) 경우가 있다. 즉, unique한 판자는 다항식에 $(2x+1)$ 을 기여한다.

 

- 이 판자가 둘 이상 있다면, 이 판자를 양쪽에 넣을 수 있고 (+2개, 1), 왼쪽 또는 오른쪽에 넣을 수 있고 (+1개, 2), 안 넣을수 있다 (+0개, 1). 즉, non-unique한 판자는 몇개가 있든 다항식에 $(x^2 + 2x + 1)$ 만큼만 기여한다.

 

따라서, unique한 판자와 그렇지 않은 판자의 개수를 모두 세고, 각각을 $a, b$ 라고 할 때, 

$$(x^2+2x+1)^b (2x+1)^a$$ 의 계수들을 빠르게 계산하면 된다. 각 $a$제곱, $b$제곱 부분은 이항정리를 이용하여 전처리할 수 있고, 최대 수십만차 다항식을 곱셈해야 하므로 FFT를 쓰면 된다. Modulo 998244353이므로 NTT를 쓰자.

 

NTT에 대해서는 그런게 가능하며 어떻게 한다는건 알았지만 긁어 붙일만한 Reference Code가 없었는데, FFT에서 Field만 바꾼 것이므로 그냥 바로 구현할 수 있었다. 

 

3. 코드

FFT 구현체는 기본적으로 https://blog.myungwoo.kr/54 를 많이 참조했는데, 실수연산시의 정밀도 등을 위해 root의 값을 미리 다 구해놓는다는 정도의 차이가 있다. 그외에는 거의 대부분의 사람들이 비슷한 FFT 코드를 쓰는 것 같다.

실수 FFT 코드와 최대한 비슷하게 만들기 위해 typedef int base 같은게 들어가 있다. ㅋㅋ

#include <bits/stdc++.h>
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC target("avx,avx2,fma")
#define ll long long
#define eps 1e-7
#define all(x) ((x).begin()),((x).end())
#define usecppio ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
using namespace std;
#define int ll
const bool debug = 0;
using pii = pair<int, int>;
typedef long double D;
/* Library for solving Combinatorics */
const int MAX_N = 606060;
const int mod = 998244353;
ll modpow(ll x, ll y, ll p = mod)
{
    ll res = 1;
    x = x % p;
    while (y > 0)
    {
        if (y & 1)
            res = (res*x) % p;
        y = y>>1;
        x = (x*x) % p;
    }
    return res;
}

int fact[MAX_N+5];
int invfact[MAX_N+5];

int binom(int n, int r)
{
    int t = (fact[n]*(invfact[r]))%mod;
    return (t*invfact[n-r])%mod;
}

void precompute()
{
    fact[0] = 1; invfact[0] = 1;
    for (int i = 1; i<=MAX_N; i++)
    {
        fact[i] = fact[i-1]*i;
        fact[i] %= mod;
    }
    invfact[MAX_N] = modpow(fact[MAX_N],mod-2);
    for (int i = MAX_N-1; i>=1; i--)
    {
        invfact[i] = invfact[i+1]*(i+1);
        invfact[i] %= mod;
    }
}

int frac(int a, int b)
{
    return (a*(modpow(b,mod-2)))%mod;
}
/* Library for solving Combinatorics */

vector <int> one;
vector <int> many;
vector <int> res;
vector <int> white(303030, 0);
vector <int> red;

//NTT Polynomial Multiplication
#define sz(v) ((int)(v).size())
typedef int base;
typedef vector <int> vi;
typedef vector <base> vb;
const double PI = acos(-1);

void fft(vb &a, bool invert)
{
    int n = sz(a);
    for (int i = 1, j=0; i<n; i++)
    {
        int bit = n>>1;
        for (; j>=bit; bit>>=1)
        {
            j -= bit;
        }
        j += bit;
        if (i < j)
            swap(a[i],a[j]);
    }
    vector<base> root(n/2);
    int ang = modpow(3, (mod - 1) / n);
    if(invert) ang = modpow(ang, mod - 2);
    root[0] = 1;
    for(int i = 1; i<n/2; i++) root[i] = (root[i-1]*ang)%mod;
    for (int len = 2; len <= n; len <<= 1)
    {
        int step = n / len;
        for (int i = 0; i<n; i+= len)
        {
            for (int j = 0; j<len/2; j++)
            {
                base u = a[i+j], v = (a[i+j+len/2]*root[step*j])%mod;
                a[i+j] = (u+v)%mod;
                a[i+j+len/2] = (u-v)%mod;
            }
        }
    }
    if (invert)
    {
        for (int i = 0; i<n; i++)
            a[i] = frac(a[i],n);
    }
    for (int i = 0; i<n; i++)
        a[i] = (a[i]+10*mod)%mod;
}

void multiply(const vi &a, const vi &b, vi &res_)
{
    vector <base> fa(all(a)), fb(all(b));
    int n = 1;
    while(n < max(sz(a),sz(b)))
        n <<= 1;
    n <<= 1;
    fa.resize(n); fb.resize(n);
    fft(fa,0), fft(fb,0);
    for (int i = 0; i<n; i++)
        fa[i] = (fa[i]*fb[i]+mod)%mod;
    fft(fa,1);
    res_.resize(n);
    for (int i = 0; i<n; i++)
        res_[i] = (fa[i]+mod)%mod;
}

int ans[606060];


int32_t main()
{
    int n, k;
    precompute();
    usecppio
    cin >> n >> k;
    for (int i = 0; i<n; i++)
    {
        int x;
        cin >> x;
        white[x]++;
    }
    for (int i = 0; i<k; i++)
    {
        int r; cin >> r; red.push_back(r);
    }
    for (int rb:red)
    {
        one.clear(); many.clear();
        res.clear();
        one.resize(303030);
        many.resize(303030);
        int a = 0, b = 0;
        for (int i = 0; i<rb; i++)
        {
            if (!white[i]) continue;
            if (white[i]==1)
                a++;
            else b++;
        }
        for (int i = 0; i<=a; i++)
        {
            int u = binom(a, i);
            int v = modpow(2, i);
            one[i] = (u*v)%mod;
        }

        for (int i = 0; i<=2*b; i++)
            many[i] = binom(2*b, i);

        multiply(one, many, res);
        for (int i = 0; i<=n; i++)
        {
            ans[rb + i + 1] += res[i];
            ans[rb + i + 1] %= mod;
        }
    }
    int q; cin >> q;
    while(q--)
    {
        int Q; cin >> Q;
        cout << ans[Q/2] << '\n';
    }
}
admin