NKOJ 3446 (HN Training 2015)Shopping (点分治+树形dp)

P3446【HN Training 2015 Round7】

问题描述
这里写图片描述

这里写图片描述
这里写图片描述


容易发现最后答案是树上的一个联通块,但直接dp难度较大,考虑用点分治转化成一定包含根的联通块。

点分治后,每一层考虑包含根的联通块,那么转化成一个树形依赖背包,只有选了父节点才能选子节点,并且每个物品有个数限制。

这里对于这种树形依赖dp,可以采用dfs序来简化,因为在dfs序中,一颗子树必然是连续的一段,那么令$F[i][j]$表示在$dfs序i-n这些节点中容积为j的背包的最优解$,因此在物品数量均为1时可以得到dp方程
$$
F[i][j]=max(F[i+size[i]][j],F[i+1][j-c[i]]+w[i])
$$
第一个转移表示不选i节点的物品,那么跳过i这颗子树,第二个转移表示选。答案就是$F[1][m]$

接着考虑物品数量的限制,这里我用的二进制拆分的方法,即将d个物品拆成$log\ d$个物品,举个例子,比如将$10$个物品可以拆成$1,2,4,3$这4个物品,容易发现,无论从这10个物品中取多少个,都可以用上述4个物品表示出来。

修改后的dp方程可以参照代码,改动不大,只是多了一个$log$的复杂度。最终复杂度$Tnm\log n\log d$,实际上跑得还是很快。


代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include<stdio.h>
#include<algorithm>
#include<cstring>
#define N 505
#define M 4005
using namespace std;
int T,n,m,w[N],c[N],d[N],F[N][M],ans;
int TOT,LA[N],NE[N*2],EN[N*2];
int Min,rt,si[N],VT,sz[N],wi[N],ci[N],di[N];
bool mark[N];
void ADD(int x,int y)
{
TOT++;
EN[TOT]=y;
NE[TOT]=LA[x];
LA[x]=TOT;
}
void Gsi(int x,int f)
{
int i,y;si[x]=1;
for(i=LA[x];i;i=NE[i])
{
y=EN[i];
if(y==f||mark[y])continue;
Gsi(y,x);si[x]+=si[y];
ans=max(ans,F[1][m]);
}
}
void Grt(int x,int f,int s)
{
int i,y,Max=s-si[x];
for(i=LA[x];i;i=NE[i])
{
y=EN[i];
if(y==f||mark[y])continue;
Grt(y,x,s);
if(Max<si[y])Max=si[y];
}
if(Max<Min)Min=Max,rt=x;
}
void DFS(int x,int f)
{
int i,y,p=++VT;si[x]=1;
for(i=LA[x];i;i=NE[i])
{
y=EN[i];
if(mark[y]||y==f)continue;
DFS(y,x);si[x]+=si[y];
}
sz[p]=si[x];wi[p]=w[x];ci[p]=c[x];di[p]=d[x];
}
void Gans(int x)
{
int i,j,k,y;VT=0;DFS(x,0);//DFS序
for(i=VT;i>0;i--)
{
for(j=0;j<=m;j++)F[i][j]=F[i+sz[i]][j];
for(k=1,y=di[i];y>0;y-=k,k<<=1)//拆分物品
for(j=m;j>=k*ci[i]||j>=y*ci[i];j--)//背包
if(k<y)F[i][j]=max(F[i][j],max(F[i][j-k*ci[i]],F[i+1][j-k*ci[i]])+k*wi[i]);
else F[i][j]=max(F[i][j],max(F[i][j-y*ci[i]],F[i+1][j-y*ci[i]])+y*wi[i]);
}
ans=max(ans,F[1][m]);
for(i=VT;i>0;i--)
for(j=0;j<=m;j++)F[i][j]=0;
}
void DC(int x)
{
int i,y;mark[x]=1;
Gans(x);
for(i=LA[x];i;i=NE[i])
{
y=EN[i];if(mark[y])continue;
Min=1e9;Gsi(y,0);Grt(y,x,si[y]);DC(rt);
}
}
int main()
{
int i,j,k,x,y;
scanf("%d",&T);
while(T--)
{
TOT=0;ans=0;
memset(LA,0,sizeof(LA));
memset(mark,0,sizeof(mark));
scanf("%d%d",&n,&m);
for(i=1;i<=n;i++)scanf("%d",&w[i]);
for(i=1;i<=n;i++)scanf("%d",&c[i]);
for(i=1;i<=n;i++)scanf("%d",&d[i]);
for(i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
ADD(x,y);ADD(y,x);
}
Min=1e9;Gsi(1,0);Grt(1,0,n);DC(rt);
printf("%d\n",ans);
}
}