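// 题目 (Problem): strictly second-smallest spanning tree. The problem statement itself was not included in this paste.

// 代码 (Code)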
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Capacities: N bounds the vertex count, M bounds the edge count.
constexpr int N = 1e5 + 10;
constexpr int M = 3e5 + 10;
// "Infinity" sentinel; negated (-inf) it also serves as "no value" in d1/d2.
constexpr int inf = 0x3f3f3f3f;

// Head/next-array adjacency list: h[v] is the first slot of v's list, e[] the target,
// w[] the weight, ne[] the next slot, idx the next free slot. Sized 2*N so every
// undirected tree edge fits in both directions.
int h[N], e[2 * N], ne[2 * N], idx, w[2 * N];
// Union-find parent array.
int p[N];
// Binary-lifting tables: fa = 2^k-th ancestor, d = depth,
// d1/d2 = largest / strictly second-largest edge weight on that 2^k-edge jump.
int fa[N][18], d[N], d1[N][18], d2[N][18];
// Vertex count and edge count read from input.
int n, m;

// One input edge a-b with weight c; f marks whether Kruskal kept it in the MST.
// Ordered by weight so std::sort arranges edges for Kruskal.
struct edge
{
    int a;
    int b;
    int c;
    bool f;
    bool operator<(const edge& v) const { return c < v.c; }
} edge[M];
// Append a directed edge a -> b with weight c to the adjacency list: the new slot
// takes over as the head of a's list and links to the previous head.
void add(int a, int b, int c)
{
    const int slot = idx++;
    e[slot] = b;
    w[slot] = c;
    ne[slot] = h[a];
    h[a] = slot;
}
// Return the union-find root of x and compress the path, so every node visited
// ends up pointing directly at the root.
// Written iteratively: kruskal() links roots without union-by-rank, so a parent
// chain of length ~n can form, and a recursive find would then recurse ~n deep
// and could exhaust the call stack.
int find(int x)
{
    // Pass 1: locate the root.
    int root = x;
    while (p[root] != root) root = p[root];

    // Pass 2: repoint every node on the path straight at the root.
    while (p[x] != root)
    {
        const int next = p[x];
        p[x] = root;
        x = next;
    }
    return root;
}
// Kruskal's algorithm. Sort the m edges by weight and keep each edge whose
// endpoints lie in different components. Kept edges get f set; the total weight
// of the kept edges (the MST weight) is returned.
ll kruskal()
{
    for (int v = 1; v <= n; ++v) p[v] = v;
    std::sort(edge, edge + m);

    ll total = 0;
    for (int i = 0; i < m; ++i)
    {
        const int ra = find(edge[i].a);
        const int rb = find(edge[i].b);
        if (ra == rb) continue;  // same component: this edge would close a cycle
        p[ra] = rb;
        edge[i].f = true;
        total += edge[i].c;
    }
    return total;
}
// Build the adjacency list from the edges flagged by kruskal() (f set), inserting
// each one in both directions. Other edges are left out.
void build()
{
    std::memset(h, -1, sizeof h);
    for (int i = 0; i < m; ++i)
    {
        if (!edge[i].f) continue;
        const int u = edge[i].a;
        const int v = edge[i].b;
        const int wt = edge[i].c;
        add(u, v, wt);
        add(v, u, wt);
    }
}
void bfs()
{memset(d, 0x3f, sizeof d);queue<int> q;q.push(1);d[0] = 0, d[1] = 1;while(q.size()){int u = q.front(); q.pop();for(int i = h[u]; ~i; i = ne[i]){int j = e[i];if(d[j] > d[u] + 1){q.push(j);d[j] = d[u] + 1;fa[j][0] = u;d1[j][0] = w[i], d2[j][0] = -inf;for(int k = 1; k <= 17; k++){int anc = fa[j][k-1];fa[j][k] = fa[anc][k-1];d1[j][k] = -inf, d2[j][k] = -inf;int distance[4] = {d1[j][k-1], d2[j][k-1], d1[anc][k-1], d2[anc][k-1]};for(int t = 0; t < 4; t++){if(distance[t] > d1[j][k]) d2[j][k] = d1[j][k], d1[j][k] = distance[t];else if(distance[t] != d1[j][k] && distance[t] > d2[j][k]) d2[j][k] = distance[t];}}}}}
}
int lca(int a, int b, int c)
{static int dist[N*2];int cnt = 0;if(d[a] < d[b]) swap(a, b);for(int i = 17; i >= 0; i--){if(d[fa[a][i]] >= d[b]){dist[cnt++] = d1[a][i];dist[cnt++] = d2[a][i];a = fa[a][i];}}if(a != b){for(int i = 17; i >= 0; i--){if(fa[a][i] != fa[b][i]){dist[cnt++] = d1[a][i];dist[cnt++] = d2[a][i];dist[cnt++] = d1[b][i];dist[cnt++] = d2[b][i];a = fa[a][i], b = fa[b][i];}}dist[cnt++] = d1[a][0];dist[cnt++] = d1[b][0]; //这里必须都是d1,因为这条路径只有1条边,所以:d1 = w[i], d2 = -inf}int dist1 = -inf, dist2 = -inf;for(int i = 0; i < cnt; i++){if(dist[i] > dist1) dist2 = dist1, dist1 = dist[i];else if(dist[i] != dist1 && dist[i] > dist2) dist2 = dist[i];}if(c > dist1) return c - dist1;if(c > dist2) return c - dist2;return inf;
}
int main()
{ios::sync_with_stdio(0);cin.tie(0);cin >> n >> m;for(int i = 0; i < m; i++){int a, b, c;cin >> a >> b >> c;edge[i] = {a, b, c};}ll sum = kruskal();build();bfs();ll res = 1e18;for(int i = 0; i < m; i++)if(!edge[i].f){int a = edge[i].a, b = edge[i].b, c = edge[i].c;ll t = sum + lca(a, b, c);res = min(res, t);}cout << res;
}