使用线段树(SegmentTree)求解区间和

SegmentTree算法示例

Posted by Jeremy Song on 2022-07-30
Estimated Reading Time 2 Minutes
Words 621 In Total
Viewed Times

介绍

线段树(segment tree),顾名思义, 是用来存放给定区间(segment, or interval)内对应信息的一种数据结构。与树状数组(binary indexed tree)相似,线段树也用来处理数组相应的区间查询(range query)和元素更新(update)操作。与树状数组不同的是,线段树不止可以适用于区间求和的查询,也可以进行区间最大值,区间最小值(Range Minimum/Maximum Query problem)或者区间异或值的查询。

从数据结构的角度来说,线段树是用一个完全二叉树来存储对应于其每一个区间(segment)的数据。该二叉树的每一个结点中保存着相对应于这一个区间的信息。同时,线段树所使用的这个二叉树是用一个数组保存的,与堆(Heap)的实现方式相同。

算法示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import java.util.Arrays;

public class SegmentTree {
int[] nums;
int n;
int[] st;

public SegmentTree(int[] nums) {
this.nums = nums;
this.n = nums.length;
this.st = new int[2 * n];// segment tree 的size是nums的2倍即可
init();
}

void init() {
// for (int i = n; i < n * 2; i++) {
// st[i] = nums[i - n];
// }
System.arraycopy(nums, 0, st, n, n);
for (int i = n - 1; i > 0; i--) {
st[i] = st[i << 1] + st[(i << 1) | 1]; // 因为 (i << 1) 永远是偶数,所以 (i << 1) | 1 = (i << 1) + 1
// st[i << 1] 和 st[(i << 1) | 1] 分别表示左右孩子
}
}

void update(int i, int val) {
i += n; // 转成st的下标
st[i] = val; // 更新叶子
while (i > 0) {
// st[i ^ 1] 表示 st[i] 的兄弟
st[i >> 1] = st[i] + st[i ^ 1];
i >>= 1;
}
}

/**
* 计算nums数组left到right的和
*
* @param left nums的下标,从0开始
* @param right nums的下标,从0开始
* @return
*/
int between(int left, int right) {
// 转成到st的下标
left += n;
right += n;

int res = 0;
while (left <= right) {
if ((left & 1) == 1) res += st[left++]; // st[left]是右子节点(索引是奇数)
if ((right & 1) == 0) res += st[right--]; // st[right]是左子节点(索引是偶数)

left >>= 1;
right >>= 1;
}

return res;
}

// test code below

public static void main(String[] args) {
int[] input = new int[10];
Arrays.fill(input, 1);

SegmentTree segTree = new SegmentTree(input);

segTree.update(4, 1);
segTree.update(9, 10);

System.out.println(segTree.st.length);
System.out.println("=============");
System.out.println(segTree.between(1, 2));
System.out.println(segTree.between(3, 5));
System.out.println(segTree.between(3, 9));
System.out.println(segTree.between(0, 9));
}
}

欢迎关注我的公众号 须弥零一,跟我一起学习IT知识。


如果您喜欢此博客或发现它对您有用,则欢迎对此发表评论。 也欢迎您共享此博客,以便更多人可以参与。 如果博客中使用的图像侵犯了您的版权,请与作者联系以将其删除。 谢谢 !