TopCoder TCO2013 Wildcard SemiMultiple

2016.04.08 15:39 Fri| 16 visits oi_2016| 2016_刷题日常| Text

算法分析

首先,如果 m 是一个偶数,可以直接通过缩减 n 和预处理答案使得 m 变成一个奇数。这一过程十分简单,以至于连我都会。

之后,我们可以初步确定答案:通过 O(nm) 的普及组动态规划求出长度为 n 且除以 m 的余数为 0~m-1 的数的个数。可以发现只有与 2^i 或 -2^i 同余的数可能对答案有贡献,即翻转相应的一位使得余数变为 0。可是,其中仍然存在有不合法的数存在着。

这里,我们可以发现,一个数不合法仅当它在它的所有同余的二进制位上都不能通过翻转使得余数变成 0。于是考虑将这些位从答案中分离出来,即求出在不存在这些位的情况下, 除以 m 的余数为 0~m-1 的数的个数,这可以通过解方程解决:

设 t[x] 为在所有位任意取值的前提下,除以 m 的余数为 x 的数的个数,a[x] 为当不存在第 i 位时,除以 m 的余数为 x 的数的个数,b 为 2^i % m。列出方程:对于任意 0<=x<m,t[x] = a[x-t]+a[x]。由于模数为奇数,所以上面的方程中,a 的有关联的项之间形成了环状结构,可以很方便地在 O(m) 的时间内解出来。

对于一个余数,重复上述操作就可以去除所有令人不爽的位的影响啦,之后我们可以直接得到多算的数的个数,用答案减一下就好。

这种做法果然比各种乱七八糟的一大坨 DP 堆在一起优雅多了!

代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <bits/stdc++.h>
using namespace std;

int mod = 1000000007;

class SemiMultiple
{
public:
    void solve(int t[], int m, int b)
    {
        int vis[2005] = {0};
        for (int i = 0; i < m; ++i)
            if (!vis[i])
            {
                int sum = 0;
                for (int j = i; !vis[j]; j = (j + b) % m)
                    sum = (sum + t[j]) % mod, vis[j] = true;
                sum = 1ll * sum * ((mod + 1) / 2) % mod;
                int fl = 0, j = i;
                do {
                    if (fl) sum = (sum - t[j] + mod) % mod;
                    fl ^= 1, j = (j + b) % m;
                } while (j != i);
                do
                {
                    t[j] = (t[j] - sum + mod) % mod;
                    sum = t[j], j = (j + b) % m;
                } while (j != i);
            }
    }
    int f[2005][2005] = {{0}};
    int count(int n, int m)
    {
        int re = 0, cnt = 0;
        while (m % 2 == 0 && n)
            ++cnt, --n, m >>= 1;
        int d[2005] = {0};
        for (int t = 1, i = 0; i < n; ++i, t = t * 2 % m)
            ++d[t];
        f[0][0] = 1;
        for (int t = 1, i = 1; i <= n; ++i, t = t * 2 % m)
            for (int j = 0; j < m; ++j)
                f[i][j] = (f[i - 1][j] + f[i - 1][(j - t + m) % m]) % mod;
        re = 1ll * cnt * f[n][0] % mod;
        for (int t[2005], t2[2005], i = 1; i < m; ++i)
            if (d[i] || d[m - i])
            {
                re = (re + f[n][i]) % mod;
                memcpy(t, f[n], sizeof t);
                for (int j = 1; j <= d[i]; ++j)
                    solve(t, m, i);
                for (int j = 1; j <= d[m - i]; ++j)
                {
                    solve(t, m, m - i);
                    memcpy(t2, t, sizeof t);
                    for (int k = 0; k < m; ++k)
                        t[k] = t2[(k + i) % m];
                }
                re = (re - t[i] + mod) % mod;
            }
        return re;
    }
};