一棵边权均为1的无根树, q个询问: 给出m个关键点, 每个点隶属于与它最近的关键点, 距离相同取编号最小的, 求每个关键点管多少点. (n,q,Σm ≤ 3*10^5) 虚树学习笔记 | Sengxian's Blog
虚树是一棵包含所有关键点和关键点两两之间LCA的树, 祖先-后代关系和原树保持一样.
建虚树的过程简述如下: 1. 加入根节点. 把关键点按原树中前序遍历的DFS序排列. 去重. 2. 模拟 DFS 的过程, S 即递归栈, 元素为当前点和它的祖先. 根节点入栈. 3. 设栈顶元素为 u (由于加入了根节点, S 始终非空), 新加入点 v. 设 w = lca(u,v). 如果 w = u, 说明 v 是 u 的后代, 直接压入栈中. 4. 否则, w 包含 u 的这一棵子树访问完毕. 从栈顶第二个点 S[top-2] 向 S[top-1] 连边, 直到栈中只剩下一个点, 或 dep[S[top-2]] < dep[w]. 那么, S[top-1] 是深度 ≥ dep[w] 的最后一个点 (对于栈中只剩一个点的情况, 由于根节点没有祖先, 它显然是最后一个). 如果 S[top-1] ≠ w, 则从 w 向 S[top-1] 连边, 并且用 w 替换 S[top-1] 成为栈顶. 压入点 v. 5. 最后, 弹栈, 连边, 直到 S 为空.
根节点加不加其实无所谓啦~ 但是DP的时候也会方便一点.
那这道题怎么做呢? 虚树上除了询问点, 还有它们之间的LCA. 每条边上还可能挂着一堆子树......如果能知道每个询问点在虚树上管哪一部分就好办了.
学习了一下题解. 虚树上每个点被谁管, 相比每个询问点管谁, 更加容易处理. 两遍DP, 维护最近次近 (要求不在同一棵子树中; 把上方的所有点也看成一棵子树) 即可.
对于虚树上的一条边 (u,v) (不含端点), 设 belong[u] = x, belong[v] = y. 如果 x = y, 那么 (u,v) 也属于 x. 否则, 先判断 u-x, v-y 相差太悬殊的情形. 那么接下来, (u,v) 上存在一个分界点 z, 使得 (z,u) 属于 x, (v,z] 属于 y. 列个不等式可以解出它的位置. 如果解出一个整数, 那么该点与 x,y 距离相等, 属于哪边取决于 x,y 的编号大小. 用倍增找到这个点, 累加答案.
对于虚树上的一个点 u, 把它自己和挂在点 u 上且不在虚树中的点的数目累加给 belong[u]. 不用考虑上方的部分, 因为我们保证根在虚树中.
typedef pair<int, int> ii;
const int N = 3e5 + 1, D = 20, inf = (1<<30)-1;
bool q[N];
int n, dfs_clock, max_d, num, dfn[N], anc[N][D], fa[N], up[N], dep[N], sz[N], V[N], ans[N];
vector<int> adj[N], T[N];
void dfs(int u, int p)
{
dfn[u] = ++dfs_clock;
anc[u][0] = p;
dep[u] = dep[p] + 1;
sz[u] = 1;
rep (i, 0, max_d)
anc[u][i+1] = anc[anc[u][i]][i];
rep (i, 0, adj[u].size())
{
int v = adj[u][i];
if (v != p)
{
dfs(v, u);
sz[u] += sz[v];
}
}
}
int jump(int x, int h)
{
for (int i = 0; h; ++i, h >>= 1)
if (h & 1)
x = anc[x][i];
return x;
}
int lca(int x, int y)
{
if (dep[y] < dep[x]) swap(x, y);
y = jump(y, dep[y]-dep[x]);
if (x == y) return y;
for (int i = max_d; i >= 0; --i)
if (anc[x][i] != anc[y][i])
x = anc[x][i], y = anc[y][i];
return anc[x][0];
}
bool cmp(int i, int j)
{
return dfn[i] < dfn[j];
}
struct VirtualTree
{
int t, cnt, S[N], tmp[N];
inline void link(int x, int y)
{
fa[y] = x;
T[x].push_back(y);
tmp[cnt++] = y;
}
void build()
{
sort(V, V+num, cmp);
t = 0, cnt = 0;
S[t++] = V[0];
rep (i, 1, num)
{
int p = lca(S[t-1], V[i]);
if (S[t-1] != p)
{
while (t > 1 && dep[S[t-2]] >= dep[p])
{
link(S[t-2], S[t-1]);
--t;
}
if (S[t-1] != p)
{
link(p, S[t-1]);
S[t-1] = p;
}
}
S[t++] = V[i];
}
while (t > 1)
{
link(S[t-2], S[t-1]);
--t;
}
tmp[cnt++] = S[0];
fa[S[0]] = 0;
copy(tmp, tmp + cnt, V);
num = cnt;
}
} VT;
struct Info
{
ii x[2];
Info()
{
x[0] = x[1] = ii(inf, 0);
}
void operator|=(const ii& t)
{
if (t <= x[0])
x[1] = x[0], x[0] = t;
else if (t < x[1])
x[1] = t;
}
} f[N];
void dp()
{
rep (i, 0, num)
{
int u = V[i];
f[u] = Info();
if (q[u]) f[u] |= ii(0, u);
rep (j, 0, T[u].size())
{
int v = T[u][j];
f[u] |= ii(f[v].x[0].first + dep[v] - dep[u], f[v].x[0].second);
}
}
per (i, num-1, 0)
{
int u = V[i], p = fa[u];
ii t = f[p].x[f[p].x[0].second == f[u].x[0].second];
f[u] |= ii(t.first + dep[u] - dep[p], t.second);
}
}
inline int calc(int x, int y)
{
return sz[x] - sz[y];
}
void solve()
{
rep (i, 0, num-1)
{
int v = V[i], u = fa[v], l = dep[v]-dep[u], w = jump(v, l-1),
x = f[u].x[0].second, y = f[v].x[0].second;
up[v] = w;
if (l == 1) continue;
if (x == y)
{
ans[x] += calc(w, v);
}
else
{
int a = f[u].x[0].first, b = f[v].x[0].first;
if (a-b <= -l)
{
ans[x] += calc(w, v);
}
else if (a-b >= l)
{
ans[y] += calc(w, v);
}
else
{
int t = (a-b+l)/2 - (!((a-b+l) & 1) && x < y), z = jump(v, t);
ans[x] += calc(w, z);
ans[y] += calc(z, v);
}
}
}
rep (i, 0, num)
{
int u = V[i], s = sz[u];
rep (i, 0, T[u].size())
s -= sz[up[T[u][i]]];
ans[f[u].x[0].second] += s;
}
}
void clear()
{
rep (i, 0, num)
{
int v = V[i];
T[v].clear();
q[v] = false;
ans[v] = 0;
}
}
int main()
{
scanf("%d", &n);
rep (i, 0, n-1)
{
int x, y;
scanf("%d%d", &x, &y);
adj[x].push_back(y);
adj[y].push_back(x);
}
while ((1<<max_d) < n) ++max_d;
dfs(1, 0);
int m;
scanf("%d", &m);
while (m--)
{
static int k, h[N];
scanf("%d", &k);
num = k;
rep (i, 0, k)
{
scanf("%d", h+i);
V[i] = h[i];
q[h[i]] = true;
}
if (!q[1]) V[num++] = 1;
VT.build();
dp();
solve();
rep (i, 0, k)
printf("%d ", ans[h[i]]);
puts("");
clear();
}
return 0;
}