快速数论变换与多项式

NTT和多项式的常用基本操作。

在模意义下的。


导入

一般来说,我们需要进行的操作是在模意义下的。

此时我们就无法使用单位根和实数意义下的FFT了。

我们的解决办法是:

原根

我们选择单位根是因为其具有良好的性质可以被我们利用。

如果我们在模意义下找到了单位根的替代品的话,就可以摆脱狭窄的值域和缓慢的三角函数了。

首先我们看一下我们之前用到的单位根的性质:

  1. $\omega^k_n = (\omega^1_n)^k$

同时我们需要保证我们找到的这个 $\omega^1_n$ 不会使得所有的 $\omega^i_n$ 都相同。

  1. $\omega^k_n = \omega^{k \ \bmod{\ n}}_n$

这可以推导出 $\omega^{k + \frac{n}{2}}_n = -\omega^k_n$。

  1. $\omega^{2k}_{2n} = \omega^k_n$

等价于 $(\omega^k_{2n})^2 = \omega^k_n$。

我们再来看看原根的定义。

首先我们定义「阶」。

对于一个数 $a$,其在模 $p$ 意义下的阶为,最小的整数 $k$ 使得 $a^k \equiv 1 \pmod{p}$。
对于 $a \not \perp p$ 的情况,我们认为其阶为 $\infty$ 或直接认为不存在。
当然,我们一般求阶的时候 $p$ 都是质数,且 $a \in [0,p)$。
记 $\delta_p(a)$ 为 $a$ 在模 $p$ 意义下的阶。

然后是「原根」。

对于一个数 $g$,如果满足 $\delta_p(g) = \varphi(p)$,我们就称 $g$ 为 $p$ 的一个原根。

筛选

当然,我们可以猜到,不是随便一个 $g$ 就可以当做单位根来用的。

我们需要找到一个 $g$ 使得其阶恰好为 $n$ 才可以。

一般来说我们的 $p$ 是个质数,所以其原根的阶会达到 $p-1$,而一般我们不会只处理 $n=p-1$ 的情况。

但是我们可以发现,$g^k$ 的阶数为 $\frac{p-1}{\gcd(k,p-1)}$,这给了我们拓展的机会。
同时我们发现,其仍然是 $p-1$ 的约数,所以当 $n \not | (p-1)$ 的时候我们是找不到阶恰好为 $n$ 的数的。
这也解释了为什么我们常用的NTT模数都是形如 $k \times 2^r + 1$ 的,比如说 $998244353 = 119 \times 2^{23} + 1$,$1004535809 = 479 \times 2^{21} + 1$。

NTT

此时我们把单位根换成筛选后的原根,然后套入 FFT 板子中就好了。

同时我们做一些小优化,比如说预处理单位根、用unsigned long long来减少取模次数等等。

板子如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned ll
const int N = 2100010;
const ll G = 114514, mod = 998244353;
ll qpow(ll a, ll x = mod - 2)
{
ll res = 1;
while(x)
{
if(x & 1)res = res * a % mod;
a = a * a % mod;
x >>= 1;
}
return res;
}
const ll invG = qpow(G);
int tr[N];
void NTT(ll *g, bool op, int n)
{
static ull f[N], w[N] = { 1 };
for(int i = 0; i < n; i++)f[i] = g[tr[i]];
for(int l = 1; l < n; l <<= 1)
{
ull tmpG = qpow(op ? G : invG, (mod - 1) / (l * 2));
for(int i = 1; i < l; i++)w[i] = w[i - 1] * tmpG % mod;
for(int k = 0; k < n; k += l * 2)
{
for(int p = 0; p < l; p++)
{
ll tmp = w[p] * f[k | l | p] % mod;
f[k | l | p] = f[k | p] + mod - tmp;
f[k | p] = f[k | p] + tmp;
}
}
if(l == (1 << 10))
for(int i = 0; i < n; i++)f[i] %= mod;
}
if(!op)
{
ull invn = qpow(n);
for(int i = 0; i < n; i++)g[i] = f[i] % mod * invn % mod;
}
else
{
for(int i = 0; i < n; i++)g[i] = f[i] % mod;
}
}
int n, m;
ll f[N], g[N];
int main()
{
scanf("%d%d", &n, &m);
n++, m++;
for(int i = 0; i < n; i++)
scanf("%lld", &f[i]);
for(int i = 0; i < m; i++)
scanf("%lld", &g[i]);
int len = 1;
while(len < n + m)len <<= 1;
for(int i = 0; i < len; i++)
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
NTT(f, 1, len), NTT(g, 1, len);
for(int i = 0; i < len; i++)
f[i] = f[i] * g[i] % mod;
NTT(f, 0, len);
for(int i = 0; i < n + m - 1; i++)
printf("%lld ", f[i]);
putchar('\n');
return 0;
}

