BZOJ2286 [SDOI2011]消耗战

2015.01.11 15:41 Sun| 1 visits oi_2015| 2015_刷题日常| Text

Solution

可以发现每一次询问中,若暴力树形 DP ,则时间复杂度为可怕的 $\mathcal O(n)$ ,令人不忍直视 >_< 。

考虑建立虚树,即对于每个询问,建立只包含这一次询问的节点及其两两间 LCA 的至多 $2k_i$ 个节点的树,并在这一基础上进行树形 DP,总时间复杂度便可降为 $\mathcal O(\sum k_i)$ 。

可以在原图上利用 LCA 的单调性通过维护一个单调栈构造这一虚树。对于一次询问,将节点按照 DFS 序排序,依次插入单调栈。枚举到某节点时,求出它与原栈顶间的 LCA 并弹栈,直到栈顶元素不大于这一 LCA 为止。这样,可以保证单调栈中保存的是从当前节点到根节点之间所有的有关节点。将每一次询问过程中弹出的节点依次连边,最终将单调栈中将剩余节点依次连边,即可得到所求虚树。在其上简单进行 DP 即可得到答案。

注意需要使用时间戳记录每一个询问所涉及到的节点,并用 long long 记录答案!!!

为了让代码更加美观并富有多样性?!可以分别使用 vector 和链式前向星来存边 233333

Code

#include <bits/stdc++.h>
using namespace std;

#define N 250005
#define M 500005
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3fLL
#define print(x) cout << #x << " = " << x << endl

vector<int> v[N];
int mark[N], cost[N];
int n, m, k, a[N], test;
int head[N], id[N], deep[N], fa[N][19];
long long val[N][19];

struct data {
    int next, to, val;
} edge[M];

inline void add(int x, int y, int z)
{
    static int cnt = 0;
    edge[++cnt].to = y;
    edge[cnt].val = z;
    edge[cnt].next = head[x];
    head[x] = cnt;
}

inline int read()
{
    int x = 0, f = 1; char ch = getchar();
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = x * 10 + ch-'0'; ch = getchar(); }
    return x * f;
}

void dfs(int x, int d)
{
    static int tot = 0;
    id[x] = ++tot;
    deep[x] = d;
    for (int y, i = head[x]; i; i = edge[i].next)
        if((y = edge[i].to) != fa[x][0])
        {
            fa[y][0] = x;
            val[y][0] = edge[i].val;
            dfs(y, d + 1);
        }
}

int lca(int x, int y)
{
    if (deep[x] < deep[y]) swap(x, y);
    for (int i = 18; i >= 0; i--)
        if (deep[fa[x][i]] >= deep[y])
            x = fa[x][i];
    if (x == y) return x;
    for (int i = 18; i >= 0; i--)
        if (fa[x][i] != fa[y][i])
            x = fa[x][i], y = fa[y][i];
    return fa[x][0];
}

long long dist(int x, int y)
{
    long long re = inf;
    for (int i = 18; i >= 0; i--)
        if (deep[fa[x][i]] >= deep[y])
            re = min(re, val[x][i]), x = fa[x][i];
    return re;
}

inline bool cmp(int x, int y)
{
    return id[x] < id[y];
}

void add(int x, int y)
{
    static int vis[N];
    if (vis[x] != test)
        vis[x] = test, v[x] = vector<int>();
    if (vis[y] != test)
        vis[y] = test, v[y] = vector<int>();
    v[x].push_back(y);
}

int s[N], top;

void build()
{
    top = 1; s[top] = 1;
    v[1] = vector<int>();
    for (int i = 1; i <= k; ++i)
    {
        int t = lca(s[top], a[i]);
        while (deep[s[top-1]] > deep[t])
            add(s[top-1], s[top]), top--;
        if (t != s[top]) add(t, s[top--]);
        if (t != s[top]) s[++top] = t;
        if (a[i] != s[top]) s[++top] = a[i];
    }
    while (top - 1)
        add(s[top-1], s[top]), top--;
}

long long dfs(int x)
{
    if (mark[x] == test) return INF;
    long long re = 0;
    for (vector<int>::iterator i = v[x].begin(); i != v[x].end(); ++i)
        re += min(dfs(*i), dist(*i, x));
    return re;
}

long long getans()
{
    sort(a + 1, a + k + 1, cmp);
    for (int i = 1; i <= k; ++i)
        mark[a[i]] = test;
    build();
    return dfs(1);
}

int main()
{
    cin >> n;
    for (int x, y, z, i = 1; i < n; ++i)
        x = read(), y = read(), z = read(),
        add(x, y, z), add(y, x, z);
    memset(val, 0x3f, sizeof val);
    dfs(1, 1);
    for (int j = 1; j <= 18; ++j)
        for (int i = 1; i <= n; ++i)
            fa[i][j] = fa[fa[i][j-1]][j-1],
            val[i][j] = min(val[fa[i][j-1]][j-1], val[i][j-1]);
    cin >> m;
    for (test = 1; test <= m; ++test)
    {
        k = read();
        for (int j = 1; j <= k; ++j)
            a[j] = read();
        printf("%lld\n", getans());
    }
    return 0;
}