Kuhn-Munkres

Kuhn-Munkres算法(KM算法)的作用是求解二分图最大权最佳完美匹配。

二分图的各种概念这里就不提了,下面直接进入正题。

KM算法是通过给每个顶点一个标号(顶标)来求最大权完美匹配的问题转化为求完美匹配的问题的。

首先介绍两个概念:

可行顶标:对于所有顶点的值l,使得对于任意边e:x→y,都满足\(l_x+l_y \geq w_e\),KM算法始终满足此条件。

相等子图:包含原图中的所有点,但只包含满足\(l_x+l_y=w_e\)的边,那么因此如果相等子图有完美匹配,这个匹配一定是最大完美匹配(因为任意边都受到条件\(l_x+l_y=w_e\)的限制,所以不可能大于所有点的定标之和)。

我们设二分图右侧所有节点顶标为0,那么所有左侧节点的顶标必为从它出发所有的边的最大值。

然后求满足上述条件的完美匹配,如果成功,则算法结束,否则我们必须修改顶标,让更多的边能够参与进来。

我们求当前相等子图的完美匹配失败是由于对于某个未匹配顶点,我们找不到从它出发的增广路,此时我们只能得到一条交替路。如果我们将交替路中左侧顶点的顶标全部减小某个值k,右侧顶点的顶标全部增加某个值k,那么我们就可以发现:两端都在交替路中的边没有变化;两端都不在交替路的边也没有变化;对于一端在交替路左侧,另一段不在交替路的边,\(l_x+l_y\)有所减小,因此它有可能进入了相等子图,因而使相等子图得到了扩大;对于一端在交替路右侧,另一端不在交替路的边,\(l_x+l_y\)有所扩大,因此仍不可能属于相等子图。

我们按照贪心的思想,只能让满足条件的权值最大的边能被选中,因此这个k值应该取\(min{l_x+l_y-w_e}\),其中x属于左侧交替路,y不属于右侧交替路。

每次这样修改后就又有新的边进入相等子图,然后我们就可以继续寻找增广路。

因此算法的轮廓就出来了:

我们按照匈牙利算法的思路进行增广,如果没有找到匹配则修改顶标的值,然后重复以上步骤即可。复杂度\(O(n^3)\)(这个复杂度其实是有些假的,可以被卡,真的n^3复杂度需要把dfs改为bfs,不过我还没见过那么毒瘤的题)。

例题为HDU2255

#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
#define inf 0x7f7f7f7f
int n, sum, head[305], mch[305], s[305], t[305], slack[305], x[305], y[305];
struct node
{
    int v;
    int w;
    int next;
} a[90005];
inline void ins(int u, int v, int w)
{
    ++sum;
    a[sum].v = v;
    a[sum].w = w;
    a[sum].next = head[u];
    head[u] = sum;
    return;
}
inline bool dfs(int k, int cnt)
{
    s[k] = cnt;
    int d = head[k];
    while (d)
    {
        if (t[a[d].v] == cnt)
        {
            d = a[d].next;
            continue;
        }
        int p = x[k] + y[a[d].v] - a[d].w;
        if (!p)
        {
            t[a[d].v] = cnt;
            if (!mch[a[d].v] || dfs(mch[a[d].v], cnt))
            {
                mch[a[d].v] = k;
                return true;
            }
        }
        else
            slack[a[d].v] = min(slack[a[d].v], p);
        d = a[d].next;
    }
    return false;
}
inline void update(int cnt)
{
    int p = inf;
    for (int i = 1; i <= n; ++i)
        if (t[i] != cnt)
            p = min(p, slack[i]);
    for (int i = 1; i <= n; ++i)
    {
        if (s[i] == cnt)
            x[i] -= p;
        if (t[i] == cnt)
            y[i] += p;
    }
    return;
}
inline void km()
{
    memset(x, 0, sizeof(x));
    memset(y, 0, sizeof(y));
    memset(mch, 0, sizeof(mch));
    for (int i = 1; i <= n; ++i)
    {
        int d = head[i];
        while (d)
        {
            x[i] = max(x[i], a[d].w);
            d = a[d].next;
        }
    }
    for (int i = 1; i <= n; ++i)
    {
        memset(s, 0, sizeof(s));
        memset(t, 0, sizeof(t));
        memset(slack, 0x7f, sizeof(slack));
        int cnt = 0;
        while (!dfs(i, ++cnt))
            update(cnt);
    }
    return;
}
int main()
{
    while (scanf("%d", &n) == 1)
    {
        sum = 0;
        memset(head, 0, sizeof(head));
        for (int i = 1; i <= n; ++i)
            for (int j = 1; j <= n; ++j)
            {
                int w;
                scanf("%d", &w);
                ins(i, j, w);
            }
        km();
        int ans = 0;
        for (int i = 1; i <= n; ++i)
            ans += x[i] + y[i];
        printf("%d\n", ans);
    }
    return 0;
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注