BZOJ3611 [HEOI2014]大工程

2015.01.19 16:00 Mon| 0 visits oi_2015| 2015_刷题日常| Text

Problem

Description

国家有一个大工程,要给一个非常大的交通网络里建一些新的通道。

我们这个国家位置非常特殊,可以看成是一个单位边权的树,城市位于顶点上。

在 2 个国家 a,b 之间建一条新通道需要的代价为树上 a,b 的最短路径。

现在国家有很多个计划,每个计划都是这样,我们选中了 k 个点,然后在它们两两之间新建 C(k,2)条新通道。

现在对于每个计划,我们想知道:

  1. 这些新通道的代价和 2.这些新通道中代价最小的是多少 3.这些新通道中代价最大的是多少

Input

第一行 n 表示点数。

接下来 n-1 行,每行两个数 a,b 表示 a 和 b 之间有一条边。点从 1 开始标号。

接下来一行 q 表示计划数。

对每个计划有 2 行,第一行 k 表示这个计划选中了几个点。第二行用空格隔开的 k 个互不相同的数表示选了哪 k 个点。

Output

输出 q 行,每行三个数分别表示代价和,最小代价,最大代价。

Sample Input

10

2 1

3 2

4 1

5 2

6 4

7 5

8 6

9 7

10 9

5

2

5 4

2

10 4

2

5 2

2

6 1

2

6 1

Sample Output

3 3 3

6 6 6

1 1 1

2 2 2

2 2 2

Hint

对于第 1,2 个点: n<=10000

对于第 3,4,5 个点: n<=100000,交通网络构成一条链

对于第 6,7 个点: n<=100000 对于第 8,9,10 个点: n<=1000000

Solution

构建虚树,进行树上 DP (合并子树,更新答案,详见代码)。

Code

#include <bits/stdc++.h>
using namespace std;

#define N 1100005
#define M 2100005
#define inf 0x3fffffffffffffff

inline int read()
{
    int x = 0, f = 1; char ch = getchar();
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
    return x * f;
}

vector<int> v[N];
int n, q, k, test, a[N], head[N], deep[N], fa[N][21], mark[N], id[N];

struct data {
    int next, to;
} edge[M];

inline void add(int x, int y)
{
    static int cnt = 0;
    edge[++cnt].to = y;
    edge[cnt].next = head[x];
    head[x] = cnt;
}

void dfs(int x, int d)
{
    static int tot = 0;
    id[x] = ++tot; deep[x] = d;
    for (int y, i = head[x]; i; i = edge[i].next)
        if ((y = edge[i].to) != fa[x][0])
            fa[y][0] = x, dfs(y, d + 1);
}

inline bool cmp(int x, int y) { return id[x] < id[y]; }

inline int dist(int x, int y) { return deep[y] - deep[x]; }

inline int lca(int x, int y)
{
    if (deep[x] < deep[y]) swap(x, y);
    for (int i = 20; i >= 0; --i)
        if (deep[fa[x][i]] >= deep[y])
            x = fa[x][i];
    if (x == y) return x;
    for (int i = 20; i >= 0; --i)
        if (fa[x][i] != fa[y][i])
            x = fa[x][i], y = fa[y][i];
    return fa[x][0];
}

inline void connect(int x, int y)
{
    static int vis[N];
    if (vis[x] != test)
        vis[x] = test, v[x] = vector<int>();
    if (vis[y] != test)
        vis[y] = test, v[y] = vector<int>();
    v[x].push_back(y);
}

int s[N], top;
void build()
{
    s[top = 1] = 1;
    for (int i = 1; i <= k; ++i)
    {
        int t = lca(s[top], a[i]);
        while (deep[s[top-1]] > deep[t])
            connect(s[top-1], s[top]), --top;
        if (t != s[top]) connect(t, s[top--]);
        if (t != s[top]) s[++top] = t;
        if (a[i] != s[top]) s[++top] = a[i];
    }
    while (top - 1)
        connect(s[top-1], s[top]), --top;
}

long long sumt, mint, maxt;
void getans(int x, int fa)
{
    static long long sum[N], sumd[N], mind[N], maxd[N];
    sum[x] = (mark[x] == test); sumd[x] = 0;
    mind[x] = (mark[x] == test) ? 0 : inf;
    maxd[x] = (mark[x] == test) ? 0 : -inf;
    for (vector<int>::iterator i = v[x].begin(); i != v[x].end(); ++i)
        if (*i != fa)
        {
            getans(*i, x);
            int d = dist(x, *i);
            sumt += (sumd[x] + sum[x] * d) * sum[*i] + sumd[*i] * sum[x];
            sum[x] += sum[*i];
            sumd[x] += sumd[*i] + sum[*i] * d;
            mint = min(mint, mind[x] + mind[*i] + d);
            maxt = max(maxt, maxd[x] + maxd[*i] + d);
            mind[x] = min(mind[x], mind[*i] + d);
            maxd[x] = max(maxd[x], maxd[*i] + d);
        }
}

void print()
{
    sumt = 0; mint = inf; maxt = -inf;
    getans(1, 0);
    printf("%lld %lld %lld\n", sumt, mint, maxt);
}

int main()
{
    cin >> n;
    for (int x, y, i = 1; i < n; ++i)
        x = read(), y = read(), add(x, y), add(y, x);
    dfs(1, 1);
    for (int j = 1; j <= 20; ++j)
        for (int i = 1; i <= n; ++i)
            fa[i][j] = fa[fa[i][j-1]][j-1];
    cin >> q;
    for (test = 1; test <= q; ++test)
    {
        k = read();
        for (int i = 1; i <= k; ++i)
            a[i] = read(), mark[a[i]] = test;
        sort(a + 1, a + k + 1, cmp);
        build(); print();
    }
}