[usaco2013 Open] Yin and Yang

n个点的无根树, 每个边权为0或1. 一条路径阴阳平衡当且仅当它至少有两条边, 且路径上0,1数目相等. 统计有多少条简单路径上可以选一个点作休息站, 休息站的两侧是阴阳平衡的路径. 两条路径相等当且仅当它们包含的边集相等. (1≤n≤10^5) (bzoj 3127 & 3697) 总算自己想出来一道 (点分治裸) 题. QAQ

树上路径统计的两种常用方法: 树上DP, 树分治. 发现本题后者更合适.

设重心为x. 如果有一条经过重心的合法路径y-x-z, 那么, 休息站要么在y-x上, 要么在x-z上. 我们对每个点y!=x, 统计有多少个点z. y-x上能建休息站, 当且仅当y-x上存在一点t, 使得 dis(x, t) = dis(x, y). 如果y-x上能建休息站, x-z上是否可以建就无所谓了; 否则, x-z上必须有休息站. 注意处理z=x的情况 (发现答案偏小才想到 QAQ).

typedef long long ll;

const int N = 1e5;

struct Edge {
	int to, w;
};

bool g[N], b[N];
int top, seq[N], d[N], sz[N], cnt[2*N+1], tot[2][2*N+1];
ll ans;
vector<Edge> adj[N];

int get_centroid(int u, int p, int n)
{
	int r = -1, mx = 0;
	sz[u] = 1;
	rep (i, 0, adj[u].size()) {
		int v = adj[u][i].to;
		if (g[v] || v == p) continue;
		int t = get_centroid(v, u, n);
		r = t >= 0 ? t : r;
		sz[u] += sz[v];
		mx = max(mx, sz[v]);
	}
	mx = max(mx, n-sz[u]);
	return r >= 0 ? r : (mx <= n/2 ? u : -1);
}

void dfs(int u, int p)
{
	b[u] = ++cnt[d[u]+N] > 1;
	if (d[u] == 0 && cnt[d[u]+N] > 2) ans += 2;
	++tot[b[u]][d[u]+N];
	
	seq[top++] = u;
	
	rep (i, 0, adj[u].size()) {
		int v = adj[u][i].to;
		if (!g[v] && v != p) {
			d[v] = d[u] + adj[u][i].w;
			dfs(v, u);
		}
	}

	--cnt[d[u]+N];
}

void solve(int u, int n)
{
	int x = get_centroid(u, -1, n);
	g[x] = true;

	top = 0;

	cnt[N] = 1;
	rep (i, 0, adj[x].size()) {
		int y = adj[x][i].to;
		if (!g[y]) {
			d[y] = adj[x][i].w;
			sz[y] = sz[y]<sz[x] ? sz[y] : n-sz[x];
			dfs(y, x);
		}
	}
	cnt[N] = 0;
	
	int k = 0;
	rep (i, 0, adj[x].size()) {
		static int tmp[2][2*N+1];
		int y = adj[x][i].to;
		if (g[y]) continue;
		rep (j, k, k + sz[y]) ++tmp[b[seq[j]]][d[seq[j]]+N];
		rep (j, k, k + sz[y]) {
			int z = seq[j], t = b[z], l = -d[z] + N;
			ans += tot[t^1][l] - tmp[t^1][l];
			if (t)
				ans += tot[1][l] - tmp[1][l];
		}
		rep (j, k, k + sz[y]) tmp[b[seq[j]]][d[seq[j]]+N] = 0;
		k += sz[y];
	}

	while (top--) tot[b[seq[top]]][d[seq[top]]+N] = 0;
	
	rep (i, 0, adj[x].size()) {
		int y = adj[x][i].to;
		if (!g[y])
			solve(y, sz[y]);
	}
}

int main()
{
	int n;
	scanf("%d", &n);
	rep (i, 0, n-1) {
		int a, b, t;
		scanf("%d%d%d", &a, &b, &t);
		--a, --b;
		t = t ? 1 : -1;
		adj[a].push_back((Edge){b, t});
		adj[b].push_back((Edge){a, t});
	}
	solve(0, n);
	printf("%lld\n", ans/2);
	return 0;
}