Applications of Segment Tree
A segment tree is the best fit when all three of the following hold at once:
range queries (min/max)
frequent updates
frequent queries
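As an illustration of the min/max case in the list above, here is a compact array-based, bottom-up segment tree for range minimum with point updates. It is only a sketch and not the pointer-based approach used in the solutions below; the class name RangeMinSegmentTree and its method names are hypothetical, and it assumes a non-empty input array and 0-based, inclusive query bounds.

```java
// Sketch: iterative (bottom-up) segment tree for range-minimum queries with point updates.
// Class and method names are hypothetical, chosen for illustration only.
class RangeMinSegmentTree {
    private final int n;
    // tree[n..2n-1] hold the leaves (the input values); tree[1..n-1] hold internal nodes.
    private final int[] tree;

    RangeMinSegmentTree(int[] nums) {
        n = nums.length;
        tree = new int[2 * n];
        // Copy the input into the leaf level.
        for (int i = 0; i < n; i++) tree[n + i] = nums[i];
        // Build internal nodes bottom-up: each node is the min of its two children.
        for (int i = n - 1; i >= 1; i--) tree[i] = Math.min(tree[2 * i], tree[2 * i + 1]);
    }

    // Point update: set nums[i] = val, then recompute every ancestor. O(log n).
    void update(int i, int val) {
        int p = i + n;
        tree[p] = val;
        for (p /= 2; p >= 1; p /= 2) tree[p] = Math.min(tree[2 * p], tree[2 * p + 1]);
    }

    // Minimum of nums[i..j], inclusive. O(log n).
    int queryMin(int i, int j) {
        int res = Integer.MAX_VALUE;
        // Walk the half-open leaf interval [l, r) upward, taking boundary nodes as needed.
        int l = i + n, r = j + n + 1;
        while (l < r) {
            if ((l & 1) == 1) res = Math.min(res, tree[l++]);
            if ((r & 1) == 1) res = Math.min(res, tree[--r]);
            l /= 2;
            r /= 2;
        }
        return res;
    }
}
```

Keeping the tree in one flat array of size 2n avoids node objects and recursion; because min is commutative, this bottom-up query also works when n is not a power of two.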
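A very good problem: it exercises all the common segment tree operations (build, point update, range query).
A Binary Indexed Tree (Fenwick tree) solves this problem with the same O(log n) cost per operation but smaller constants and less code; I'll study it when I have time.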
public class NumArray {
    // Each node covers nums[start..end] and stores the sum of that range.
    private class SegmentTreeNode {
        int start;
        int end;
        int sum;
        SegmentTreeNode left, right;
        public SegmentTreeNode() {}
        public SegmentTreeNode(int start, int end, int sum) {
            this.start = start;
            this.end = end;
            this.sum = sum;
            this.left = null;
            this.right = null;
        }
    }
    SegmentTreeNode root;
    public NumArray(int[] nums) {
        // Check here so nums.length is never read on a null array; root stays null for empty input.
        if (nums == null || nums.length == 0) return;
        root = buildTree(nums, 0, nums.length - 1);
    }
    // Build recursively: a leaf holds one element, an internal node holds left.sum + right.sum. O(n).
    private SegmentTreeNode buildTree(int[] nums, int start, int end) {
        if (start == end) return new SegmentTreeNode(start, end, nums[start]);
        int mid = start + (end - start) / 2;
        SegmentTreeNode left = buildTree(nums, start, mid);
        SegmentTreeNode right = buildTree(nums, mid + 1, end);
        SegmentTreeNode root = new SegmentTreeNode(start, end, left.sum + right.sum);
        root.left = left;
        root.right = right;
        return root;
    }
    // Declared public, consistent with the other NumArray methods.
    public void update(int i, int val) {
        update(root, i, val);
    }
    // Point update: descend to the leaf for index i, then recompute the sums on the way back up. O(log n).
    private void update(SegmentTreeNode root, int i, int val) {
        if (root == null) return;
        if (i < root.start || i > root.end) return; // i is outside this node's range
        if (root.start == i && root.end == i) {
            root.sum = val;
            return;
        }
        update(root.left, i, val);
        update(root.right, i, val);
        root.sum = root.left.sum + root.right.sum;
    }
    public int sumRange(int i, int j) {
        return sumRange(root, i, j);
    }
    // Range query: clip [i, j] to this node's range; a fully covered node answers with its stored sum. O(log n).
    private int sumRange(SegmentTreeNode root, int i, int j) {
        if (root == null) return 0;
        if (i > root.end || j < root.start) return 0; // no overlap
        i = Math.max(i, root.start);
        j = Math.min(j, root.end);
        if (root.start == i && root.end == j) return root.sum;
        int left = sumRange(root.left, i, j);
        int right = sumRange(root.right, i, j);
        return left + right;
    }
}
// Your NumArray object will be instantiated and called as such:
// NumArray numArray = new NumArray(nums);
// numArray.sumRange(0, 1);
// numArray.update(1, 10);
// numArray.sumRange(1, 2);
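The Binary Indexed Tree alternative mentioned at the start of this problem can be sketched as follows. This is a minimal sketch, not the author's code: the class name FenwickNumArray is hypothetical, it keeps a copy of the values so that update(i, val), which sets a value, can be turned into a delta, and its constructor simply calls update once per element, which costs O(n log n).

```java
// Sketch of a Binary Indexed Tree (Fenwick tree) version of NumArray.
// The class name FenwickNumArray is hypothetical; the public methods mirror NumArray.
class FenwickNumArray {
    private final int[] nums; // current values, used to turn "set to val" into a delta
    private final int[] bit;  // 1-indexed Fenwick array; bit[k] covers nums[k - lowbit(k) .. k - 1]

    FenwickNumArray(int[] nums) {
        this.nums = new int[nums.length];
        this.bit = new int[nums.length + 1];
        // Insert each element as a point update (O(n log n) overall).
        for (int i = 0; i < nums.length; i++) update(i, nums[i]);
    }

    // Set nums[i] = val by adding the difference to every covering slot. O(log n).
    public void update(int i, int val) {
        int delta = val - nums[i];
        nums[i] = val;
        for (int k = i + 1; k < bit.length; k += k & (-k)) bit[k] += delta;
    }

    // Sum of nums[0..i-1]. O(log n).
    private int prefixSum(int i) {
        int s = 0;
        for (int k = i; k > 0; k -= k & (-k)) s += bit[k];
        return s;
    }

    // Sum of nums[i..j], inclusive.
    public int sumRange(int i, int j) {
        return prefixSum(j + 1) - prefixSum(i);
    }
}
```

Both operations touch O(log n) slots of a flat array with no node objects or recursion, which is where the constant-factor advantage over the pointer-based segment tree comes from.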
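In terms of problem type, this one looks very similar to the previous one. But because the operations it needs are simpler (queries only, no updates), it really boils down to DP over a prefix-sum array...
The lesson: segment trees are powerful, but don't reach for one at the slightest excuse.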
Again, a segment tree is the best fit when all three of the following hold at once:
range queries
frequent updates
frequent queries
With range queries but no updates, the problem reduces to a 1D/2D DP (prefix sums), and a segment tree shows none of its advantages.
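Remember to pad the front of the prefix-sum array with a 0 (prefixSum[0] = 0), so that sumRange(i, j) = prefixSum[j + 1] - prefixSum[i] needs no special case for i = 0.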
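A worked instance of the padding rule, with P standing for prefixSum and an illustrative input of n = 6 elements:

```latex
P[0] = 0, \qquad P[k] = \sum_{t=0}^{k-1} \mathrm{nums}[t] \quad (1 \le k \le n)

\mathrm{sumRange}(i, j) = \sum_{t=i}^{j} \mathrm{nums}[t] = P[j+1] - P[i]

\mathrm{nums} = [-2, 0, 3, -5, 2, -1] \;\Rightarrow\; P = [0, -2, -2, 1, -4, -2, -3]

\mathrm{sumRange}(0, 2) = P[3] - P[0] = 1 - 0 = 1, \qquad \mathrm{sumRange}(2, 5) = P[6] - P[2] = -3 - (-2) = -1
```

Without the leading zero, sumRange(0, j) would need its own branch, because P[i - 1] does not exist for i = 0.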
public class NumArray {
    // prefixSum[k] = nums[0] + ... + nums[k - 1]; prefixSum[0] = 0 is the padding.
    int[] prefixSum;
    public NumArray(int[] nums) {
        int n = (nums == null) ? 0 : nums.length;
        prefixSum = new int[n + 1];
        for (int i = 0; i < n; i++) {
            prefixSum[i + 1] = prefixSum[i] + nums[i];
        }
    }
    // Sum of nums[i..j], inclusive, in O(1).
    public int sumRange(int i, int j) {
        return prefixSum[j + 1] - prefixSum[i];
    }
}
// Your NumArray object will be instantiated and called as such:
// NumArray numArray = new NumArray(nums);
// numArray.sumRange(0, 1);
// numArray.sumRange(1, 2);
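The 2D case mentioned above (a matrix with region queries but no updates) works the same way with a 2D prefix-sum table padded by a zero row and a zero column. A minimal sketch, assuming a non-null rectangular matrix; the class name PrefixSum2D is hypothetical, and the method name sumRegion mirrors the NumMatrix interface used elsewhere on this page.

```java
// Sketch of the 2D prefix-sum ("2D DP") approach for region sums without updates.
// The class name PrefixSum2D is hypothetical.
class PrefixSum2D {
    // pre[r][c] = sum of matrix[0..r-1][0..c-1]; row 0 and column 0 are the zero padding.
    private final int[][] pre;

    PrefixSum2D(int[][] matrix) {
        int rows = matrix.length;
        int cols = rows == 0 ? 0 : matrix[0].length;
        pre = new int[rows + 1][cols + 1];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                // Inclusion-exclusion: block above + block to the left - their overlap + this cell.
                pre[r + 1][c + 1] = pre[r][c + 1] + pre[r + 1][c] - pre[r][c] + matrix[r][c];
            }
        }
    }

    // Sum of the rectangle with corners (row1, col1) and (row2, col2), inclusive. O(1).
    int sumRegion(int row1, int col1, int row2, int col2) {
        return pre[row2 + 1][col2 + 1] - pre[row1][col2 + 1]
             - pre[row2 + 1][col1] + pre[row1][col1];
    }
}
```

Construction is O(mn) and every query is O(1), but a single cell update would force rebuilding O(mn) entries, which is why this only suits the no-update variant.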
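This problem can of course be solved with a segment tree after flattening the matrix to 1D (row-major order): split a region into several row intervals and add up the results. But it is slow.
This problem shows both what a segment tree can do and where it falls short.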
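With width = m and height = n, the flattened 1D segment tree has these costs:
build: O(mn)
update: O(log(mn))
query: O(n * log(mn)), because a region that does not span whole rows needs one range query per row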
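To put the list above into rough numbers, assume (purely for illustration) m = n = 1024 and ignore constant factors; the 2D Binary Indexed Tree figures use its O(log m · log n) cost per prefix query, with four prefix queries per region:

```latex
m = n = 1024 \;\Rightarrow\; \log_2 m = \log_2 n = 10, \qquad \log_2(mn) = 20

\text{flattened segment tree: } \text{query} \approx n \cdot \log_2(mn) = 1024 \cdot 20 \approx 2 \times 10^4, \qquad \text{update} \approx \log_2(mn) = 20

\text{2D BIT: } \text{query} \approx 4 \cdot \log_2 m \cdot \log_2 n = 4 \cdot 100 = 400, \qquad \text{update} \approx \log_2 m \cdot \log_2 n = 100
```

So the weak point of the flattened tree is the per-row query loop; its point update is actually cheaper than a 2D BIT's.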
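This problem is better suited to a binary indexed tree. Tutorials: one here, another one here, plus a Chinese-language post recommended by Teacher Chen (陈老师).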
public class NumMatrix {
    // Each node covers the flattened indices [start, end] and stores their sum.
    private class SegmentTreeNode {
        int start;
        int end;
        int sum;
        SegmentTreeNode left;
        SegmentTreeNode right;
        public SegmentTreeNode() {}
        public SegmentTreeNode(int start, int end, int sum) {
            this.start = start;
            this.end = end;
            this.sum = sum;
            this.left = null;
            this.right = null;
        }
    }
    int width;
    int height;
    SegmentTreeNode root;
    // Row-major flattening: cell (x, y) = (row, col) maps to index x * width + y.
    private int getIndex(int x, int y) {
        return x * width + y;
    }
    public NumMatrix(int[][] matrix) {
        // Also reject zero-width matrices: buildTree divides by width.
        if (matrix == null || matrix.length == 0 || matrix[0].length == 0) return;
        height = matrix.length;
        width = matrix[0].length;
        root = buildTree(matrix, 0, width * height - 1);
    }
    private SegmentTreeNode buildTree(int[][] matrix, int start, int end) {
        // A leaf maps its flattened index back to (row, col) = (start / width, start % width).
        if (start == end) return new SegmentTreeNode(start, end, matrix[start / width][start % width]);
        int mid = start + (end - start) / 2;
        SegmentTreeNode left = buildTree(matrix, start, mid);
        SegmentTreeNode right = buildTree(matrix, mid + 1, end);
        SegmentTreeNode root = new SegmentTreeNode(start, end, left.sum + right.sum);
        root.left = left;
        root.right = right;
        return root;
    }
    public void update(int row, int col, int val) {
        int index = getIndex(row, col);
        update(root, index, val);
    }
    // Same point update as the 1D tree, on the flattened index. O(log(mn)).
    private void update(SegmentTreeNode root, int index, int val) {
        if (root == null) return;
        if (index < root.start || index > root.end) return; // index is outside this node's range
        if (root.start == index && root.end == index) {
            root.sum = val;
            return;
        }
        update(root.left, index, val);
        update(root.right, index, val);
        root.sum = root.left.sum + root.right.sum;
    }
    public int sumRegion(int row1, int col1, int row2, int col2) {
        int sum = 0;
        if (col1 == 0 && col2 == width - 1) {
            // Whole rows are contiguous in row-major order, so one range query suffices.
            sum = querySum(root, getIndex(row1, col1), getIndex(row2, col2));
        } else {
            // Otherwise query each row's [col1, col2] slice separately: up to height queries.
            for (int r = row1; r <= row2; r++) {
                sum += querySum(root, getIndex(r, col1), getIndex(r, col2));
            }
        }
        return sum;
    }
    // Range sum over flattened indices [start, end], same logic as the 1D sumRange. O(log(mn)).
    private int querySum(SegmentTreeNode root, int start, int end) {
        if (root == null) return 0;
        if (start > root.end || end < root.start) return 0; // no overlap
        start = Math.max(start, root.start);
        end = Math.min(end, root.end);
        if (start == root.start && end == root.end) return root.sum;
        int left = querySum(root.left, start, end);
        int right = querySum(root.right, start, end);
        return left + right;
    }
}
// Your NumMatrix object will be instantiated and called as such:
// NumMatrix numMatrix = new NumMatrix(matrix);
// numMatrix.sumRegion(0, 1, 2, 3);
// numMatrix.update(1, 1, 10);
// numMatrix.sumRegion(1, 2, 3, 4);
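The binary indexed tree approach recommended for this problem extends to 2D by nesting the index loops. A minimal sketch, not taken from the linked tutorials: the class name FenwickNumMatrix is hypothetical, the matrix is assumed non-null and rectangular, and a copy of the values is kept so that update, which sets a cell, can be converted into a delta. Each update and each sumRegion call costs O(log m · log n).

```java
// Sketch of a 2D Binary Indexed Tree (Fenwick tree) version of NumMatrix.
// The class name FenwickNumMatrix is hypothetical; the public methods mirror NumMatrix.
class FenwickNumMatrix {
    private final int rows, cols;
    private final int[][] vals; // current cell values, used to compute update deltas
    private final int[][] bit;  // 1-indexed 2D Fenwick array

    FenwickNumMatrix(int[][] matrix) {
        rows = matrix.length;
        cols = rows == 0 ? 0 : matrix[0].length;
        vals = new int[rows][cols];
        bit = new int[rows + 1][cols + 1];
        // Insert every cell as a point update.
        for (int r = 0; r < rows; r++)
            for (int c = 0; c < cols; c++) update(r, c, matrix[r][c]);
    }

    // Set matrix[row][col] = val by adding the delta to all covering slots. O(log m * log n).
    public void update(int row, int col, int val) {
        int delta = val - vals[row][col];
        vals[row][col] = val;
        for (int r = row + 1; r <= rows; r += r & (-r))
            for (int c = col + 1; c <= cols; c += c & (-c))
                bit[r][c] += delta;
    }

    // Sum of the rectangle matrix[0..row-1][0..col-1].
    private int prefix(int row, int col) {
        int s = 0;
        for (int r = row; r > 0; r -= r & (-r))
            for (int c = col; c > 0; c -= c & (-c))
                s += bit[r][c];
        return s;
    }

    // Inclusion-exclusion over four prefix rectangles.
    public int sumRegion(int row1, int col1, int row2, int col2) {
        return prefix(row2 + 1, col2 + 1) - prefix(row1, col2 + 1)
             - prefix(row2 + 1, col1) + prefix(row1, col1);
    }
}
```

Unlike the flattened segment tree, the query cost does not grow with the number of rows in the region, at the price of a somewhat more expensive point update.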