[CF125E]MST Company

Description

给定一个 $n$ 点 $m$ 边的图,以及参数 $k$,求在 1 号点度数强制为 $k$ 的情况下的最小生成树。

$$
1 \le n \le 5000, 1 \le m \le 10 ^5, 1 \le w _i \le 10 ^5, 0 \le k < 5000
$$

Solution

hack

强烈谴责题目数据太水…

首先大部分的凸优化(或者说 wqs/带权二分)的做法都错了。这样只能正确的得到最小生成树的权值,但是要以此为基础构造方案会错。比如下面两组数据:

1
2
3
4
5
6
7
8
9
5 8 3
1 2 7
1 4 8
1 5 7
1 3 3
2 3 1
3 5 1
3 4 10
4 5 3
1
2
3
4
5
6
7
8
9
10
7 9 3
1 2 8
1 7 4
1 4 3
1 5 6
3 6 1
3 4 4
4 7 1
4 6 3
5 6 1

上面的最小生成树的权值是 19,下面的是 20。绝大多数程序会输出 20 和 -1…

可以证明 在这种策略下做出的最小生成树即使度数超过 $k$,也必然可以通过换边来得到度数恰好为 $k$ 的生成树,并且权值一样(凸包上多点共线)。然而直接根据排序顺序并限制度数再去构造方案,很可能得到不连通或错误的结果。

破圈

不妨考虑一个暴力一点的思路($n$ 的范围较小):

首先去掉 1 号点跑最小生成森林。

现在加入 1 号点。设此时除 1 号点外的联通块个数为 $x$,若 $k < x$,必然无解。

否则对于每个联通块加入能联通这个联通块的最小的边,这样就得到了一棵强制 $x$ 度生成树。

现在依然考虑经典的破圈算法:

对于一条不在生成树上的与 1 相连的边,若将其加到最小生成树中,其可以替换所产生的环上的一条最大的边,会使权值变小 $maxEdge - w$,其中 $maxEdge$ 是最大边边权,$w$ 是这条边的边权。

在所有 $maxEdge - w$ 找到其最大值并替换,在所替换的最大边不与 1 相连的前提下,就能得到一棵强制 $x + 1$ 度生成树。

将这个过程进行 $k - x$ 次,就能得到题目所要求的强制 $k$ 度生成树。

找到不与 1 相连的最大边由于数据范围原因可以暴力算。设 $maxE _i$ 表示在当前生成树上从 1 出发到 $i$ 节点的路径上忽略与 1 相连的边的最大边,每次 DFS 一遍 $O(n)$ 算就行。

$O(m \log m + nk)$

二分

虽然我本人感觉是对的,并且经过了上万组数据的对拍,但是因为数据湿度,这个做法依然可能会假,请谨慎阅读

现在这个构造法应该只能做这个题(强制 $k$ 度),另一道也需要二分的题(恰好有 $k$ 条白色边的最小生成树)无法用这个方法构造。

之前已经说了,二分最后得到的最小生成树中 1 的度数会大于等于 $k$。

先通过二分把这个生成树建出来,并固定 1 号点为根。

考虑枚举生成树之外的所有边用来替换那些与 1 相连的边,直到 1 号点度数为 $k$。假设当前边为 $(u, v)$,权值为 $w$,以及 $u$ 归属于 1 号点的儿子中 $p$ 的子树,$v$ 归属于 1 号点的儿子中 $q$ 的子树。

如果 $p = q$ 自然无可替换。否则,如果 $p, q$ 中任意一个点与 1 相连的边权是 $w$,那么这条边 $(1, p/q)$ 就可以被 $(u, v)$ 替换。不妨设可替换的边是 $(1, p)$,那么 $p$ 子树内的点需要全部接入 $q$(也即下一次查询 $p$ 子树内的点,得到结果应当为 $q$)。这是简单并查集操作。

另外边权必须恰好是 $w$,才可以替换。实际上因为已经是最小生成树了,不可能替换之后边权更小,证明中也提到了一定是边权相同的边导致了 1 的度数超出了 $k$,所以必须严格等于。

