虚树的概念
虚树,即将树上的一些关键点和
其两两之间的 lca 之间连边形成的树。
虚树解决的问题
利用虚树可以对于指定的多组点集 S 的询问进行
每组 O(∣S∣(logn+log∣S∣)+f(∣S∣)) 的回答,
其中 f(x) 表示对一棵 x 节点的树
单组询问 这个问题的复杂度。
可以看出,这个复杂度基本(除 logn 外)与 n 无关。
这样,对于多组询问的回答就省去了
每次询问 都遍历一整棵树的 O(n) 复杂度了。
由此可见,虚树适合解决多组询问树上特定点集 S
且 ∑∣S∣ 有限的问题。
如何构造虚树
常见的构造方法和栈有关,
但 我不会且 有更无脑的方法。
具体的,先预处理每个点的 dfn 序,将关键点按 dfn 序排序后插入相邻两者的 lca 。
此后再次按 dfn 序排序并去重,并
从相邻两者的 lca 向后者连边 (详见代码)。
code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| void solve() { scanf("%d", &k); vector<int> o(k); for (auto &i : o) scanf("%d", &i); sort(o.begin(), o.end(), dfncmp); for (int i = 0; i < k; i++) o.push_back(lca(o[i], o[i + 1])); sort(o.begin(), o.end(), dfncmp); o.erase(unique(o.begin(), o.end()), o.end()); for (int i = 1; i < o.size(); i++) tmp[lca(o[i - 1], o[i])].push_back(o[i]);
for (auto i : o) tmp[i].clear(); }
|
如何使用虚树
视具体题目而定。
CF613D Kingdom and its Cities
题意:给定一棵树, q 组询问,每组询问给定 k 个点,
你可以删掉不同于那 k 个点的 m 个点,使得这 k 个点两两不连通,
要求最小化 m,如果不可能输出 −1 。询问之间独立。
由于数据范围 n≤105,∑k≤105,q≤105 ,故可以考虑使用虚树解决。
首先,注意到若有树上相邻两点同为关键点则必然无解。
建出虚树后,递归遍历虚树中的节点,对每个点分以下情况讨论:
- 子树中无关键点,不对答案产生贡献。
- 当前点为关键点,则对于每个有关键点的子树均需删掉至少一个点,贡献为有关键点的子树个数。
- 当前点不为关键点,若子树中仅有一个关键点,对答案不产生贡献,标记以此节点为根的子树中有关键点;
否则删除此点,对答案产生 1 贡献,并标记以此节点为根的子树中无关键点(已断开)。
code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
| #include <bits/stdc++.h> using namespace std; int n, _, fa[100010][20], dfn[100010], tot, k, dep[100010]; bool flag[100010]; vector<int> g[100010], tmp[100010]; bool dfncmp(int x, int y) { return dfn[x] < dfn[y]; } void dfs(int p) { dfn[p] = ++tot, dep[p] = dep[fa[p][0]] + 1; for (int i = 1; i <= 17; i++) fa[p][i] = fa[fa[p][i - 1]][i - 1]; for (auto i : g[p]) if (i != fa[p][0]) fa[i][0] = p, dfs(i); } int lca(int x, int y) { if (dep[x] > dep[y]) swap(x, y); for (int i = 17; ~i; i--) if (dep[fa[y][i]] >= dep[x]) y = fa[y][i]; for (int i = 17; ~i; i--) if (fa[y][i] != fa[x][i]) x = fa[x][i], y = fa[y][i]; return x == y ? x : fa[x][0]; } int work(int p) { int cnt = 0, ans = 0; for (auto i : tmp[p]) ans += work(i), cnt += flag[i]; if (cnt) { if (flag[p]) ans += cnt; else if (cnt == 1) flag[p] = true; else ans++; } return ans; } void solve() { scanf("%d", &k); vector<int> o(k); for (auto &i : o) scanf("%d", &i), flag[i] = true; for (auto i : o) { if (flag[fa[i][0]]) { for (auto j : o) flag[j] = false; printf("-1\n"); return; } } sort(o.begin(), o.end(), dfncmp); for (int i = 0; i < k; i++) o.push_back(lca(o[i], o[i + 1])); sort(o.begin(), o.end(), dfncmp); o.erase(unique(o.begin(), o.end()), o.end()); for (int i = 1; i < o.size(); i++) tmp[lca(o[i - 1], o[i])].push_back(o[i]); printf("%d\n", work(o[0])); for (auto i : o) tmp[i].clear(), flag[i] = false; } int main() { scanf("%d", &n); for (int i = 1; i < n; i++) { int x, y; scanf("%d%d", &x, &y); g[x].push_back(y), g[y].push_back(x); } dfs(1); scanf("%d", &_); while (_--) solve(); return 0; }
|
Luogu P2495 [SDOI2011] 消耗战
题意:给定一棵带权树, m 组询问,每组询问给定 k 个点,
要求砍断一些边使得这 k 个点均与 1 号点不连通,
求砍断的边权和的最小值。
同样的,由于数据范围 n≤2.5×105,∑k≤5×105,m≤5×105 ,考虑使用虚树解决。
建出虚树后,对每个点进行 dp ,显然有如下方程:
f(x)={∑p∈son(x)min(f(p),len(p))len(x)flag[x]=trueotherwise
其中 flag[x] 表示点 x 是否为关键点,son(x) 表示点 x 在
虚树 上的子节点, len(x) 表示点 x 到其
虚树上父节点 的
原树中 边权最小值,这是可以在建虚树时一并处理的,若你使用的是倍增求 lca ,你会发现方法极其类似。
综上,我们便可以通过虚树+树形 dp 解决这个问题。实际实现时由于每节点仅使用一次,不必用 dp 数组存储结果 还省得清空 ;且由于我将 len 与树边一同存储,还有要考虑 1 号点是否在虚树中等问题,代码与描述稍有出入。
code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
| #include <bits/stdc++.h> using namespace std; const int N = 250010; int n, _, dfn[N], tot, fa[N][20], len[N][20], dep[N], k; bool flag[N]; vector<pair<int, int>> g[N], tmp[N]; bool cmp(int x, int y) { return dfn[x] < dfn[y]; } void init(int p) { dfn[p] = ++tot, dep[p] = dep[fa[p][0]] + 1; for (int i = 1; i <= 18; i++) fa[p][i] = fa[fa[p][i - 1]][i - 1], len[p][i] = min(len[p][i - 1], len[fa[p][i - 1]][i - 1]); for (auto i : g[p]) { if (i.first != fa[p][0]) fa[i.first][0] = p, len[i.first][0] = i.second, init(i.first); } } int lca(int x, int y) { if (dep[x] > dep[y]) swap(x, y); for (int i = 18; ~i; i--) if (dep[fa[y][i]] >= dep[x]) y = fa[y][i]; for (int i = 18; ~i; i--) if (fa[y][i] != fa[x][i]) y = fa[y][i], x = fa[x][i]; return x == y ? x : fa[x][0]; } int le(int x, int y) { int ans = len[x][0]; for (int i = 18; ~i; i--) if (dep[fa[x][i]] >= dep[y]) ans = min(ans, len[x][i]), x = fa[x][i]; return ans; } long long work(int x) { if (flag[x]) return len[x][0]; long long ans = 0; for (auto i : tmp[x]) ans += min(work(i.first), (long long)i.second); return ans; } void solve() { scanf("%d", &k); vector<int> o(k); for (auto &i : o) scanf("%d", &i), flag[i] = true; sort(o.begin(), o.end(), cmp); for (int i = 1; i < k; i++) o.push_back(lca(o[i - 1], o[i])); sort(o.begin(), o.end(), cmp); o.erase(unique(o.begin(), o.end()), o.end()); for (int i = 1; i < o.size(); i++) tmp[lca(o[i - 1], o[i])].emplace_back(o[i], le(o[i], lca(o[i - 1], o[i]))); long long ans = work(o[0]); printf("%lld\n", o[0] == 1 ? ans : min(ans, (long long)le(o[0], 1))); for (auto i : o) flag[i] = false, tmp[i].clear(); } int main() { scanf("%d", &n); for (int i = 1; i < n; i++) { int x, y, z; scanf("%d%d%d", &x, &y, &z); g[x].emplace_back(y, z), g[y].emplace_back(x, z); } init(1); scanf("%d", &_); while (_--) solve(); return 0; }
|
that’s all.