FFT

FFT学习

代码见github

P开头的题目来自luogu.com.cn

Introduction

FFT用于解决离散卷积问题。比如说已知

{f0,f1,,fn1}{g0,g1,,gn1}\begin{aligned} \{f_0, f_1, \cdots, f_{n - 1}\} \\ \{g_0, g_1, \cdots, g_{n - 1}\} \end{aligned}

{h0,h1,,hn1}\{h_0, h_1, \cdots, h_{n - 1}\}

满足

hi=j=0ifjgijh_i = \sum_{j = 0}^{i} f_j\cdot g_{i - j}

h=fgh = f * g

Analysis

插值定理

已知nn个点(x0,y0),(x1,y1),,(xn1,yn1)(x_0, y_0), (x_1, y_1), \cdots, (x_{n - 1}, y_{n - 1}),其中xixj(ij),xi,yiCx_i \neq x_j (i\neq j), x_i, y_i\in \mathbb C,则可以唯一确定通过这nn个点的n1n-1次多项式

pn1(x)=a0+a1x+a2x2++an1xn1p_{n-1}(x) = a_0 + a_1x + a_2 x^2 + \cdots + a_{n - 1}x^{n - 1}

证明

设多项式形式如上,则

{y0=a0+a1x0++an1x0n1y1=a0+a1x1++an1x1n1yn1=a0+a1xn1++an1xn1n1\begin{cases} y_0 &= a_0 + a_1x_0 + \cdots + a_{n - 1}x_0^{n-1}\\ y_1 &= a_0 + a_1x_1 + \cdots + a_{n - 1}x_1^{n-1}\\ &\vdots\\ y_{n-1} &= a_0 + a_1x_{n-1} + \cdots + a_{n - 1}x_{n-1}^{n-1}\\ \end{cases}

写成矩阵的形式:

(y0y1yn1)=(1x0x0n11x1x1n11xn1xn1n1)(a0a1an1)\begin{pmatrix} y_0\\y_1\\\vdots\\y_{n-1} \end{pmatrix}=\begin{pmatrix} 1 & x_0 & \cdots & x_0^{n-1}\\ 1 & x_1 & \cdots & x_{1}^{n-1}\\ \vdots & \vdots & \ddots & \vdots\\ 1 & x_{n-1} & \cdots & x_{n-1}^{n-1} \end{pmatrix}\begin{pmatrix} a_0\\a_1\\\vdots\\a_{n-1}\end{pmatrix}

中间大矩阵是一个vandermonde矩阵,行列式为i>j(xixj)0\prod_{i>j} (x_i - x_j)\neq 0,所以是可逆的。

y\bf y已知,求a\bf a,用a=X1y\bf a = \bf X^{-1} \bf y即可。

问题转换

不妨把ffgg数列都写成多项式的形式,即:

f(x)=f0+f1x++fn1xn1g(x)=g0+g1x++gn1xn1\begin{aligned} f(x) = f_0 + f_1x + \cdots + f_{n-1}x^{n-1}\\ g(x) = g_0 + g_1x + \cdots + g_{n-1}x^{n-1} \end{aligned}

f×gf\times g可以算出来:

h(x)=f(x)×g(x)=i=02n2xijfjgijh(x) = f(x)\times g(x) = \sum_{i=0}^{2n-2} x^{i}\sum_{j} f_jg_{i-j}

得到的h(x)h(x)是一个2n22n-2次多项式,且h(x)h(x)x0,x1,,xn1x^0, x^1, \cdots, x^{n-1}的系数就是hh数列。

假设我们选定x0,x1,,x2n2x_0, x_1, \cdots, x_{2n-2},用这2n12n-1个点作为横坐标对ffgg两个函数进行插值,由于ffgg是已知的,所以可以计算f(x0),f(x1),,f(x2n2)f(x_0), f(x_1), \cdots, f(x_{2n-2})g(x0),g(x1),,g(x2n2)g(x_0), g(x_1), \cdots, g(x_{2n-2})

同时有

h(x)=f(x)g(x)h(x) = f(x)g(x)

所以

h(x0)=f(x0)g(x0)h(x1)=f(x1)g(x1)h(x2n2)=f(x2n2)g(x2n2)\begin{aligned} h(x_0) &= f(x_0)g(x_0)\\ h(x_1) &= f(x_1)g(x_1)\\ \vdots &\\ h(x_{2n-2}) &= f(x_{2n-2})g(x_{2n-2}) \end{aligned}

所以我们得到了hh函数的2n12n-1个插值,于是唯一确定了h(x)h(x)这个函数。得到该函数后取前nn个系数就可以得到hh数列。

于是我们期望的流程是:

  • ffgg函数插值;
  • 计算出hh在各个插值点的值;
  • 反解出hh的系数。

为了以后说明的方便,我们假设n=2mn=2^m

Interpolate

接下来研究怎么对ffgg函数插值。只研究ff函数,因为gg完全相同。

为了方便,我们现在假设只需要对ff进行nn个点的插值(本身需要2n12n-1个点,但是我们也可以直接把n2nn\leftarrow 2n,然后假设ff是一个2n2n次多项式,进行2n2n个点的插值,效果相同)。

计算f(x0)f(x_0)O(n)O(n)的,所以获得所有插值就O(n2)O(n^2)了!

怎么解决这个问题呢?我们需要取一些特殊的插值点,以便重复利用插值点的数值。

如果有两个插值点x0x_0x1x_1满足x0=x1x_0 = -x_1,则我们如果对ff的系数进行奇偶分类(我们假设n=2mn=2^m):

fe(x)=f0+f2x+f4x2++fn2xn/21fo(x)=f1+f3x+f5x2++fn1xn/21\begin{aligned} f_e(x) = f_0 + f_2x + f_4x^2 + \cdots + f_{n-2}x^{n/2-1}\\ f_o(x) = f_1 + f_3x + f_5x^2 + \cdots + f_{n-1}x^{n/2-1} \end{aligned}

那么有

f(x0)=fe(x02)+x0fo(x02)f(x1)=fe(x12)+x1fo(x12)\begin{aligned} f(x_0) = f_e(x_0^2) + x_0f_o(x_0^2)\\ f(x_1) = f_e(x_1^2) + x_1f_o(x_1^2)\\ \end{aligned}

由于x02=x12x_0^2 = x_1^2,可以化为

f(x0)=fe(x02)+x0fo(x02)f(x1)=fe(x02)x0fo(x02)\begin{aligned} f(x_0) = f_e(x_0^2) + x_0f_o(x_0^2)\\ f(x_1) = f_e(x_0^2) - x_0f_o(x_0^2)\\ \end{aligned}

也就是说,我们只需要计算fef_efof_o两个长度为n/2n/2的函数在x02x_0^2处的插值信息,我们就可以得到ffx0x_0x1x_1处的插值信息。

但是,如果插值点都取实数的话,到第二层,所有x0x\geq 0,也就不存在相反数了(不能递归)。

解决这个问题的办法就是采用复数单位根来作为插值点!

对于长度为nn的函数ff,我们把复平面上的单位圆分成nn份来取插值点。具体而言:

xk=e2jπkn=cos(2πkn)+jsin(2πkn)x_k = e^{\frac{2\mathrm j\pi k}{n}} = \cos(\frac{2\pi k}{n}) + \mathrm j \sin(\frac{2\pi k}{n})

其中j=1\mathrm j = \sqrt -1

为了方便,我们可以记ω=e2jπn\omega = e^{\frac{2\mathrm j\pi}{n}},则

xk=ωkx_k = \omega^k

那么就有性质:

xk=xk+n2,k<n2x_k = -x_{k + \frac{n}{2}}, k < \frac{n}{2}

这样就可以两两凑成一对,问题转化为了两个n2\frac{n}{2}的子问题,并且由于xk2x_k^2刚好可以覆盖所有的ejπkne^{\frac{\mathrm j\pi k}{n}},所以这个问题可以一直递归下去,边界就是返回当前的系数。

复杂度:

T(n)=2T(n/2)+O(n)    T(n)=O(nlogn)T(n) = 2T(n/2) + O(n)\implies T(n) = O(n\log n)

Calculate hh

第二步:计算出hh函数在插值点的值,直接乘就好了,O(n)O(n)。接下来研究第三步:反解出hh的系数。

我们不妨先回顾一下第一步解决了什么问题:

给定n1n-1次多项式ff,求出ffnn次单位根上的值。
形式化的说:已知f0,f1,,fn1f_0, f_1, \cdots, f_{n-1},计算y0,y1,,yn1y_0, y_1, \cdots, y_{n-1}

(y0y1yn1)=(1x0x0n11x1x1n11xn1xn1n1)(f0f1fn1)\begin{pmatrix} y_0\\y_1\\\vdots\\y_{n-1} \end{pmatrix} = \begin{pmatrix} 1 & x_0 & \cdots & x_0^{n-1}\\ 1 & x_1 & \cdots & x_1^{n-1}\\ \vdots & \vdots & \ddots & \vdots\\ 1 & x_{n-1} & \vdots & x_{n-1}^{n-1} \end{pmatrix}\begin{pmatrix} f_0\\f_1\\\vdots\\f_{n-1} \end{pmatrix}

其中xk=ωkx_k = \omega^kω=e2jπ/n\omega = e^{2\mathrm j\pi/n}

为了方便,我们依然认为hh函数的长度只有nn(如果不够,用之前相同的方法,后面补00,把f,g,hf,g,h都补到相同的n=2mn=2^m)。

那么当前的问题是

(y0(h)y1(h)yn1(h))=(1x0x0n11x1x1n11xn1xn1n1)(h0h1hn1)\begin{pmatrix} y^{(h)}_0\\y^{(h)}_1\\\vdots\\y^{(h)}_{n-1} \end{pmatrix} = \begin{pmatrix} 1 & x_0 & \cdots & x_0^{n-1}\\ 1 & x_1 & \cdots & x_1^{n-1}\\ \vdots & \vdots & \ddots & \vdots\\ 1 & x_{n-1} & \vdots & x_{n-1}^{n-1} \end{pmatrix}\begin{pmatrix} h_0\\h_1\\\vdots\\h_{n-1} \end{pmatrix}

已知y(h)\mathbf y^{(h)},求h\mathbf h

我们知道中间那个矩阵是可逆的,所以我们现在尝试求他的逆(普通方法是O(n3)O(n^3)的,我们需要更快的办法)。

令中间的矩阵为X\bm X,则

Xij=(ωi)j=ωij\bm X_{ij} = (\omega^{i})^j = \omega^{ij}

则可以看出X\bm X是一个对称矩阵。计算XXT\bm X\bm X^{*T}

XXijT=k=0n1ωikωkj=1ω(ij)n1ωij={0ijni=j\bm X\bm X^{*T}_{ij} = \sum_{k=0}^{n-1} \omega^{ik}\omega^{-kj} = \frac{1 - \omega^{(i-j)n}}{1 - \omega^{i-j}} = \begin{cases} 0 & i\neq j\\ n & i = j\end{cases}

所以说$$\bm X^{-1} = \frac{1}{n} \bm X^{T} = \frac{1}{n}\bm X^{}$$

那么问题就转化为了

h=Xy(h)\bm h = \bm X^{*}\bm y^{(h)}

这个问题和第一阶段的问题非常相似,只不过第一阶段用的ω\omega换成了ω1\omega^{-1}

所以我们再做一次第一阶段的问题,就可以在O(nlogn)O(n\log n)时间内完成反解!

NTT

复数的计算依赖浮点数,容易丢失精度。

于是可以使用数论变换代替(寻找一个恰当的模数MODMOD,某个整数ω\omega在模MODMOD运算下可以形成一个2m2^m的环)。

实现技巧

  1. 复数运算加速:使用手写的complex类,可以快1.5x~2x(见P3803_fast.cpp);
  2. 使用倍增替代递归:朴素的实现(不改变数据分布)不会带来好处,只会让cache全部未命中(stride太大)。见P3803_loop.cpp);更优化的实现,将初始数据排布变为左边都是偶数,右边都是奇数,连续访问内存。见P3803_loop_v2.cpp)。
  3. 减少运算冗余:每次插值计算只需要算一半,另一半用缓存好的取相反数;
  4. 使用NTT解决精度问题。没有使用NTT很难过P4721,见P4721_naiveFFT.cpp。优化过的见P4721_NTT.cpp。可以通过python3 P4721_judger.py检验正确性。

练习题

编译环境

g++14, macOS

参考vscode tasks.json

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
{
"version": "2.0.0",
"tasks": [
{
"type": "shell",
"label": "C++ build & run",
"command": "/opt/homebrew/bin/g++-14",
"args": [
"-g",
"${file}",
"-o",
"${fileDirname}/build/${fileBasenameNoExtension}",
"&&",
"${fileDirname}/build/${fileBasenameNoExtension}"
],
"options": {
"cwd": "${fileDirname}"
},
"problemMatcher": [
"$gcc"
],
"group": "build",
"detail": "compiler: /opt/homebrew/bin/g++-14"
}
]
}