复杂度的话,二分部分为 $O(m \log m \log SIZE)$,其中 $SIZE$ 为值域大小。如果将与 1 相连的边与其他边每次归并排序或者桶排可以优化变为 $O(m \log SIZE)$。换边构树部分为 $O(m \log n)$。

Code

gen

1
2
3
4
5
6
7
8
9
10
from random import randint
from cyaron import *
import os
n = 200
m = randint(n - 1, min(1e5, n * (n - 1) / 2))
k = randint(1, n - 1)
outs = "%d %d %d\n" % (n, m, k)
graph = Graph.UDAG(n, m, weight_limit = 1e2, self_loop=False, repeated_edges=False)
outs += graph.to_str();
print(outs, file = open("tmp.in", "w"))

spj

需搭配 std 和 gen 食用

另外由于 python 玄学原因,每次只能运行 340 次…

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <cstdlib>
#include <iostream>

using namespace std;
using ptr = FILE*;

const int maxN = 5e3 + 5;
const int maxM = 1e5 + 5;

int n, m, _ans, num, k;
int fa[maxN], size[maxN];

struct Edge
{ int u, v, w; }
edge[maxM];

int Find(int x)
{ return fa[x] == x ? x : fa[x] = Find(fa[x]); }

inline bool Merge(int u, int v)
{
int fu = Find(u), fv = Find(v);
if (fu == fv)
return false;
if (size[fu] > size[fv])
swap(fu, fv);
fa[fu] = fv, size[fv] += size[fu];
return true;
}

inline void Err(const char* str)
{ fprintf(stderr, "%s\n", str); exit(0); }

int main()
{
for (int cnt = 1; ; ++cnt)
{
system("python3 gen.py && ./tmp && ./std");
ptr in = fopen("tmp.in", "r"), out = fopen("tmp.out", "r"), ans = fopen("tmp.ans", "r");
int tmp;
fscanf(in, "%d %d %d", &n, &m, &k);
for (int i = 1; i <= n; ++i)
fa[i] = i, size[i] = 1;
for (int i = 1; i <= m; ++i)
fscanf(in, "%d %d %d", &edge[i].u, &edge[i].v, &edge[i].w);
if (fscanf(out, "%d", &_ans) != 1)
Err("No ans.");
if (fscanf(ans, "%d", &num) != 1)
Err("Invalid output.");
if (num == -1 and _ans != -1)
Err("No solution.");
if (num != -1 and _ans == -1)
Err("Not.");
if (num == -1 and _ans == -1)
{
printf("AC %d times!\n", cnt);
continue;
}
if (num != n - 1)
Err("edges' number wrong!");
int curans = 0, c = 0;
for (int i = 1; i <= num; ++i)
{
if (fscanf(ans, "%d", &tmp) != 1)
Err("Invalid output.");
if (!Merge(edge[tmp].u, edge[tmp].v))
Err("Not tree.");
curans += edge[tmp].w;
if (edge[tmp].u == 1 or edge[tmp].v == 1)
++c;
}
if (curans != _ans)
Err("Wrong answer!");
if (c != k)
Err("1's degree isn't k!");
printf("AC %d times!\n", cnt);
}
return 0;
}

std

破圈

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#include<bits/stdc++.h>

#define mp make_pair
#define pb push_back
#define x first
#define y second

using namespace std;

typedef long long LL;
typedef pair<int,int> pii;

template <typename T> inline T read()
{
T sum=0, fl=1; char ch=getchar();
for(; !isdigit(ch); ch=getchar()) if(ch=='-') fl=-1;
for(; isdigit(ch); ch=getchar()) sum=(sum<<3)+(sum<<1)+ch-'0';
return sum*fl;
}

const int maxn=5000+5;
const int maxm=1e5+5;

