Chaos Slover 补档计划 - 莫队算法

2015.11.09 16:09 Mon| 7 visits oi_2016| ChaosSlover补档计划| Text

莫队算法

大概已经成为没有人愿意考的算法了吧(时代的眼泪)。最初莫涛神犇搞了一个“曼哈顿距离最小生成树”,写起来不能更酸爽。后来被大家优化成了分块,变得更加良心了呢。时间复杂度 O(Msqrt(N))。

BZOJ2038,3781 小Z的袜子

下面贴一个经典莫队算法早期代码。随便感受一下吧。//现在已经完全不会了。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#include <bits/stdc++.h>
using namespace std;

#define N 50005
#define M 400005
#define inf 0x3f3f3f3f

pair<int,int> q[N];
pair<pair<int,int>,int> p[N];
pair<int,pair<int,int> > edge[M];
int f[N],g[N];
int head[N],nxt[N*2],to[N*2],cnt;
int n,m,v[N],l[N],r[N],rnk[N],fa[N];

void change(int opt)
{
    if(opt==2||opt==4)
        for(int i=1;i<=n;i++)
            swap(p[i].first.first,p[i].first.second);
    if(opt==3)
        for(int i=1;i<=n;i++)
            p[i].first.first=-p[i].first.first;
}

inline void modify(int x,int y,int z)
{
    for(int i=x;i;i-=i&-i)
        if(f[i]>y)
            f[i]=y,g[i]=z;
}

inline int query(int x)
{
    int t=inf,re=0;
    for(int i=x;i<=n;i+=i&-i)
        if(f[i]<t)
            t=f[i],re=g[i];
    return re;
}

inline int dist(pair<int,int> a,pair<int,int> b)
{
    return abs(a.first-b.first)+abs(a.second-b.second);
}

inline void add(int x,int y)
{
    to[++cnt]=y;
    nxt[cnt]=head[x];
    head[x]=cnt;
}

inline int find(int x)
{
    if(x==fa[x]) return x;
    return fa[x]=find(fa[x]);
}

void MMST()
{
    int tot=0;
    for(int i=1;i<=4;i++)
    {
        change(i);
        sort(p+1,p+n+1);
        for(int i=1;i<=n;i++)
            q[i]=make_pair(p[i].first.second-p[i].first.first,i);
        sort(q+1,q+n+1);
        for(int i=1;i<=n;i++)
            rnk[q[i].second]=i;
        memset(f,0x3f,sizeof f);
        memset(g,0,sizeof g);
        for(int i=n;i;i--)
        {
            int t=query(rnk[i]);
            if(t) edge[++tot]=make_pair(dist(p[i].first,p[t].first),
                                pair<int,int>(p[i].second,p[t].second));
            modify(rnk[i],p[i].first.first+p[i].first.second,i);
        }
    }
    sort(edge+1,edge+tot+1);
    for(int i=1;i<=n;i++) fa[i]=i;
    for(int i=1,cnt=0;i<=tot&&cnt<n-1;i++)
    {
        int fx=find(edge[i].second.first),fy=find(edge[i].second.second);
        if(fx!=fy)
            cnt++, fa[fx]=fy,
            add(edge[i].second.first,edge[i].second.second),
            add(edge[i].second.second,edge[i].second.first);
    }
}

unsigned int now,sum[N],ans[N],down[N];
int nowl=1,nowr=0;
void dfs(int pre,int x)
{
    while(nowr<r[x]) now+=sum[v[++nowr]]++;
    while(nowl>l[x]) now+=sum[v[--nowl]]++;
    while(nowr>r[x]) now-=--sum[v[nowr--]];
    while(nowl<l[x]) now-=--sum[v[nowl++]];
    down[x]=(unsigned int)(nowr-nowl+1)*(nowr-nowl)>>1;
    ans[x]=now;
    for(int i=head[x],y;i;i=nxt[i])
        if((y=to[i])!=pre){
            dfs(x,y);
            while(nowr<r[x]) now+=sum[v[++nowr]]++;
            while(nowl>l[x]) now+=sum[v[--nowl]]++;
            while(nowr>r[x]) now-=--sum[v[nowr--]];
            while(nowl<l[x]) now-=--sum[v[nowl++]];
        }
}

