[bzoj 4871] [Shoi2017]摧毁“树状图”

删掉树上两条边不相交的路径 (可退化为点), 问最多裂成多少个连通块. (多组数据, ∑ n ≤ 5 × 10^5) 这当然是一道树型DP, 但是我好晕啊......又去学习了一下WA神犇的题解, 发现自己方向不错, 但是少了一步转化. WA神犇的思路十分清晰, 分类讨论恰到好处. 本来挺腌臢的一道题......顿时变得平易近人, 甚至有种美感?

的度数. 从树 中删掉一个点 , 得到 个连通块; 删掉两个点 , 得到 个连通块; 删掉一棵树 , 得到 个连通块. (我缺少这步转化, 没有抽象出"度数之和", 从而难以简化问题)

.

现在, 删掉两条边不相交的链 . 根据它们的位置关系, 分三种情况讨论: - 有一个公共点. 是从该点出发的不超过4条链. 得到 个连通块. - 点导出子图有一条公共边. 即, 存在 , 使得 . 得到 个连通块. - 其他. 即, 存在一个点 , 将 分割成两部分 , 使得 . 得到 个连通块.

现在问题变得很简单了~

任取一个根.

: 从u向下的所有链, 的最大值, : 所有以u为lca的路径, 的最大值. 这两个值可以一遍DFS求出.

为从 Missing \left or extra \right\left\\{f(v)|v\in S\right\\} 中取不超过 个数, 和的最大值 (定义取0个数和为0). 中所有和 相邻的点的集合.

  • 有一个公共点. 设该点为 , 用 更新答案.
  • 点导出子图有一条公共边. 设该边为 , 则 是其中一条路径的lca. 用 Missing \left or extra \rightsum(adj(u)-\left\\{v\right\\}, 2) + g(v) + 2 更新答案.
  • 其他. 设 的lca为 , 的lca为 , 分为两种情况:
    • 有祖先-后代关系 (不妨设的祖先), 设 的路径上, 除 以外的第一个点是 ;
    • 没有祖先-后代关系, 设 . 设 的子树 (包含 ) 中 的最大值, 的子树 (不包含 ) 中 的最大值. 分别用 Missing \left or extra \rightsum(adj(x)-\left\\{z\right\\}, 2) + h'(z) + 3Missing \left or extra \rightsum(adj(w)-\left\\{fa(w)\right\\}, 2) + 3 更新答案.

可以用一个struct维护 的前4大. 要从集合中排除一个元素, 和我们维护的最大值, 次大值做做比较即可.

至此, 问题解决, 撒花~

#define Z(x) max(0, x)

const int N = 5e5, inf = 1e9;

template<typename T>
inline void upmax(T& x, T v)
{
	x = max(x, v);
}

struct Info {
	int v[4];
	Info(int t=-inf)
	{
		v[0] = v[1] = v[2] = v[3] = t;
	}
	void operator+=(int t)
	{
		rep (i, 0, 4) if (t >= v[i]) swap(t, v[i]);
	}
	int one(int t=-inf)
	{
		return max(0, t == v[0] ? v[1] : v[0]);
	}
	int two(int t=-inf)
	{
		if (t == v[0]) return Z(v[1]) + Z(v[2]);
		if (t == v[1]) return Z(v[0]) + Z(v[2]);
		return Z(v[0]) + Z(v[1]);
	}
	int four()
	{
		return Z(v[0]) + Z(v[1]) + Z(v[2]) + Z(v[3]);
	}
} f[N];

int ans, F[N], g[N], h[N], d[N];
vector<int> adj[N];

int dfs_1(int u, int p)
{
	f[u] = -inf;
	rep (i, 0, adj[u].size()) {
		int v = adj[u][i];
		if (v != p) {
			f[u] += dfs_1(v, u);
		}
	}
	g[u] = f[u].two() + d[u];
	return F[u] = f[u].one() + d[u];
}

void dfs_2(int u, int p, int t)
{
	f[u] += t;
	h[u] = t = -inf;
	rep (i, 0, adj[u].size()) {
		int v = adj[u][i];
		if (v != p) {
			dfs_2(v, u, f[u].one(F[v]) + d[u]);
			upmax(ans, f[u].two(F[v]) + d[u] + max(g[v] + 2, h[v] + 3));
			upmax(h[v], g[v]);
			upmax(h[u], h[v]);
			upmax(ans, h[v] + t + 3);
			upmax(t, h[v]);
		}
	}
	upmax(ans, f[u].four() + d[u] + 2);
}

int main()
{
	int T, x, n;
	scanf("%d%d", &T, &x);
	x *= 2;
	while (T--) {
		scanf("%d", &n);
		rep (i, 0, x) scanf("%*d");
		rep (i, 0, n) adj[i].clear();
		rep (i, 0, n-1) {
			int u, v;
			scanf("%d%d", &u, &v);
			--u, --v;
			adj[u].push_back(v);
			adj[v].push_back(u);
		}
		rep (i, 0, n) d[i] = (int)adj[i].size() - 2;
		ans = 0;
		dfs_1(0, -1);
		dfs_2(0, -1, -inf);
		printf("%d\n", ans);
	}
	return 0;
}