int n, m, k, c;
LL mx[maxn]; bool del[maxm<<1], no[maxm];
int fa[maxn], id[maxm<<1]; vector<int>vec, E;
int he[maxn], ne[maxm<<1], to[maxm<<1], w[maxm<<1];
struct EDGE { int u, v, w, id; } e[maxm];

bool operator<(const EDGE&a, const EDGE&b) { return a.w<b.w; }

void add_edge(int u, int v, int val)
{
ne[++c]=he[u]; he[u]=c; to[c]=v; w[c]=val;
ne[++c]=he[v]; he[v]=c; to[c]=u; w[c]=val;
}

int find(int x) { return fa[x]==x ? x : fa[x]=find(fa[x]); }
void forest()
{
for(int i=1; i<=n; ++i) fa[i]=i;
sort(e+1, e+m+1);
for(int i=1; i<=m; ++i)
{
int u=find(e[i].u), v=find(e[i].v);
if(u==v || u==1 || v==1) continue;
fa[u]=v; E.pb(e[i].id);
add_edge(e[i].u, e[i].v, e[i].w); id[c]=id[c-1]=e[i].id;
}
}

void get_max(int p, int f)
{
for(int i=he[p]; i; i=ne[i])
{
if(to[i]==f || del[i] || to[i]==1) continue;
if(w[i]>=w[mx[p]]) mx[to[i]]=i;
else mx[to[i]]=mx[p];
get_max(to[i], p);
}
}

void Solve()
{
n=read<int>(), m=read<int>(), k=read<int>(); c=1;
for(int i=1; i<=m; ++i) e[i]=(EDGE){read<int>(), read<int>(), read<int>(), i};
forest();

for(int i=1; i<=m; ++i)
{
if(e[i].u!=1 && e[i].v!=1) continue;
int u=find(e[i].u), v=find(e[i].v);
if(u == v) vec.pb(i);
else fa[u]=v, add_edge(e[i].u, e[i].v, e[i].w), E.pb(e[i].id), --k;
}

for(int i=2; i<=n; ++i) if(find(i)!=find(i-1)) return (void) printf("-1");
if(k<0) return (void) printf("-1");

for(int i=1; i<=k; ++i)
{
if(!vec.size()) return (void) printf("-1"); /**/
memset(mx, 0, sizeof(mx));
for(int j=he[1]; j; j=ne[j]) get_max(to[j], 1);
int now=-1, ans=2147483647;
for(int j=vec.size()-1; j>=0; --j)
{
int p=vec[j], var=max(e[p].u, e[p].v);
if(e[p].w-w[mx[var]]<ans) ans=e[p].w-w[mx[var]], now=j;
}
int p=vec[now], var=e[p].u==1?e[p].v:e[p].u;
add_edge(e[p].v, e[p].u, e[p].w); E.pb(e[p].id);
no[id[mx[var]]]=1; del[mx[var]]=del[mx[var]^1]=1;
vec.erase(vec.begin()+now);
}

printf("%d\n", n-1);
for(int i=E.size()-1; i>=0; --i) if(!no[E[i]]) printf("%d ", E[i]);
}

int main()
{

// freopen("125e.in","r",stdin);
// freopen("125e.out","w",stdout);

Solve();

return 0;
}

二分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
#include <iostream>
#include <set>
#include <vector>
#include <bitset>
#include <algorithm>

using namespace std;

const int maxN = 5e3 + 5;
const int maxM = 1e5 + 5;

int n, m, k, maxw, ans;
int fa[maxN], size[maxN], val[maxN], id[maxN];
set<int> out;
vector<struct Edge> G[maxN];
bitset<maxM> vis;

struct Edge
{
int u, v, w, id;

Edge() { }

Edge(int _u, int _v, int _w, int _id) : u(_u), v(_v), w(_w), id(_id) { }
} edge[maxM];

bool operator<(const Edge& x, const Edge& y)
{ return x.w == y.w ? x.u < y.u : x.w < y.w; }

int Find(int x)
{ return fa[x] == x ? x : fa[x] = Find(fa[x]); }

