P6072 『MdOI R1』 Path 题解


true
Path
MdOI R1
省选/NOI-
#9d3dcf
  • Luogu P6072

题目要求我们从一棵树中选出两条互不相交的简单路径,使得其边权异或和之和最大。

首先我们需要进行一次转化,边权不好维护,但是边权的异或和可以转化为两端点到根的这两条路径上所有边权的异或和。
如果再将每一个点到根的路径上的边权的异或和当做点权给它赋上去的话,我们就可以将边权异或和之和最大改成两端点的点权异或和之和最大了。

然后就是考虑如何不重不漏地找到这样不相交的两条路径。

我们可以考虑边分治,以当前分治中心来隔开两条路径。
不过有一点需要考虑的就是,两条路径的其中一条可能会跨过上一层的分治中心,这可是很麻烦的一件事。

换一种思路,同样是考虑用一条边隔开两条路径,我们选取这条边的一个端点,那么这两条路径就分别在其子树内和子树外。
看起来好像可以维护的样子……

那么我们就考虑对于每一个点,在其子树内和子树外分别找到一条路径,使得其两端点的点权异或和之和最大。
对于子树外的路径,我们可以套用P8511这道由乃OI的做法,时间复杂度 $O(n \log v)$;
对于子树内的路径,我们仍然采用支配的思路来分类讨论。

我们需要明确一个观点,对于两个互相包含的子树,较大的子树的答案是不会小于另一个的答案的。

