[CF1090J]Two Prefixes

Description

给定字符串 $S, T$,求 $S$ 的非空前缀拼上 $T$ 的非空前缀形成的本质不同的字符串数。

$$
1 \le |S|, |T| \le 10 ^5
$$

Solution

后缀自动机做法还不会,略。

对文本串 $T$ 建出 KMP 自动机,匹配串 $S$ 在上面跑所到的最终节点 $u$ 满足 $u$ 所代表的 $T$ 的前缀是最长的,能匹配 $S$ 的某个后缀的前缀节点

考虑简单容斥,总字符串数为 $|S| \times |T|$,减去重复字符串数。

假设 $a, a’(|a| < |a’|)$ 和 $b, b’$ 分别为 $S, T$ 的某个前缀,并且 $a + b = a’ + b’$。

那么 $b’$ 一定为 $b$ 的某个 border(换言之,$b’$ 既是 $b$ 的前缀,也是 $b$ 的后缀)。

紫色和蓝色部分分别相同。

可以得到 $a’$ 的去掉 $a$ 的那段后缀和 $b$ 去掉 $b’$ 的部分(也就是紫色部分)完全一样。因为后缀的前缀相当于子串,所以枚举所有的 $b$,令 $b’$ 为 $b$ 最长的 border,用 $b’ - b$ 得到 $T$ 的一段前缀,这段前缀在 $S$ 中作为子串出现的次数就是重复次数。

为何不是所有 border 而是最长的 border,是因为要避免重复减去。

黄色部分为 $T _j$ 的最长的 border,蓝线部分相同。由上面的内容,$S _i + T _{j} = S _{BLUE} + T _{YELLOW}$。此时视作删去 $S _i + T _j$,而非删去 $S _{BLUE} + T _{YELLOW}$。

如果计算所有的 border,$S _i + T _j$ 这个不合法的贡献就有可能会被删去多于一次。因为对于任意的 $S _i + T _j$ 都是删除不少于一次,所以一定不合法。取最长的 border 是为了保证能删除所有的不合法的情况。

现在还需要解决找到 $T$ 的某一个前缀在 $S$ 中出现了几次的问题。将 $T$ 的 KMP 自动机建出来,用 $S$ 的字符逐个匹配,可以得到 $S$ 的所有后缀与 $T$ 能匹配的最长公共前缀。

设 $u$ 是 $S$ 的某个后缀匹配到的最长公共前缀的前缀节点,那么需要在 fail 树上将 $u$ 到根的路径全部 +1,因为 $u$ 仅代表最长而不是所有,需要将是它的 border 的前缀都 +1。

这是个简单的树上差分问题,统计完后直接做就行。

$O(n)$。代码里面记得注意一些细节,因为不能够接空串。也不需要显式的建出 KMP 自动机并补转移边,直接暴力跳 fail 就行了,复杂度是对的。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <cstdio>
#include <cstring>
#include <iostream>

using namespace std;

typedef long long LL;

const int maxN = 1e5 + 5;

int n, m;
LL ans;
int fail[maxN], cnt[maxN];
char S[maxN], T[maxN];

int main()
{
freopen("password.in", "r", stdin);
freopen("password.out", "w", stdout);
ios::sync_with_stdio(false);
cin >> (S + 1) >> (T + 1);
n = strlen(S + 1), m = strlen(T + 1);
for (int i = 2, j = 0; i <= m; ++i)
{
while (j and T[j + 1] != T[i])
j = fail[j];
j += (T[j + 1] == T[i]);
fail[i] = j;
}
for (int i = 1, j = 0; i <= n; ++i)
{
while (j and T[j + 1] != S[i])
j = fail[j];
j += (T[j + 1] == S[i]);
if (j < i)
++cnt[j];
else
++cnt[fail[j]];
}
for (int i = m; i; --i)
cnt[fail[i]] += cnt[i];
(ans = n) *= m;
for (int i = 1; i <= m; ++i)
if (fail[i] and i - fail[i] > 0)
ans -= cnt[i - fail[i]];
cout << ans << endl;
return 0;
}