NKOJ 4128 (JSOI 2016)独特的树叶(树哈希)

P4128[Jsoi2016]独特的树叶

问题描述

JYY有两棵树A和B:树A有N个点,编号为1到N;树B有N+1个点,编号为1到N+1。JYY知道树B恰好是由树A加上一个叶

节点,然后将节点的编号打乱后得到的。他想知道,这个多余的叶子到底是树B中的哪一个叶节点呢?

输入格式

输入一行包含一个正整数N。

接下来N-1行,描述树A,每行包含两个整数表示树A中的一条边;

接下来N行,描述树B,每行包含两个整数表示树B中的一条边。

1≤N≤10^5

输出格式

输出一行一个整数,表示树B中相比树A多余的那个叶子的编号。如果有多个符合要求的叶子,输出B中编号最小的那一个的编号。

样例输入

5
1 2
2 3
1 4
1 5
1 2
2 3
3 4
4 5
3 6

样例输出

1


这道题用树哈希来处理比较方便。
先将树A以每个点作为根的哈希值都算出来存到一个set中,然后将树B去掉一个叶子节点后算出哈希值,再在set中查找。
如果直接算n次哈希会超时,需要用到递推求哈希值。

这里为了方便递推,我用的哈希函数是
$父节点Hash=(Hash[son_1]+p)\bigoplus (Hash[son_2]+p)\bigoplus……+size*q+1$

那么递推的方式就比较显然了,具体可以参见代码。
只需要做一次树dp就可以求出以每个点为根的哈希值。

至于树B去掉一个叶子后的哈希值,事实上也可以同样的递推,只需要在处理到叶子节点的父节点时算一下以他为根,且去掉这个叶子的哈希值然后在set中查找即可。

用unordered_set可以做到近似$O(n)$


代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<unordered_set>
#define ll long long
#define N 200005
using namespace std;
const ll p=1e9+7;
unordered_set<ll>Q;
ll Hash[N];
int n,D[N],si[N],Ans=1e9;
int TOT,LA[N],NE[N],EN[N];
void ADD(int x,int y)
{
TOT++;
EN[TOT]=y;
NE[TOT]=LA[x];
LA[x]=TOT;
}
void Ghash(int x,int f)
{
int i,y;si[x]=1;
for(i=LA[x];i;i=NE[i])
{
y=EN[i];if(y==f)continue;
Ghash(y,x);si[x]+=si[y];
Hash[x]=Hash[x]^Hash[y]+17;
}
Hash[x]+=si[x]*p+1;
}
void DFS1(int x,int f)
{
int i,y;Q.insert(Hash[x]);
for(i=LA[x];i;i=NE[i])
{
y=EN[i];if(y==f)continue;
ll tmp=(Hash[x]-si[x]*p-1)^(Hash[y]+17);
tmp+=(n-si[y])*p+1;
Hash[y]-=si[y]*p+1;
Hash[y]^=tmp+17;
Hash[y]+=n*p+1;
si[y]=n;DFS1(y,x);
}
}
void DFS2(int x,int f)
{
int i,y;
for(i=LA[x];i;i=NE[i])
{
y=EN[i];if(y==f)continue;
if(D[y]>1)
{
ll tmp=(Hash[x]-si[x]*p-1)^(Hash[y]+17);
tmp+=(si[x]-si[y])*p+1;
Hash[y]-=si[y]*p+1;
Hash[y]^=tmp+17;
Hash[y]+=si[x]*p+1;
si[y]=si[x];DFS2(y,x);
}
else
{
ll tmp=(Hash[x]-si[x]*p-1)^(Hash[y]+17);
tmp+=(si[x]-si[y])*p+1;
if(Q.count(tmp))Ans=min(Ans,y);
}
}
}
int main()
{
int i,j,k,x,y;
scanf("%d",&n);
for(i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
ADD(x,y);ADD(y,x);
}
Ghash(1,0);DFS1(1,0);TOT=0;
memset(LA,0,sizeof(LA));
memset(Hash,0,sizeof(Hash));
for(i=1;i<=n;i++)
{
scanf("%d%d",&x,&y);
D[x]++;D[y]++;
ADD(x,y);ADD(y,x);
}
for(x=1;x<=n;x++)if(D[x]>1)break;
Ghash(x,0);DFS2(x,0);
printf("%d",Ans);
}