inline bool Merge(int u, int v)
{
int fu = Find(u), fv = Find(v);
if (fu == fv)
return false;
if (size[fu] > size[fv])
swap(fu, fv);
fa[fu] = fv, size[fv] += size[fu];
return true;
}

void Init()
{
for (int i = 1; i <= n; ++i)
fa[i] = i, size[i] = 1;
}

bool Check()
{
Init();
int deg = 0;
for (int i = 1; i <= m; ++i)
if (edge[i].u != 1)
Merge(edge[i].u, edge[i].v);
else
++deg;
if (deg < k)
return false;
int cnt = 0;
for (int i = 2; i <= n; ++i)
if (Find(i) == i)
++cnt;
if (cnt > k)
return false;
for (int i = 1; i <= m; ++i)
if (edge[i].u == 1)
Merge(edge[i].u, edge[i].v);
cnt = 0;
for (int i = 1; i <= n; ++i)
if (Find(i) == i)
++cnt;
return cnt == 1;
}

void DFS(int u, int father)
{
fa[u] = fa[father];
for (const auto& v : G[u]) if (v.u != father)
DFS(v.u, u);
}

int main()
{
#ifndef ONLINE_JUDGE
freopen("tmp.in", "r", stdin);
freopen("tmp.ans", "w", stdout);
#endif
ios::sync_with_stdio(false);
cin >> n >> m >> k;
if (n == 1)
{ cout << 0 << endl; return 0; }
for (int i = 1; i <= m; ++i)
{
Edge& o = edge[i];
cin >> o.u >> o.v >> o.w;
o.id = i;
if (o.u > o.v)
swap(o.u, o.v);
maxw = max(maxw, o.w);
}
if (!Check())
{
cout << "-1" << endl;
return 0;
}
int l = -maxw, r = maxw, p, ecnt, cnt;
while (l <= r)
{
p = (l + r) >> 1;
ecnt = cnt = 0;
Init();
for (int i = 1; i <= m; ++i)
if (edge[i].u == 1)
edge[i].w += p;
sort(edge + 1, edge + m + 1);
for (int i = 1; i <= m; ++i)
if (Merge(edge[i].u, edge[i].v))
{
++ecnt;
if (edge[i].u == 1)
++cnt;
if (ecnt == n - 1)
break;
}
for (int i = 1; i <= m; ++i)
if (edge[i].u == 1)
edge[i].w -= p;
if (cnt >= k)
ans = p, l = p + 1;
else
r = p - 1;
}
Init(), cnt = 0;
for (int i = 1; i <= m; ++i)
if (edge[i].u == 1)
edge[i].w += ans;
sort(edge + 1, edge + m + 1);
for (int i = 1; i <= m; ++i)
if (Merge(edge[i].u, edge[i].v))
{
vis.set(i);
if (edge[i].u == 1)
++cnt;
out.insert(edge[i].id);
G[edge[i].u].emplace_back(edge[i].v, 0, edge[i].w, edge[i].id);
G[edge[i].v].emplace_back(edge[i].u, 0, edge[i].w, edge[i].id);
}
for (const auto& v : G[1])
{
val[v.u] = v.w, id[v.u] = v.id;
fa[1] = v.u, DFS(v.u, 1);
}
for (int i = 1; i <= m and cnt > k; ++i) if (edge[i].u != 1 and !vis[i])
{
int fu = Find(edge[i].u), fv = Find(edge[i].v);
if (fu == fv)
continue;
if (val[fu] == edge[i].w)
{
fa[fu] = fv, --cnt;
out.insert(edge[i].id);
out.erase(id[fu]);
continue;
}
if (val[fv] == edge[i].w)
{
fa[fv] = fu, --cnt;
out.insert(edge[i].id);
out.erase(id[fv]);
continue;
}
}
if (cnt == k)
{
cout << n - 1 << endl;
for (int i : out)
cout << i << ' ';
cout << endl;
}
else
cout << "-1" << endl;
return 0;
}