NKOJ 4090 找相同子串(后缀自动机/后缀数组+线段树)

P4090[HAOI2016]找相同子串

问题描述

给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。
两个方案不同当且仅当这两个子串中有一个位置不同。

输入格式

两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母

输出格式

输出一个整数表示答案

样例输入

aabb
bbaa

样例输出

10


首先看看简洁优美的自动机做法。

将两个串中间加个字符,连起来建机子。

其实只需要在求Right集合大小的时候把他拆成两部分,一部分是第一个子串的Right,另一部分是第二个的。然后每个点求和,即$Ans=\sum v_1[x] \times v_2[x] \times (Max[x]-Max[pra[x]])$


代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<cstring>
#define N 1000005
using namespace std;
char s1[N],s2[N];
long long Ans;
int n,m,tot=1,las=1,rt=1,Max[N],pra[N],son[N][27],v[N][2];
int TOT,LA[N],NE[N],EN[N];
int NP(int x)
{
Max[++tot]=x;
return tot;
}
void Ins(int t,int d)
{
int p=las,q,np,nq;
np=NP(Max[p]+1);v[np][d]=1;
while(p&&!son[p][t])son[p][t]=np,p=pra[p];
if(!p)pra[np]=rt;
else
{
q=son[p][t];
if(Max[q]==Max[p]+1)pra[np]=q;
else
{
nq=NP(Max[p]+1);
memcpy(son[nq],son[q],sizeof(son[q]));
pra[nq]=pra[q];
pra[q]=pra[np]=nq;
while(son[p][t]==q)son[p][t]=nq,p=pra[p];
}
}
las=np;
}
void ADD(int x,int y)
{
TOT++;
EN[TOT]=y;
NE[TOT]=LA[x];
LA[x]=TOT;
}
void DFS(int x)
{
int i,y;
for(i=LA[x];i;i=NE[i])
{
y=EN[i];DFS(y);
v[x][0]+=v[y][0];
v[x][1]+=v[y][1];
}
Ans+=1ll*v[x][0]*v[x][1]*(Max[x]-Max[pra[x]]);
}
int main()
{
scanf("%s%s",s1,s2);
n=strlen(s1);
m=strlen(s2);
for(int i=0;i<n;i++)Ins(s1[i]-'a',0);Ins(26,0);
for(int i=0;i<m;i++)Ins(s2[i]-'a',1);
for(int i=1;i<=tot;i++)ADD(pra[i],i);
DFS(rt);printf("%lld",Ans);
}

然后再来看看卡到GG的后缀数组搞法。
利用单调性可以搞成线性的,然而为了方便当然是线段树。

两个串连起来建好机子,倒起讨论,用线段树维护一下当前Height值的数量即可。
注意到相同子串个数恰是LCP之和。


代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#include<stdio.h>
#include<algorithm>
#include<cstring>
#define ll long long
#define N 555555
using namespace std;
char s[N],A[N],B[N];
int n,m,SA[N],H[N],Rank[N];
int wa[N],wb[N],T[N];
int tot,ls[N*4],rs[N*4],lazy[N*4][2],cnt[N*4][2];
ll sum[N*4][2];
bool cmp(int *r,int a,int b,int l)
{return r[a]==r[b]&&r[a+l]==r[b+l];}
void GSA(char *r,int *sa,int a,int b)
{
int i,j,p,*x=wa,*y=wb,*t;
for(i=0;i<a;i++)T[x[i]=r[i]]++;
for(i=1;i<b;i++)T[i]+=T[i-1];
for(i=a-1;i>=0;i--)sa[--T[x[i]]]=i;
for(p=1,j=1;p<a;j<<=1,b=p)
{
for(p=0,i=a-j;i<a;i++)y[p++]=i;
for(i=0;i<a;i++)if(sa[i]>=j)y[p++]=sa[i]-j;
for(i=0;i<b;i++)T[i]=0;
for(i=0;i<a;i++)T[x[y[i]]]++;
for(i=1;i<b;i++)T[i]+=T[i-1];
for(i=a-1;i>=0;i--)sa[--T[x[y[i]]]]=y[i];
for(t=x,x=y,y=t,p=1,x[sa[0]]=0,i=1;i<a;i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
}
}
void GH(char *r,int *sa,int a)
{
int i,j,k=0;
for(i=1;i<=a;i++)Rank[sa[i]]=i;
for(i=0;i<a;H[Rank[i++]]=k)
for(k?k--:0,j=sa[Rank[i]-1];r[i+k]==r[j+k];k++);
}
void PD(int p,int t)
{
lazy[ls[p]][t]=lazy[rs[p]][t]=1;lazy[p][t]=0;
sum[ls[p]][t]=sum[rs[p]][t]=cnt[ls[p]][t]=cnt[rs[p]][t]=0;
}
int BT(int x,int y)
{
int p=++tot;
if(x<y)
{
int mid=x+y>>1;
ls[p]=BT(x,mid);
rs[p]=BT(mid+1,y);
}
return p;
}
void ADD(int p,int l,int r,int k,int d,int t)
{
if(lazy[p][t])PD(p,t);
if(l==r){cnt[p][t]+=d;sum[p][t]=1ll*cnt[p][t]*l;return;}
int mid=l+r>>1;
if(k<=mid)ADD(ls[p],l,mid,k,d,t);
else ADD(rs[p],mid+1,r,k,d,t);
cnt[p][t]=cnt[ls[p]][t]+cnt[rs[p]][t];
sum[p][t]=sum[ls[p]][t]+sum[rs[p]][t];
}
int GC(int p,int l,int r,int x,int y,int t)
{
if(lazy[p][t])return 0;
if(x<=l&&y>=r)
{
int k=cnt[p][t];
cnt[p][t]=sum[p][t]=0;
lazy[p][t]=1;
return k;
}
int mid=l+r>>1,cs=0;
if(x<=mid&&y>=l)cs+=GC(ls[p],l,mid,x,y,t);
if(x<=r&&y>mid)cs+=GC(rs[p],mid+1,r,x,y,t);
cnt[p][t]=cnt[ls[p]][t]+cnt[rs[p]][t];
sum[p][t]=sum[ls[p]][t]+sum[rs[p]][t];
return cs;
}
void GA()
{
int i,k;ll ans=0;int t=n+m+1;
BT(1,t);H[t+1]=H[t]+1;
for(i=t-1;i>0;i--)
{
if(H[i+1]<H[i+2])k=GC(1,1,t,H[i+1]+1,H[i+2],1);
else k=0;
if(H[i+1])
{
if(SA[i+1]>n)ADD(1,1,t,H[i+1],k+1,1);
else if(k)ADD(1,1,t,H[i+1],k,1);
}
if(SA[i]<n)ans+=sum[1][1];
if(H[i+1]<H[i+2])k=GC(1,1,t,H[i+1]+1,H[i+2],0);
else k=0;
if(H[i+1])
{
if(SA[i+1]<n)ADD(1,1,t,H[i+1],k+1,0);
else if(k)ADD(1,1,t,H[i+1],k,0);
}
if(SA[i]>n)ans+=sum[1][0];
}
printf("%lld",ans);
}
int main()
{
scanf("%s%s",s,B);
n=strlen(s);
m=strlen(B);
for(int i=n+1;i<=n+m;i++)s[i]=B[i-n-1];
s[n]='z'+1;s[n+m+1]='a'-1;
GSA(s,SA,n+m+2,300);
GH(s,SA,n+m+1);GA();
}