BZOJ3238 [AHOI2013]差异

2015.01.04 20:53 Sun| 0 visits oi_2015| 2015_刷题日常| Text

Solution

后缀数组+二分RMQ/单调栈。

很显然, $\sum\limits_{1\le i<j \le n} len(T_i)+len(T_j)+2\times lcp(T_i, T_j)$ 可以拆成独立的两部分分别求解。

首先, $\sum\limits_{1\le i<j \le n} len(T_i)+len(T_j)={len\times(len+1)\over 2}\times(len-1)$ 。考虑计算 $\sum\limits_{1\le i<j \le n} lcp(T_i, T_j)$ 。求出原串的后缀数组之后观察发现对于一个 height[i] ,能更新的范围是左面不存在height值小于等于它的区间和右面不存在height值小于它的区间长度的乘积,单调栈 $\mathcal O(n)$ 维护即可。注意讨论边界问题!!!

Code

#include <bits/stdc++.h>
using namespace std;
#define N 500005

char str[N];
int x[N], y[N], c[N];
int n, sa[N], rank[N], height[N];

inline bool same(int a, int b, int l)
{
    return y[a] == y[b] &&
    ((a+l>=n && b+l>=n) || (a+l<n && b+l<n && y[a+l] == y[b+l]));
}

void get_sa(int m=256)
{
    for(int i=0; i<n; i++) c[x[i] = str[i]]++;
    for(int i=1; i<m; i++) c[i] += c[i-1];
    for(int i=n-1; i>=0; i--) sa[--c[x[i]]] = i;
    for(int k=1; k<n; k<<=1)
    {
        int p = 0;
        for(int i=n-k; i<n; i++) y[p++] = i;
        for(int i=0; i<n; i++) if(sa[i]>=k) y[p++] = sa[i]-k;
        memset(c, 0, m*sizeof(int));
        for(int i=0; i<n; i++) c[x[y[i]]]++;
        for(int i=1; i<m; i++) c[i] += c[i-1];
        for(int i=n-1; i>=0; i--) sa[--c[x[y[i]]]] = y[i];
        for(int i=0; i<n; i++) y[i] = x[i];
        m = 0; x[sa[0]] = 0;
        for(int i=1; i<n; i++)
            x[sa[i]] = same(sa[i], sa[i-1], k) ? m : ++m;
        if(++m == n) break;
    }
}

void get_height()
{
    for(int i=0; i<n; i++) rank[sa[i]]=i;
    for(int i=0, j, k=0; i<n; height[rank[i++]]=k)
        if(rank[i]) for(k=max(k-1, 0), j=sa[rank[i]-1]; str[i+k]==str[j+k]; k++);
}

stack<int> s;
long long getans()
{
    long long ans = 0;
    static int l[N], r[N];
    for(int i=1; i<n; i++)
    {
        while(!s.empty() && height[i] < height[s.top()]) s.pop();
        if(s.empty()) l[i] = 0;
        else l[i] = s.top();
        s.push(i);
    }
    while(!s.empty()) s.pop();
    for(int i=n-1; i; i--)
    {
        while(!s.empty() && height[i] <= height[s.top()]) s.pop();
        if(s.empty()) r[i] = n;
        else r[i] = s.top();
        s.push(i);
    }
    for(int i=0; i<n; i++)
        ans -= (long long)(i-l[i])*(r[i]-i)*height[i];
    return ans*2 + (long long)n*(n+1)/2*(n-1);
}

int main()
{
    scanf("%s", str);
    n=strlen(str);
    get_sa();
    get_height();
    cout << getans() << endl;
    return 0;
}