终于做完OS课设了,奖励自己学点算法(

最近看了点恋爱小甜饼搞得心里暖暖的🥰本来想当封面但是感觉根号分治不够重量级,攒一攒再说了,再放张狈宝吧


概述

根号分治不是一种算法,而是一类通过数据规模进行分类讨论,在不同的情况下使用不同的算法,将整体的复杂度控制在一个比较小的范围的思想。

常见于这样的情形:直接暴力计算会导致时间复杂度太高,如果预处理空间又不够。我们选取一个阈值 $L$(例如 $L=\sqrt{n}$ ),对于小于 $\sqrt{n}$ 的数据进行预处理,大于 $\sqrt{n}$ 的数据暴力计算,两种情形下复杂度均不超过 $O(n\sqrt{n})$ ,就实现了复杂度的平衡。一般来讲阈值选取根号大小是最合适的,故而得名根号分治。

思想本身非常浅显易懂,比整体二分还自然。直接看例题。

例题

P3396 哈希冲突

给定序列,单点修改,查询下标模 $x$ 等于 $y$ 的所有下标对应元素之和。

根号分治板子题。

$x \gt \sqrt{n}$ 时直接暴力计算,满足条件的下标只有 $O(\sqrt{n})$ 个。

$x \le \sqrt{n}$ 时预处理,定义 $f[x][y]$ 是所有下标模 $x$ 等于 $y$ 的元素之和,查询直接输出,修改时也只需要修改 $O(\sqrt{n})$ 次。

总复杂度 $O(m\sqrt{n})$ ,可以通过本题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include <bits/stdc++.h>
#define FIO cin.tie(0); ios::sync_with_stdio(false)
#define all(x) (x).begin(), (x).end()
#define fi first
#define se second
#define TEST
#define TESTS int t = 1; cin >> t; while (t--)

#if 0
#define int i64
#define inf 0x3f3f3f3f3f3f3f3fLL
#else
#define inf 0x3f3f3f3f
#endif

using namespace std;
using i64 = long long;
using u32 = unsigned;
using u64 = unsigned long long;
using pii = std::pair<int, int>;

constexpr int N = 2e5 + 10;
constexpr int B = 500;
constexpr int MOD = 998244353;

int f[B][B];
int a[N];
int n, m, blen;

void solve() {
cin >> n >> m;
for (int i = 1; i <= n; ++i) cin >> a[i];
blen = sqrt(n);
for (int i = 1; i <= blen; ++i) {
for (int j = 1; j <= n; ++j) {
f[i][j % i] += a[j];
}
}
while (m--) {
char c;
int x, y;
cin >> c >> x >> y;
if (c == 'A') {
if (x <= blen) cout << f[x][y] << "\n";
else {
int res = 0;
for ( ; y <= n; y += x) res += a[y];
cout << res << "\n";
}
} else {
int delta = y - a[x];
a[x] = y;
for (int i = 1; i <= blen; ++i) f[i][x % i] += delta;
}
}
}

signed main() {

FIO;
TEST {
solve();
}

return 0;
}

CF797E Array Queries

给定正整数序列,每次询问给定 $p$、$k$ ,不断执行 $p = p + a[p]+k$ 直到 $p \gt n$ ,输出操作次数。

根据 $k$ 的大小分类讨论。

$k \gt \sqrt{n}$ 时,最多跳 $O(\sqrt{n})$ 次,直接暴力计算;

$k \le \sqrt{n}$ 时,预先计算答案,定义 $dp[p][k]$ 是按照题意操作的次数,做简单DP即可。

总复杂度 $O((n+m)\sqrt{n})$ 。

注意DP时循环两维的顺序!如果换过来会造成很多缓存不命中,慢一倍以上!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <bits/stdc++.h>
#define FIO cin.tie(0); ios::sync_with_stdio(false)
#define all(x) (x).begin(), (x).end()
#define fi first
#define se second
#define TEST
#define TESTS int t = 1; cin >> t; while (t--)

#if 0
#define int i64
#define inf 0x3f3f3f3f3f3f3f3fLL
#else
#define inf 0x3f3f3f3f
#endif

using namespace std;
using i64 = long long;
using u32 = unsigned;
using u64 = unsigned long long;
using pii = std::pair<int, int>;

constexpr int N = 1e5 + 10;
constexpr int B = 400;
constexpr int MOD = 998244353;

int dp[N][B];
int a[N];

void solve() {
int n, q, blen;
cin >> n;
for (int i = 1; i <= n; ++i) cin >> a[i];
blen = sqrt(n);
// 注意顺序
for (int i = n; i >= 1; --i) {
for (int k = 0; k <= blen; ++k) {
if (i + a[i] + k > n) dp[i][k] = 1;
else dp[i][k] = dp[i + a[i] + k][k] + 1;
}
}
cin >> q;
while (q--) {
int p, k;
cin >> p >> k;
if (k <= blen) cout << dp[p][k] << "\n";
else {
int res = 0;
do {
p += a[p] + k;
res++;
} while (p <= n);
cout << res << "\n";
}
}
}

signed main() {

FIO;
solve();

return 0;
}

CF1921F Sum of Progression

给定序列,每次询问给定 $s$、$d$、$k$ ,从第 $s$ 个元素开始,每次步长为 $d$ ,共取 $k$ 个元素,将取到的第 $i$ 个元素乘 $i$ ,求和。

根据 $d$ 的大小分类讨论。

$d \gt \sqrt{n}$ 时,最多跳 $O(\sqrt{n})$ 次,直接暴力计算;

$d \le \sqrt{n}$ 时,预先计算答案。定义 $f[d][i]$ 为步长为 $d$ ,从第 $i$ 个元素开始取,得到的元素后缀和,$g[d][i]$ 为步长为 $d$ ,从第 $i$ 个元素开始取,将取到的元素乘上对应系数的后缀和。我们要的答案就是 $g[d][s]-g[d][s+dk]-k\times f[d][s+dk]$ 。

总复杂度 $O((n+m)\sqrt{n})$ 。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include <bits/stdc++.h>
#define FIO cin.tie(0); ios::sync_with_stdio(false)
#define all(x) (x).begin(), (x).end()
#define fi first
#define se second
#define TEST
#define TESTS int t = 1; cin >> t; while (t--)

#if 1
#define int i64
#define inf 0x3f3f3f3f3f3f3f3fLL
#else
#define inf 0x3f3f3f3f
#endif

using namespace std;
using i64 = long long;
using u32 = unsigned;
using u64 = unsigned long long;
using pii = std::pair<int, int>;

constexpr int N = 1e5 + 10;
constexpr int B = 400;
constexpr int MOD = 998244353;

int f[B][N], g[B][N];
int a[N];

void solve() {
int n, q;
cin >> n >> q;
for (int i = 1; i <= n; ++i) cin >> a[i];
int blen = sqrt(n);
for (int d = 1; d <= blen; ++d) {
for (int i = n; i >= 1; --i) {
f[d][i] = a[i] + (i + d > n ? 0 : f[d][i + d]);
}
}
for (int d = 1; d <= blen; ++d) {
for (int i = n; i >= 1; --i) {
g[d][i] = f[d][i] + (i + d > n ? 0 : g[d][i + d]);
}
}
while (q--) {
int s, d, k;
cin >> s >> d >> k;
if (d <= blen) {
cout << g[d][s] - (s + d * k > n ? 0 : g[d][s + d * k] + k * f[d][s + d * k]) << " ";
continue;
}
int res = 0;
for (int i = s, j = 1; i <= n && j <= k; i += d, j++) {
res += a[i] * j;
}
cout << res << " ";
}
cout << "\n";
}

signed main() {

FIO;
TESTS {
solve();
}

return 0;
}

P5309 [Ynoi2011] 初始化

给定序列,每次对模 $x$ 等于 $y$ 的下标元素加 $z$ ,查询区间和。

在序列分块的基础上加入根号分治。

先讨论修改,我们维护整块和 $sum[i]$ 。把修改操作看成在长度为 $x$ 的周期上对应位置进行单点加,于是可以维护每个周期内的前缀后缀修改和,即 $pre[x][i]$ 和 $suf[x][i]$。

$x \gt \sqrt{n}$ 时,最多改 $O(\sqrt{n})$ 次,直接暴力修改 $arr$ 和 $sum$;

$x \le \sqrt{n}$ 时,修改周期前后缀信息。在这个范围下也只需要修改 $O(\sqrt{n})$ 次。

再看查询。不考虑周期信息,先按照经典分块拿到散块和整块之和。再遍历 $\sqrt{n}$ 内的每个周期,把贡献拆成前后缀,累加即可得到答案,单周期内累加是 $O(1)$ 的,因此查询复杂度还是 $O(\sqrt{n})$ 。做完了。

遍历每个周期需要做大量取模,并且这里模数是变量,编译器无法优化,使用long long会导致TLE!

只能使用4字节的整型!(unsigned比int快,因为编译器可以使用更激进的SIMD优化)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include <bits/stdc++.h>
#define FIO cin.tie(0); ios::sync_with_stdio(false)
#define all(x) (x).begin(), (x).end()
#define fi first
#define se second
#define TEST
#define TESTS int t = 1; cin >> t; while (t--)

#if 0
#define int u32
#define inf 0x3f3f3f3f3f3f3f3fLL
#else
#define inf 0x3f3f3f3f
#endif

using namespace std;
using i64 = long long;
using u32 = unsigned;
using u64 = unsigned long long;
using pii = std::pair<int, int>;

inline char nc() {
static char buf[1 << 20], *p1, *p2;
return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 20, stdin), p1 == p2) ? EOF : *p1++;
}
#ifndef ONLINE_JUDGE
#define nc getchar
#endif
void read() {}
template<typename T, typename... T2>
inline void read(T &x, T2 &... oth) {
x = 0; char c = nc(), up = c;
while(!isdigit(c)) up = c, c = nc();
while(isdigit(c)) x = x * 10 + c - '0', c = nc();
up == '-' ? x = -x : 0;
read(oth...);
}

