树链剖分

终于开始填暑假的坑了

首先是概念:总的来说,就是把一棵树剖分成若干条链,然后利用数据结构维护每一条链。复杂度:O(logn)

在学习树链剖分之前,建议先去学线段树。

首先列出一些基本概念:

  • 重儿子:若siz[u]是v的子节点中siz值最大的,那么u就是v的重儿子。

  • 轻儿子:v的其它子节点。

  • 重边:点v与其重儿子的连边。

  • 轻边:点v与其轻儿子的连边。

  • 重链:由重边连成的路径。

  • 轻链:由轻边连成的路径。


一般来说,我们使用的树链剖分都是重链剖分。至于为什么要用重链剖分,是因为复杂度要控制在log以内,这个是可以证明的,网上也已经有不少博客讲解过,这里就不讲了。

一般的树链剖分主要使用两次DFS(当然BFS也是可以的)欲处理出一些信息,再通过线段树来维护。

第一次DFS:预处理出id(节点编号),dep(节点深度),siz(节点大小),son(节点的重儿子),fa(节点的父节点)。这一部分还是很简单的,下面直接给出代码(路径使用邻接表存储,下面的代码均是如此):
void dfs1(long long id,long long f,long long deep)
{
fa[id]=f;
dep[id]=deep;
siz[id]=1;
son[id]=-1;
long long d=head[id];
while(d)
{
if(a[d].v!=f)
{
dfs1(a[d].v,id,deep+1);
siz[id]+=siz[a[d].v];
if(siz[id]==-1||siz[son[id]]<siz[a[d].v])
son[id]=a[d].v;
}
d=a[d].next;
}
return;
}

第二次DFS:预处理出dfn(节点的dfs序),top(节点所在的链的顶点),num(dfs序所对应的节点编号)。
void dfs2(long long id,long long topf)
{
top[id]=topf;
dfn[id]=++cnt;
num[cnt]=b[id];
if(son[id]!=-1)
dfs2(son[id],topf);
long long d=head[id];
while(d)
{
if(a[d].v!=fa[id]&&a[d].v!=son[id])
dfs2(a[d].v,a[d].v);
d=a[d].next;
}
return;
}

 

我们不难发现,第二次DFS时搜索的顺序是从重儿子开始的,这是因为这样搜索下来我们可以保证每条重链上的节点的dfn是连续的。

因此我们就可以使用线段树维护(以下代码均为线段树的基本操作,可以直接略过)。
void build(long long id,long long l,long long r)
{
tr[id].l=l;
tr[id].r=r;
if(l==r)
{
tr[id].num=num[l];
return;
}
long long mid=(l+r)/2;
build(lch,l,mid);
build(rch,mid+1,r);
tr[id].num=(tr[lch].num+tr[rch].num)%p;
return;
}
void pushdown(long long id)
{
if(tr[id].l!=tr[id].r)
{
long long lazy=tr[id].lazy;
tr[id].lazy=0;
tr[lch].lazy+=lazy;
tr[rch].lazy+=lazy;
tr[lch].num+=lazy*(tr[lch].r-tr[lch].l+1);
tr[rch].num+=lazy*(tr[rch].r-tr[rch].l+1);
tr[lch].num%=p;
tr[rch].num%=p;
tr[lch].lazy%=p;
tr[rch].lazy%=p;
}
return;
}
void add(long long id,long long l,long long r,long long val)
{
pushdown(id);
if(tr[id].l==l&&tr[id].r==r)
{
tr[id].lazy+=val;
tr[id].num+=val*(tr[id].r-tr[id].l+1);
tr[id].lazy%=p;
tr[id].num%=p;
return;
}
long long mid=(tr[id].l+tr[id].r)/2;
if(r<=mid)
add(lch,l,r,val);
if(l>=mid+1)
add(rch,l,r,val);
if(l<=mid&&r>=mid+1)
{
add(lch,l,mid,val);
add(rch,mid+1,r,val);
}
tr[id].num=(tr[lch].num+tr[rch].num)%p;
return;
}
long long ask(long long id,long long l,long long r)
{
pushdown(id);
if(tr[id].l==l&&tr[id].r==r)
return tr[id].num;
long long mid=(tr[id].l+tr[id].r)/2;
if(r<=mid)
return ask(lch,l,r);
if(l>=mid+1)
return ask(rch,l,r);
if(l<=mid&&r>=mid+1)
return (ask(lch,l,mid)+ask(rch,mid+1,r))%p;
}

树链剖分主要用于求树上两点之间的最短距离,那么怎么求呢?我们可以发现,当两个点在一条链上时,两点在链上的距离就是它们之间的最短距离。而当两点不在一条链上时,我们就选择深度较大的拿个节点,将他更新到他所在的链的顶点的父节点。下面给出代码:
void lian(long long x,long long y,long long val)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])
swap(x,y);
add(1,dfn[top[x]],dfn[x],val);
x=fa[top[x]];
}
if(dep[x]>dep[y])
swap(x,y);
add(1,dfn[x],dfn[y],val);
return;
}

 

树链剖分有时也用于修改两点间的最短路径上的权值,实现方式也是和查询差不多的,均是通过线段树来实现。下面也直接给出代码:
long long ask_lian(long long x,long long y)
{
long long ans=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])
swap(x,y);
ans+=ask(1,dfn[top[x]],dfn[x]);
ans%=p;
x=fa[top[x]];
}
if(dep[x]>dep[y])
swap(x,y);
ans+=ask(1,dfn[x],dfn[y]);
ans%=p;
return ans;
}

