AC Automaton

记录AC自动机算法的学习。

回顾KMP算法

对于一个字符串ss,每一位的下标为1,2,,n1, 2, \cdots, n,则我们可以求出一个数组pp,满足以下约定:

sis_i表示ssii个字符构成的子串;
pip_i表示sis_i的子串(不包括sis_i自身)中,最长的、且同时为sis_i的前后缀的子串长度。

这样的定义同时保证了pip_i指向了sis_i的最长前后缀的最后一个字符。

考虑用O(n)O(n)时间求出这个pp数组:

假设p0=0p_0 = 0,且p0,p1,,pi1p_0, p_1, \cdots, p_{i - 1}都已经求好,那么我们如下方法求pip_i

  1. j=pi1j = p_{i - 1}
  2. 如果j+1j+1对应的字符和ii对应的字符相等,那么pip_i就是j+1j+1
  3. 如果不满足(2),且j=0j=0,则pi=0p_i=0
  4. 如果(2)(3)均不满足,则jpjj\leftarrow p_j,回到(2)。

这样做的正确性是:

  1. 首先pipi1+1p_i \leq p_{i - 1} + 1,因为只要定了pip_i,则往前推一格就能得到pi1p_{i-1}的下界;
  2. 同理,当pi1+1p_{i-1}+1ii对应的字符不相等时,pippi1+1p_i\leq p_{p_{i - 1}} + 1
  3. 以此类推,所以只要不停地往前迭代pp即可。

复杂度正确性:

pp每次最多增加11,往前迭代过程中jj每次最少减少11,所以总共最多迭代nn次。