首先我们还是需要找到整棵树中异或和最大的两个点,记作 $x$ 和 $y$。(代码中用的是 l1l2

对于 $x$ 和 $y$ 均在子树中的点,其子树内答案就是 $(x,y)$,此时我们直接比较已经提前求好了的子树外答案即可。
这些点集中在从 $\operatorname{lca}(x,y)$ 到根的路径上,我们从LCA处一路往上跳即可。

1
2
3
//x与y均在子树内的点
for(int i = g; i >= 1; i = fa[i])
res = max(res, ans[i] + (a[l1] ^ a[l2]));

对于 $x$ 和 $y$ 均在子树外的点,其子树外答案就是 $(x,y)$,我们只需要求出其子树内的答案即可。
这些点集中在 $x$ 到 $y$ 的路径上的点的子树中和从 $\operatorname{lca}(x,y)$ 到根的路径上的点的子树中,我们可以直接求出答案进行比较,也可以利用我们得到的性质,只维护路径上的点的直接儿子的答案即可。
我们还可以进一步降低工作量,忽略掉从 $\operatorname{lca}(x,y)$ 到根的路径上的点的子树中的点,因为其在求 $\operatorname{lca}(x,y)$ 的子树外的答案的时候就已经统一被求过一遍了。
对于这种点,我们在遍历从 $x$ 到 $y$ 的路径上的点的时候,对于其每一个直接儿子求一遍答案并更新一次即可。
这样处理的话,插入的次数是 $O(n)$ 的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
//x与y均在子树外的点
queue<int>q;
for(int i = l1; i != g; i = fa[i])
{
for(int j = h[i]; ~j; j = ne[j])
if(!tag[e[j]])q.push(e[j]);
}
for(int i = l2; i != g; i = fa[i])
{
for(int j = h[i]; ~j; j = ne[j])
if(!tag[e[j]])q.push(e[j]);
}
for(int i = h[g]; ~i; i = ne[i])
if(!tag[e[i]])q.push(e[i]);
for(int i = 1; i <= n; i++)vis[i] = false;
while(!q.empty())
{
tr.clear(), maxn = 0;
dfs2(q.front(), fa[q.front()]);
res = max(res, maxn + (a[l1] ^ a[l2]));
q.pop();
}
res = max(res, maxn + (a[l1] ^ a[l2]));

最后是两个点中只有其中一个在子树内的点,我们只能依次求出其子树内答案了。
看起来很麻烦,实际上这些点只存在于从 $x$ 到 $y$ 的这一条路径上。
对于一条路径,我们可以以 $O(n)$ 次插入的复杂度来从下到上求出子树内的答案,那么我们就可以把 $x$ 到 $y$ 的路径拆成 $x$ 到 $\operatorname{lca}(x,y)$ 和 $\operatorname{lca}(x,y)$ 到 $y$ 这两条路径,分别维护。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
//x到y路径上的点
tr.clear(), maxn = 0;
for(int i = 1; i <= n; i++)vis[i] = false;
for(int i = l1; i != fa[g]; i = fa[i])
{
dfs2(i, fa[i]);
vis[i] = true;
res = max(res, maxn + ans[i]);
}
tr.clear(), maxn = 0;
for(int i = 1; i <= n; i++)vis[i] = false;
for(int i = l2; i != fa[g]; i = fa[i])
{
dfs2(i, fa[i]);
vis[i] = true;
res = max(res, maxn + ans[i]);
}

三种类型的点的插入次数都是 $O(n)$ 级别的,总时间复杂度是 $O(n \log v)$。

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 30010, M = N << 1;
struct Trie
{
struct Node
{
int s[2];
int end;
};
Node tr[N * 30];
int idx;
void clear()
{
for(int i = 0; i <= idx; i++)
tr[i].s[0] = tr[i].s[1] = tr[i].end = 0;
idx = 0;
}
void insert(int x, int id)
{
int p = 0;
for(int i = 29; i >= 0; i--)
{
int c = (x >> i) & 1;
if(!tr[p].s[c])tr[p].s[c] = ++idx;
p = tr[p].s[c];
}
tr[p].end = id;
}
int query(int x)
{
int p = 0;
for(int i = 29; i >= 0; i--)
{
int c = (x >> i) & 1;
if(tr[p].s[c ^ 1])p = tr[p].s[c ^ 1];
else if(tr[p].s[c])p = tr[p].s[c];
else return 0;
}
return tr[p].end;
}
};

int n;
int h[N], e[M], ne[M], w[M], idx;
void add(int a, int b, int c)
{
e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx++;
}
int a[N], ans[N];
Trie tr;
int l1, l2; int maxn;
int dep[N], fa[N], deg[N];
bool tag[N];
void dfs1(int p, int fa, int v)
{
dep[p] = dep[fa] + 1, ::fa[p] = fa, a[p] = v;
for(int i = h[p]; ~i; i = ne[i])
{
if(e[i] == fa)continue;
dfs1(e[i], p, v ^ w[i]);
}
}
int lca(int p, int q)
{
if(dep[p] < dep[q])swap(p, q);
tag[p] = tag[q] = true;
while(dep[p] > dep[q])p = fa[p], tag[p] = true;
if(p == q)return p;
while(p != q)p = fa[p], q = fa[q], tag[p] = tag[q] = true;
return p;
}
int sta[N], tt;
bool vis[N];
void dfs2(int p, int fa)
{
tr.insert(a[p], p);
int q = tr.query(a[p]);
if((a[p] ^ a[q]) > maxn)maxn = a[p] ^ a[q];
for(int i = h[p]; ~i; i = ne[i])
{
if(e[i] == fa || vis[e[i]])continue;
dfs2(e[i], p);
}
}
void solve(int p, int q)
{
if(p == q)return;
if(dep[p] < dep[q])swap(p, q);
tt = 0;
while(p != q)sta[++tt] = p, vis[p] = true, p = fa[p];
tr.clear(), maxn = 0;
dfs2(1, 0);
for(int i = tt; i; i--)
{
ans[sta[i]] = maxn;
dfs2(sta[i], fa[sta[i]]);
vis[sta[i]] = false;
}
}
int res = 0;
int main()
{
memset(h, -1, sizeof(h));
scanf("%d", &n);
bool flag = true;
for(int i = 1; i < n; i++)
{
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
add(u, v, w), add(v, u, w);
deg[u]++, deg[v]++;
if(deg[u] > 2 || deg[v] > 2)flag = false;
}
dfs1(1, 0, 0);
tr.clear(), maxn = 0;
for(int i = 1; i <= n; i++)
{
tr.insert(a[i], i);
int j = tr.query(a[i]);
if((a[i] ^ a[j]) > maxn)l1 = j, l2 = i, maxn = a[i] ^ a[j];
}
//求子树外的答案
int g = lca(l1, l2);
for(int i = g; i >= 1; i = fa[i])tag[i] = true;
for(int i = 2; i <= n; i++)
if(!tag[i])ans[i] = maxn;
solve(1, l1), solve(g, l2);
//一条链且x到y是整条链的特殊情况
if(flag && deg[l1] == 1 && deg[l2] == 1)
{
tr.clear(), maxn = 0;
if(deg[l2] < deg[l1])swap(l1, l2);
for(int i = l2; i != l1; i = fa[i])
{
tr.insert(a[i], i);
int j = tr.query(a[i]);
if((a[i] ^ a[j]) > maxn)maxn = a[i] ^ a[j];
res = max(res, ans[i] + maxn);
}
printf("%d\n", res);
return 0;
}
//x与y均在子树内的点
for(int i = g; i >= 1; i = fa[i])
res = max(res, ans[i] + (a[l1] ^ a[l2]));
//x与y均在子树外的点
queue<int>q;
for(int i = l1; i != g; i = fa[i])
{
for(int j = h[i]; ~j; j = ne[j])
if(!tag[e[j]])q.push(e[j]);
}
for(int i = l2; i != g; i = fa[i])
{
for(int j = h[i]; ~j; j = ne[j])
if(!tag[e[j]])q.push(e[j]);
}
for(int i = h[g]; ~i; i = ne[i])
if(!tag[e[i]])q.push(e[i]);
for(int i = 1; i <= n; i++)vis[i] = false;
while(!q.empty())
{
tr.clear(), maxn = 0;
dfs2(q.front(), fa[q.front()]);
res = max(res, maxn + (a[l1] ^ a[l2]));
q.pop();
}
res = max(res, maxn + (a[l1] ^ a[l2]));
//x到y路径上的点
tr.clear(), maxn = 0;
for(int i = 1; i <= n; i++)vis[i] = false;
for(int i = l1; i != fa[g]; i = fa[i])
{
dfs2(i, fa[i]);
vis[i] = true;
res = max(res, maxn + ans[i]);
}
tr.clear(), maxn = 0;
for(int i = 1; i <= n; i++)vis[i] = false;
for(int i = l2; i != fa[g]; i = fa[i])
{
dfs2(i, fa[i]);
vis[i] = true;
res = max(res, maxn + ans[i]);
}

printf("%d\n", res);
return 0;
}