NTT和多项式的常用基本操作。
在模意义下的。
导入
一般来说,我们需要进行的操作是在模意义下的。
此时我们就无法使用单位根和实数意义下的FFT了。
我们的解决办法是:
原根
我们选择单位根是因为其具有良好的性质可以被我们利用。
如果我们在模意义下找到了单位根的替代品的话,就可以摆脱狭窄的值域和缓慢的三角函数了。
首先我们看一下我们之前用到的单位根的性质:
- $\omega^k_n = (\omega^1_n)^k$
同时我们需要保证我们找到的这个 $\omega^1_n$ 不会使得所有的 $\omega^i_n$ 都相同。
- $\omega^k_n = \omega^{k \ \bmod{\ n}}_n$
这可以推导出 $\omega^{k + \frac{n}{2}}_n = -\omega^k_n$。
- $\omega^{2k}_{2n} = \omega^k_n$
等价于 $(\omega^k_{2n})^2 = \omega^k_n$。
我们再来看看原根的定义。
首先我们定义「阶」。
对于一个数 $a$,其在模 $p$ 意义下的阶为,最小的整数 $k$ 使得 $a^k \equiv 1 \pmod{p}$。
对于 $a \not \perp p$ 的情况,我们认为其阶为 $\infty$ 或直接认为不存在。
当然,我们一般求阶的时候 $p$ 都是质数,且 $a \in [0,p)$。
记 $\delta_p(a)$ 为 $a$ 在模 $p$ 意义下的阶。
然后是「原根」。
对于一个数 $g$,如果满足 $\delta_p(g) = \varphi(p)$,我们就称 $g$ 为 $p$ 的一个原根。
筛选
当然,我们可以猜到,不是随便一个 $g$ 就可以当做单位根来用的。
我们需要找到一个 $g$ 使得其阶恰好为 $n$ 才可以。
一般来说我们的 $p$ 是个质数,所以其原根的阶会达到 $p-1$,而一般我们不会只处理 $n=p-1$ 的情况。
但是我们可以发现,$g^k$ 的阶数为 $\frac{p-1}{\gcd(k,p-1)}$,这给了我们拓展的机会。
同时我们发现,其仍然是 $p-1$ 的约数,所以当 $n \not | (p-1)$ 的时候我们是找不到阶恰好为 $n$ 的数的。
这也解释了为什么我们常用的NTT模数都是形如 $k \times 2^r + 1$ 的,比如说 $998244353 = 119 \times 2^{23} + 1$,$1004535809 = 479 \times 2^{21} + 1$。
NTT
此时我们把单位根换成筛选后的原根,然后套入 FFT 板子中就好了。
同时我们做一些小优化,比如说预处理单位根、用unsigned long long
来减少取模次数等等。
板子如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
| #include<bits/stdc++.h> using namespace std; #define ll long long #define ull unsigned ll const int N = 2100010; const ll G = 114514, mod = 998244353; ll qpow(ll a, ll x = mod - 2) { ll res = 1; while(x) { if(x & 1)res = res * a % mod; a = a * a % mod; x >>= 1; } return res; } const ll invG = qpow(G); int tr[N]; void NTT(ll *g, bool op, int n) { static ull f[N], w[N] = { 1 }; for(int i = 0; i < n; i++)f[i] = g[tr[i]]; for(int l = 1; l < n; l <<= 1) { ull tmpG = qpow(op ? G : invG, (mod - 1) / (l * 2)); for(int i = 1; i < l; i++)w[i] = w[i - 1] * tmpG % mod; for(int k = 0; k < n; k += l * 2) { for(int p = 0; p < l; p++) { ll tmp = w[p] * f[k | l | p] % mod; f[k | l | p] = f[k | p] + mod - tmp; f[k | p] = f[k | p] + tmp; } } if(l == (1 << 10)) for(int i = 0; i < n; i++)f[i] %= mod; } if(!op) { ull invn = qpow(n); for(int i = 0; i < n; i++)g[i] = f[i] % mod * invn % mod; } else { for(int i = 0; i < n; i++)g[i] = f[i] % mod; } } int n, m; ll f[N], g[N]; int main() { scanf("%d%d", &n, &m); n++, m++; for(int i = 0; i < n; i++) scanf("%lld", &f[i]); for(int i = 0; i < m; i++) scanf("%lld", &g[i]); int len = 1; while(len < n + m)len <<= 1; for(int i = 0; i < len; i++) tr[i] = (tr[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0); NTT(f, 1, len), NTT(g, 1, len); for(int i = 0; i < len; i++) f[i] = f[i] * g[i] % mod; NTT(f, 0, len); for(int i = 0; i < n + m - 1; i++) printf("%lld ", f[i]); putchar('\n'); return 0; }
|
多项式四则运算
为方便,下面我们设多项式 $F(x)$ 的 $n$ 次项系数为 $f[n]$,与 $[x^n]F(x)$ 是等价的形式,
加减
直接对应系数加减就好了。
$$
F(x) \pm G(x) = \sum_{i=0} (f[i] \pm g[i])x^i
$$
乘法
根据我们之前的定义,多项式的乘法是类似卷积形式的:
$$
F(x) \times G(x) = \sum_{i=0} x^i \sum_{k=0}^i f[k]g[i-k]
$$
我们可以使用 NTT 来在 $O(n \log n)$ 的时间复杂度内做完。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
| #include<bits/stdc++.h> using namespace std; #define ll long long #define ull unsigned ll #define clr(f,n) memset(f,0,(n)*sizeof(ll)); #define cpy(f,g,n) memcpy(f,g,(n)*sizeof(ll)); const int N = 270010; const ll G = 3, mod = 998244353; ll qpow(ll a, ll x = mod - 2) { ll res = 1; while(x) { if(x & 1)res = res * a % mod; a = a * a % mod; x >>= 1; } return res; } const ll invG = qpow(G); int tr[N], tf; void initr(int n) { if(tf == n)return; tf = n; for(int i = 0; i < n; i++) tr[i] = (tr[i >> 1] >> 1) | ((i & 1) ? n >> 1 : 0); } void NTT(ll *g, bool op, int n) { initr(n); static ull f[N], w[N] = { 1 }; for(int i = 0; i < n; i++)f[i] = ((mod << 5) + g[tr[i]]) % mod; for(int l = 1; l < n; l <<= 1) { ull tmpG = qpow(op ? G : invG, (mod - 1) / (l * 2)); for(int i = 1; i < l; i++)w[i] = w[i - 1] * tmpG % mod; for(int k = 0; k < n; k += l * 2) { for(int p = 0; p < l; p++) { ll tmp = w[p] * f[k | l | p] % mod; f[k | l | p] = f[k | p] + mod - tmp; f[k | p] = f[k | p] + tmp; } } if(l == (1 << 10)) for(int i = 0; i < n; i++)f[i] %= mod; } if(!op) { ull invn = qpow(n); for(int i = 0; i < n; i++)g[i] = f[i] % mod * invn % mod; } else { for(int i = 0; i < n; i++)g[i] = f[i] % mod; } } void mul(ll *f, ll *g, int n) { for(int i = 0; i < n; i++) f[i] = f[i] * g[i] % mod; } void timep(ll *f, ll *g, int m, int lim) { static ll sav[N]; int n = 1; while(n < m + m)n <<= 1; clr(sav, n); cpy(sav, g, n); NTT(f, 1, n); NTT(sav, 1, n); mul(f, sav, n); NTT(f, 0, n); clr(f + lim, n - lim); clr(sav, n); }
|
在前面NTT的基础上增加了一些宏定义,并略微调整之使得其适配后面的操作。
后面的直接调用此模板即可。
乘法逆元
定义就是对于一个多项式 $F(x)$,找到一个多项式 $G(x)$ 是的 $F(x) \times G(x) = 1$。
首先很明显,$g[0]$ 就是 $f[0]$ 的乘法逆元。
然后我们尝试由此拓展出整个多项式。
直接递推显然不行,我们无法承受 $O(n^2)$ 的时间复杂度。
那么试试倍增。
假设我们当前进行到了第 $n$ 次项,我们尝试将其推广到 $2n$ 次项。
那么我们当前的多项式 $R_*(x)$ 就可以看做是 $G(x) \pmod{x^n}$,而我们将要推广到的多项式 $R(x)$ 就可以看做是 $G(x) \pmod{x^{2n}}$。
这样才可以从 $g[0]$ 开始倍增。
很明显有 $R(x) = R_*(x) \pmod{x^n}$,作差得 $R(x) - R_*(x) = 0 \pmod{x^n}$。
平方一下得 $R^2(x) - 2R_*(x)R(x) + R^2_*(x) = 0 \pmod{x^{2n}}$,
等式两边同时乘以 $F(x)$ 可以把 $R(x)$ 消掉,得 $R(x) - 2R_*(x) + R^2_*(x)F(x) = 0 \pmod{x^{2n}}$,
由此可以根据 $F(x)$ 和 $R_*(x)$ 求出 $R(x)$。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| void invp(ll *f, int m) { int n = 1; while(n < m)n <<= 1; static ll w[N], r[N], sav[N]; w[0] = qpow(f[0]); for(int len = 2; len <= n; len <<= 1) { for(int i = 0; i < (len >> 1); i++)r[i] = w[i]; cpy(sav, f, len); NTT(sav, 1, len); NTT(r, 1, len); mul(r, sav, len); NTT(r, 0, len); clr(r, len >> 1); cpy(sav, w, len); NTT(sav, 1, len); NTT(r, 1, len); mul(r, sav, len); NTT(r, 0, len); for(int i = len >> 1; i < len; i++) w[i] = (w[i] * 2 - r[i] + mod) % mod; } cpy(f, w, m); clr(sav, n); clr(w, n); clr(r, n); }
|
注意此处需要开两倍空间,所有带有乘法逆元操作的都需要开两倍空间。
带余数的除法
如果我们可以直接把余式减去,然后就可以直接做上面的多项式除法了。
但是余式的系数集中在低次项,我们能做的只有 $\bmod{x^k}$。
考虑翻转一下多项式的系数,定义 $F^T(x) = \sum_{i=0} f[n-i]x^i$。
稍微换个形式可以得到 $F^T(x) = x^n F(x^{-1})$。
首先我们有式子 $F(x) = Q(x) \times G(x) + R(x)$,其中 $F(x)$ 和 $G(x)$ 已知。
换元得到 $F(x^{-1}) = Q(x^{-1}) \times G(x^{-1}) + R(x^{-1})$,
两边同时乘以 $x^n$ 得到 $x^n F(x^{-1}) = x^n Q(x^{-1}) \times G(x^{-1}) + x^n R(x^{-1})$,
写成翻转的形式得到 $F^T(x) = Q^T(x) \times G^T(x) + x^{n-m+1} R^T(x)$。
此时余式 $R(x)$ 原本集中在低次项的系数全部到了高次项,而低次项都为 $0$。
我们对 $x^{n-m+1}$ 取模,得到 $F^T(x) = Q^T(x) \times G^T(x) \pmod{x^{n-m+1}}$。
然后我们就可以求出来 $Q^T(x)$,进而得到 $Q(x)$。
剩下的 $R(x)$ 直接减就好了。
1 2 3 4 5 6 7 8 9 10 11 12 13
| void divp(ll *f, ll *g, int n, int m) { static ll q[N], t[N]; int l = n - m + 1; rev(g, m); cpy(q, g, l); rev(g, m); rev(f, n); cpy(t, f, l); rev(f, n); invp(q, l); timep(q, t, l, l); rev(q, l); timep(g, q, n, n); for(int i = 0; i < m - 1; i++) g[i] = (f[i] - g[i] + mod) % mod; clr(g + m - 1, l); cpy(f, q, l); clr(f + l, n - l); }
|
函数执行完之后,f
数组内是商式,g
数组内是余式。
多项式开根
我们有式子 $F(x) - G^2(x) = 0$,其中 $F(x)$ 已知。
设 $H(G(x)) = G(x)^2 - F(x)$,则 $H’(G(x)) = 2G(x)$。
根据牛顿迭代 $G(F(x)) = F_*(x) - \frac{G(F_*(x))}{G’(F_*(x))} \pmod{x^n}$,得到:
$$
\begin{align}
G(x) &= G_*(x) - \frac{H(G_*(x))}{H’(G_*(x))} \\
&= G_*(x) - \frac{G^2_*(x) - F(x)}{2G_*(x)} \\
&= G_*(x) + \frac{F(x) - G^2_*(x)}{2G_*(x)} \\
&= \frac{F(x) + G^2_*(x)}{2G_*(x)}
\end{align}
$$
由此可以倍增。
因为题目保证了 $f[0] = 1$,所以我们的 $g[0]$ 也就是 $1$ 了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
| void sqrtp(ll *f, int m) { int n = 1; while(n < m)n <<= 1; static ll b1[N], b2[N]; b1[0] = 1; for(int len = 2; len <= n; len <<= 1) { for(int i = 0; i < (len >> 1); i++)b2[i] = (b1[i] << 1) % mod; invp(b2, len); NTT(b1, 1, len); mul(b1, b1, len); NTT(b1, 0, len); for(int i = 0; i < len; i++)b1[i] = (f[i] + b1[i]) % mod; timep(b1, b2, len, len); } cpy(f, b1, m); clr(b1, n + n); clr(b2, n + n); }
|
多项式求导与积分
根据积分公式和求导的公式直接做就好了。
导数:$F’(x) = \sum_{i=1} f[i]ix^{i-1}$
积分:$\int F(x) = C + \sum_{i=0} \frac{f[i]x^{i+1}}{i}$
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
| ll inv[N]; void initi(int lim) { inv[1] = 1; for(int i = 2; i <= lim; i++) inv[i] = inv[mod % i] * (mod - mod / i) % mod; } void derp(ll *f, int m) { for(int i = 1; i < m; i++) f[i - 1] = f[i] * i % mod; f[m - 1] = 0; } void intp(ll *f, int m) { for(int i = m; i; i--) f[i] = f[i - 1] * inv[i] % mod; f[0] = 0; }
|
需要预处理逆元。
多项式ln与exp
为了保证我们可以只使用四则运算来得到 $\ln(F(x))$ 和 $e^{F(x)}$,我们需要使用麦克劳林级数:
$$
\begin{gather}
\ln(F(x)) = - \sum_{i=1} \frac{(1-F(x))^i}{i} \\
\exp(F(x)) = \sum_{i=0} \frac{F^i(x)}{i!}
\end{gather}
$$
当然,因为我们有了积分和求导两个操作,我们可以利用 $\ln’(x)=\frac{1}{x}$ 来简化我们的操作。
设 $G(x) = \ln(F(x))$,则
$$
\begin{align}
\ln(F(x)) &= G(x) \\
\frac{d}{dx} \ln(F(x)) &= \frac{d}{dx} G(x) \\
\frac{dF(x)}{dx} \frac{d}{dF(x)} \ln(F(x)) &= \frac{d}{dx} G(x) \\
\frac{F’(x)}{F(x)} &= G’(x)\\
\int \frac{F’(x)}{F(x)} &= G(x)
\end{align}
$$
因为exp是ln的逆运算,所以可以直接牛顿迭代来倍增。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
| void lnp(ll *f, int m) { static ll g[N]; cpy(g, f, m); invp(g, m); derp(f, m); timep(f, g, m, m); intp(f, m - 1); clr(g, m); } void exp(ll *f, int m) { static ll s[N], s2[N]; int n = 1; while(n < m)n <<= 1; s2[0] = 1; for(int len = 2; len <= n; len <<= 1) { cpy(s, s2, len >> 1); lnp(s, len); for(int i = 0; i < len; i++)s[i] = (f[i] - s[i] + mod) % mod; s[0] = (s[0] + 1) % mod; timep(s2, s, len, len); } cpy(f, s2, m); clr(s, n); clr(s2, n); }
|
注意求exp的时候需要调用求ln的函数。
多项式快速幂
为了做 $O(n \log n)$ 的快速幂,我们需要利用 $\ln(ab) = \ln(a)+b$ 将幂运算转化为乘法。
具体来说, $A^k(x) = e^{\ln(A(x)) \times k}$。
此时需要满足 $a[0] = 1$。
如果不是的话,我们可以将 $A(x)$ 写作 $B(x) \times cx^p$ 的形式,此时 $b[0]=1$,可以套用上面的方法。
具体来说,$A^k(x) = (B(x) \times cx^p)^k = B^k(x) \times c^kxp^k$。
此时需要注意 $c$ 的幂次需要根据费马小定理对 $mod-1$ 取模。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
| int n; char s[N]; ll f[N]; int main() { scanf("%d%s", &n, s); for(int i = 0; i < n; i++) scanf("%lld", &f[i]); int p = 0, c; while(!f[p])p++; c = qpow(f[p]); ll k = 0, k2 = 0; for(int i = 0; isdigit(s[i]); i++) { k = (k * 10 + s[i] - '0') % mod; k2 = (k2 * 10 + s[i] - '0') % (mod - 1); if(k * p >= n) { for(int i = 0; i < n; i++)printf("0 "); putchar('\n'); return 0; } } initi(n = n - p * k); for(int i = 0; i < n; i++)f[i] = f[i + p] * c % mod; clr(f + n, p * k); lnp(f, n); for(int i = 0; i < n; i++)f[i] = f[i] * k % mod; exp(f, n); for(int i = 0; i < p * k; i++)printf("0 "); c = qpow(c, mod - 1 - k2); for(int i = 0; i < n; i++)f[i] = f[i] * c % mod; for(int i = 0; i < n; i++) printf("%lld ", f[i]); putchar('\n'); return 0; }
|