[Luogu5599]失昼城的守星使

Description

给定一棵 $n$ 个节点的树,边有边权,每个节点初始有一个颜色(黑 / 白)。有 $m$ 个操作:

1 x: 翻转 $x$ 节点的颜色。

2 x y: 询问所有黑色节点到 $x \to y$ 这条链上的距离之和。

(定义一个节点到链的距离为到链上节点距离的最小值)

$$
1 \le n, m \le 2 \times 10 ^5
$$

Solution

有两种做法。先介绍一种细节少一点的,也较为优美的做法,第二种做法应该是标算。

First

(应该)与标算不太一样,主要是实现上除了板子几乎没有细节。

首先随意定根。

一个点 $u$ 到链的距离等价于 $u$ 到这条链上 LCA 的距离减去 $u$ 到 LCA 的重合路径部分。这里的重合指的是 $u$ 到 LCA 的路径与这条链的重合部分。

考虑这个重合路径部分怎么算。类似 [LNOI2014]LCA 的做法,重合部分显然等价于 $u$ 到根的路径与这条链的重合部分。将每个黑点到根的边全部打上一个 +1 的标记,每次直接查询这条链上的路径和即可。

(注意这里的 +1 是累加一次边权,不是权值 +1)

然后再考虑 $u$ 到这条链上 LCA 的距离。

现在设 LCA 为 $v$,拆一下可以得到 $dis (u, v) = dis _u + dis _v - 2 dis _{lca(u, v)}$(注意这里的 $dis(u, v)$ 是两点距离,$dis _u$ 是 $u$ 到根的距离)

用上面的修改方法一样能做,查询 $v$ 到根的路径上的路径和,即可得到 $\sum dis _{lca(u, v)}$。而 $\sum dis _u + dis _v$ 很容易维护(直接记当前有多少个黑点和黑点的 $dis$ 和),于是解决问题。

用 LCT 处理可以做到 $O(n \log n)$,不过懒得写了,$O(n \log ^2 n)$ 树剖就行。

虽然是一个比较奇怪的区间加(整体累加一次边权),不过大同小异,跟普通线段树没有什么区别。

Second

树剖之后,每个节点记录除重儿子子树外子树内黑点到当前节点的距离。

设链两端点为 $u, v$,其 LCA 为 $x$。首先考虑询问时,$x$ 的子树内的那些黑点的贡献怎么处理。

容易想到,因为每个节点忽略了重儿子子树,所以在跳链的时候,一段重链上的点权可以直接累加;而跨过轻链时,对于重链头的父亲节点 $y$,需要计算重儿子子树内的贡献,减去当前这个重链头的子树内的贡献。

注意这里减去加上的是整棵子树的贡献,不能忽略重儿子。

子树外的用 First 里说的统计就行。

这个方法细节特别多所以我就不写了 而且我只是口胡出锅了不能找我负责

如果想写这种做法的话去网上搜搜题解吧,这里只是一个大概的思路。

(大概看了下官方题解,没怎么看懂)

Code

gen

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from random import randint
from cyaron import *
import os
cnt = 0
while True :
n, m = int(2e5), int(2e5)
outs = "%d %d %d\n" % (n, m, 0)
tree = Graph.tree(n, weight_limit = 1e5)
for i in tree.iterate_edges() :
outs += "%d %d %d\n" % (i.start, i.end, i.weight)
for i in range(n) :
outs += "%d " % (randint(0, 1))
outs += '\n'
for i in range(m) :
if randint(0, 1) is 0 :
outs += "1 %d\n" % randint(1, n)
else :
outs += "2 %d %d\n" % (randint(1, n), randint(1, n))
print(outs, file = open("tmp.in", "w"))
os.system("./std && ./tmp")
if os.system("diff tmp.out tmp.ans") :
print("WA")
exit()
else :
cnt += 1
print("AC %d times!" % cnt)

std

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
/**********************************************************
* Author : EndSaH
* Email : hjxhb1@gmail.com
* Created Time : 2019-09-17 10:26
* FileName : tmp.cpp
* Website : https://endsah.cf
* *******************************************************/

#include <cstdio>
#include <cctype>
#include <bitset>
#include <vector>

using pii = std::pair<int, int>;
using LL = long long;

#define fir first
#define sec second
#define debug(...) fprintf(stderr, __VA_ARGS__)
#define Debug(s) debug("The message in line %d, Function %s: %s\n", __LINE__, __FUNCTION__, s)
#define getchar() (ipos == iend and (iend = (ipos = _ibuf) + fread(_ibuf, 1, __bufsize, stdin), ipos == iend) ? EOF : *ipos++)
#define putchar(ch) (opos == oend ? fwrite(_obuf, 1, __bufsize, stdout), opos = _obuf : 0, *opos++ = (ch))
#define __bufsize (1 << 21 | 1)

char _ibuf[__bufsize], _obuf[__bufsize], _stk[50];
char *ipos = _ibuf, *iend = _ibuf, *opos = _obuf, *oend = _obuf + __bufsize, *stkpos = _stk;

struct END
{ ~END() { fwrite(_obuf, 1, opos - _obuf, stdout); } }
__;

inline int read()
{
register int x = 0;
register char ch;
while (!isdigit(ch = getchar()));
while (x = x * 10 + (ch & 15), isdigit(ch = getchar()));
return x;
}

template <typename _INT>
inline void write(_INT x)
{
while (*++stkpos = x % 10 ^ 48, x /= 10, x);
while (stkpos != _stk)
putchar(*stkpos--);
}

