BZOJ3246 [IOI2013]Dreaming

2015.03.04 16:10 Wed| 1 visits oi_2015| 2015_刷题日常| Text

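Solution

The input is a forest. For every tree, compute its diameter and its radius. The radius is the smallest eccentricity of any node, where a node's eccentricity is its distance to the farthest node in the same tree. Two DFS passes give every eccentricity:
- The first pass computes, for each node, the longest and second-longest downward paths.
- The second pass propagates the longest path that leaves each node through its parent.

It is optimal to link the center of every other tree to the center of the tree with the largest radius. Sort the radii as r1 ≥ r2 ≥ r3 and let L be the length of each new trail. The answer is the maximum of:
- the largest diameter,
- r1 + r2 + L (when there are at least two trees),
- r2 + r3 + 2L (when there are at least three).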

Code

#include <bits/stdc++.h>
using namespace std;
// Capacity of every per-node / per-component array (nodes are 1-based).
const int N = 100005;

// n nodes, m input edges, l = length of each trail that may be added.
int n, m, l;
// head: first edge index of each node's adjacency list (0 = empty list);
// root: starting node of each component; cnt: number of components so far.
int head[N], root[N], cnt;
// Per-node results of the first DFS pass: visited flag, longest and
// second-longest downward path, child on the longest path, weighted depth.
int vis[N], maxd[N], max2d[N], son[N], deep[N];
// dis: longest path that leaves a node through its parent (second pass).
// maxt / mint: largest / smallest eccentricity per component.
// maxb >= max2b >= max3b: the three largest component radii; ans: result.
int dis[N], maxt[N], mint[N], maxb, max2b, max3b, ans;

// Reads the next decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative. If end of input is reached before a
// digit appears, 0 is returned. The original looped forever in that case,
// because it stored getchar()'s result in a `char` and never tested for EOF.
inline int read()
{
    int x = 0, f = 1;
    int ch = getchar();  // int, not char: EOF must stay distinguishable
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}

// One directed half of an undirected weighted edge, stored in a
// linked-list ("forward star") adjacency structure:
//   next - index of the edge added earlier from the same source (0 ends the list)
//   to   - target node
//   val  - edge weight
struct graph
{
    int next, to, val;

    graph() {}
    graph(int nxt, int target, int weight)
        : next(nxt), to(target), val(weight) {}
};

// Edge pool; every undirected edge occupies two slots (one per direction).
graph edge[N << 1];

// Pushes the directed edge x -> y with weight z onto the front of x's
// adjacency list. Slots are handed out from 1 upward so that index 0 can
// terminate a list.
inline void add(int x, int y, int z)
{
    // Running count of stored edges; deliberately named differently from
    // the global component counter `cnt`, which the original shadowed.
    static int edge_count = 0;
    ++edge_count;
    edge[edge_count] = graph(head[x], y, z);
    head[x] = edge_count;
}

// First pass over the tree containing x, entered from parent fa (0 = none).
// Marks each reached node visited, stores its weighted depth from the entry
// node in deep[], and computes bottom-up:
//   maxd[x]  - length of the longest downward path starting at x;
//   son[x]   - the child through which that longest path goes;
//   max2d[x] - the longest downward path from x avoiding son[x] (0 if none).
// Relies on maxd/max2d of each node being 0 on entry (they are globals and
// every node is visited exactly once).
void dfs(int x, int fa)
{
    vis[x] = 1;
    for (int e = head[x]; e != 0; e = edge[e].next)
    {
        const int child = edge[e].to;
        if (child == fa) continue;

        const int w = edge[e].val;
        deep[child] = deep[x] + w;
        dfs(child, x);

        const int down = maxd[child] + w;  // best path from x through child
        if (down >= maxd[x])
        {
            // New longest path; the previous best becomes the runner-up.
            max2d[x] = maxd[x];
            maxd[x] = down;
            son[x] = child;
        }
        else if (down >= max2d[x])
        {
            max2d[x] = down;
        }
    }
}

// Second pass over component p, entered at x from parent fa (0 = none).
// For each node `child` below x it sets dis[child] to the longest path that
// starts at child and first steps up to x. max(maxd[child], dis[child]) is
// then child's eccentricity, which is folded into maxt[p] (running maximum)
// and mint[p] (running minimum).
// Assumes dfs() already ran from the same entry node, dis[] of the entry node
// is 0, and the caller seeded maxt[p]/mint[p] with the entry node's own
// eccentricity (main does this with maxd[root]).
void getd(int p, int x, int fa)
{
    for (int e = head[x]; e; e = edge[e].next)
    {
        const int child = edge[e].to;
        if (child == fa) continue;

        // The longest path out of x that does not go back into `child`:
        // either keep climbing (dis[x]) or descend into another child of x.
        const int sibling_best = (child == son[x]) ? max2d[x] : maxd[x];
        dis[child] = max(dis[x], sibling_best) + edge[e].val;

        const int ecc = max(maxd[child], dis[child]);
        if (ecc > maxt[p]) maxt[p] = ecc;
        if (ecc < mint[p]) mint[p] = ecc;

        getd(p, child, x);
    }
}

int main()
{
    cin >> n >> m >> l;
    for (int x, y, z, i = 1; i <= m; ++i)
    {
        x = read() + 1; y = read() + 1; z = read();
        add(x, y, z); add(y, x, z);
    }
    for (int i = 1; i <= n; ++i)
    {
        if (vis[i]) continue;
        root[++cnt] = i;
        dfs(i, 0);
        maxt[cnt] = mint[cnt] = maxd[i];
        getd(cnt, i, 0);
        ans = max(ans, maxt[cnt]);
        if (mint[cnt] >= maxb)
            max3b = max2b, max2b = maxb, maxb = mint[cnt];
        else if (mint[cnt] >= max2b)
            max3b = max2b, max2b = mint[cnt];
        else if (mint[cnt] >= max3b)
            max3b = mint[cnt];
    }
    if (cnt > 1) ans = max(ans, maxb + max2b + l);
    if (cnt > 2) ans = max(ans, max2b + max3b + l * 2);
    cout << ans << endl;
}