unsigned int gcd(int a,int b)
{
    return b?gcd(b,a%b):a;
}

inline void read(int &a)
{
    char ch;
    while((ch=getchar())<'0'||ch>'9');
    a=ch-'0';
    while((ch=getchar())>='0'&&ch<='9')
        a=(a<<3)+(a<<1)+ch-'0';
}


int main()
{
    cin>>m>>n;
    for(int i=1;i<=m;i++)
        read(v[i]);
    for(int i=1;i<=n;i++)
        read(l[i]), read(r[i]),
        p[i]=make_pair(pair<int,int>(l[i],r[i]),i);
    MMST();
    dfs(0,1);
    for(int i=1;i<=n;i++)
    {
        if(ans[i]==0){
            puts("0/1"); continue;
        }
        int t=gcd(ans[i],down[i]);
        printf("%u/%u\n",ans[i]/t,down[i]/t);
    }
    return 0;
}

BZOJ3289 Mato的文件管理

用树状数组维护的莫队算法。莫队算法的思想就是按照一定规律将所有询问排序,使得从上一个答案推到下一个答案的时候不用花费太多时间。将左端点分块后排序,相邻两次询问之间左端点移动距离为 O(sqrt(N)),左端点在某一块内时右端点最多共移动 O(N) 的距离。因此总移动距离为 O(Msqrt(N))。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <bits/stdc++.h>
using namespace std;

#define N 50005

vector<int> v;
int n, m, a[N], f[N], ans[N];

struct query
{
    int p, x, y, b;
    query() {}
    query(int _p, int _x, int _y, int _b)
    : p(_p), x(_x), y(_y), b(_b) {}
    bool operator <(const query &q) const
    {
        return b == q.b ? y < q.y : b < q.b;
    }
} q[N];

inline int read()
{
    int x = 0, f = 1; char ch = getchar();
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
    return x * f;
}

inline void modify(int x, int d)
{
    for (int i = x; i <= n; i += i & -i)
        f[i] += d;
}

inline int getsum(int x)
{
    int re = 0;
    for (int i = x; i; i -= i & -i)
        re += f[i];
    return re;
}

int main()
{
    cin >> n;
    for (int i = 1; i <= n; ++i)
        a[i] = read(), v.push_back(a[i]);
    sort(v.begin(), v.end());
    for (int i = 1; i <= n; ++i)
        a[i] = lower_bound(v.begin(), v.end(), a[i]) - v.begin() + 1;
    int b = sqrt(n);
    cin >> m;
    for (int x, y, i = 1; i <= m; ++i)
        x = read(), y = read(), q[i] = query(i, x, y, x / b);
    sort(q + 1, q + m + 1);
    int nl = 1, nr = 0, re = 0;
    for (int i = 1; i <= m; ++i)
    {
        while (nl > q[i].x)
            re += getsum(a[--nl] - 1), modify(a[nl], 1);
        while (nr < q[i].y)
            re += getsum(n) - getsum(a[++nr]), modify(a[nr], 1);
        while (nl < q[i].x)
            re -= getsum(a[nl] - 1), modify(a[nl++], -1);
        while (nr > q[i].y)
            re -= getsum(n) - getsum(a[nr]), modify(a[nr--], -1);
        ans[q[i].p] = re;
    }
    for (int i = 1; i <= m; ++i)
        printf("%d\n", ans[i]);
}

BZOJ3236 [Ahoi2013]作业

很简单的分块,可是数据范围比较大。看起来真的有点虚。也可以用树套树做,可是有可能会 TLE。

在离散化的时候一定要小心,在这里 WA 了两次。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
#include <bits/stdc++.h>
using namespace std;

#define N 100005
#define M 1000005

vector<int> v;
int n, m, a[N], f1[N], f2[N], cnt[N], ans1[M], ans2[M];

struct data
{
    int p, l, r, a, b, blo;
    data() {}
    data(int _p, int _l, int _r, int _a, int _b, int _blo)
    : p(_p), l(_l), r(_r), a(_a), b(_b), blo(_blo) {}
    inline bool operator <(const data &d) const
    {
        return blo == d.blo ? r < d.r : blo < d.blo;
    }
} q[M];

inline int read()
{
    int x = 0, f = 1; char ch = getchar();
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
    return x * f;
}