讲到这里,树链剖分的主体就结束了,还有一些别的操作就等到了具体的题目再讲吧。

相关练习:洛谷P3384树链剖分(本篇博客使用的代码片段皆节选于此题的代码,忽略我奇怪的变量和函数名

以下是本题的完整代码:
#include<iostream>
#include<cstdio>
using namespace std;
#define lch id*2
#define rch id*2+1
long long n,m,sum,r,p,cnt,b[100005],head[100005],siz[100005],top[100005],son[100005],dep[100005],fa[100005],dfn[100005],num[100005];
struct node
{
long long v;
long long next;
}a[200005];
struct tree
{
long long l;
long long r;
long long num;
long long lazy;
}tr[400005];
void ins(long long u,long long v)
{
++sum;
a[sum].v=v;
a[sum].next=head[u];
head[u]=sum;
return;
}
void dfs1(long long id,long long f,long long deep)
{
fa[id]=f;
dep[id]=deep;
siz[id]=1;
son[id]=-1;
long long d=head[id];
while(d)
{
if(a[d].v!=f)
{
dfs1(a[d].v,id,deep+1);
siz[id]+=siz[a[d].v];
if(siz[id]==-1||siz[son[id]]<siz[a[d].v])
son[id]=a[d].v;
}
d=a[d].next;
}
return;
}
void dfs2(long long id,long long topf)
{
top[id]=topf;
dfn[id]=++cnt;
num[cnt]=b[id];
if(son[id]!=-1)
dfs2(son[id],topf);
long long d=head[id];
while(d)
{
if(a[d].v!=fa[id]&&a[d].v!=son[id])
dfs2(a[d].v,a[d].v);
d=a[d].next;
}
return;
}
void build(long long id,long long l,long long r)
{
tr[id].l=l;
tr[id].r=r;
if(l==r)
{
tr[id].num=num[l];
return;
}
long long mid=(l+r)/2;
build(lch,l,mid);
build(rch,mid+1,r);
tr[id].num=(tr[lch].num+tr[rch].num)%p;
return;
}
void pushdown(long long id)
{
if(tr[id].l!=tr[id].r)
{
long long lazy=tr[id].lazy;
tr[id].lazy=0;
tr[lch].lazy+=lazy;
tr[rch].lazy+=lazy;
tr[lch].num+=lazy*(tr[lch].r-tr[lch].l+1);
tr[rch].num+=lazy*(tr[rch].r-tr[rch].l+1);
tr[lch].num%=p;
tr[rch].num%=p;
tr[lch].lazy%=p;
tr[rch].lazy%=p;
}
return;
}
void add(long long id,long long l,long long r,long long val)
{
pushdown(id);
if(tr[id].l==l&&tr[id].r==r)
{
tr[id].lazy+=val;
tr[id].num+=val*(tr[id].r-tr[id].l+1);
tr[id].lazy%=p;
tr[id].num%=p;
return;
}
long long mid=(tr[id].l+tr[id].r)/2;
if(r<=mid)
add(lch,l,r,val);
if(l>=mid+1)
add(rch,l,r,val);
if(l<=mid&&r>=mid+1)
{
add(lch,l,mid,val);
add(rch,mid+1,r,val);
}
tr[id].num=(tr[lch].num+tr[rch].num)%p;
return;
}
long long ask(long long id,long long l,long long r)
{
pushdown(id);
if(tr[id].l==l&&tr[id].r==r)
return tr[id].num;
long long mid=(tr[id].l+tr[id].r)/2;
if(r<=mid)
return ask(lch,l,r);
if(l>=mid+1)
return ask(rch,l,r);
if(l<=mid&&r>=mid+1)
return (ask(lch,l,mid)+ask(rch,mid+1,r))%p;
}
void lian(long long x,long long y,long long val)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])
swap(x,y);
add(1,dfn[top[x]],dfn[x],val);
x=fa[top[x]];
}
if(dep[x]>dep[y])
swap(x,y);
add(1,dfn[x],dfn[y],val);
return;
}
long long ask_lian(long long x,long long y)
{
long long ans=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])
swap(x,y);
ans+=ask(1,dfn[top[x]],dfn[x]);
ans%=p;
x=fa[top[x]];
}
if(dep[x]>dep[y])
swap(x,y);
ans+=ask(1,dfn[x],dfn[y]);
ans%=p;
return ans;
}
int main()
{
scanf("%lld%lld%lld%lld",&n,&m,&r,&p);
for(long long i=1;i<=n;++i)
scanf("%lld",&b[i]);
for(long long i=1;i<n;++i)
{
long long x,y;
scanf("%lld%lld",&x,&y);
ins(x,y);
ins(y,x);
}
dfs1(r,0,1);
dfs2(r,r);
build(1,1,n);
for(long long i=1;i<=m;++i)
{
long long pan;
scanf("%lld",&pan);
if(pan==1)
{
long long x,y,z;
scanf("%lld%lld%lld",&x,&y,&z);
lian(x,y,z);
}
if(pan==2)
{
long long x,y;
scanf("%lld%lld",&x,&y);
printf("%lld\n",ask_lian(x,y));
}
if(pan==3)
{
long long x,z;
scanf("%lld%lld",&x,&z);
add(1,dfn[x],dfn[x]+siz[x]-1,z);
}
if(pan==4)
{
long long x;
scanf("%lld",&x);
printf("%lld\n",ask(1,dfn[x],dfn[x]+siz[x]-1));
}
}
return 0;
}

 

评论

此博客中的热门博文