luogu4719(动态dp)

题目链接

https://www.luogu.org/problemnew/show/P4719

题解

教程链接

动态 $DP$ 是在普通 $DP$ 的基础上加了修改操作,使得题目变得非常不可做。。经常需要用到数据结构维护,常见的如:用线段树维护最大子段和以及维护 $LIS$ 等

然后连模板题都这么难么。。

考虑树上独立集的求法,设 $d[i][0/1]$ 为取/不取 $i$ 的最大独立集

数据结构只能维护序列上的情况,因此可以用树连剖分之后考虑链上的情况,如果要转化成链上问题得先把轻儿子处理好,那么设 $g[i][0/1]$ 为不考虑重儿子的最大独立集

加入重儿子的时候,可以用广义矩阵乘法(定义乘法为加法,加法为取 $max$ )做如下转移:

那么修改的时候只考虑这个节点作为轻儿子所产生的影响即可




代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
/**
*         ┏┓    ┏┓
*         ┏┛┗━━━━━━━┛┗━━━┓
*         ┃       ┃  
*         ┃   ━    ┃
*         ┃ >   < ┃
*         ┃       ┃
*         ┃... ⌒ ...  ┃
*         ┃ ┃
*         ┗━┓ ┏━┛
*          ┃ ┃ Code is far away from bug with the animal protecting          
*          ┃ ┃ 神兽保佑,代码无bug
*          ┃ ┃           
*          ┃ ┃       
*          ┃ ┃
*          ┃ ┃           
*          ┃ ┗━━━┓
*          ┃ ┣┓
*          ┃ ┏┛
*          ┗┓┓┏━━━━━━━━┳┓┏┛
*           ┃┫┫ ┃┫┫
*           ┗┻┛ ┗┻┛
*/

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<queue>
#include<map>
#include<stack>
#include<cmath>
#include<set>
#include<bitset>
#include<assert.h>
#define inc(i,l,r) for(int i=l;i<=r;i++)
#define dec(i,l,r) for(int i=l;i>=r;i--)
#define link(x) for(edge *j=h[x];j;j=j->next)
#define mem(a) memset(a,0,sizeof(a))
#define ll long long
#define eps 1e-8
#define succ(x) (1<<x)
#define lowbit(x) (x&(-x))
#define sqr(x) ((x)*(x))
#define mid (x+y)/2
#define NM 100005
#define nm 200005
#define pi 3.1415926535897931
using namespace std;
const int inf=1e9+7;
ll read(){
ll x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return f*x;
}




struct edge{int t;edge*next;}e[nm],*h[NM],*o=e;
void add(int x,int y){o->t=y;o->next=h[x];h[x]=o++;}
int d[NM],top[NM],son[NM],end[NM],TOP,tot,id[NM],f[NM],size[NM],c[NM];
int dp[NM][2],g[NM][2],n,m,a[NM],_x,_y;
void dfs1(int x){
size[x]=1;dp[x][1]=a[x];dp[x][0]=0;
link(x)if(j->t!=f[x]){
f[j->t]=x;d[j->t]=d[x]+1;
dfs1(j->t);
dp[x][1]+=dp[j->t][0];
dp[x][0]+=max(dp[j->t][1],dp[j->t][0]);
if(size[j->t]>size[son[x]])son[x]=j->t;
size[x]+=size[j->t];
}
}
int dfs2(int x){
top[x]=TOP;id[x]=++tot;end[x]=x;c[tot]=x;
if(son[x])end[x]=dfs2(son[x]);
g[x][1]=a[x];g[x][0]=0;
link(x)if(!top[j->t]){
dfs2(TOP=j->t);
g[x][1]+=dp[j->t][0];
g[x][0]+=max(dp[j->t][1],dp[j->t][0]);
}
return end[x];
}

struct mat{int n,m,a[2][2];}one;
mat operator*(const mat&x,const mat&y){
mat s;s.n=x.n;s.m=y.m;mem(s.a);
inc(i,0,x.n)inc(k,0,x.m)inc(j,0,s.m)s.a[i][j]=max(s.a[i][j],x.a[i][k]+y.a[k][j]);
return s;
}

struct node{
node*l,*r;
int x,y;
mat s;
node(int x,int y,node*l=0,node*r=0):x(x),y(y),l(l),r(r){
if(x==y){
s.n=s.m=1;s.a[0][0]=s.a[0][1]=g[c[x]][0];s.a[1][0]=g[c[x]][1];s.a[1][1]=-inf;
}else upd();
}
void upd(){s=l->s*r->s;}
void mod(){
if(x==y){
s.a[0][0]=s.a[0][1]=g[c[x]][0];s.a[1][0]=g[c[x]][1];
return;
}
if(_x<=mid)l->mod();else r->mod();
upd();
}
mat sum(){
if(_x<=x&&y<=_y)return s;
if(_x>mid)return r->sum();
if(_y<=mid)return l->sum();
return l->sum()*r->sum();
}
}*root;
node*build(int x,int y){return x==y?new node(x,y):new node(x,y,build(x,mid),build(mid+1,y));}

void ch(int x){
g[x][1]+=_y-a[x];a[x]=_y;
_x=id[x];root->mod();
while(top[x]>1){
x=top[x];
_x=id[x];_y=id[end[x]];mat t=root->sum()*one;
g[f[x]][1]+=t.a[0][0]-dp[x][0];
g[f[x]][0]+=max(t.a[0][0],t.a[1][0])-max(dp[x][0],dp[x][1]);
dp[x][1]=t.a[1][0];dp[x][0]=t.a[0][0];
x=f[top[x]];_x=id[x];root->mod();
}
}


int main(){
//freopen("data.in","r",stdin);
n=read();m=read();
inc(i,1,n)a[i]=read();
inc(i,2,n){_x=read();_y=read();add(_x,_y);add(_y,_x);}
dfs1(f[1]=1);dfs2(TOP=1);
//inc(i,1,n)printf("%d ",top[i]);putchar('\n');
//inc(i,1,n)printf("%d ",end[i]);putchar('\n');
root=build(1,n);
one.n=1;one.m=0;
_x=1;_y=id[end[1]];mat ans=root->sum()*one;
//printf("%d\n",max(ans.a[0][0],ans.a[1][0]));
while(m--){
_x=read();_y=read();ch(_x);
_x=1;_y=id[end[1]];mat ans=root->sum()*one;
printf("%d\n",max(ans.a[0][0],ans.a[1][0]));
}
return 0;
}