NKOJ 3941 (HNOI 2014)世界树(虚树+树形dp+倍增)

P3941[Hnoi2014]世界树

问题描述

世界树是一棵无比巨大的树,它伸出的枝干构成了整个世界。在这里,生存着各种各样的种族和生灵,他们共同信奉着绝对公正公平的女神艾莉森,在他们的信条里,公平是使世界树能够生生不息、持续运转的根本基石。
世界树的形态可以用一个数学模型来描述:世界树中有n个种族,种族的编号分别从1到n,分别生活在编号为1到n的聚居地上,种族的编号与其聚居地的编号相同。有的聚居地之间有双向的道路相连,道路的长度为1。保证连接的方式会形成一棵树结构,即所有的聚居地之间可以互相到达,并且不会出现环。定义两个聚居地之间的距离为连接他们的道路的长度;例如,若聚居地a和b之间有道路,b和c之间有道路,因为每条道路长度为1而且又不可能出现环,所卧a与c之间的距离为2。
出于对公平的考虑,第i年,世界树的国王需要授权m[i]个种族的聚居地为临时议事处。对于某个种族x(x为种族的编号),如果距离该种族最近的临时议事处为y(y为议事处所在聚居地的编号),则种族x将接受y议事处的管辖(如果有多个临时议事处到该聚居地的距离一样,则y为其中编号最小的临时议事处)。
现在国王想知道,在q年的时间里,每一年完成授权后,当年每个临时议事处将会管理多少个种族(议事处所在的聚居地也将接受该议事处管理)。 现在这个任务交给了以智慧著称的灵长类的你:程序猿。请帮国王完成这个任务吧。

输入格式

第一行为一个正整数n,表示世界树中种族的个数。
接下来n-l行,每行两个正整数x,y,表示x聚居地与y聚居地之间有一条长度为1的双
向道路。接下来一行为一个正整数q,表示国王询问的年数。
接下来q块,每块两行:
第i块的第一行为1个正整数m[i],表示第i年授权的临时议事处的个数。
第i块的第二行为m[i]个正整数h[l]、h[2]、…、h[m[i]],表示被授权为临时议事处的聚居地编号(保证互不相同)。

输出格式

输出包含q行,第i行为m[i]个整数,该行的第j(j=1,2…,,m[i])个数表示第i年被授权的聚居地h[j]的临时议事处管理的种族个数。

样例输入

10
2 1
3 2
4 3
5 4

6 1
7 3
8 3
9 4
10 1
5

2
6 1
5
2 7 3 6 9
1

8
4
8 7 10 3
5
2 9 3 5 8

样例输出

1 9
3 1 4 1 1
10
1 1 3 5
4 1 3 1 1

提示

N<=300000, q<=300000,m[1]+m[2]+…+m[q]<=300000


注意到多次询问的点总数和n同阶,考虑虚树。
构建好虚树后,在虚树上做两次树dp,算出到每个点距离最近的关键点,然后统计答案。
注意到有些点不在虚树中,但也要统计答案,这些点可以分成两类,一类是在虚树中两点之间的点,一类是虚树中点的(不在虚树中的)子树中的点。
对于第二类点,可以在虚树上用原size减去在虚树中的子树的size,注意这里的size需要倍增找到该点在原树中的儿子。
对于第一类点,可以枚举一条虚树中的边,然后在边上倍增找到中点,然后将两段的size分别加到两端的最近点上。
总时间复杂度$O(n\log n)$