template <typename _Tp>
inline bool Chkmax(_Tp& x, const _Tp& y)
{ return x < y ? x = y, true : false; }

template <typename _Tp>
inline bool Chkmin(_Tp& x, const _Tp& y)
{ return x > y ? x = y, true : false; }

const int maxN = 2e5 + 5;

int n, m, cnt;
LL sum;
std::bitset<maxN> vis;
std::vector<pii> G[maxN];

// HLD and SEG
namespace HLD
{
int dfst, ql, qr, addval;
int size[maxN], top[maxN], son[maxN], dfn[maxN], fa[maxN], dep[maxN];
LL dis[maxN];
int val[maxN], ref[maxN];
int tag[maxN << 2];
LL sum[maxN << 2], real[maxN << 2];

void DFS1(int u)
{
size[u] = 1;
for (const auto& i : G[u])
{
if (size[i.fir])
continue;
fa[i.fir] = u, dep[i.fir] = dep[u] + 1;
val[i.fir] = i.sec, dis[i.fir] = dis[u] + i.sec;
DFS1(i.fir);
size[u] += size[i.fir];
if (size[i.fir] > size[son[u]])
son[u] = i.fir;
}
}

void DFS2(int u)
{
dfn[u] = ++dfst;
ref[dfst] = u;
if (son[u])
top[son[u]] = top[u], DFS2(son[u]);
for (const auto& i : G[u])
if (!top[i.fir])
top[i.fir] = i.fir, DFS2(i.fir);
}

int LCA(int u, int v)
{
while (top[u] != top[v])
{
if (dep[top[u]] < dep[top[v]])
std::swap(u, v);
u = fa[top[u]];
}
return dep[u] > dep[v] ? v : u;
}

inline void Add(int addval, int x)
{
real[x] += sum[x] * addval;
tag[x] += addval;
}

inline void Pushdown(int x)
{
if (tag[x])
{
Add(tag[x], x << 1), Add(tag[x], x << 1 | 1);
tag[x] = 0;
}
}

inline void Pushup(int x)
{ real[x] = real[x << 1] + real[x << 1 | 1]; }

void Build(int l = 1, int r = n, int cur = 1)
{
if (l == r)
{
sum[cur] = val[ref[l]];
return;
}
int mid = (l + r) >> 1;
Build(l, mid, cur << 1), Build(mid + 1, r, cur << 1 | 1);
sum[cur] = sum[cur << 1] + sum[cur << 1 | 1];
}

void Modify(int l = 1, int r = n, int cur = 1)
{
if (ql <= l and r <= qr)
{
Add(addval, cur);
return;
}
int mid = (l + r) >> 1;
Pushdown(cur);
if (ql <= mid)
Modify(l, mid, cur << 1);
if (mid < qr)
Modify(mid + 1, r, cur << 1 | 1);
Pushup(cur);
}

LL Query(int l = 1, int r = n, int cur = 1)
{
if (ql <= l and r <= qr)
return real[cur];
if (ql > r or qr < l)
return 0;
int mid = (l + r) >> 1;
Pushdown(cur);
return Query(l, mid, cur << 1) + Query(mid + 1, r, cur << 1 | 1);
}

void Init()
{
dep[1] = 1, DFS1(1);
top[1] = 1, DFS2(1);
Build();
}

void Add(int l, int r, int val)
{ ql = l, qr = r, addval = val, Modify(); }

LL Ask(int l, int r)
{ ql = l, qr = r; return Query(); }

void Add_root(int u, int addval)
{
while (top[u] != 1)
Add(dfn[top[u]], dfn[u], addval), u = fa[top[u]];
Add(2, dfn[u], addval);
}

LL Query_chain(int u, int v)
{
LL res = 0;
while (top[u] != top[v])
{
if (dep[top[u]] < dep[top[v]])
std::swap(u, v);
res += Ask(dfn[top[u]], dfn[u]), u = fa[top[u]];
}
if (dep[u] > dep[v])
std::swap(u, v);
return res + Ask(dfn[u] + 1, dfn[v]);
}
}

void Oper(int x)
{
if (vis[x])
sum -= HLD::dis[x], --cnt, vis.reset(x), HLD::Add_root(x, -1);
else
sum += HLD::dis[x], ++cnt, vis.set(x), HLD::Add_root(x, 1);
}

int main()
{
#ifndef ONLINE_JUDGE
freopen("tmp.in", "r", stdin);
freopen("tmp.out", "w", stdout);
#endif
n = read(), m = read(), read();
for (int i = 2; i <= n; ++i)
{
int u = read(), v = read(), w = read();
G[u].emplace_back(v, w), G[v].emplace_back(u, w);
}
HLD::Init();
// for (int i = 1; i <= n; ++i)
// printf("%d %d %d\n", HLD::size[i], HLD::top[i], HLD::dfn[i]);
for (int i = 1; i <= n; ++i)
if (read())
Oper(i);
for (int i = 1; i <= m; ++i)
{
int opt = read(), x = read();
// debug("%d %d %d\n", i, opt, x);
if (opt == 1)
Oper(x);
else
{
int y = read(), lca = HLD::LCA(x, y);
LL tmp = sum + cnt * HLD::dis[lca] - (HLD::Query_chain(1, lca) << 1);
LL minus = HLD::Query_chain(x, y);
write(tmp - minus), putchar('\n');
}
}
return 0;
}