NKOJ4029 CodeChef COUNTARI(分块+FFT)

P4029 [CodeChef] COUNTARI

问题描述

给定一个长度为N的数组A[],求有多少对$i, j, k(1\leq i<j<k\leq N)$满足$A[k]-A[j]=A[j]-A[i]$

输入格式

第一行一个整数$N(N<=10^5)$。

接下来一行N个数$A[i](A[i]<=30000)$。

输出格式

一行一个整数。

样例输入

10

3 5 3 6 3 4 10 4 5 2

样例输出

9


容易想到一个朴素的暴力,枚举$j$的位置,然后两边卷积得到$A[i]+A[k]=2A[j]$的$(i,k)$的数量。考虑优化。

用分块处理,设块的大小为$K$,那么分三种情况来讨论。

当$i,j,k$在同一块内时,枚举$(i,k)$,同时维护$cnt[j]$,那么单块可以在$O(K^2)$内出解,总复杂度就是$O(NK)$的

当$i,j,k$有两个在同一块内时,同样枚举$(i,k)$,复杂度仍然是$O(NK)$

当$i,j,k$均不在同一块时,枚举每一块,将左右两边的块卷积起来,复杂度就是$\frac{N}{K}M\log M$,$M=max{A[i]}$

那么只需要一个比较优秀的$K$即可解决。似乎当$K=5.3\sqrt{N}$时效果不错。


代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<complex>
#define N 100005
#define ll long long
using namespace std;
struct com
{
double r,v;
com(double x=0,double y=0){r=x;v=y;}
}B[N],D[N],w[2][N];
com operator+(com &a,com &b){return com(a.r+b.r,a.v+b.v);}
com operator-(com &a,com &b){return com(a.r-b.r,a.v-b.v);}
com operator*(com &a,com &b){return com(a.r*b.r-a.v*b.v,a.r*b.v+a.v*b.r);}
int n,K,tot,lp[N],rp[N],A[N],Lcnt[N],Rcnt[N],cnt[N],L,rev[N];
ll ans;
const double pi=acos(-1.0);
void FFT_pre()
{
int i,j,k;
com f=com(cos(2.0*pi/L),sin(2.0*pi/L));
com g=com(cos(2.0*pi/L),-sin(2.0*pi/L));
w[0][0]=w[1][0]=1;
for(i=1;i<L;i++)
{
w[0][i]=w[0][i-1]*g;
w[1][i]=w[1][i-1]*f;
}
}
void FFT(com C[],int ty)
{
register int i,j,k,m,t;com t0,t1;
for(i=0;i<L;i++)if(i<rev[i])swap(C[i],C[rev[i]]);
for(m=1,t;t=L/(m<<1),m<L;m<<=1)
for(k=0;k<L;k+=m<<1)
for(i=k,j=0;i<k+m;i++,j+=t)
{
t0=C[i];t1=C[i+m]*w[ty][j];
C[i]=t0+t1;
C[i+m]=t0-t1;
}
if(ty==1)return;t0=1.0/L;
for(i=0;i<L;i++)C[i]=C[i]*t0;
}
int main()
{
int i,j,k,x,y,Max=0,len=0;
scanf("%d",&n);K=5.3*sqrt(n);
for(i=1;i<=n;i++)
{
scanf("%d",&A[i]);
Max=max(Max,A[i]);
Rcnt[A[i]]++;
rp[i/K]=i;
if(!lp[i/K])lp[i/K]=i;
}
tot=n/K;L=1;while(L<=(Max<<1))L<<=1,len++;FFT_pre();
for(i=0;i<L;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<len>>1);
for(i=0;i<=tot;i++)
{
for(j=lp[i];j<=rp[i];j++)Rcnt[A[j]]--;
for(j=lp[i];j<rp[i];j++)
{
for(k=j+1;k<=rp[i];k++)
{
if(A[j]+A[k]+1&1)ans+=cnt[A[j]+A[k]>>1];
if((A[j]<<1)-A[k]>=0)ans+=Lcnt[(A[j]<<1)-A[k]];
if((A[k]<<1)-A[j]>=0)ans+=Rcnt[(A[k]<<1)-A[j]];
cnt[A[k]]++;
}
for(k=j+1;k<=rp[i];k++)cnt[A[k]]--;
}
if(i!=0&&i!=tot)
{
fill(B,B+L,0);fill(D,D+L,0);
for(j=0;j<L;j++)B[j]=Lcnt[j],D[j]=Rcnt[j];
FFT(B,1);FFT(D,1);
for(j=0;j<L;j++)B[j]=B[j]*D[j];
FFT(B,0);
for(j=lp[i];j<=rp[i];j++)ans+=floor(B[A[j]<<1].r+0.5);
}
for(j=lp[i];j<=rp[i];j++)Lcnt[A[j]]++;
}
printf("%lld",ans);

}