代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<cstring>
#define N 600005
using namespace std;
int n,q,P[N],ans[N],Q[N],top,cnt,PP[N];
int dfn[N],VT,F[N],G[N],Sum[N],Size[N];
bool mark[N];
int dep[N],fa[N][20],S=19;
int TOT,LA[N],NE[N],EN[N];
int tot,la[N],ne[N],st[N],en[N],le[N];
bool cmp(int x,int y)
{return dfn[x]<dfn[y];}
void ADD(int x,int y)
{
TOT++;
EN[TOT]=y;
NE[TOT]=LA[x];
LA[x]=TOT;
}
void add(int x,int y,int z)
{
tot++;
st[tot]=x;
en[tot]=y;
le[tot]=z;
ne[tot]=la[x];
la[x]=tot;
}
void DFS(int x,int f)
{
int i,y;
Size[x]=1;
dfn[x]=++VT;
fa[x][0]=f;
dep[x]=dep[f]+1;
for(i=1;i<=S;i++)fa[x][i]=fa[fa[x][i-1]][i-1];
for(i=LA[x];i;i=NE[i])
{
y=EN[i];
if(y!=f)DFS(y,x),Size[x]+=Size[y];
}
}
int LCA(int x,int y)
{
if(!x||!y)return 0;
if(dep[x]<dep[y])swap(x,y);
int i,t=dep[x]-dep[y];
for(i=0;i<=S;i++)
if(t>>i&1)x=fa[x][i];
if(x==y)return x;
for(i=S;i>=0;i--)
if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
void BT()
{
int i,lca;
if(!mark[1])Q[++top]=1;
for(i=1;i<=cnt;i++)
{
lca=LCA(P[i],Q[top]);
if(lca==Q[top]){Q[++top]=P[i];continue;}
while(dep[lca]<dep[Q[top-1]])
{
add(Q[top-1],Q[top],dep[Q[top]]-dep[Q[top-1]]);
top--;
}
add(lca,Q[top],dep[Q[top]]-dep[lca]);
if(lca!=Q[--top])Q[++top]=lca;
Q[++top]=P[i];
}
while(--top)add(Q[top],Q[top+1],dep[Q[top+1]]-dep[Q[top]]);
}
int Gsi(int x,int f)
{
int i;
for(i=S;i>=0;i--)if(dep[fa[x][i]]>dep[f])x=fa[x][i];
return Size[x];
}
void DP1(int x)
{
int i,y;F[x]=G[x]=1e9;Sum[x]=Size[x];
if(mark[x])F[x]=0,G[x]=x;
for(i=la[x];i;i=ne[i])
{
y=en[i];DP1(y);Sum[x]-=Gsi(y,x);
if(F[y]+le[i]<F[x])F[x]=F[y]+le[i],G[x]=G[y];
else if(F[y]+le[i]==F[x]&&G[y]<G[x])G[x]=G[y];
}
}
void DP3(int x,int d,int f)
{
if(x!=1)
{
if(F[x]>F[f]+d)F[x]=F[f]+d,G[x]=G[f];
else if(F[x]==F[f]+d&&G[x]>G[f])G[x]=G[f];
}
for(int i=la[x];i;i=ne[i])DP3(en[i],le[i],x);
}
void DP2(int x)
{
ans[G[x]]+=Sum[x];
for(int i=la[x];i;i=ne[i])DP2(en[i]);
la[x]=mark[x]=0;
}
int Gans(int x,int t)
{
int sum=0,i;
for(i=0;i<=S;i++)
if(t>>i&1)sum+=Size[fa[x][i]]-Size[x],x=fa[x][i];
return sum;
}
int main_main()
{
int i,j,k,x,y,d;
scanf("%d",&n);
for(i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
ADD(x,y);ADD(y,x);
}
DFS(1,0);
scanf("%d",&q);
for(i=1;i<=q;i++)
{
scanf("%d",&cnt);
for(j=1;j<=cnt;j++)scanf("%d",&P[j]),mark[P[j]]=1,PP[j]=P[j];
sort(P+1,P+cnt+1,cmp);tot=0;
BT();DP1(1);DP3(1,0,0);DP2(1);
for(j=1;j<=tot;j++)
{
if(le[j]==1)continue;
x=st[j];y=en[j];
if(dep[x]<dep[y])swap(x,y);
if(G[x]==G[y])ans[G[x]]+=Gans(x,dep[x]-dep[y]-1);
else
{
k=le[j]-1-(F[x]-F[y]);
d=Gans(x,(k>>1)+(G[x]<G[y]?k&1:0));
ans[G[x]]+=d;
ans[G[y]]+=Gans(x,dep[x]-dep[y]-1)-d;
}
}
for(j=1;j<=cnt;j++)printf("%d ",ans[PP[j]]),ans[PP[j]]=0;
puts("");
}
}
const int main_stack=16;
char my_stack[128<<20];
int main() {
__asm__("movl %%esp, (%%eax);\n"::"a"(my_stack):"memory");
__asm__("movl %%eax, %%esp;\n"::"a"(my_stack+sizeof(my_stack)-main_stack):"%esp");
main_main();
__asm__("movl (%%eax), %%esp;\n"::"a"(my_stack):"%esp");
return 0;
}