快速数论变换

2014.12.09 22:48 Tue| 14 visits oi_2015| 2015_算法笔记| Text

貌似NTT没有省选题与之对应?这不重要!orz VFleaKing大神的神题下江南还有华丽丽的模板着实教育了吾等蒟蒻。

2014-12-09

NTT可以避免精度误差和双精度复数运算,在模P内进行运算,思想和FFT类似。这里贴一下VFleaKing大神的模板……

PS:这个P在选择的时候貌似有好多说道,一不小心就WA的一干二净?

VFleaKing的NTT

//UOJ 34

#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;

typedef long long s64;

const int P = 998244353;
const int FFT_G = 3;

const int MaxN = 100000;
const int MaxFFTN = 262144;

int preGPow[MaxFFTN];

inline int modpow(int a, const int &n)
{
    int res = 1;
    int t = a;
    for (int i = n; i > 0; i >>= 1)
    {
        if (i & 1)
            res = (s64)res * t % P;
        t = (s64)t * t % P;
    }
    return res;
}

void fft(int *a, int n, int s, int *out)
{
    if (n == 1)
    {
        out[0] = a[0];
        return;
    }

    int m = n >> 1;
    fft(a, m, s + 1, out);
    fft(a + (1 << s), m, s + 1, out + m);
    for (int i = 0; i < m; i++)
    {
        int o = out[i], e = (s64)out[i + m] * preGPow[i << s] % P;
        out[i] = (o + e) % P;
        out[i + m] = (o + P - e) % P;
    }
}

inline void polymulto(int n, int *a, int m, int *b)
{
    static int da[MaxFFTN];
    static int db[MaxFFTN];

    int tn = 1;
    while (tn < n + m)
        tn <<= 1;

    int curG = modpow(FFT_G, (P - 1) / tn);
    preGPow[0] = 1;
    for (int i = 1; i < tn; i++)
        preGPow[i] = (s64)preGPow[i - 1] * curG % P;

    fft(a, tn, 0, da);
    fft(b, tn, 0, db);

    for (int i = 0; i < tn; i++)
        da[i] = (s64)da[i] * db[i] % P;
    reverse(preGPow + 1, preGPow + tn);

    fft(da, tn, 0, a);

    int revTN = modpow(tn, P - 2);
    for (int i = 0; i < tn; i++)
        a[i] = (s64)a[i] * revTN % P;
}

int main()
{
    int n, m;
    static int a[MaxFFTN], b[MaxFFTN];

    cin >> n >> m;
    n++, m++;
    for (int i = 0; i < n; i++)
        scanf("%d", &a[i]);
    for (int i = 0; i < m; i++)
        scanf("%d", &b[i]);

    polymulto(n, a, m, b);

    for (int i = 0; i < n + m - 1; i++)
        printf("%d ", a[i]);
    printf("\n");

    return 0;
}

2014-12-10

今天上午发奋图强扒代码。。。

VFleaKing说递归版之所以慢是因为多次复制数组。但是实际证明,哪怕是多次复制数组,NTT还是快……

UOJ评测结果:

递归版FFT:4837ms

非递归版FFT:1097ms

递归版NTT:1140ms

非递归版NTT:960ms

VFleaKing的递归版不复制数组NTT:重测无数次之后994ms

最快提交:EoLV_01

最短提交:EoLV_01

我的NTT

递归版

#include <bits/stdc++.h>
using namespace std;

#define P 998244353
#define N 262145

int s,t,n,a[N],b[N],p[N],w[N];

int power(int a,int b)
{
    int re=1;
    a=a%P;
    while(b)
    {
        if(b&1) re=((long long)re*a)%P;
        a=((long long)a*a)%P; b>>=1;
    }
    return re;
}

void NTT(int x[],int n)
{
    for(int i=0,t=0;i<n;i++)
    {
        if(i>t) swap(x[i],x[t]);
        for(int j=n>>1;(t^=j)<j;j>>=1);
    }
    int d=0;
    for(int i=n;i;i>>=1) d++; d-=2;
    for(int s=2;s<=n;d--,s<<=1)
        for(int i=0;i<n;i+=s)
        {
            for(int t,j=0;j<s>>1;j++)
                t=(long long)w[j<<d]*x[i+j+(s>>1)]%P,
                x[i+j+(s>>1)]=(x[i+j]-t+P)%P, x[i+j]=(x[i+j]+t)%P;
        }
}

void getw(int n)
{
    int G=power(3,(P-1)/n);
    w[0]=1;
    for (int i=1;i<n;i++)
        w[i]=(long long)w[i-1]*G%P;
}

inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

int main()
{
    cin>>s>>t;
    for(int i=0;i<=s;i++) a[i]=read();
    for(int i=0;i<=t;i++) b[i]=read();
    for(int x=max(s,t)+1,i=1;i>>2<x;i<<=1) n=i;
    getw(n);
    NTT(a,n), NTT(b,n);
    for(int i=0;i<n;i++) p[i]=((long long)a[i]*b[i])%P;
    reverse(w+1,w+n);
    NTT(p,n);
    int rev=power(n,P-2);
    for(int i=0;i<=s+t;i++)
        printf("%d ",int((long long)p[i]*rev%P));
    return 0;
}

非递归版

#include <bits/stdc++.h>
using namespace std;

#define P 998244353
#define N 262145

int s,t,n,a[N],b[N],p[N],w[N];

int power(int a,int b)
{
    int re=1;
    a=a%P;
    while(b)
    {
        if(b&1) re=((long long)re*a)%P;
        a=((long long)a*a)%P; b>>=1;
    }
    return re;
}

void NTT(int x[],int n)
{
    for(int i=0,t=0;i<n;i++)
    {
        if(i>t) swap(x[i],x[t]);
        for(int j=n>>1;(t^=j)<j;j>>=1);
    }
    int d=0;
    for(int i=n;i;i>>=1) d++; d-=2;
    for(int s=2;s<=n;d--,s<<=1)
        for(int i=0;i<n;i+=s)
        {
            for(int t,j=0;j<s>>1;j++)
                t=(long long)w[j<<d]*x[i+j+(s>>1)]%P,
                x[i+j+(s>>1)]=(x[i+j]-t+P)%P, x[i+j]=(x[i+j]+t)%P;
        }
}

void getw(int n)
{
    int G=power(3,(P-1)/n);
    w[0]=1;
    for (int i=1;i<n;i++)
        w[i]=(long long)w[i-1]*G%P;
}

inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

int main()
{
    cin>>s>>t;
    for(int i=0;i<=s;i++) a[i]=read();
    for(int i=0;i<=t;i++) b[i]=read();
    for(int x=max(s,t)+1,i=1;i>>2<x;i<<=1) n=i;
    getw(n);
    NTT(a,n), NTT(b,n);
    for(int i=0;i<n;i++) p[i]=((long long)a[i]*b[i])%P;
    reverse(w+1,w+n);
    NTT(p,n);
    int rev=power(n,P-2);
    for(int i=0;i<=s+t;i++)
        printf("%d ",int((long long)p[i]*rev%P));
    return 0;
}