题目
第十二届蓝桥杯C/C++ B组决赛——机房
题解代码
#include<bits/stdc++.h>
using namespace std;
// Inclusive counting-loop shorthands used throughout the file:
// fir(i, a, b) runs i = a, a+1, ..., b; rif(i, b, a) runs i = b, b-1, ..., a.
// Arguments are parenthesized so expression arguments (e.g. Deep - 1) expand safely.
#define fir(i, a, b) for(int i = (a); i <= (b); i++)
#define rif(i, b, a) for(int i = (b); i >= (a); i--)
// N: capacity of every per-node array. Nodes are numbered 1..n, so n must stay below N
//    (presumably the problem guarantees n <= 1e5 -- confirm against the statement).
// Deep: number of binary-lifting levels (jumps of 2^0 .. 2^16). Combined, these jumps span
//    up to 2^17 - 1 = 131071 levels, which covers any depth difference when n < N.
const int N = 1e5 + 5, Deep = 17;
// main() stores every undirected edge twice (x->y and y->x), and edge ids start at 1.
// The edge pool therefore needs about 2 * N slots; sizing it by N alone overflows for large n.
const int EDGE_CAP = 2 * N;
int n, m;
// Linked adjacency lists ("chain forward star"):
//   head[u] is the newest edge id leaving u (0 = none),
//   Next[e] is the next edge id in the same list,
//   ver[e] is the edge's target node.
// tot is the number of edge ids handed out so far.
// w[u] is incremented once per incident edge in main, so it holds u's degree (the node weight).
int tot, head[N], ver[EDGE_CAP], Next[EDGE_CAP], w[N];
// depth[u]: 1 for the root (node 1), 0 for the sentinel node 0.
// fa[u][k]: the 2^k-th ancestor of u (0 once the jump passes the root).
// dist[u][k]: sum of w over the 2^k nodes from u upward, excluding fa[u][k].
//   Only meaningful when fa[u][k] is a real (nonzero) node.
int depth[N], fa[N][Deep], dist[N][Deep];
// Appends the directed edge x -> y to x's adjacency list.
// The new edge becomes the head of the list; edge ids are handed out as 1, 2, 3, ...
void add(int x, int y){
    const int id = ++tot;
    ver[id] = y;
    Next[id] = head[x];
    head[x] = id;
}
// Breadth-first traversal of the tree from root 1. It fills depth[] and the
// binary-lifting tables fa[][] and dist[][].
// A node's tables are built as soon as it is discovered. Its ancestors were
// discovered earlier, so their tables are already complete.
// Note: dist[1][*] is never written (stays 0). That is harmless because lca()
// only takes jumps that land on a real node, so the root is never inside a
// summed segment.
void bfs(){
    // 0x3f3f3f3f acts as "unvisited"; node 0 is the sentinel above the root.
    memset(depth, 0x3f, sizeof depth);
    depth[0] = 0;
    depth[1] = 1;
    queue<int> pending;
    pending.push(1);
    while (!pending.empty()) {
        const int u = pending.front();
        pending.pop();
        for (int e = head[u]; e != 0; e = Next[e]) {
            const int v = ver[e];
            // Only an unvisited neighbour can still have a depth above depth[u] + 1.
            if (depth[v] > depth[u] + 1) {
                depth[v] = depth[u] + 1;
                pending.push(v);
                fa[v][0] = u;
                dist[v][0] = w[v];
                for (int k = 1; k < Deep; ++k) {
                    // Split a 2^k jump into two 2^(k-1) jumps through `mid`.
                    const int mid = fa[v][k - 1];
                    fa[v][k] = fa[mid][k - 1];
                    dist[v][k] = dist[v][k - 1] + dist[mid][k - 1];
                }
            }
        }
    }
}
// Despite its name, returns the sum of w[] over every node on the simple path
// between x and y, both endpoints and their lowest common ancestor included.
// It locates the LCA with binary lifting.
// Requires bfs() to have been run first.
int lca(int x, int y){
    // Make x the deeper endpoint.
    if (depth[x] < depth[y]) swap(x, y);
    int total = 0;
    // Phase 1: lift x to y's depth, adding the weights of the nodes it leaves behind.
    for (int k = Deep - 1; k >= 0; --k) {
        const int up = fa[x][k];
        if (depth[up] < depth[y]) continue;
        total += dist[x][k];
        x = up;
    }
    // y was an ancestor of the original x: it is the LCA, count it once.
    if (x == y) {
        return total + w[x];
    }
    // Phase 2: lift both until their parents coincide, summing both sides.
    for (int k = Deep - 1; k >= 0; --k) {
        const int upX = fa[x][k];
        const int upY = fa[y][k];
        if (upX <= 0 || upX == upY) continue;
        total += dist[x][k] + dist[y][k];
        x = upX;
        y = upY;
    }
    // x and y are now children of the LCA; add them and the LCA itself.
    const int meet = fa[x][0];
    return total + dist[x][0] + dist[y][0] + w[meet];
}
int main(){
cin >> n >> m;
int x, y;
fir(i, 1, n - 1) {
scanf("%d%d", &x, &y);
add(x, y);
add(y, x);
w[x]++;
w[y]++;
}
bfs();
while(m--) {
scanf("%d%d", &x, &y);
if(x == y) {
printf("%d\n", w[x]);
continue;
}
printf("%d\n", lca(x, y));
}
return 0;
}