primedst(点分治+FFT)

题目链接

https://cn.vjudge.net/problem/CodeChef-PRIMEDST

题意

给定一棵树,找距离为质数的点对数

题解

很容易想到点分,但是合并的时候复杂度过大,需要考虑降低合并的复杂度

而这个树背包本质也是个多项式,所以合并的时候相当于做多项式乘法,这就可以用 $FFT$ 优化了。。

这里答案的总数会达到 $10^9$ ,所以选了个比较大的 $NTT$ 模数做 $NTT$ 。。

然后注意分治的时候当前子根和其他点产生的贡献只记了一次,其他点之间记了两次的差异即可。。




代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
/**
*         ┏┓    ┏┓
*         ┏┛┗━━━━━━━┛┗━━━┓
*         ┃       ┃  
*         ┃   ━    ┃
*         ┃ >   < ┃
*         ┃       ┃
*         ┃... ⌒ ...  ┃
*         ┃ ┃
*         ┗━┓ ┏━┛
*          ┃ ┃ Code is far away from bug with the animal protecting          
*          ┃ ┃ 神兽保佑,代码无bug
*          ┃ ┃           
*          ┃ ┃       
*          ┃ ┃
*          ┃ ┃           
*          ┃ ┗━━━┓
*          ┃ ┣┓
*          ┃ ┏┛
*          ┗┓┓┏━━━━━━━━┳┓┏┛
*           ┃┫┫ ┃┫┫
*           ┗┻┛ ┗┻┛
*/

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<queue>
#include<map>
#include<stack>
#include<cmath>
#include<set>
#include<bitset>
#include<complex>
#include<assert.h>
#define inc(i,l,r) for(int i=l;i<=r;i++)
#define dec(i,l,r) for(int i=l;i>=r;i--)
#define link(x) for(edge *j=h[x];j;j=j->next)
#define mem(a) memset(a,0,sizeof(a))
#define ll long long
#define eps 1e-8
#define succ(x) (1<<x)
#define lowbit(x) (x&(-x))
#define sqr(x) ((x)*(x))
#define mid (x+y)/2
#define NM 50005
#define nm 150005
using namespace std;
const double pi=acos(-1);
const ll inf=2013265921;
ll read(){
ll x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return f*x;
}



struct edge{int t;edge*next;}e[nm],*h[NM],*o=e;
void add(int x,int y){o->t=y;o->next=h[x];h[x]=o++;}
int n,m,size[NM],tot,smin,root,_x,_y,rev[nm];
bool v[NM],_v[nm];
ll a[nm],b[nm],c[nm],ans;
ll qpow(ll x,ll t){return t?qpow(sqr(x)%inf,t>>1)*(t&1?x:1ll)%inf:1ll;}


void dfs2(int x,int f){size[x]=1;link(x)if(!v[j->t]&&j->t!=f)dfs2(j->t,x),size[x]+=size[j->t];}
void getroot(int x,int f){
int s=tot-size[x];
link(x)if(!v[j->t]&&j->t!=f)getroot(j->t,x),s=max(s,size[j->t]);
if(s<smin)smin=s,root=x;
}

void fft(ll*a,int f){
inc(i,0,n-1)if(i<rev[i])swap(a[i],a[rev[i]]);
for(int k=1;k<n;k<<=1){
ll t=qpow(31,(inf-1)/k/2);if(f==-1)t=qpow(t,inf-2);
for(int i=0;i<n;i+=k<<1){
ll w=1;
for(int j=0;j<k;j++,w=w*t%inf){
ll x=a[i+j],y=w*a[i+j+k]%inf;
a[i+j]=(x+y)%inf;a[i+j+k]=(x-y+inf)%inf;
}
}
}
}
void plu(ll*a,ll*b){
int bit=0;
n=n+m-1;
while(succ(bit)<n)bit++;n=succ(bit);
inc(i,m,n)b[i]=0;
ll invn=qpow(n,inf-2);
inc(i,1,n-1)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
fft(a,1);fft(b,1);inc(i,0,n-1)a[i]=a[i]*b[i]%inf;
fft(a,-1);inc(i,0,n-1)a[i]=a[i]*invn%inf;
}


void dfs(int x,int f,int t){
a[t]++;tot=max(tot,t);
link(x)if(!v[j->t]&&j->t!=f)dfs(j->t,x,t+1);
}

void div(int x){
dfs2(x,0);
tot=size[x];smin=inf;
getroot(x,0);tot=0;
dfs(root,0,0);
v[root]++;
inc(i,0,tot)c[i]=a[i],a[i]=0;
inc(i,2,tot)if(!_v[i])ans+=c[i];
m=tot+1;
link(root)if(!v[j->t]){
tot=0;
dfs(j->t,root,1);
inc(i,0,m-1)b[i]=c[i]-a[i];
n=tot+1;
plu(a,b);
inc(i,0,n-1){if(!_v[i])ans+=a[i];a[i]=0;}
}
link(root)if(!v[j->t])div(j->t);
}

int main(){
n=1e5;_v[1]++;_v[0]++;
inc(i,2,n)if(!_v[i])for(int j=i<<1;j<=n;j+=i)_v[j]++;
n=read();inc(i,2,n){_x=read();_y=read();add(_x,_y);add(_y,_x);}
_x=n;
div(1);
return 0*printf("%.7lf\n",1.0*ans/_x/(_x-1));
}