多项式四则运算

为方便,下面我们设多项式 $F(x)$ 的 $n$ 次项系数为 $f[n]$,与 $[x^n]F(x)$ 是等价的形式,

加减

直接对应系数加减就好了。

$$
F(x) \pm G(x) = \sum_{i=0} (f[i] \pm g[i])x^i
$$

乘法

根据我们之前的定义,多项式的乘法是类似卷积形式的:

$$
F(x) \times G(x) = \sum_{i=0} x^i \sum_{k=0}^i f[k]g[i-k]
$$

我们可以使用 NTT 来在 $O(n \log n)$ 的时间复杂度内做完。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned ll
#define clr(f,n) memset(f,0,(n)*sizeof(ll));
#define cpy(f,g,n) memcpy(f,g,(n)*sizeof(ll));
const int N = 270010;
const ll G = 3, mod = 998244353;
ll qpow(ll a, ll x = mod - 2)
{
ll res = 1;
while(x)
{
if(x & 1)res = res * a % mod;
a = a * a % mod;
x >>= 1;
}
return res;
}
const ll invG = qpow(G);
int tr[N], tf;
void initr(int n)
{
if(tf == n)return;
tf = n;
for(int i = 0; i < n; i++)
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) ? n >> 1 : 0);
}
void NTT(ll *g, bool op, int n)
{
initr(n);
static ull f[N], w[N] = { 1 };
for(int i = 0; i < n; i++)f[i] = ((mod << 5) + g[tr[i]]) % mod;
for(int l = 1; l < n; l <<= 1)
{
ull tmpG = qpow(op ? G : invG, (mod - 1) / (l * 2));
for(int i = 1; i < l; i++)w[i] = w[i - 1] * tmpG % mod;
for(int k = 0; k < n; k += l * 2)
{
for(int p = 0; p < l; p++)
{
ll tmp = w[p] * f[k | l | p] % mod;
f[k | l | p] = f[k | p] + mod - tmp;
f[k | p] = f[k | p] + tmp;
}
}
if(l == (1 << 10))
for(int i = 0; i < n; i++)f[i] %= mod;
}
if(!op)
{
ull invn = qpow(n);
for(int i = 0; i < n; i++)g[i] = f[i] % mod * invn % mod;
}
else
{
for(int i = 0; i < n; i++)g[i] = f[i] % mod;
}
}
void mul(ll *f, ll *g, int n)
{
for(int i = 0; i < n; i++)
f[i] = f[i] * g[i] % mod;
}
void timep(ll *f, ll *g, int m, int lim)
{
static ll sav[N];
int n = 1;
while(n < m + m)n <<= 1;
clr(sav, n); cpy(sav, g, n);
NTT(f, 1, n); NTT(sav, 1, n);
mul(f, sav, n); NTT(f, 0, n);
clr(f + lim, n - lim); clr(sav, n);
}

在前面NTT的基础上增加了一些宏定义,并略微调整之使得其适配后面的操作。
后面的直接调用此模板即可。

乘法逆元

定义就是对于一个多项式 $F(x)$,找到一个多项式 $G(x)$ 是的 $F(x) \times G(x) = 1$。

首先很明显,$g[0]$ 就是 $f[0]$ 的乘法逆元。

然后我们尝试由此拓展出整个多项式。

直接递推显然不行,我们无法承受 $O(n^2)$ 的时间复杂度。
那么试试倍增。

假设我们当前进行到了第 $n$ 次项,我们尝试将其推广到 $2n$ 次项。

那么我们当前的多项式 $R_*(x)$ 就可以看做是 $G(x) \pmod{x^n}$,而我们将要推广到的多项式 $R(x)$ 就可以看做是 $G(x) \pmod{x^{2n}}$。
这样才可以从 $g[0]$ 开始倍增。
很明显有 $R(x) = R_*(x) \pmod{x^n}$,作差得 $R(x) - R_*(x) = 0 \pmod{x^n}$。
平方一下得 $R^2(x) - 2R_*(x)R(x) + R^2_*(x) = 0 \pmod{x^{2n}}$,
等式两边同时乘以 $F(x)$ 可以把 $R(x)$ 消掉,得 $R(x) - 2R_*(x) + R^2_*(x)F(x) = 0 \pmod{x^{2n}}$,
由此可以根据 $F(x)$ 和 $R_*(x)$ 求出 $R(x)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
void invp(ll *f, int m)
{
int n = 1;
while(n < m)n <<= 1;
static ll w[N], r[N], sav[N];
w[0] = qpow(f[0]);
for(int len = 2; len <= n; len <<= 1)
{
for(int i = 0; i < (len >> 1); i++)r[i] = w[i];
cpy(sav, f, len); NTT(sav, 1, len);
NTT(r, 1, len); mul(r, sav, len);
NTT(r, 0, len); clr(r, len >> 1);
cpy(sav, w, len); NTT(sav, 1, len);
NTT(r, 1, len); mul(r, sav, len);
NTT(r, 0, len);
for(int i = len >> 1; i < len; i++)
w[i] = (w[i] * 2 - r[i] + mod) % mod;
}
cpy(f, w, m);
clr(sav, n); clr(w, n); clr(r, n);
}