inline void modify1(int x, int d)
{
    for (int i = x; i <= n + 1; i += i & -i)
        f1[i] += d;
}

inline void modify2(int x, int d)
{
    for (int i = x; i <= n + 1; i += i & -i)
        f2[i] += d;
}

inline int getsum1(int x)
{
    int re = 0;
    for (int i = x; i; i -= i & -i)
        re += f1[i];
    return re;
}

inline int getsum2(int x)
{
    int re = 0;
    for (int i = x; i; i -= i & -i)
        re += f2[i];
    return re;
}

inline void add(int x)
{
    modify1(x, 1);
    if (!cnt[x]) modify2(x, 1);
    ++cnt[x];
}

inline void del(int x)
{
    modify1(x, -1);
    --cnt[x];
    if (!cnt[x]) modify2(x, -1);
}

int main()
{
    cin >> n >> m;
    for (int i = 1; i <= n; ++i)
        a[i] = read(), v.push_back(a[i]);
    sort(v.begin(), v.end());
    for (int i = 1; i <= n; ++i)
        a[i] = lower_bound(v.begin(), v.end(), a[i]) - v.begin() + 1;
    int b = sqrt(n);
    for (int tl, tr, ta, tb, i = 1; i <= m; ++i)
    {
        tl = read(); tr = read(); ta = read(); tb = read();
        ta = lower_bound(v.begin(), v.end(), ta) - v.begin() + 1;
        tb = upper_bound(v.begin(), v.end(), tb) - v.begin();
        q[i] = data(i, tl, tr, ta, tb, tl / b);
    }
    sort(q + 1, q + m + 1);
    int l = 1, r = 0;
    for (int i = 1; i <= m; ++i)
    {
        while (l > q[i].l) add(a[--l]);
        while (r < q[i].r) add(a[++r]);
        while (l < q[i].l) del(a[l++]);
        while (r > q[i].r) del(a[r--]);
        ans1[q[i].p] = getsum1(q[i].b) - getsum1(q[i].a - 1);
        ans2[q[i].p] = getsum2(q[i].b) - getsum2(q[i].a - 1);
    }
    for (int i = 1; i <= m; ++i)
        printf("%d %d\n", ans1[i], ans2[i]);
    return 0;
}

BZOJ3809 Gty的二逼妹子序列

和上一道题基本相同啊,然而改一改交上去之后就 T 掉了。

上一道题用树状数组维护取值,总时间复杂度其实是 O(Msqrt(N)log(N)+Mlog(N)) 的。其实完全可以利用分块的思想,O(1) 修改,O(sqrt(N)) 询问,将时间复杂度降为 O(Msqrt(N)+Msqrt(N))。然而为什么我的上一道题并没有超时呢?

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include <bits/stdc++.h>
using namespace std;

#define N 100005
#define M 1000005

vector<int> v;
int n, m, blo, a[N], f[N], cnt[N], sum[N], bel[N], ans[M];

struct data
{
    int p, l, r, a, b;
    data() {}
    data(int _p, int _l, int _r, int _a, int _b)
    : p(_p), l(_l), r(_r), a(_a), b(_b) {}
    inline bool operator <(const data &d) const
    {
        return bel[l] == bel[d.l] ? r < d.r : bel[l] < bel[d.l];
    }
} q[M];

inline int read()
{
    int x = 0, f = 1; char ch = getchar();
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
    return x * f;
}

inline int getsum(int l, int r)
{
    int bl = bel[l], br = bel[r], re = 0;
    for (int i = bl + 1; i < br; ++i)
        re += sum[i];
    if (bl == br)
    {
        for (int i = l; i <= r; ++i)
            re += cnt[i] ? 1 : 0;
    }
    else
    {
        for (int i = l; bel[i] == bl; ++i)
            re += cnt[i] ? 1 : 0;
        for (int i = r; bel[i] == br; --i)
            re += cnt[i] ? 1 : 0;
    }
    return re;
}   

inline void add(int x)
{
    if (!cnt[x]) ++sum[bel[x]];
    ++cnt[x];
}

inline void del(int x)
{
    --cnt[x];
    if (!cnt[x]) --sum[bel[x]];
}

