SDOI 2016 模式字符串(点分治+哈希)

[SDOI2016]模式字符串

问题描述
给出n个结点的树结构T,其中每一个结点上有一个字符,这里我们所说的字符只考虑大写字母A到Z,再给出长度为m的模式串s,其中每一位仍然是A到z的大写字母。

Alice希望知道,有多少对结点<u,v>满足T上从u到V的最短路径形成的字符串可以由模式串S重复若干次得到?这里结点对<u,v>是有序的,也就是说<u,v>和<v,u>需要被区分.

所谓模式串的重复,是将若干个模式串S依次相接(不能重叠).例如当S=PLUS的时候,重复两次会得到PLUSPLUS,重复三次会得到PLUSPLUSPLUS,同时要注恿,重复必须是整数次的。例如当S=XYXY时,因为必须重复整数次,所以XYXYXY不能看作是S重复若干次得到的。

输入格式
每一个数据有多组测试,

第一行输入一个整数C,表示总的测试个数。

对于每一组测试来说:

第一行输入两个整数,分别表示树T的结点个数n与模式长度m。结点被依次编号为1到n,

之后一行,依次给出了n个大写字母(以一个长度为n的字符串的形式给出),依次对应树上每一个结点上的字符(第i个字符对应了第i个结点).

之后n-1行,每行有两个整数u和v表示树上的一条无向边,之后一行给定一个长度为m的由大写字母组成的字符串,为模式串S。

1<=C<=10,3<=N<=10000003<=M<=1000000

输出格式
给出C行,对应C组测试。每一行输出一个整数,表示有多少对节点<u,v>满足从u到v的路径形成的字符串恰好是模式串的若干次重复.

样例输入
1
11 4
IODSSDSOIOI
1 2
2 3
3 4
1 5
5 6
6 7
3 8
8 9
6 10
10 11
SDOI

样例输出
5


此题是树上路径问题,显然考虑点分治,那么我们需要判断一条路径是否是给定串的倍数。

考虑用哈希处理,我们将若干个模式串拼起来,得到一个长度大于等于n的串,然后求出这个串的前缀和后缀哈希值,这样我们点分治的时候,只需要从根出发,算算哈希值,就知道长度为$k$的前缀和后缀分别有多少个了,然后算一算就知道有多少符合条件的路径了,注意到长度$k$超过$m$的前缀和后缀直接累加到$k\%m$上就行了。

复杂度$O(n\log n)$


代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<cstring>
#define N 1000005
#define ll long long
#define ull unsigned long long
using namespace std;
char buf[1 << 20], *p1, *p2;
#define GC (p1==p2&&(p2=(p1=buf)+fread(buf,1,1000000,stdin),p1==p2)?0:*p1++)
inline int _R() {
int d; char t;
while (t=GC,t<'0'||t>'9');
d=t-'0';
while(t=GC,t>='0'&&t<='9')d=(d<<3)+(d<<1)+t-'0';
return d;
}
inline void _S(char *c) {
char *t=c,ch;
while(ch=GC,ch==' '||ch=='\n'||ch=='\r');
*t++=ch;
while (ch=GC,ch!=' '&&ch!='\n'&&ch!='\r')*t++=ch;
*t=0;
}
const ull p=131;
ull Suf[N],Pre[N];
ll ans,suf[N],pre[N];
char C[N],s[N];
int T,n,m,si[N],Min,rt,MX;
bool mark[N];
int TOT,LA[N],NE[N<<1],EN[N<<1];
void ADD(int x,int y)
{
TOT++;
EN[TOT]=y;
NE[TOT]=LA[x];
LA[x]=TOT;
}
void Gsi(int x,int f)
{
int i,y;si[x]=1;
for(i=LA[x];i;i=NE[i])
{
y=EN[i];if(mark[y]||y==f)continue;
Gsi(y,x);si[x]+=si[y];
}
}
void Grt(int x,int s,int f)
{
int i,y,Max=s-si[x];
for(i=LA[x];i;i=NE[i])
{
y=EN[i];if(mark[y]||y==f)continue;
Grt(y,s,x);Max=max(Max,si[y]);
}
if(Max<Min)Min=Max,rt=x;
}
void Gans(int x,int f,int d,ull Hash)
{
int i,y;Hash=Hash*p+C[x];
y=m-d%m;y-=y>=m?m:0;
if(Hash==Suf[d])ans+=pre[y];
if(Hash==Pre[d])ans+=suf[y];
for(i=LA[x];i;i=NE[i])
{
y=EN[i];if(mark[y]||y==f)continue;
Gans(y,x,d+1,Hash);
}
}
void Updata(int x,int f,int d,ull Hash)
{
int i,y;Hash=Hash*p+C[x];
if(Hash==Suf[d])suf[d%m]++,MX=max(MX,d%m);
if(Hash==Pre[d])pre[d%m]++,MX=max(MX,d%m);
for(i=LA[x];i;i=NE[i])
{
y=EN[i];if(mark[y]||y==f)continue;
Updata(y,x,d+1,Hash);
}
}
void Cal(int x)
{
int i,y;ull Hash=C[x];MX=1;
if(Hash==Suf[1])suf[1]++;
if(Hash==Pre[1])pre[1]++;
for(i=LA[x];i;i=NE[i])
{
y=EN[i];if(mark[y])continue;
Gans(y,x,1,0);
Updata(y,x,2,Hash);
}
fill(suf,suf+MX+1,0);
fill(pre,pre+MX+1,0);
}
void DC(int x)
{
int i,y;mark[x]=1;Cal(x);
for(i=LA[x];i;i=NE[i])
{
y=EN[i];if(mark[y])continue;
Min=1e9;Gsi(y,x);Grt(y,si[y],x);
if(si[y]>=m)DC(rt);
}
}
int main_main()
{
int i,j,k,x,y;ull z;
T=_R();
while(T--)
{
n=_R();m=_R();
TOT=0;ans=0;
fill(LA,LA+n+1,0);
fill(mark,mark+n+1,0);
_S(C+1);
for(i=1;i<n;i++)
{
x=_R();y=_R();
ADD(x,y);ADD(y,x);
}
_S(s+1);
for(i=1,j=1,z=1;i<=n;i++,z*=p,j++,j-=j>m?m:0)Pre[i]=Pre[i-1]+z*s[j];
for(i=1,j=m,z=1;i<=n;i++,z*=p,j--,j+=j<1?m:0)Suf[i]=Suf[i-1]+z*s[j];
Min=1e9;Gsi(1,0);Grt(1,si[1],0);DC(rt);
printf("%lld\n",ans);
}
}
const int main_stack=16;
char my_stack[128<<20];
int main() {
__asm__("movl %%esp, (%%eax);\n"::"a"(my_stack):"memory");
__asm__("movl %%eax, %%esp;\n"::"a"(my_stack+sizeof(my_stack)-main_stack):"%esp");
main_main();
__asm__("movl (%%eax), %%esp;\n"::"a"(my_stack):"%esp");
return 0;
}