注意此处需要开两倍空间,所有带有乘法逆元操作的都需要开两倍空间。

带余数的除法

如果我们可以直接把余式减去,然后就可以直接做上面的多项式除法了。

但是余式的系数集中在低次项,我们能做的只有 $\bmod{x^k}$。

考虑翻转一下多项式的系数,定义 $F^T(x) = \sum_{i=0} f[n-i]x^i$。
稍微换个形式可以得到 $F^T(x) = x^n F(x^{-1})$。

首先我们有式子 $F(x) = Q(x) \times G(x) + R(x)$,其中 $F(x)$ 和 $G(x)$ 已知。
换元得到 $F(x^{-1}) = Q(x^{-1}) \times G(x^{-1}) + R(x^{-1})$,
两边同时乘以 $x^n$ 得到 $x^n F(x^{-1}) = x^n Q(x^{-1}) \times G(x^{-1}) + x^n R(x^{-1})$,
写成翻转的形式得到 $F^T(x) = Q^T(x) \times G^T(x) + x^{n-m+1} R^T(x)$。
此时余式 $R(x)$ 原本集中在低次项的系数全部到了高次项,而低次项都为 $0$。
我们对 $x^{n-m+1}$ 取模,得到 $F^T(x) = Q^T(x) \times G^T(x) \pmod{x^{n-m+1}}$。
然后我们就可以求出来 $Q^T(x)$,进而得到 $Q(x)$。
剩下的 $R(x)$ 直接减就好了。

1
2
3
4
5
6
7
8
9
10
11
12
13
void divp(ll *f, ll *g, int n, int m)
{
static ll q[N], t[N];
int l = n - m + 1;
rev(g, m); cpy(q, g, l); rev(g, m);
rev(f, n); cpy(t, f, l); rev(f, n);
invp(q, l); timep(q, t, l, l); rev(q, l);
timep(g, q, n, n);
for(int i = 0; i < m - 1; i++)
g[i] = (f[i] - g[i] + mod) % mod;
clr(g + m - 1, l);
cpy(f, q, l); clr(f + l, n - l);
}

函数执行完之后,f数组内是商式,g数组内是余式。

多项式开根

我们有式子 $F(x) - G^2(x) = 0$,其中 $F(x)$ 已知。

设 $H(G(x)) = G(x)^2 - F(x)$,则 $H’(G(x)) = 2G(x)$。

根据牛顿迭代 $G(F(x)) = F_*(x) - \frac{G(F_*(x))}{G’(F_*(x))} \pmod{x^n}$,得到:

$$
\begin{align}
G(x) &= G_*(x) - \frac{H(G_*(x))}{H’(G_*(x))} \\
&= G_*(x) - \frac{G^2_*(x) - F(x)}{2G_*(x)} \\
&= G_*(x) + \frac{F(x) - G^2_*(x)}{2G_*(x)} \\
&= \frac{F(x) + G^2_*(x)}{2G_*(x)}
\end{align}
$$

由此可以倍增。

因为题目保证了 $f[0] = 1$,所以我们的 $g[0]$ 也就是 $1$ 了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
void sqrtp(ll *f, int m)
{
int n = 1;
while(n < m)n <<= 1;
static ll b1[N], b2[N];
b1[0] = 1;
for(int len = 2; len <= n; len <<= 1)
{
for(int i = 0; i < (len >> 1); i++)b2[i] = (b1[i] << 1) % mod;
invp(b2, len);
NTT(b1, 1, len); mul(b1, b1, len); NTT(b1, 0, len);
for(int i = 0; i < len; i++)b1[i] = (f[i] + b1[i]) % mod;
timep(b1, b2, len, len);
}
cpy(f, b1, m); clr(b1, n + n); clr(b2, n + n);
}