int main()
{
    cin >> n >> m;
    blo = sqrt(n);
    for (int i = 1; i <= n; ++i)
        a[i] = read(), bel[i] = (i - 1) / blo + 1;
    for (int tl, tr, ta, tb, i = 1; i <= m; ++i)
    {
        tl = read(); tr = read(); ta = read(); tb = read();
        q[i] = data(i, tl, tr, ta, tb);
    }
    sort(q + 1, q + m + 1);
    int l = 1, r = 0;
    for (int i = 1; i <= m; ++i)
    {
        while (l > q[i].l) add(a[--l]);
        while (r < q[i].r) add(a[++r]);
        while (l < q[i].l) del(a[l++]);
        while (r > q[i].r) del(a[r--]);
        ans[q[i].p] = getsum(q[i].a, q[i].b);
    }
    for (int i = 1; i <= m; ++i)
        printf("%d\n", ans[i]);
    return 0;
}

BZOJ3757 苹果树

树上莫队算法。首先需要考虑如何对树分块才能够使得每一个块之内任意两个结点间距离不超过一个定值。实际上这用一次 DFS 即可解决。

树上分块裸题:BZOJ1086 王室联邦。

将所有询问的左端点按照所在块的编号排序,右端点按照 DFS 序排序。这样能够保证总复杂度不会改变,仍然是 O(Msqrt(N))。

根据网上到处都有的 VFK 大神的证明,知道了链上的一个端点从点 u 移动到点 u',需要修改的点就是链 (u,u') 上除了 lca 外的所有点。只需要将这些点的状态全部取反就好了。至于色盲的判断,很简单啊。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#include <bits/stdc++.h>
using namespace std;

#define N 50005
#define M 100005

stack<int> s;
int n, m, blo, col[N], head[N], id[N], bel[N];
int deep[N], fa[N][17], vis[N], cnt[N], tot, ans[M];

struct data
{
    int p, u, v, a, b;
    data() {}
    data(int _p, int _u, int _v, int _a, int _b)
    : p(_p), u(_u), v(_v), a(_a), b(_b) {}

    inline bool operator <(const data &d) const
    {
        return bel[u] == bel[d.u] ? id[v] < id[d.v] : bel[u] < bel[d.u];
    }
} q[M];

inline int read()
{
    int x = 0, f = 1; char ch = getchar();
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
    return x * f;
}

struct graph
{
    int next, to;
    graph() {}
    graph(int _next, int _to)
    : next(_next), to(_to) {}
} edge[M];

inline void add(int x, int y)
{
    static int cnt = 0;
    edge[++cnt] = graph(head[x], y);
    head[x] = cnt;
}

int dfs(int x)
{
    static int tot = 0, c = 0;
    id[x] = ++c; deep[x] = deep[fa[x][0]] + 1;
    s.push(x);
    int size = 0;
    for (int i = head[x]; i; i = edge[i].next)
        if (edge[i].to != fa[x][0])
        {
            fa[edge[i].to][0] = x;
            size += dfs(edge[i].to);
            if (size >= blo)
            {
                size = 0; ++tot;
                while (s.top() != x)
                    bel[s.top()] = tot, s.pop();
            }
        }
    return size + 1;
}

int lca(int x, int y)
{
    if (deep[x] < deep[y]) swap(x, y);
    for (int i = 16; i >= 0; --i)
        if (deep[fa[x][i]] >= deep[y])
            x = fa[x][i];
    if (x == y) return x;
    for (int i = 16; i >= 0; --i)
        if (fa[x][i] != fa[y][i])
            x = fa[x][i], y = fa[y][i];
    return fa[x][0];
}

inline void change(int x)
{
    if (vis[x]) { --cnt[col[x]]; if (!cnt[col[x]]) --tot; }
    else { if (!cnt[col[x]]) ++tot; ++cnt[col[x]]; }
    vis[x] ^= 1;
}

inline void reverse(int x, int f)
{
    for (int i = x; i != f; i = fa[i][0])
        change(i);
}

inline void getans(int x)
{
    int t = lca(q[x].u, q[x].v);
    change(t);
    ans[q[x].p] = tot;
    if (q[x].a != q[x].b && cnt[q[x].a] && cnt[q[x].b])
        --ans[q[x].p];
    change(t);
}