参考代码(【模板】KMP

注意特判i=1i=1的情况,否则会pi=ip_i = i

Show Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#include <iostream>
#include <vector>
signed main() {
std::string s1, s2; std::cin >> s1 >> s2;
s1 = '.' + s1;
s2 = '.' + s2;
const int m = s1.length() - 1;
const int n = s2.length() - 1;
std::vector<int> p(n + 1, 0);
for (int i = 1, j = 0; i <= n; ++i) {
while (j && s2[j + 1] != s2[i]) j = p[j];
if (j + 1 != i && s2[j + 1] == s2[i]) p[i] = ++j;
}
for (int i = 1, j = 0; i <= m; ++i) {
while (j && s2[j + 1] != s1[i]) j = p[j];
if (s2[j + 1] == s1[i]) ++j;
if (j == n) (std::cout << (i - j + 1) << std::endl), j = p[j];
}
for (int i = 1; i <= n; ++i) {
std::cout << p[i] << ' ';
}
std::cout << std::endl;
return 0;
}
## AC自动机

问题引入

假设现在有nn个字符串s1,s2,,sns_1, s_2, \cdots, s_n,要判断每个字符串是否是字符串tt的子串。

如果用KMP算法,对于nn个字符串求失配指针pp,复杂度为O(nsi)O(n\sum \left|s_i\right|);分别跟tt匹配,复杂度为O(t)=O(nt)O(\sum \left|t\right|) = O(n\left|t\right|),复杂度为O(n(si+t))O(n(\sum \left|s_i\right| + \left|t\right|)),当t\left|t\right|很大时就GG了。

AC自动机算法

搞一个Trie,把所有sis_i都丢进去,其中根节点编号为00,设点uu的子节点中,连着字符cc的子节点为neu,cne_{u, c},没有则是00

然后我们模仿KMP对字典树上每个节点求失配指针pp。在树上进行一个BFS,每次求vv节点的pp时,深度小于vv的节点都已经求好了。设uuvv的父亲,(u,v)(u,v)边对应字符为cc,那么我们不停的找到pup_u,如果nepu,cne_{p_u, c}存在,那么就找到了,否则一直取pp。这样就可以找到所有点的pp

匹配的时候,如果uu的下一个失配了,那么我们就跳到pup_u继续尝试匹配。

复杂度分析

建树这一步复杂度为O(nm×si)O(nm\times\sum\left|s_i\right|),其中mm表示字符数量。

但是求解失配指针的这一步好像炸缸了,可能会被卡到平方级(每次都往前跳O(n)O(n)下),但不知道为啥过了

为了优化求失配指针的过程,我们不能再一直往前跳了,但是匹配的时候,由于要匹配的东西是一个线性结构,依然满足之前的均摊复杂度,只要保证每次跳完长度至少1-1即可。

回顾一下需要往前继续跳的情况:nepu,c=0ne_{p_u, c}=0。反正空着也是空着,我们不妨让nepu,c=neppu,cne_{p_u, c}=ne_{p_{p_u}, c},这样就有个很有意思的结果:如果nene不存在,那么nene会指向第一个存在的失配指针,因为如果它指向的也不存在,那么那个不存在的会指向一个存在的(归纳)。

这样处理之后,我们就不需要跳了,直接设pv=nepu,cp_v = ne_{p_u, c}就好了。

这样求解失配指针的复杂度就优化到了O(nmsi)O(nm\sum\left|s_i\right|)

最后就是查询。以此题为例。

首先可以想到在每个trie的节点上开一个vector记录当前节点是哪些字符串的结尾。

tt进行匹配,目标字符串tt对于第ii位匹配到节点uu,则表示到uu节点的路径的所有后缀都可以匹配到第ii位。给节点uu打上标记vis[u]++,最后我们希望所有的pu,ppu,p_u, p_{p_u}, \cdots都可以继承vis[u]

直接暴力跳会和之前一样复杂度爆炸,不如用拓扑排序(pp数组建成的图当然是一个DAG)。

这样查询的复杂度是O(t)O(\left|t\right|)

修改后也没快多少。

总的时间复杂度:O(nmsi+t)O(nm\sum\left|s_i\right|+\left|t\right|),如果mm可视为常数(比如说26),则复杂度是O(nsi+t)O(n\sum\left|s_i\right| + \left|t\right|),比KMP快,尤其是在t\left|t\right|比较大的情况下。

但是当mm不可忽略时还有待研究。可能可以使用类似map的东西,带一个log\log的复杂度。

参考代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include <iostream>
#include <vector>
#include <cstring>
#include <queue>
#include <algorithm>

const int M = 26;
class acAutomaton {
private:
std::vector<std::vector<int> > ne;
std::vector<std::vector<int> > end;
std::vector<int> p;
int idx = 0;
int num = 0;
int add() {
ne.push_back(std::vector<int>(M, 0));
end.push_back(std::vector<int>());
return idx++;
}
void calcPre() {
p.resize(idx);
p[0] = 0;
std::queue<int> q;
for (int i = 0; i < M; ++i) {
if (ne[0][i]) q.push(ne[0][i]);
// 注意,直接插入根节点会出现和KMP一样的问题
}

while (!q.empty()) {
int u = q.front(); q.pop();
for (int i = 0; i < M; ++i) {
int& v = ne[u][i];
if (v) {
p[v] = ne[p[u]][i]; // 上一个可以的地方
q.push(v);
} else {
v = ne[p[u]][i]; // 直接插到上一个可以的地方
}
}
}

// for (int i = 0; i < idx; ++i) {
// // for (int j = 0; j < M; ++j) {
// // printf("i = %d, j = %d, ne = %d\n", i, j, ne[i][j]);
// // }
// // printf(">>i = %d, p = %d\n", i, p[i]);
// }
}
public:
acAutomaton() {add(); }
void addToTrie(const std::string& s) {
int now = 0;
for (const char& c : s) {
if (!ne[now][c - 'a']) ne[now][c - 'a'] = add();
now = ne[now][c - 'a'];
// end[now] = (&c == &s.back() ? ++num : 0);
if (&c == &s.back()) end[now].push_back(num++);
// if (end[now]) printf("end %d = %d\n", now, end[now]);
}
}
std::vector<int> calc(const std::string& t) {
calcPre();
std::vector<int> vis(idx, 0);
int u = 0;
for (const char c : t) {
while (u && !ne[u][c - 'a']) u = p[u];
u = ne[u][c - 'a'];
vis[u] ++;
}

// for (int i = 1; i < idx; ++i) {
// int j = i;
// if (!vis[j]) continue;
// while (p[j]) j = p[j], vis[j] += vis[j];
// }

std::vector<int> d(idx, 0);
std::queue<int> q;
for (int i = 1; i < idx; ++i) {
d[p[i]]++;
}
for (int i = 1; i < idx; ++i) {
if (!d[i]) q.push(i);
}

while (!q.empty()) {
int u = q.front(); q.pop();
if (!u) continue;
int v = p[u];
vis[v] += vis[u];
--d[v];
if (!d[v]) q.push(v);
}

std::vector<int> res(num);
for (int i = 1; i < idx; ++i) {
// printf("vis[%d] = %d\n", i, vis[i]);
if (end[i].empty()) continue;
for (const int j : end[i]) res[j] = vis[i];
}
return res;
}
};

signed main() {
int n; std::cin >> n;
acAutomaton T;
for (int i = 0; i < n; ++i) {
std::string s; std::cin >> s;
T.addToTrie(s);
}
std::string t; std::cin >> t;
for (const auto x : T.calc(t)) std::cout << x << std::endl;
return 0;
}