题目链接:
题意: 给出一棵 n 个节点的带边权的树, 有 m 个形如 x y 的询问, 要求输出所有 x, y节点之间的最短距离.
思路: dis[i] 存储 i 节点到根节点的最短距离, lca 为 x, y 的最近公共祖先, 那么 x, y 之间的最短距离为: dis[x] + dis[y] - 2 * dis[lca] .
解法1: tarjan离线算法
关于该算法
1 Tarjan(u)//marge和find为并查集合并函数和查找函数 2 { 3 for each(u,v) //访问所有u子节点v 4 { 5 Tarjan(v); //继续往下遍历 6 marge(u,v); //合并v到u上 7 标记v被访问过; 8 } 9 for each(u,e) //访问所有和u有询问关系的e10 {11 如果e被访问过;12 u,e的最近公共祖先为find(e);13 }14 }
详解见:
代码:
1 #include2 #include 3 #include 4 using namespace std; 5 6 const int MAXN = 4e4 + 10; 7 struct node{ 8 int u, v, w, lca, next; 9 }edge1[MAXN << 1], edge2[810];//edge1记录树, edge2记录询问10 11 int vis[MAXN], pre[MAXN], dis[MAXN];//vis[i]标记i是否已经搜索过, pre[i]记录i的根节点, dis[i]记录i到根节点的距离12 int head1[MAXN], head2[MAXN], ance[MAXN], ip1, ip2;13 14 void init(void){15 memset(vis, 0, sizeof(vis));16 memset(dis, 0, sizeof(dis));17 memset(head1, -1, sizeof(head1));18 memset(head2, -1, sizeof(head2));19 ip1 = ip2 = 0;20 }21 22 void addedge1(int u, int v, int w){ //前向星23 edge1[ip1].v = v;24 edge1[ip1].w = w;25 edge1[ip1].next = head1[u];26 head1[u] = ip1++;27 }28 29 void addedge2(int u, int v){30 edge2[ip2].u = u;31 edge2[ip2].v = v;32 edge2[ip2].lca = -1;33 edge2[ip2].next = head2[u];34 head2[u] = ip2++;35 }36 37 int find(int x){38 return pre[x] == x ? x : pre[x] = find(pre[x]);39 }40 41 void jion(int x, int y){42 x = find(x);43 y = find(y);44 if(x != y) pre[y] = x;45 }46 47 void tarjan(int u){48 vis[u] = 1;49 ance[u] = pre[u] = u;50 for(int i = head1[u]; i != -1; i = edge1[i].next){51 int v = edge1[i].v;52 int w = edge1[i].w;53 if(!vis[v]){54 dis[v] = dis[u] + w;55 tarjan(v);56 jion(u, v);57 }58 }59 for(int i = head2[u]; i != -1; i = edge2[i].next){60 int v = edge2[i].v;61 if(vis[v]) edge2[i].lca = edge2[i ^ 1].lca = ance[find(v)];62 }63 }64 65 int main(void){66 int t, n, m, x, y, z;67 scanf("%d", &t);68 while(t--){69 init();70 scanf("%d%d", &n, &m);71 for(int i = 1; i < n; i++){72 scanf("%d%d%d", &x, &y, &z);73 addedge1(x, y, z);74 addedge1(y, x, z);75 }76 for(int i = 0; i < m; i++){77 scanf("%d%d", &x, &y);78 addedge2(x, y);79 addedge2(y, x);80 }81 dis[1] = 0;82 tarjan(1);83 for(int i = 0; i < m; i++){84 int cc = i << 1;85 int u = edge2[cc].u;86 int v = edge2[cc].v;87 int lca = edge2[cc].lca;88 printf("%d\n", dis[u] + dis[v] - 2 * dis[lca]);89 }90 }91 return 0;92 }
解法2: lca转RMQ
关于该算法
ver[] 存储树的 dfs 路径
first[u] 为顶点 u 在 ver 数组中第一次出现时的下标
deep[indx] 为顶点 ver[indx] 的深度
对于求 x, y 的 lca, 先令 l = first[x], r = first[y], 即 l, r 分别为 x, y 第一次在 ver 数组中出现时对应的下标
在 deep[] 数组中找到区间 [l, r] 中的最小值, 其下标对应的 ver 值即为 x, y 的 lca. (区间最值可以用 RMQ 处理)
详解见:
代码:
1 #include2 #include 3 #include 4 #include 5 using namespace std; 6 7 const int MAXN = 4e4 + 10; 8 struct node{ 9 int v, w, next;10 }edge[MAXN << 1];11 12 int dp[MAXN << 1][30]; //dp[i][j]存储deep数组中下标i开始长度为2^j的子串中最小值的下标13 int first[MAXN], ver[MAXN << 1], deep[MAXN << 1];14 int vis[MAXN], head[MAXN], dis[MAXN], ip, indx;15 16 inline void init(void){17 memset(vis, 0, sizeof(vis));18 memset(head, -1, sizeof(head));19 ip = 0;20 indx = 0;21 }22 23 void addedge(int u, int v, int w){24 edge[ip].v = v;25 edge[ip].w = w;26 edge[ip].next = head[u];27 head[u] = ip++;28 }29 30 void dfs(int u, int h){31 vis[u] = 1; //标记已搜索过的点32 ver[++indx] = u; //记录dfs路径33 first[u] = indx; //记录顶点u第一次出现时对应的ver数组的下标34 deep[indx] = h; //记录ver数组中对应下标的点的深度35 for(int i = head[u]; i != -1; i = edge[i].next){36 int v = edge[i].v;37 if(!vis[v]){38 dis[v] = dis[u] + edge[i].w;39 dfs(v, h + 1);40 ver[++indx] = u;41 deep[indx] = h;42 }43 }44 }45 46 void ST(int n){47 for(int i = 1; i <= n; i++){48 dp[i][0] = i;49 }50 for(int j = 1; (1 << j) <= n; j++){51 for(int i = 1; i + (1 << j) - 1 <= n; i++){52 int x = dp[i][j - 1], y = dp[i + (1 << (j - 1))][j - 1];53 dp[i][j] = deep[x] < deep[y] ? x : y;54 }55 }56 }57 58 int RMQ(int l, int r){59 int len = log2(r - l + 1);60 int x = dp[l][len], y = dp[r - (1 << len) + 1][len];61 return deep[x] < deep[y] ? x : y;62 }63 64 int LCA(int x, int y){65 int l = first[x], r = first[y];66 if(l > r) swap(l, r);67 int pos = RMQ(l, r);68 return ver[pos];69 }70 71 int main(void){72 int t, n, m, x, y, z;73 scanf("%d", &t);74 while(t--){75 init();76 scanf("%d%d", &n, &m);77 for(int i = 1; i < n; i++){78 scanf("%d%d%d", &x, &y, &z);79 addedge(x, y, z);80 addedge(y, x, z);81 }82 dis[1] = 0;83 dfs(1, 1);84 ST(2 * n - 1);85 for(int i = 0; i < m; i++){86 scanf("%d%d", &x, &y);87 int lca = LCA(x, y);88 printf("%d\n", dis[x] + dis[y] - 2 * dis[lca]);89 }90 }91 return 0;92 }