int main()
{
    cin >> n >> m;
    blo = pow(n, 2.0/3.0);
    for (int i = 1; i <= n; ++i)
        col[i] = read();
    for (int x, y, i = 1; i <= n; ++i)
        x = read(), y = read(), add(x, y), add(y, x);
    dfs(0);
    for (int j = 1; j <= 16; ++j)
        for (int i = 1; i <= n; ++i)
            fa[i][j] = fa[fa[i][j - 1]][j - 1];
    for (int u, v, a, b, i = 1; i <= m; ++i)
    {
        u = read(), v = read(), a = read(), b = read();
        if (id[u] > id[v]) swap(u, v);
        q[i] = data(i, u, v, a, b);
    }
    sort(q + 1, q + m + 1);
    int t = lca(q[1].u, q[1].v);
    reverse(q[1].u, t); reverse(q[1].v, t); getans(1);
    for (int i = 2; i <= m; ++i)
    {
        t = lca(q[i - 1].u, q[i].u);
        reverse(q[i - 1].u, t); reverse(q[i].u, t);
        t = lca(q[i - 1].v, q[i].v);
        reverse(q[i - 1].v, t); reverse(q[i].v, t);
        getans(i);
    }
    for (int i = 1; i <= m; ++i)
        printf("%d\n", ans[i]);
    return 0;
}

BZOJ3052&UOJ58 [wc2013]糖果公园

谜之时间复杂度——强行不是 O(N^2)。

首先发现利用莫队算法,对于两次询问,分别需要处理时间上的更改和路径上的更改。然而我们并不会写三维曼哈顿 MST。

考虑一种时间复杂度很神奇的方式——

把树按照 N^(2/3) 的大小分块。之后将询问排序时,第一维关键字是左端点所在块,第二维关键字是右端点所在块,第三维关键字是时间。

下面来计算时间复杂度吧!

首先计算路径上的转移次数,很显然是 O(M*N^(2/3)) 次。

接下来计算时间上的转移次数。把树按照 N^(2/3) 的大小分块,也就意味着一共分出了 N^(1/3) 块,左、右端点在这些块上的分布情况一共有 (N^(1/3))^2=N^(2/3) 种。在其中的每一种内,时间都会被修改 O(M) 次。因此一共会对时间进行 O(M*N^(2/3)) 次修改。

发现 N、M 同阶,所以总时间复杂度为 O(N^(5/3))。评测机表示很开心。

写完之后根本没有需要调试。真是令人开心。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#include <bits/stdc++.h>
using namespace std;

#define N 100005

stack<int> s;
int blo, n, m, q, v[N], w[N], c[N], pos[N], to[N];
int id[N], bel[N], deep[N], fa[N][17], vis[N], cnt[N];
long long sum, ans[N];

struct data
{
    int p, x, y;
    data() {}
    data(int _p, int _x, int _y) : p(_p), x(_x), y(_y)
    {
        if (id[x] > id[y]) swap(x, y);
    }
    inline bool operator <(const data &d) const
    {
        if (bel[x] == bel[d.x] && bel[y] == bel[d.y]) return p < d.p;
        if (bel[x] == bel[d.x]) return bel[y] < bel[d.y];
        return bel[x] < bel[d.x];
    }
} d[N];

inline int read()
{
    int x = 0, f = 1; char ch = getchar();
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
    return x * f;
}

int head[N];

struct graph
{
    int next, to;
    graph() {}
    graph(int _next, int _to)
    : next(_next), to(_to) {}
} edge[N * 2];

inline void add(int x, int y)
{
    static int cnt = 0;
    edge[++cnt] = graph(head[x], y);
    head[x] = cnt;
}

int dfs(int x)
{
    static int tot = 0, c = 0;
    id[x] = ++c; deep[x] = deep[fa[x][0]] + 1;
    s.push(x);
    int size = 0;
    for (int i = head[x]; i; i = edge[i].next)
        if (edge[i].to != fa[x][0])
        {
            fa[edge[i].to][0] = x;
            size += dfs(edge[i].to);
            if (size >= blo)
            {
                size = 0; ++tot;
                while (s.top() != x)
                    bel[s.top()] = tot, s.pop();
            }
        }
    return size + 1;
}

