动态dp学习笔记

动态dp

基本描述

是指一类树上dp问题不再局限于静态,通过多次修改权值达到动态效果的dp问题

是猫锟在WC2018上讲的黑科技

模板题

【模板】"动态 DP"&动态树分治

朴素dp很简单,即入门树形dp,“没有上司的舞会”

我们能得到转移方程

fi,0f_{i,0} 表示点 ii 不选择,子树内得到的最大权值和

fi,1f_{i,1} 表示点 ii 选择,子树内得到的最大权值和

fi,0=sSonimax(fs,0,fs,1)fi,1=ai+sSonifs,0f_{i,0}=\sum_{s\in Son_i}max(f_{s,0},f_{s,1})\\ f_{i,1}=a_i+\sum_{s\in Son_i}f_{s,0}

动态dp的一种核心思路是通过重链剖分的优秀性质实现的

将子树内点的贡献拆成轻儿子与重儿子两种

hih_i 表示点 ii 的重儿子

gi,0g_{i,0} 表示点 ii 的所有轻儿子可选可不选对 ii 的贡献之和

gi,1g_{i,1} 表示选点 ii 并且所有轻儿子不选对 ii 的贡献之和

转移方程简化为

fi,0=gi,0+max{fhi,0,fhi,1}fi,1=gi,1+fhi,0f_{i,0}=g_{i,0}+\max\{f_{h_i,0},f_{h_i,1}\}\\ f_{i,1}=g_{i,1}+f_{h_i,0}

这样的方程是可以矩阵优化的

[gi,0,gi,0gi,1,inf]×[fhi,0fhi,1]=[fi,0fi,1]\begin{bmatrix}g_{i,0},g_{i,0}\\g_{i,1},-\inf\end{bmatrix}\times\begin{bmatrix}f_{h_i,0}\\f_{h_i,1}\end{bmatrix}=\begin{bmatrix}f_{i,0}\\f_{i,1}\end{bmatrix}

这里的矩乘是 max+\max + 矩乘

发现我们想得到 fi,0f_{i,0}fi,1f_{i,1} 只要把 点 ii 所在重链的链尾到点 ii 的转移矩阵全部乘起来即可

又链尾一定是叶子,叶子的 ff 我们能轻易得到

查询可以在 O(8logn)O(8*logn) 的时间内在线段树上得到

考虑修改点 ii 权值

容易知道只会对点 ii 至根的路径产生影响

由于我们把问题转化成了若干转移矩阵的乘积,只需考虑对 gg 的影响即可

发现重链上均不会有影响,只会当路径上某点为其父亲节点的轻儿子时会产生影响

当链顶走到另一条链时,链顶是父亲的轻儿子,会对 gg 产生影响

我们暴力的将链顶的 ff 查出来,贡献到父亲的 gg 上即可

我们知道一点到根的路径上最多有 O(logn)O(logn) 条轻边

我们在 O(log2n)O(log^2n) 的时间内解决了该问题

