Kuhn-Munkres

Kuhn-Munkres算法(KM算法)的作用是求解二分图最大权最佳完美匹配。



二分图的各种概念这里就不提了,下面直接进入正题。

KM算法是通过给每个顶点一个标号(顶标)来求最大权完美匹配的问题转化为求完美匹配的问题的。

首先介绍两个概念:

可行顶标:对于所有顶点的值l,使得对于任意边e:x→y,都满足$latex l_x+l_y \geq w_e$,KM算法始终满足此条件。

相等子图:包含原图中的所有点,但只包含满足$latex l_x+l_y=w_e$的边,那么因此如果相等子图有完美匹配,这个匹配一定是最大完美匹配(因为任意边都受到条件$latex l_x+l_y=w_e$的限制,所以不可能大于所有点的定标之和)。

我们设二分图右侧所有节点顶标为0,那么所有左侧节点的顶标必为从它出发所有的边的最大值。

然后求满足上述条件的完美匹配,如果成功,则算法结束,否则我们必须修改顶标,让更多的边能够参与进来。

我们求当前相等子图的完美匹配失败是由于对于某个未匹配顶点,我们找不到从它出发的增广路,此时我们只能得到一条交替路。如果我们将交替路中左侧顶点的顶标全部减小某个值k,右侧顶点的顶标全部增加某个值k,那么我们就可以发现:两端都在交替路中的边没有变化;两端都不在交替路的边也没有变化;对于一端在交替路左侧,另一段不在交替路的边,$latex l_x+l_y$有所减小,因此它有可能进入了相等子图,因而使相等子图得到了扩大;对于一端在交替路右侧,另一端不在交替路的边,$latex l_x+l_y$有所扩大,因此仍不可能属于相等子图。

我们按照贪心的思想,只能让满足条件的权值最大的边能被选中,因此这个k值应该取$latex min{l_x+l_y-w_e}$,其中x属于左侧交替路,y不属于右侧交替路。

每次这样修改后就又有新的边进入相等子图,然后我们就可以继续寻找增广路。

因此算法的轮廓就出来了:

我们按照匈牙利算法的思路进行增广,如果没有找到匹配则修改顶标的值,然后重复以上步骤即可。复杂度$latex O(n^3)$(这个复杂度其实是有些假的,可以被卡,真的n^3复杂度需要把dfs改为bfs,不过我还没见过那么毒瘤的题)。

例题为HDU2255
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
#define inf 0x7f7f7f7f
int n, sum, head[305], mch[305], s[305], t[305], slack[305], x[305], y[305];
struct node
{
int v;
int w;
int next;
} a[90005];
inline void ins(int u, int v, int w)
{
++sum;
a[sum].v = v;
a[sum].w = w;
a[sum].next = head[u];
head[u] = sum;
return;
}
inline bool dfs(int k, int cnt)
{
s[k] = cnt;
int d = head[k];
while (d)
{
if (t[a[d].v] == cnt)
{
d = a[d].next;
continue;
}
int p = x[k] + y[a[d].v] - a[d].w;
if (!p)
{
t[a[d].v] = cnt;
if (!mch[a[d].v] || dfs(mch[a[d].v], cnt))
{
mch[a[d].v] = k;
return true;
}
}
else
slack[a[d].v] = min(slack[a[d].v], p);
d = a[d].next;
}
return false;
}
inline void update(int cnt)
{
int p = inf;
for (int i = 1; i <= n; ++i)
if (t[i] != cnt)
p = min(p, slack[i]);
for (int i = 1; i <= n; ++i)
{
if (s[i] == cnt)
x[i] -= p;
if (t[i] == cnt)
y[i] += p;
}
return;
}
inline void km()
{
memset(x, 0, sizeof(x));
memset(y, 0, sizeof(y));
memset(mch, 0, sizeof(mch));
for (int i = 1; i <= n; ++i)
{
int d = head[i];
while (d)
{
x[i] = max(x[i], a[d].w);
d = a[d].next;
}
}
for (int i = 1; i <= n; ++i)
{
memset(s, 0, sizeof(s));
memset(t, 0, sizeof(t));
memset(slack, 0x7f, sizeof(slack));
int cnt = 0;
while (!dfs(i, ++cnt))
update(cnt);
}
return;
}
int main()
{
while (scanf("%d", &n) == 1)
{
sum = 0;
memset(head, 0, sizeof(head));
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= n; ++j)
{
int w;
scanf("%d", &w);
ins(i, j, w);
}
km();
int ans = 0;
for (int i = 1; i <= n; ++i)
ans += x[i] + y[i];
printf("%d\n", ans);
}
return 0;
}
UPDATE:洛谷P6577有卡dfs的模板了,补一个bfs版
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
#define int long long
#define inf 0x7fffffffffffffff
int n, m, sum;
int head[505], mx[505], my[505];
int s[505], t[505], slack[505];
int x[505], y[505], pre[505];
int l[250005];
struct node
{
int v;
int w;
int nxt;
} a[500005];
void ins(int u, int v, int w)
{
++sum;
a[sum].v = v;
a[sum].w = w;
a[sum].nxt = head[u];
head[u] = sum;
return;
}
void mch(int k)
{
while (k)
{
int tmp = mx[pre[k]];
mx[pre[k]] = k;
my[k] = pre[k];
k = tmp;
}
return;
}
void bfs(int k)
{
int hh = 0, tt = 1;
l[1] = k;
while (1)
{
while (hh < tt)
{
k = l[++hh];
s[k] = 1;
int d = head[k];
while (d)
{
if (!t[a[d].v])
{
int p = x[k] + y[a[d].v] - a[d].w;
if (p < slack[a[d].v])
pre[a[d].v] = k;
if (!p)
{
t[a[d].v] = 1;
if (!my[a[d].v])
{
mch(a[d].v);
return;
}
else
l[++tt] = my[a[d].v];
}
else
slack[a[d].v] = min(slack[a[d].v], p);
}
d = a[d].nxt;
}
}
int p = inf;
for (int i = 1; i <= n; ++i)
if (!t[i])
p = min(p, slack[i]);
for (int i = 1; i <= n; ++i)
{
if (s[i])
x[i] -= p;
if (t[i])
y[i] += p;
else
slack[i] -= p;
}
for (int i = 1; i <= n; ++i)
if (!t[i] && !slack[i])
{
t[i] = 1;
if (!my[i])
{
mch(i);
return;
}
else
l[++tt] = my[i];
}
}
return;
}
void km()
{
for (int i = 1; i <= n; ++i)
{
int d = head[i];
while (d)
{
x[i] = max(x[i], a[d].w);
d = a[d].nxt;
}
}
for (int i = 1; i <= n; ++i)
{
memset(s, 0, sizeof(s));
memset(t, 0, sizeof(t));
memset(slack, 0x7f, sizeof(slack));
bfs(i);
}
return;
}
signed main()
{
scanf("%lld%lld", &n, &m);
for (int i = 1; i <= m; ++i)
{
int u, v, w;
scanf("%lld%lld%lld", &u, &v, &w);
ins(u, v, w);
}
km();
int ans = 0;
for (int i = 1; i <= n; ++i)
ans += x[i] + y[i];
printf("%lld\n", ans);
for (int i = 1; i <= n; ++i)
printf("%lld ", my[i]);
return 0;
}

评论

此博客中的热门博文