int lca(int x, int y)
{
    if (deep[x] < deep[y]) swap(x, y);
    for (int i = 16; i >= 0; --i)
        if (deep[fa[x][i]] >= deep[y])
            x = fa[x][i];
    if (x == y) return x;
    for (int i = 16; i >= 0; --i)
        if (fa[x][i] != fa[y][i])
            x = fa[x][i], y = fa[y][i];
    return fa[x][0];
}

inline void change(int x, int d)
{
    if (d == 1) { ++cnt[x]; sum += 1ll * v[x] * w[cnt[x]]; }
    else { sum -= 1ll * v[x] * w[cnt[x]]; --cnt[x]; }
}

inline void reverse(int x, int f)
{
    for (int i = x; i != f; i = fa[i][0])
        change(c[i], vis[i] ? -1 : 1), vis[i] ^= 1;
}

inline void getans(int x)
{
    int t = lca(d[x].x, d[x].y);
    change(c[t], 1);
    ans[d[x].p] = sum;
    change(c[t], -1);
}

inline void run(int s, int t)
{
    for (int i = s; i <= t; ++i)
        if (pos[i])
        {
            if (vis[pos[i]])
                change(c[pos[i]], -1), change(to[i], 1);
            swap(c[pos[i]], to[i]);
        }
    for (int i = s; i >= t; --i)
        if (pos[i])
        {
            if (vis[pos[i]])
                change(c[pos[i]], -1), change(to[i], 1);
            swap(c[pos[i]], to[i]);
        }
}

int main()
{
    cin >> n >> m >> q;
    blo = pow(n, 2.0 / 3.0);
    for (int i = 1; i <= m; ++i)
        v[i] = read();
    for (int i = 1; i <= n; ++i)
        w[i] = read();
    for (int x, y, i = 1; i < n; ++i)
        x = read(), y = read(), add(x, y), add(y, x);
    for (int i = 1; i <= n; ++i)
        c[i] = read();
    dfs(1);
    for (int j = 1; j <= 16; ++j)
        for (int i = 1; i <= n; ++i)
            fa[i][j] = fa[fa[i][j - 1]][j - 1];
    for (int t, x, y, i = 1; i <= q; ++i)
    {
        t = read(), x = read(), y = read();
        if (t == 0) pos[i] = x, to[i] = y;
        else d[i] = data(i, x, y);
    }
    sort(d + 1, d + q + 1);
    int now = 0;
    int t = lca(d[1].x, d[1].y);
    reverse(d[1].x, t); reverse(d[1].y, t);
    run(now, d[1].p); now = d[1].p;
    getans(1);
    for (int i = 2; i <= q; ++i)
    {
        t = lca(d[i - 1].x, d[i].x);
        reverse(d[i - 1].x, t); reverse(d[i].x, t);
        t = lca(d[i - 1].y, d[i].y);
        reverse(d[i - 1].y, t); reverse(d[i].y, t);
        run(now, d[i].p); now = d[i].p;
        getans(i);
    }
    for (int i = 1; i <= q; ++i)
        if (ans[i])
            printf("%lld\n", ans[i]);
    return 0;
}

BZOJ4129 Haruna’s Breakfast

在糖果公园的基础上又套了一个傻分块(BZOJ3585),代码长度成功上 4000B。

我们要求的 mex 函数值一定不大于 N,因此 a 的取值大于 n 的话直接就不需要考虑了。将它们赋值为 n+1 刚刚好。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#include <bits/stdc++.h>
using namespace std;

#define N 50005

stack<int> s;
int blo, n, m, q, c[N], pos[N], to[N];
int id[N], bel[N], deep[N], fa[N][17], vis[N], cnt[N];
int blos, bels[N], sum[N], ans[N];

struct data
{
    int p, x, y;
    data() {}
    data(int _p, int _x, int _y) : p(_p), x(_x), y(_y)
    {
        if (id[x] > id[y]) swap(x, y);
    }
    inline bool operator <(const data &d) const
    {
        if (bel[x] == bel[d.x] && bel[y] == bel[d.y]) return p < d.p;
        if (bel[x] == bel[d.x]) return bel[y] < bel[d.y];
        return bel[x] < bel[d.x];
    }
} d[N];

