题意: n个点的无根树, m条有向路径, 点上有权值W, 边的长度为1. 一条路径能被一个点观察到当且仅当: 该点在路径上, 且该点到起点的距离等于权值. 求每个点能观察到的路径数量. 限制与约定(n=m): 1. 10%: n=991, 所有路径起点等于终点. 2. 10%: n=992, 点权为0. 3. 5%: n=993. 4. 15%: n=99994, 树退化为一条链, 1与2相连, 2与3相连, ..., (n-1)与n相连. 5. 20%: n=99995, 所有路径起点为1. 6. 20%: n=99996, 所有路径终点为1. 7. 15%: n=99997. 8. 5%: n=299998.
这是一道好题, 从部分分说起 (即前6个子任务, 第8个子任务大概是用来卡常数大的非线性做法的).
路径起点等于终点. 只需数一数有多少个起点到i的距离为Wi. 枚举起点或点i, DFS.
Wi = 0. 只需数一数有多少个起点与i重合.
n=993. 从每条路径的起点开始DFS,
. 个人认为考场上较为实惠的策略是前25%都用这个暴力. 链. 一条路径在链上就是一条线段, 想到了什么? 扫描线. 对方向为1->n和n->1的路径各做一次, 用一个计数器, 事件开始时该位置+1, 结束时开始位置-1. 处理结束事件之前将
i-w[i]
(正向)或i+w[i]
(反向)处的计数器累加至答案即可. 具体地, 可以对每个点开一个vector
以记录事件.起点为1. 以1为根. 问题简化为对所有深度等于W的点, 统计它在多少条路径上. 给路径打标记即可, 和链上做法相似.
终点为1. 以1为根. 问题简化为对所有点统计以它为根的子树中有多少个起点到它的距离为Wi. 一个想法是在DFS的过程中统计, 将某点的子树合并. 但是, 怎样记录信息, 才能使子树之间互不干扰呢? 启发式合并. 分享一篇文章: [Tutorial] Sack (dsu on tree) - Codeforces.
到此为止, (理想情况下) 我们获得了80分. 来想想满分做法.
Subtask 5, 6 给我们启发: 将路径从起点, 终点的lca处拆成两部分. 对于方向向上的部分, 统计满足d[v] = d[u] + w[u]
的起点v
的个数, 与 Subtask 6 相似; 然而, 对于方向向下的部分, 如果要统计终点v
, 满足的关系式是d[v] = d[u] + len[v] - w[u]
, 把v
的约束移到同一边, d[v] - len[v] = d[u] - w[u]
. (Tip: 由于可能产生负数, 实现的时候等式两边应同时加上n
, 并且空间开两倍.) 不统计起点终点而用lca也可以, 就是麻烦了些. 问题转化为在适当的时刻, 对每一个u
, 统计满足d[v] = d[u] + w[u]
或d[v] - len[v] = d[u] - w[u]
的点v
的数目. 我们用两个cnt
计数器数组实现.
把进入路径和离开路径的事件挂在点上, 计数器 +1, -1. 完成子树的DFS后, 合并信息, 处理本点的事件, 更新答案.
和 Subtask 6 一样, 如果x
有两个儿子y
, z
, 并且使用全局计数器, 先遍历y
, 再遍历z
, 那么, 在遍历z
的时候, 由于子树y
中开启的一些事件还没有结束, 计数器并未清零. 简单的解决方案是每棵子树用独立的局部计数器, 自底向上启发式合并. 更好的解决方案是利用DFS的性质. 答案ans[u]
是DFS完以u
为根的子树后, 对应位置 (d[u] + w[u]
与 d[u] - w[u]
) 计数器的增加量. 所以, 在进入某点的时候减去对应位置的计数器, 离开某点的时候答案加上对应位置的计数器, 得到的就是答案. 如果找lca用Tarjan的离线+并查集算法, 时间复杂度为
可不可以在lca处+1, 起点终点处-1呢? 不行......因为没法确定路径往哪边延伸.
看~ 这是一道简单的NOIP题, 考察了DFS和求lca -_-b
以下是该做法的实现. (尝试一下宏和简单头文件, 我猜想这种风格在工程上大概是要挨打 QAQ)
由于需要保证每对点只处理一次, Tarjan算法有两个需要注意的地方: 1. vis[u] = true
得放在DFS子树之后, 否则对于那种不拐弯的情况会处理两次. 2. 路径的起点与终点重合, 只提出一个询问.
#include <bits/stdc++.h>
#define Rep(i, a, b) for (int i = (a), i##_ = (b); i < i##_; ++i)
#define For(i, a, b) for (int i = (a), i##_ = (b); i <= i##_; ++i)
#define Down(i, a, b) for (int i = (a), i##_ = (b); i >= i##_; --i)
#define pb push_back
#define fst first
#define sec second
using namespace std;
typedef pair<int, int> ii;
const int MAX_N = 3e5, MAX_M = 3e5;
int n, f[MAX_N + 1], rk[MAX_N + 1], anc[MAX_N + 1], w[MAX_N + 1], d[MAX_N + 1], cnt[2][2 * MAX_N], ans[MAX_N + 1];
bool vis[MAX_N + 1];
vector<int> adj[MAX_N + 1], st[2][MAX_N + 1], ed[2][MAX_N + 1];
vector<ii> Q[MAX_N + 1];
int F(int x)
{
return f[x] ? f[x] = F(f[x]) : x;
}
inline void merge(int x, int y)
{
int fx = F(x), fy = F(y);
if (rk[fy] > rk[fx])
swap(fx, fy);
f[fy] = fx;
rk[fx] += rk[fx] == rk[fy];
}
// d, st, ed
void get_lca(int u, int p)
{
anc[u] = u;
Rep (i, 0, adj[u].size()) {
int v = adj[u][i];
if (v != p) {
d[v] = d[u] + 1;
get_lca(v, u);
merge(u, v);
anc[F(u)] = u;
}
}
vis[u] = true;
Rep (i, 0, Q[u].size()) {
ii p = Q[u][i];
int v = p.fst == u ? p.sec : p.fst;
if (vis[v]) {
int a = anc[F(v)], len = d[u] + d[v] - 2*d[a], s = d[p.fst], t = d[p.sec] - len + n;
st[0][p.fst].pb(s);
st[1][p.sec].pb(t);
ed[0][a].pb(s);
ed[1][a].pb(t);
}
}
}
void dfs(int u, int p)
{
ans[u] -= cnt[0][d[u] + w[u]] + cnt[1][d[u] - w[u] + n];
Rep (i, 0, adj[u].size()) {
int v = adj[u][i];
if (v != p)
dfs(v, u);
}
For (i, 0, 1)
Rep (j, 0, st[i][u].size()) ++cnt[i][st[i][u][j]];
Rep (j, 0, ed[1][u].size()) --cnt[1][ed[1][u][j]];
ans[u] += cnt[0][d[u] + w[u]] + cnt[1][d[u] - w[u] + n];
Rep (j, 0, ed[0][u].size()) --cnt[0][ed[0][u][j]];
}
int main()
{
int m;
scanf("%d%d", &n, &m);
Rep (i, 0, n-1) {
int u, v;
scanf("%d%d", &u, &v);
adj[u].pb(v);
adj[v].pb(u);
}
For (i, 1, n)
scanf("%d", &w[i]);
Rep (i, 0, m) {
ii p;
scanf("%d%d", &p.fst, &p.sec);
Q[p.fst].pb(p);
if (p.fst != p.sec)
Q[p.sec].pb(p);
}
get_lca(1, 0);
dfs(1, 0);
For (i, 1, n)
printf("%d\n", ans[i]);
return 0;
}
以下是我在考场上的60分做法, 做个纪念.
#include <cstdio>
#include <vector>
using namespace std;
const int MAX_N = 3e5;
vector<int> E[MAX_N];
int n, m, ans[MAX_N], w[MAX_N], cnt[MAX_N];
namespace work_1 {
bool dfs(int s, int t, int fa, int k)
{
bool ok = false;
if (s == t) {
ok = true;
} else {
for (int i = 0, v; i < E[s].size(); ++i)
if ((v = E[s][i]) != fa)
ok = ok || dfs(v, t, s, k+1);
}
if (ok && w[s] == k)
++ans[s];
return ok;
}
void solve()
{
for (int i = 1; i <= n-1; ++i) {
int u, v;
scanf("%d %d", &u, &v);
E[u].push_back(v);
E[v].push_back(u);
}
for (int i = 1; i <= n; ++i)
scanf("%d", &w[i]);
for (int i = 1; i <= m; ++i) {
int s, t;
scanf("%d %d", &s, &t);
dfs(s, t, 0, 0);
}
for (int i = 1; i <= n; ++i)
printf("%d%c", ans[i], " \n"[i == n]);
}
}
namespace work_2 {
const int MAX_N = 1e5;
struct Event {
int s, t;
};
vector<Event> L[MAX_N], R[MAX_N];
void solve()
{
for (int i = 1; i <= n-1; ++i) {
int u, v;
scanf("%d %d", &u, &v);
}
for (int i = 1; i <= n; ++i)
scanf("%d", &w[i]);
for (int i = 1; i <= m; ++i) {
int s, t;
scanf("%d %d", &s, &t);
if (s <= t) {
L[s].push_back((Event){s, t});
L[t].push_back((Event){s, t});
} else {
R[s].push_back((Event){s, t});
R[t].push_back((Event){s, t});
}
}
for (int i = 1; i <= n; ++i) {
for (int j = 0; j < L[i].size(); ++j)
if (L[i][j].s == i)
++cnt[i];
if (i > w[i])
ans[i] = cnt[i-w[i]];
for (int j = 0; j < L[i].size(); ++j)
if (L[i][j].t == i)
--cnt[L[i][j].s];
}
for (int i = n; i >= 1; --i) {
for (int j = 0; j < R[i].size(); ++j)
if (R[i][j].s == i)
++cnt[i];
if (i+w[i] <= n)
ans[i] += cnt[i+w[i]];
for (int j = 0; j < R[i].size(); ++j)
if (R[i][j].t == i)
--cnt[R[i][j].s];
}
for (int i = 1; i <= n; ++i)
printf("%d%c", ans[i], " \n"[i == n]);
}
}
namespace work_3 {
const int MAX_N = 1e5;
int dep[MAX_N];
void dfs(int u, int fa, int d)
{
dep[u] = d;
for (int i = 0, v; i < E[u].size(); ++i)
if ((v = E[u][i]) != fa) {
dfs(v, u, d+1);
cnt[u] += cnt[v];
}
}
void solve()
{
for (int i = 1; i <= n-1; ++i) {
int u, v;
scanf("%d %d", &u, &v);
E[u].push_back(v);
E[v].push_back(u);
}
for (int i = 1; i <= n; ++i)
scanf("%d", &w[i]);
for (int i = 1; i <= m; ++i) {
int s, t;
scanf("%d %d", &s, &t);
++cnt[t];
}
dfs(1, 0, 0);
for (int i = 1; i <= n; ++i)
printf("%d%c", dep[i] == w[i] ? cnt[i] : 0, " \n"[i == n]);
}
}
int main()
{
freopen("running.in", "r", stdin);
freopen("running.out", "w", stdout);
scanf("%d %d", &n, &m);
if (n % 10 <= 3)
work_1::solve();
else if (n % 10 == 4)
work_2::solve();
else
work_3::solve();
return 0;
}