多项式求导与积分

根据积分公式和求导的公式直接做就好了。

导数:$F’(x) = \sum_{i=1} f[i]ix^{i-1}$
积分:$\int F(x) = C + \sum_{i=0} \frac{f[i]x^{i+1}}{i}$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
ll inv[N];
void initi(int lim)
{
inv[1] = 1;
for(int i = 2; i <= lim; i++)
inv[i] = inv[mod % i] * (mod - mod / i) % mod;
}
void derp(ll *f, int m)
{
for(int i = 1; i < m; i++)
f[i - 1] = f[i] * i % mod;
f[m - 1] = 0;
}
void intp(ll *f, int m)
{
for(int i = m; i; i--)
f[i] = f[i - 1] * inv[i] % mod;
f[0] = 0;
}

需要预处理逆元。

多项式ln与exp

为了保证我们可以只使用四则运算来得到 $\ln(F(x))$ 和 $e^{F(x)}$,我们需要使用麦克劳林级数:

$$
\begin{gather}
\ln(F(x)) = - \sum_{i=1} \frac{(1-F(x))^i}{i} \\
\exp(F(x)) = \sum_{i=0} \frac{F^i(x)}{i!}
\end{gather}
$$

当然,因为我们有了积分和求导两个操作,我们可以利用 $\ln’(x)=\frac{1}{x}$ 来简化我们的操作。

设 $G(x) = \ln(F(x))$,则

$$
\begin{align}
\ln(F(x)) &= G(x) \\
\frac{d}{dx} \ln(F(x)) &= \frac{d}{dx} G(x) \\
\frac{dF(x)}{dx} \frac{d}{dF(x)} \ln(F(x)) &= \frac{d}{dx} G(x) \\
\frac{F’(x)}{F(x)} &= G’(x)\\
\int \frac{F’(x)}{F(x)} &= G(x)
\end{align}
$$

因为exp是ln的逆运算,所以可以直接牛顿迭代来倍增。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
void lnp(ll *f, int m)
{
static ll g[N];
cpy(g, f, m);
invp(g, m); derp(f, m);
timep(f, g, m, m);
intp(f, m - 1);
clr(g, m);
}
void exp(ll *f, int m)
{
static ll s[N], s2[N];
int n = 1;
while(n < m)n <<= 1;
s2[0] = 1;
for(int len = 2; len <= n; len <<= 1)
{
cpy(s, s2, len >> 1);
lnp(s, len);
for(int i = 0; i < len; i++)s[i] = (f[i] - s[i] + mod) % mod;
s[0] = (s[0] + 1) % mod;
timep(s2, s, len, len);
}
cpy(f, s2, m);
clr(s, n); clr(s2, n);
}

注意求exp的时候需要调用求ln的函数。

多项式快速幂

为了做 $O(n \log n)$ 的快速幂,我们需要利用 $\ln(ab) = \ln(a)+b$ 将幂运算转化为乘法。

具体来说, $A^k(x) = e^{\ln(A(x)) \times k}$。

此时需要满足 $a[0] = 1$。

如果不是的话,我们可以将 $A(x)$ 写作 $B(x) \times cx^p$ 的形式,此时 $b[0]=1$,可以套用上面的方法。

具体来说,$A^k(x) = (B(x) \times cx^p)^k = B^k(x) \times c^kxp^k$。
此时需要注意 $c$ 的幂次需要根据费马小定理对 $mod-1$ 取模。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
int n;
char s[N];
ll f[N];
int main()
{
scanf("%d%s", &n, s);
for(int i = 0; i < n; i++)
scanf("%lld", &f[i]);
int p = 0, c;
while(!f[p])p++;
c = qpow(f[p]);
ll k = 0, k2 = 0;
for(int i = 0; isdigit(s[i]); i++)
{
k = (k * 10 + s[i] - '0') % mod;
k2 = (k2 * 10 + s[i] - '0') % (mod - 1);
if(k * p >= n)
{
for(int i = 0; i < n; i++)printf("0 ");
putchar('\n');
return 0;
}
}
initi(n = n - p * k);
for(int i = 0; i < n; i++)f[i] = f[i + p] * c % mod;
clr(f + n, p * k);
lnp(f, n);
for(int i = 0; i < n; i++)f[i] = f[i] * k % mod;
exp(f, n);
for(int i = 0; i < p * k; i++)printf("0 ");
c = qpow(c, mod - 1 - k2);
for(int i = 0; i < n; i++)f[i] = f[i] * c % mod;
for(int i = 0; i < n; i++)
printf("%lld ", f[i]);
putchar('\n');
return 0;
}