inline int read()
{
    int x = 0, f = 1; char ch = getchar();
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
    return x * f;
}

int head[N];

struct graph
{
    int next, to;
    graph() {}
    graph(int _next, int _to)
    : next(_next), to(_to) {}
} edge[N * 2];

inline void add(int x, int y)
{
    static int cnt = 0;
    edge[++cnt] = graph(head[x], y);
    head[x] = cnt;
}

int dfs(int x)
{
    static int tot = 0, c = 0;
    id[x] = ++c; deep[x] = deep[fa[x][0]] + 1;
    s.push(x);
    int size = 0;
    for (int i = head[x]; i; i = edge[i].next)
        if (edge[i].to != fa[x][0])
        {
            fa[edge[i].to][0] = x;
            size += dfs(edge[i].to);
            if (size >= blo)
            {
                size = 0; ++tot;
                while (s.top() != x)
                    bel[s.top()] = tot, s.pop();
            }
        }
    return size + 1;
}

int lca(int x, int y)
{
    if (deep[x] < deep[y]) swap(x, y);
    for (int i = 16; i >= 0; --i)
        if (deep[fa[x][i]] >= deep[y])
            x = fa[x][i];
    if (x == y) return x;
    for (int i = 16; i >= 0; --i)
        if (fa[x][i] != fa[y][i])
            x = fa[x][i], y = fa[y][i];
    return fa[x][0];
}

inline int mex()
{
    int re = 0;
    while (sum[bels[re]] == blos) re += blos;
    while (cnt[re]) ++re;
    return re;
}

inline void change(int x, int d)
{
    if (d == 1) { if (cnt[x] == 0) ++sum[bels[x]]; ++cnt[x]; }
    else { --cnt[x]; if (cnt[x] == 0) --sum[bels[x]]; }
}

inline void reverse(int x, int f)
{
    for (int i = x; i != f; i = fa[i][0])
        change(c[i], vis[i] ? -1 : 1), vis[i] ^= 1;
}

inline void getans(int x)
{
    int t = lca(d[x].x, d[x].y);
    change(c[t], 1);
    ans[d[x].p] = mex();
    change(c[t], -1);
}

inline void run(int s, int t)
{
    for (int i = s; i <= t; ++i)
        if (pos[i])
        {
            if (vis[pos[i]])
                change(c[pos[i]], -1), change(to[i], 1);
            swap(c[pos[i]], to[i]);
        }
    for (int i = s; i >= t; --i)
        if (pos[i])
        {
            if (vis[pos[i]])
                change(c[pos[i]], -1), change(to[i], 1);
            swap(c[pos[i]], to[i]);
        }
}

int main()
{
    cin >> n >> q;
    blo = pow(n, 2.0 / 3.0);
    for (int i = 1; i <= n; ++i)
        c[i] = min(read(), n + 1);
    for (int x, y, i = 1; i < n; ++i)
        x = read(), y = read(), add(x, y), add(y, x);
    dfs(1);
    for (int j = 1; j <= 16; ++j)
        for (int i = 1; i <= n; ++i)
            fa[i][j] = fa[fa[i][j - 1]][j - 1];
    for (int t, x, y, i = 1; i <= q; ++i)
    {
        t = read(), x = read(), y = read();
        if (t == 0) pos[i] = x, to[i] = min(y, n + 1);
        else d[i] = data(i, x, y);
    }
    sort(d + 1, d + q + 1);
    blos = sqrt(n);
    for (int i = 0; i <= n; ++i)
        bels[i] = i / blos + 1;
    int now = 0;
    int t = lca(d[1].x, d[1].y);
    reverse(d[1].x, t); reverse(d[1].y, t);
    run(now, d[1].p); now = d[1].p;
    getans(1);
    for (int i = 2; i <= q; ++i)
    {
        t = lca(d[i - 1].x, d[i].x);
        reverse(d[i - 1].x, t); reverse(d[i].x, t);
        t = lca(d[i - 1].y, d[i].y);
        reverse(d[i - 1].y, t); reverse(d[i].y, t);
        run(now, d[i].p); now = d[i].p;
        getans(i);
    }
    for (int i = 1; i <= q; ++i)
        if (!pos[i])
            printf("%d\n", ans[i]);
    return 0;
}