constexpr int N = 2e5 + 10;
constexpr int B = 500;
constexpr int MOD = 1e9 + 7;

int bi[N], bl[B], br[B];
i64 a[N], sum[B], pre[B][B], suf[B][B];
int n, m, blen;

void modify(int x, int y, int z) {
if (x > blen) {
for (int i = y; i <= n; i += x) {
a[i] += z;
sum[bi[i]] += z;
}
} else {
for (int i = y; i <= x; ++i) pre[x][i] += z;
for (int i = y; i >= 1; --i) suf[x][i] += z;
}
}

i64 query(int l, int r) {
i64 res = 0;
if (bi[l] == bi[r]) {
for (int i = l; i <= r; ++i) res += a[i];
} else {
for (int i = l; i <= br[bi[l]]; ++i) res += a[i];
for (int i = bi[l] + 1; i < bi[r]; ++i) res += sum[i];
for (int i = bl[bi[r]]; i <= r; ++i) res += a[i];
}
res %= MOD;
for (int x = 1; x <= blen; ++x) {
int lth = (l - 1) / x + 1, rth = (r - 1) / x + 1;
int num = rth - lth - 1;
if (lth == rth) {
res = res + pre[x][(r - 1) % x + 1] - pre[x][(l - 1) % x];
} else {
res = res + suf[x][(l - 1) % x + 1] + num * pre[x][x] + pre[x][(r - 1) % x + 1];
}
}
res %= MOD;
return res;
}