完整代码如下

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=1e5+5,inf=1e9;
struct Edge{
	int to,nxt;	
}e[2*maxn];
int cnt;
int head[maxn];
int n,m;
int v[maxn];
int read(){
	int x=0,y=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')y=-1;ch=getchar();}
	while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
	return x*y;
}
void add(int u,int v){
	e[cnt].to=v;
	e[cnt].nxt=head[u];
	head[u]=cnt++;
	return ;
}
int pr[maxn];
int siz[maxn],dep[maxn],h[maxn];
void dfs1(int x,int fa){
	pr[x]=fa;
	dep[x]=dep[fa]+1;
	siz[x]=1;
	for(int i=head[x];i!=-1;i=e[i].nxt){
		int tmp=e[i].to;
		if(tmp==fa)continue;
		dfs1(tmp,x);
		siz[x]+=siz[tmp];
		if(siz[h[x]]<siz[tmp])h[x]=tmp;
	}
	return ;
}
int dfsnum;
int top[maxn],end[maxn];
int dfn[maxn],num[maxn];
int f[maxn][2],g[maxn][2];
void dfs2(int x,int tp){
	top[x]=tp;
	dfn[x]=++dfsnum;
	num[dfsnum]=x;
	if(h[x])dfs2(h[x],tp);
	g[x][1]=v[x];
	for(int i=head[x];i!=-1;i=e[i].nxt){
		int tmp=e[i].to;
		if(tmp==pr[x]||tmp==h[x])continue;
		dfs2(tmp,tmp);
		g[x][0]+=max(f[tmp][0],f[tmp][1]);
		g[x][1]+=f[tmp][0];
	}
	f[x][0]=g[x][0]+max(f[h[x]][0],f[h[x]][1]);
	f[x][1]=g[x][1]+f[h[x]][0];
	return ;
}
struct matrix{
	int n,m;
	int x[2][2];
	matrix operator *(matrix a)const{
		matrix ans;
		ans.n=n;ans.m=a.m;
		for(int i=0;i<n;i++)
			for(int j=0;j<a.m;j++){
				ans.x[i][j]=-inf;
				for(int k=0;k<m;k++)
					ans.x[i][j]=max(ans.x[i][j],x[i][k]+a.x[k][j]);
			}
		return ans;
	}
};
matrix o;
matrix c[maxn<<2];
void build(int k,int l,int r){
	if(l==r){
		c[k].n=c[k].m=2;
		c[k].x[0][0]=g[num[l]][0];
		c[k].x[0][1]=g[num[l]][0];
		c[k].x[1][0]=g[num[l]][1];
		c[k].x[1][1]=-inf;
		return ;
	}
	int mid=l+((r-l)>>1);
	build(k<<1,l,mid);
	build(k<<1|1,mid+1,r);
	c[k]=c[k<<1]*c[k<<1|1];
	return ;
}
void modify(int k,int l,int r,int x){
	if(l>x||r<x)return ;
	if(l==r){
		c[k].n=c[k].m=2;
		c[k].x[0][0]=g[num[l]][0];
		c[k].x[0][1]=g[num[l]][0];
		c[k].x[1][0]=g[num[l]][1];
		c[k].x[1][1]=-inf;
		return ;
	}
//	cout<<k<<" "<<l<<" "<<r<<" "<<x<<endl;
	int mid=l+((r-l)>>1);
	modify(k<<1,l,mid,x);
	modify(k<<1|1,mid+1,r,x);
	c[k]=c[k<<1]*c[k<<1|1];
	return ;
}
matrix query(int k,int l,int r,int x,int y){
	if(l>y||r<x)return o;
	if(l>=x&&r<=y)return c[k];
	int mid=l+((r-l)>>1);
	return query(k<<1,l,mid,x,y)*query(k<<1|1,mid+1,r,x,y);
}
matrix query_node(int x){
	matrix A;
	A.n=2;A.m=1;
	A.x[0][0]=0;A.x[1][0]=v[end[x]];
	matrix p=query(1,1,n,dfn[x],dfn[end[x]]-1);
	p=p*A;
//	cout<<p.x[0][0]<<" "<<p.x[1][0]<<endl;
	return p;
}
void modify_node(int x,int val){
	g[x][1]+=val-v[x];
	v[x]=val;
	modify(1,1,n,dfn[x]);
	while(dep[top[x]]>1){
		x=top[x];
		matrix p=query_node(x);
		g[pr[x]][0]+=max(p.x[0][0],p.x[1][0])-max(f[x][0],f[x][1]);
		g[pr[x]][1]+=p.x[0][0]-f[x][0];
		f[x][0]=p.x[0][0];f[x][1]=p.x[1][0];
		modify(1,1,n,dfn[pr[x]]);
		x=pr[x];
	}
	return ;
}
int main(){
	n=read();m=read();
	for(int i=1;i<=n;i++)v[i]=read();
	memset(head,-1,sizeof(head));
	for(int i=1;i<n;i++){
		int u,v;
		u=read();v=read();
		add(u,v);
		add(v,u);
	}
	dfs1(1,0);
	dfs2(1,1);
	for(int i=1;i<=n;i++)if(dep[i]>dep[end[top[i]]])end[top[i]]=i;
	for(int i=1;i<=n;i++)end[i]=end[top[i]];
	build(1,1,n);
	o.n=o.m=2;
	o.x[0][1]=o.x[1][0]=-inf;
	for(int i=1;i<=m;i++){
		int x,y;
		x=read();y=read();
		modify_node(x,y);n
		matrix p=query_node(1);
		printf("%d\n",max(p.x[0][0],p.x[1][0]));
	}
	return 0;
}