AHOI/HNOI2018 排列(贪心+堆)

「AHOI / HNOI2018」排列

问题描述

给定 $n$ 个整数 $a_1, a_2, …, a_n(0 \le a_i \le n)$,以及 $n$ 个整数 $w_1, w_2, …, w_n$。称 $a_1, a_2, …, a_n$ 的一个排列 $a_{p[1]}, a_{p[2]}, …, a_{p[n]}$ 为 $a_1, a_2, …, a_n$ 的一个合法排列,当且仅当该排列满足:对于任意的 $k$ 和任意的 $j$,如果 $j \le k$,那么 $a_{p[j]}$ 不等于 $p[k]$。(换句话说就是:对于任意的 $k$ 和任意的 $j$,如果 $p[k]$ 等于 $a_{p[j]}$,那么 $k<j$。)

定义这个合法排列的权值为 $w_{p[1]} + 2w_{p[2]} + … + nw_{p[n]}$。你需要求出在所有合法排列中的最大权值。如果不存在合法排列,输出 $-1$。

样例解释中给出了合法排列和非法排列的实例。

输入格式

第一行一个整数 $n$。

接下来一行 $n$ 个整数,表示 $a_1,a_2,…, a_n$。

接下来一行 $n$ 个整数,表示 $w_1,w_2,…,w_n$。

输出格式

输出一个整数表示答案。

样例输入

3
0 1 1
5 7 3

样例输出

32

提示

对于前 $20\%$ 的数据,$1 \le n \le 10$;

对于前 $40\%$ 的数据,$1 \le n \le 15$;

对于前 $60\%$ 的数据,$1 \le n \le 1000$;

对于前 $80\%$ 的数据,$1 \le n \le 100000$;

对于 $100\%$ 的数据,$1 \le n \le 500000$,$0 \le a_i \le n (1 \le i \le n)$,$1 \le w_i \le 10^9$ ,所有 $w_i$ 的和不超过 $1.5 \times 10^{13}$。


弄清楚题意后发现,题目要求的就是$a_{a_i}$在新排列中必须在$a_i$之前,那么我们从$i\rightarrow a_i$连边,这样的话就变成一颗以0为根的树,那么要求就是必须先选父亲才能选儿子。

要求权值最大等价于先选权值小的,那么考虑全局最小值优先选。如果当前全局最小值的父亲为0,那么直接选。

否则它一定在选了父亲之后第一个选,因此可以将它和它的父亲缩到一起,考虑缩点后的权值是多少,这里比较巧妙,实际上权值给成$\frac{\sum w_i}{size}$就是对的,具体证明我们考虑假如先选缩点后的$(a_1,a_2,…,a_k)$比先选$a_{k+1},…,a_n$优,那么意味着

$$
a_1+2a_2+…+ka_k+(k+1)a_{k+1}+…+na_n \geq a_{k+1}+…+(n-k)a_n+(n-k+1)a_1+…+na_k
$$
那么移项就得到

$$
(n-k)(a_1+a_2+…+a_k)\leq k(a_{k+1}+…+a_n)
$$
那么容易发现这样做是对的。因此只需要用并查集+堆来维护就行了。


代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<queue>
#define ll long long
#define N 500005
using namespace std;
char buf[1<<20],*p1,*p2;
#define GC (p1==p2&&(p2=(p1=buf)+fread(buf,1,1000000,stdin),p1==p2)?0:*p1++)
inline void _R(int &x)
{
char t=GC;
while(t<48||t>57)t=GC;
for(x=0;t>47&&t<58;t=GC)x=(x<<1)+(x<<3)+t-48;
}
struct node{ll v;int s,x;};
bool operator<(node a,node b){return a.v*b.s>b.v*a.s;}
int n,a[N],fa[N],si[N];
ll sum[N],ans;
priority_queue<node>Q;
int gf(int x){return x==fa[x]?x:fa[x]=gf(fa[x]);}
int main()
{
register int i,j,k,x,y,fx,fy;_R(n);
for(i=0;i<=n;i++)fa[i]=i;
for(i=1;i<=n;i++)
{
_R(a[i]);
fx=gf(i);fy=gf(a[i]);
if(fx==fy)return puts("-1"),0;
fa[fx]=fy;
}
for(i=1;i<=n;i++)_R(x),sum[i]=x;
for(i=0;i<=n;i++)si[i]=1,fa[i]=i;
for(i=1;i<=n;i++)Q.push((node){sum[i],si[i],i});
while(Q.size())
{
node tmp=Q.top();
Q.pop();
if(tmp.s!=si[tmp.x])continue;
fx=gf(a[tmp.x]);
ans+=tmp.v*si[fx];
fa[tmp.x]=fx;
si[fx]+=tmp.s;
sum[fx]+=tmp.v;
if(fx)Q.push((node){sum[fx],si[fx],fx});
}
printf("%lld",ans);
}