树链剖分

终于开始填暑假的坑了

首先是概念:总的来说,就是把一棵树剖分成若干条链,然后利用数据结构维护每一条链。复杂度:O(logn)

在学习树链剖分之前,建议先去学线段树。

首先列出一些基本概念:

  • 重儿子:若siz[u]是v的子节点中siz值最大的,那么u就是v的重儿子。
  • 轻儿子:v的其它子节点。
  • 重边:点v与其重儿子的连边。
  • 轻边:点v与其轻儿子的连边。
  • 重链:由重边连成的路径。
  • 轻链:由轻边连成的路径。

一般来说,我们使用的树链剖分都是重链剖分。至于为什么要用重链剖分,是因为复杂度要控制在log以内,这个是可以证明的,网上也已经有不少博客讲解过,这里就不讲了。

一般的树链剖分主要使用两次DFS(当然BFS也是可以的)欲处理出一些信息,再通过线段树来维护。

第一次DFS:预处理出id(节点编号),dep(节点深度),siz(节点大小),son(节点的重儿子),fa(节点的父节点)。这一部分还是很简单的,下面直接给出代码(路径使用邻接表存储,下面的代码均是如此):

void dfs1(long long id,long long f,long long deep)
{
    fa[id]=f;
    dep[id]=deep;
    siz[id]=1;
    son[id]=-1;
    long long d=head[id];
    while(d)
    {
        if(a[d].v!=f)
        {
            dfs1(a[d].v,id,deep+1);
            siz[id]+=siz[a[d].v];
            if(siz[id]==-1||siz[son[id]]<siz[a[d].v])
                son[id]=a[d].v;
        }
        d=a[d].next;
    }
    return;
}

第二次DFS:预处理出dfn(节点的dfs序),top(节点所在的链的顶点),num(dfs序所对应的节点编号)。

void dfs2(long long id,long long topf)
{
    top[id]=topf;
    dfn[id]=++cnt;
    num[cnt]=b[id];
    if(son[id]!=-1)
        dfs2(son[id],topf);
    long long d=head[id];
    while(d)
    {
        if(a[d].v!=fa[id]&&a[d].v!=son[id])
            dfs2(a[d].v,a[d].v);
        d=a[d].next;
    }
    return;
}

 

我们不难发现,第二次DFS时搜索的顺序是从重儿子开始的,这是因为这样搜索下来我们可以保证每条重链上的节点的dfn是连续的。

因此我们就可以使用线段树维护(以下代码均为线段树的基本操作,可以直接略过)。

void build(long long id,long long l,long long r)
{
    tr[id].l=l;
    tr[id].r=r;
    if(l==r)
    {
        tr[id].num=num[l];
        return;
    }
    long long mid=(l+r)/2;
    build(lch,l,mid);
    build(rch,mid+1,r);
    tr[id].num=(tr[lch].num+tr[rch].num)%p;
    return;
}
void pushdown(long long id)
{
    if(tr[id].l!=tr[id].r)
    {
        long long lazy=tr[id].lazy;
        tr[id].lazy=0;
        tr[lch].lazy+=lazy;
        tr[rch].lazy+=lazy;
        tr[lch].num+=lazy*(tr[lch].r-tr[lch].l+1);
        tr[rch].num+=lazy*(tr[rch].r-tr[rch].l+1);
        tr[lch].num%=p;
        tr[rch].num%=p;
        tr[lch].lazy%=p;
        tr[rch].lazy%=p;
    }
    return;
}
void add(long long id,long long l,long long r,long long val)
{
    pushdown(id);
    if(tr[id].l==l&&tr[id].r==r)
    {
        tr[id].lazy+=val;
        tr[id].num+=val*(tr[id].r-tr[id].l+1);
        tr[id].lazy%=p;
        tr[id].num%=p;
        return;
    }
    long long mid=(tr[id].l+tr[id].r)/2;
    if(r<=mid)
        add(lch,l,r,val);
    if(l>=mid+1)
        add(rch,l,r,val);
    if(l<=mid&&r>=mid+1)
    {
        add(lch,l,mid,val);
        add(rch,mid+1,r,val);
    }
    tr[id].num=(tr[lch].num+tr[rch].num)%p;
    return;
}
long long ask(long long id,long long l,long long r)
{
    pushdown(id);
    if(tr[id].l==l&&tr[id].r==r)
        return tr[id].num;
    long long mid=(tr[id].l+tr[id].r)/2;
    if(r<=mid)
        return ask(lch,l,r);
    if(l>=mid+1)
        return ask(rch,l,r);
    if(l<=mid&&r>=mid+1)
        return (ask(lch,l,mid)+ask(rch,mid+1,r))%p;
}

树链剖分主要用于求树上两点之间的最短距离,那么怎么求呢?我们可以发现,当两个点在一条链上时,两点在链上的距离就是它们之间的最短距离。而当两点不在一条链上时,我们就选择深度较大的拿个节点,将他更新到他所在的链的顶点的父节点。下面给出代码:

void lian(long long x,long long y,long long val)
{
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]])
            swap(x,y);
        add(1,dfn[top[x]],dfn[x],val);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])
        swap(x,y);
    add(1,dfn[x],dfn[y],val);
    return;
}

 

树链剖分有时也用于修改两点间的最短路径上的权值,实现方式也是和查询差不多的,均是通过线段树来实现。下面也直接给出代码:

long long ask_lian(long long x,long long y)
{
    long long ans=0;
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]])
            swap(x,y);
        ans+=ask(1,dfn[top[x]],dfn[x]);
        ans%=p;
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])
        swap(x,y);
    ans+=ask(1,dfn[x],dfn[y]);
    ans%=p;
    return ans;
}

讲到这里,树链剖分的主体就结束了,还有一些别的操作就等到了具体的题目再讲吧。