void solve() {
read(n, m);
blen = sqrt(n);
int bnum = (n + blen - 1) / blen;
for (int i = 1; i <= n; ++i) {
read(a[i]);
}
for (int i = 1; i <= n; ++i) bi[i] = (i - 1) / blen + 1;
for (int i = 1; i <= bnum; ++i) {
bl[i] = (i - 1) * blen + 1;
br[i] = min(i * blen, n);
}
for (int i = 1; i <= n; ++i) {
sum[bi[i]] += a[i];
}
while (m--) {
int op, x, y, z;
read(op, x, y);
if (op == 1) {
read(z);
modify(x, y, z);
} else {
cout << query(x, y) << "\n";
}
}
}

signed main() {

FIO;
TEST {
solve();
}

return 0;
}

P3645 [APIO2015] 雅加达的摩天楼

$n$ 座楼,$m$ 只狗,每只狗在一座楼上。每只狗有一个 $r$ 属性,在第 $i$ 座楼时可以移动至 $i-r$ 或 $i+r$ 。第 $0$ 只狗要向第 $1$ 只狗传递信息,信息可以经由别的狗代传,同座楼的狗可共享信息,问最少需要移动几次。

BFS即可。这题根号分治体现在其复杂度的证明上。我们需要对狗的状态进行去重,只记录每座大楼里可到达的狗的所有 $r$ 属性,相同 $r$ 属性当做同一只狗进行处理。

