一种比较简短的写法:
拉出直径,再在直径的每一个点上跑一下最长链,为 $ mx_i$
这里设三点的路径交点为 \(rt\)。
假设 \(rt \rightarrow u,v,w\) 的距离为 \(dis1,dis2,dis3\) 。
容易知道 \(dis1 = (x+y-z)/2,dis2 = (-x+y+z)/2,dis3 = (x-y+z)/2\)。
在直径上的 \([dis1+1,len-dis2]\) 中选择 \(mx\) 最大值判断是否大于等于 \(dis3\)。
若成立,则在 \([dis2+1,len-dis1]\) 中选择 \(mx\) 最大值判断是否大于等于 \(dis3\)。
明显这样做是在 \(dis1\ge dis2\ge dis3\) 时成立,若 \(dis1,dis2,dis3\) 不满足递减关系,先排序,最后调整关系即可。
时间复杂度 \(O(n \log n +q)\),常熟较小,比次优解快一半。
放一个未卡常的个人觉得比较清新的代码:
#include<bits/stdc++.h>
using namespace std;
const int N=2E5+5;
int n,q,dep[N],rt,fa[N],st[N][21],mx[N],lg[N],vis[N],a[4],b[4];
vector<int>e[N],g,h[N];
int get(int x,int y){return mx[x]>mx[y]?x:y;
}
void dfs(int u,int F){fa[u]=F;dep[u]=dep[F]+1;if(dep[u]>dep[rt])rt=u;for(int v:e[u])if(v!=F&&!vis[v])dfs(v,u);
}
int query(int l,int r){int tmp=lg[r-l+1];return get(st[l][tmp],st[r+1-(1<<tmp)][tmp]);
}
bool cmp(int x,int y){return a[x]>a[y];
}
int main(){freopen("game.in", "r", stdin);freopen("game.out", "w", stdout);scanf("%d",&n);for(int i=1,u,v;i<n;i++){scanf("%d%d",&u,&v);e[u].push_back(v),e[v].push_back(u);}dfs(1,0);dfs(rt,0);while(rt)vis[rt]=1,g.push_back(rt),rt=fa[rt];for(int i=0;i<g.size();i++){int x=g[i];rt=0;dep[0]=-1;st[i+1][0]=i;dfs(x,0);mx[i]=dep[rt];while(rt)h[i].push_back(rt),rt=fa[rt];reverse(h[i].begin(),h[i].end());}for(int i=1;i<=20;i++)for(int j=1;j+(1<<i-1)<=g.size();j++)st[j][i] = get(st[j][i-1],st[j+(1<<i-1)][i-1]);for(int i=2;i<=n;i++)lg[i]=lg[i>>1]+1;scanf("%d",&q);while(q--){int u,v,w;for(int i=1;i<=3;i++)scanf("%d",a+i),b[i]=i;sort(b+1,b+4,cmp);int tmp1= (a[b[1]] + a[b[2]] - a[b[3]])/2;int tmp2= a[b[1]] - tmp1;int tmp3= a[b[2]] - tmp1;int tmp= query(tmp1+1,g.size()-tmp2);if(mx[tmp] >= tmp3){u = g[tmp-tmp1];v = g[tmp+tmp2];w = h[tmp][tmp3];} else {tmp= query(tmp2+1,g.size()-tmp1);v = g[tmp-tmp2];u = g[tmp+tmp1];w = h[tmp][tmp3];}if (b[1] == 1 && b[2] == 2)printf("%d %d %d\n", u, v, w);if (b[1] == 1 && b[2] == 3)printf("%d %d %d\n", v, u, w);if (b[1] == 2 && b[2] == 1)printf("%d %d %d\n", u, w, v);if (b[1] == 2 && b[2] == 3)printf("%d %d %d\n", v, w, u);if (b[1] == 3 && b[2] == 1)printf("%d %d %d\n", w, u, v);if (b[1] == 3 && b[2] == 2)printf("%d %d %d\n", w, v, u);}return 0;}