相关练习:洛谷P3384树链剖分(本篇博客使用的代码片段皆节选于此题的代码,忽略我奇怪的变量和函数名

以下是本题的完整代码:

#include<iostream>
#include<cstdio>
using namespace std;
#define lch id*2
#define rch id*2+1
long long n,m,sum,r,p,cnt,b[100005],head[100005],siz[100005],top[100005],son[100005],dep[100005],fa[100005],dfn[100005],num[100005];
struct node
{
    long long v;
    long long next;
}a[200005];
struct tree
{
    long long l;
    long long r;
    long long num;
    long long lazy;
}tr[400005];
void ins(long long u,long long v)
{
    ++sum;
    a[sum].v=v;
    a[sum].next=head[u];
    head[u]=sum;
    return;
}
void dfs1(long long id,long long f,long long deep)
{
    fa[id]=f;
    dep[id]=deep;
    siz[id]=1;
    son[id]=-1;
    long long d=head[id];
    while(d)
    {
        if(a[d].v!=f)
        {
            dfs1(a[d].v,id,deep+1);
            siz[id]+=siz[a[d].v];
            if(siz[id]==-1||siz[son[id]]<siz[a[d].v])
                son[id]=a[d].v;
        }
        d=a[d].next;
    }
    return;
}
void dfs2(long long id,long long topf)
{
    top[id]=topf;
    dfn[id]=++cnt;
    num[cnt]=b[id];
    if(son[id]!=-1)
        dfs2(son[id],topf);
    long long d=head[id];
    while(d)
    {
        if(a[d].v!=fa[id]&&a[d].v!=son[id])
            dfs2(a[d].v,a[d].v);
        d=a[d].next;
    }
    return;
}
void build(long long id,long long l,long long r)
{
    tr[id].l=l;
    tr[id].r=r;
    if(l==r)
    {
        tr[id].num=num[l];
        return;
    }
    long long mid=(l+r)/2;
    build(lch,l,mid);
    build(rch,mid+1,r);
    tr[id].num=(tr[lch].num+tr[rch].num)%p;
    return;
}
void pushdown(long long id)
{
    if(tr[id].l!=tr[id].r)
    {
        long long lazy=tr[id].lazy;
        tr[id].lazy=0;
        tr[lch].lazy+=lazy;
        tr[rch].lazy+=lazy;
        tr[lch].num+=lazy*(tr[lch].r-tr[lch].l+1);
        tr[rch].num+=lazy*(tr[rch].r-tr[rch].l+1);
        tr[lch].num%=p;
        tr[rch].num%=p;
        tr[lch].lazy%=p;
        tr[rch].lazy%=p;
    }
    return;
}
void add(long long id,long long l,long long r,long long val)
{
    pushdown(id);
    if(tr[id].l==l&&tr[id].r==r)
    {
        tr[id].lazy+=val;
        tr[id].num+=val*(tr[id].r-tr[id].l+1);
        tr[id].lazy%=p;
        tr[id].num%=p;
        return;
    }
    long long mid=(tr[id].l+tr[id].r)/2;
    if(r<=mid)
        add(lch,l,r,val);
    if(l>=mid+1)
        add(rch,l,r,val);
    if(l<=mid&&r>=mid+1)
    {
        add(lch,l,mid,val);
        add(rch,mid+1,r,val);
    }
    tr[id].num=(tr[lch].num+tr[rch].num)%p;
    return;
}
long long ask(long long id,long long l,long long r)
{
    pushdown(id);
    if(tr[id].l==l&&tr[id].r==r)
        return tr[id].num;
    long long mid=(tr[id].l+tr[id].r)/2;
    if(r<=mid)
        return ask(lch,l,r);
    if(l>=mid+1)
        return ask(rch,l,r);
    if(l<=mid&&r>=mid+1)
        return (ask(lch,l,mid)+ask(rch,mid+1,r))%p;
}
void lian(long long x,long long y,long long val)
{
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]])
            swap(x,y);
        add(1,dfn[top[x]],dfn[x],val);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])
        swap(x,y);
    add(1,dfn[x],dfn[y],val);
    return;
}
long long ask_lian(long long x,long long y)
{
    long long ans=0;
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]])
            swap(x,y);
        ans+=ask(1,dfn[top[x]],dfn[x]);
        ans%=p;
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])
        swap(x,y);
    ans+=ask(1,dfn[x],dfn[y]);
    ans%=p;
    return ans;
}
int main()
{
    scanf("%lld%lld%lld%lld",&n,&m,&r,&p);
    for(long long i=1;i<=n;++i)
        scanf("%lld",&b[i]);
    for(long long i=1;i<n;++i)
    {
        long long x,y;
        scanf("%lld%lld",&x,&y);
        ins(x,y);
        ins(y,x);
    }
    dfs1(r,0,1);
    dfs2(r,r);
    build(1,1,n);
    for(long long i=1;i<=m;++i)
    {
        long long pan;
        scanf("%lld",&pan);
        if(pan==1)
        {
            long long x,y,z;
            scanf("%lld%lld%lld",&x,&y,&z);
            lian(x,y,z);
        }
        if(pan==2)
        {
            long long x,y;
            scanf("%lld%lld",&x,&y);
            printf("%lld\n",ask_lian(x,y));
        }
        if(pan==3)
        {
            long long x,z;
            scanf("%lld%lld",&x,&z);
            add(1,dfn[x],dfn[x]+siz[x]-1,z);
        }
        if(pan==4)
        {
            long long x;
            scanf("%lld",&x);
            printf("%lld\n",ask(1,dfn[x],dfn[x]+siz[x]-1));
        }
    }
    return 0;
}

 

发表评论

电子邮件地址不会被公开。 必填项已用*标注