对于 $r \gt \sqrt{n}$ 的狗,最多可以到达 $O(\sqrt{n})$ 座楼,共 $m$ 只狗,因此总状态数量为 $O(m\sqrt{n})$ 。

对于 $r \le \sqrt{n}$ 的狗,由于我们已经做了去重,因此每座楼最多只有 $O(\sqrt{n})$ 种状态,共 $n$ 座楼,总状态数量为 $O(n\sqrt{n})$ 。

以上,我们证明了总状态数量的上界是 $O((n+m)\sqrt{n})$ ,BFS完全没有问题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include <bits/stdc++.h>
#define FIO cin.tie(0); ios::sync_with_stdio(false)
#define all(x) (x).begin(), (x).end()
#define fi first
#define se second
#define TEST
#define TESTS int t = 1; cin >> t; while (t--)

#if 0
#define int i64
#define inf 0x3f3f3f3f3f3f3f3fLL
#else
#define inf 0x3f3f3f3f
#endif

using namespace std;
using i64 = long long;
using u32 = unsigned;
using u64 = unsigned long long;
using pii = std::pair<int, int>

constexpr int N = 3e4 + 10;
constexpr int MOD = 998244353;

vector<int> e[N];
bitset<N> vis[N];
queue<tuple<int, int, int>> q;
int n, m;

void add(int i, int j) {
e[i].push_back(j);
}

void trigger(int x, int t) {
for (int j : e[x]) {
if (vis[x][j]) continue;
vis[x][j] = 1;
q.emplace(x, j, t);
}
e[x].clear();
}

void extend(int x, int j, int t) {
trigger(x, t);
if (vis[x][j]) return;
vis[x][j] = 1;
q.emplace(x, j, t);
}

int bfs(int s, int t) {
if (s == t) return 0;
trigger(s, 0);
while (q.size()) {
auto [x, j, time] = q.front();
q.pop();
if (x - j == t || x + j == t) return time + 1;
if (x - j >= 1) extend(x - j, j, time + 1);
if (x + j <= n) extend(x + j, j, time + 1);
}
return -1;
}

void solve() {
cin >> n >> m;
int s, t, sjump, tjump;
cin >> s >> sjump >> t >> tjump;
add(s, sjump);
add(t, tjump);
for (int i = 2; i < m; ++i) {
int x, j;
cin >> x >> j;
add(x, j);
}
cout << bfs(s, t) << "\n";
}

signed main() {

FIO;
solve();

return 0;
}

CF786C Till I Collapse

给定序列,对于每个 $k \in [1,n]$ ,计算将序列拆分成每段最多 $k$ 种数,最少需要拆成的段数。

定义 $query(k)$ ,含义为计算每段最多 $k$ 种数时最少需要拆成的段数。简单DP可以在 $O(n)$ 的时间进行计算。

我们选取一个阈值 $L$,对于 $k \in [1,L]$ ,直接调用 $query(k)$ 。

对于 $k \in [L + 1, n]$,可以发现此时答案最多为 $\frac{n}{L}$ ,并且单调递减。类似于整除分块的表现,每个答案值对应一段连续的区间。于是我们可以二分得到每段区间的右端点。

总复杂度 $O(Ln+\frac{n^2}{L}\log{n})$ ,可以发现 $L$ 取 $\sqrt{n\log{n}}$ 时最优。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#include <bits/stdc++.h>
#define FIO cin.tie(0); ios::sync_with_stdio(false)
#define all(x) (x).begin(), (x).end()
#define fi first
#define se second
#define TEST
#define TESTS int t = 1; cin >> t; while (t--)

#if 0
#define int i64
#define inf 0x3f3f3f3f3f3f3f3fLL
#else
#define inf 0x3f3f3f3f
#endif

using namespace std;
using i64 = long long;
using u32 = unsigned;
using u64 = unsigned long long;
using pii = std::pair<int, int>;

constexpr int N = 1e5 + 10;
constexpr int MOD = 998244353;

int a[N], ans[N], vis[N];
int n;

int query(int x) {
int cnt = 0;
int kind = 0;
int j = 1;
for (int i = 1; i <= n; ++i) {
if (!vis[a[i]]) {
vis[a[i]] = 1;
kind++;
if (kind > x) {
cnt++;
kind = 1;
for (int k = j; k < i; ++k) vis[a[k]] = 0;
j = i;
}
}
}
if (kind) {
cnt++;
}
for (int i = j; i <= n; ++i) vis[a[i]] = 0;
return cnt;
}

int jump(int i, int val) {
int l = i, r = n + 1;
while (l < r) {
int mid = l + r >> 1;
if (query(mid) < val) r = mid;
else l = mid + 1;
}
return l;
}

void solve() {
cin >> n;
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
int blen = sqrt(n * (__lg(n) + 1));
for (int i = 1; i <= blen; ++i) ans[i] = query(i);
for (int i = blen + 1; i <= n; i = jump(i, ans[i])) {
ans[i] = query(i);
}
for (int i = blen + 1; i <= n; ++i) {
if (ans[i]) continue;
ans[i] = ans[i - 1];
}
for (int i = 1; i <= n; ++i) cout << ans[i] << " \n"[i == n];
}

signed main() {

FIO;
solve();

return 0;
}

CF1039D You Are Given a Tree

给定一棵树,对于每个 $k \in [1,n]$ ,计算将树拆分成点数为 $k$ 的链,所有方案中满足条件的最多的链的数量。

这题跟上一题几乎是一个模子里刻出来的。换一下 $query(k)$ 的定义,其他框架直接套上面的做法。

首先DFS树,得到DFS序列,再逆序遍历该序列,就可以做到树上先处理子节点,再处理父节点的效果。基于贪心,一定是先分配子树再分配父节点能够得到的满足条件的链数最多。

对于每个节点 $u$ 维护该点作为起始点在其子树内没分配的最长链长度 $len[u]$,其子节点的 $len$ 最大值 $max1[u]$,次大值 $max2[u]$ 。

对于点 $u$ ,如果 $max1[u]+max2[u]+1\ge k$ 则可以分配一个新链。每个点先更新自己的 $len$ 信息,再更新父节点的 $max1$ 和 $max2$ 信息,即可在 $O(n)$ 的时间复杂度内对给定 $k$ 进行计算。

剩下的无需多言。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include <bits/stdc++.h>
#define FIO cin.tie(0); ios::sync_with_stdio(false)
#define all(x) (x).begin(), (x).end()
#define fi first
#define se second
#define TEST
#define TESTS int t = 1; cin >> t; while (t--)

#if 0
#define int i64
#define inf 0x3f3f3f3f3f3f3f3fLL
#else
#define inf 0x3f3f3f3f
#endif

using namespace std;
using i64 = long long;
using u32 = unsigned;
using u64 = unsigned long long;
using pii = std::pair<int, int>;

constexpr int N = 1e5 + 10;
constexpr int MOD = 998244353;

int fa[N], id[N], idx;
vector<int> e[N];
int len[N], max1[N], max2[N];
int ans[N];
int n;

void dfs(int u, int f) {
fa[u] = f;
id[++idx] = u;
for (int v : e[u]) {
if (v == f) continue;
dfs(v, u);
}
}

int query(int x) {
int cnt = 0;
for (int i = n; i >= 1; --i) {
int u = id[i], f = fa[u];
if (max1[u] + max2[u] + 1 >= x) {
cnt++;
len[u] = 0;
} else {
len[u] = max1[u] + 1;
}
if (len[u] > max1[f]) {
max2[f] = max1[f];
max1[f] = len[u];
} else if (len[u] > max2[f]) {
max2[f] = len[u];
}
}
for (int i = 1; i <= n; ++i) len[i] = max1[i] = max2[i] = 0;
return cnt;
}

int jump(int i, int val) {
int l = i, r = n + 1;
while (l < r) {
int mid = l + r >> 1;
if (query(mid) < val) r = mid;
else l = mid + 1;
}
return l;
}

void solve() {
cin >> n;
for (int i = 0, u, v; i < n - 1; ++i) {
cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
for (int i = 1; i <= n; ++i) ans[i] = -1;
dfs(1, 0);
int blen = sqrt(n);
for (int i = 1; i <= blen; ++i) {
ans[i] = query(i);
}
for (int i = blen + 1; i <= n; i = jump(i, ans[i])) {
ans[i] = query(i);
}
for (int i = 1; i <= n; ++i) {
if (ans[i] != -1) continue;
ans[i] = ans[i - 1];
}
for (int i = 1; i <= n; ++i) cout << ans[i] << "\n";
}

signed main() {

FIO;
solve();

return 0;
}