Applications of Segment Tree

  • A segment tree is the best fit when a problem meets all three of the following:

    • Range queries (sum, min/max, etc.)

    • Frequent updates

    • Frequent queries
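The list above mentions range min/max. As an illustration, here is a minimal sketch of a bottom-up (array-based) segment tree that answers range-minimum queries. The class name MinSegmentTree and its methods are hypothetical, and this iterative layout is a different style from the node-based tree used in the code below.

```java
// Hypothetical class: a bottom-up segment tree over an int array,
// using Math.min as the combine step instead of +.
class MinSegmentTree {
    private final int n;
    private final int[] tree; // tree[n..2n-1] are the leaves, tree[1..n-1] the internal nodes

    MinSegmentTree(int[] nums) {
        n = nums.length;
        tree = new int[2 * n];
        System.arraycopy(nums, 0, tree, n, n);
        // Fill parents from the bottom up: O(n).
        for (int i = n - 1; i > 0; i--) {
            tree[i] = Math.min(tree[2 * i], tree[2 * i + 1]);
        }
    }

    // Point update: set nums[i] = val, then recompute every ancestor. O(log n).
    void update(int i, int val) {
        int p = i + n;
        tree[p] = val;
        for (p /= 2; p > 0; p /= 2) {
            tree[p] = Math.min(tree[2 * p], tree[2 * p + 1]);
        }
    }

    // Minimum over the inclusive range [i, j], walking a half-open [l, r) upward. O(log n).
    int queryMin(int i, int j) {
        int res = Integer.MAX_VALUE;
        for (int l = i + n, r = j + n + 1; l < r; l /= 2, r /= 2) {
            if ((l & 1) == 1) res = Math.min(res, tree[l++]);
            if ((r & 1) == 1) res = Math.min(res, tree[--r]);
        }
        return res;
    }
}
```

Swapping Math.min for + (identity 0) or Math.max (identity Integer.MIN_VALUE) gives range sum or range max with the same structure.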

A really good problem: it exercises all the common Segment Tree operations (build, point update, range query).

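A faster approach for this problem is a Binary Indexed Tree (still O(log n) per operation, but with a smaller constant factor). I'll look into it when I have time.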
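For comparison, a minimal sketch of the same NumArray interface on a Binary Indexed Tree (Fenwick tree). The class name NumArrayBIT is hypothetical. Both operations stay O(log n), as with the segment tree; the gain is a smaller constant and a flat int[] instead of node objects.

```java
// Hypothetical class: NumArray backed by a Binary Indexed Tree (Fenwick tree).
class NumArrayBIT {
    private final int[] vals; // current values, needed to turn "set to val" into "add delta"
    private final int[] bit;  // 1-indexed: bit[k] holds the sum of vals[k - lowbit(k) .. k - 1]

    NumArrayBIT(int[] nums) {
        int n = (nums == null) ? 0 : nums.length;
        vals = new int[n];
        bit = new int[n + 1];
        // O(n log n) build by inserting each element as an update from 0.
        for (int i = 0; i < n; i++) {
            update(i, nums[i]);
        }
    }

    // Set vals[i] = val by adding the difference along the update path.
    void update(int i, int val) {
        int delta = val - vals[i];
        vals[i] = val;
        for (int k = i + 1; k < bit.length; k += k & (-k)) {
            bit[k] += delta;
        }
    }

    // Sum of vals[0 .. i - 1].
    private int prefixSum(int i) {
        int sum = 0;
        for (int k = i; k > 0; k -= k & (-k)) {
            sum += bit[k];
        }
        return sum;
    }

    // Sum of vals[i .. j], inclusive.
    int sumRange(int i, int j) {
        return prefixSum(j + 1) - prefixSum(i);
    }
}
```

The extra vals array is there because this interface sets a value, while a Fenwick tree natively adds a delta.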

public class NumArray {
    private class SegmentTreeNode{
        int start;
        int end;
        int sum;
        SegmentTreeNode left, right;

        public SegmentTreeNode(){}
        public SegmentTreeNode(int start, int end, int sum){
            this.start = start;
            this.end = end;
            this.sum = sum;
            this.left = null;
            this.right = null;
        }
    }

    SegmentTreeNode root;

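    public NumArray(int[] nums) {
        // Check for null/empty input before reading nums.length; an empty array leaves root == null.
        if(nums == null || nums.length == 0) return;
        root = buildTree(nums, 0, nums.length - 1);
    }

    // Build the tree for nums[start..end]; every node stores the sum of its interval.
    private SegmentTreeNode buildTree(int[] nums, int start, int end){
        if(nums == null || nums.length == 0) return null;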
        if(start == end) return new SegmentTreeNode(start, end, nums[start]);

        int mid = start + (end - start) / 2;

        SegmentTreeNode left = buildTree(nums, start, mid);
        SegmentTreeNode right = buildTree(nums, mid + 1, end);

        SegmentTreeNode root = new SegmentTreeNode(start, end, left.sum + right.sum);
        root.left = left;
        root.right = right;

        return root;
    }

    public void update(int i, int val) {
        update(root, i, val);
    }

    // Set nums[i] = val in the subtree rooted at root, then recompute sums on the way back up.
    private void update(SegmentTreeNode root, int i, int val){
        if(root == null) return;
        if(i < root.start) return;
        if(i > root.end) return;

        if(root.start == i && root.end == i) {
            root.sum = val;
            return;
        }

        update(root.left, i, val);
        update(root.right, i, val);

        root.sum = root.left.sum + root.right.sum;
    }

    public int sumRange(int i, int j) {
        return sumRange(root, i, j);
    }

    // Sum of [i, j] clipped to this node's interval [root.start, root.end].
    private int sumRange(SegmentTreeNode root, int i, int j){
        if(root == null) return 0;
        if(i > root.end) return 0;
        if(j < root.start) return 0;

        i = Math.max(i, root.start);
        j = Math.min(j, root.end);

        if(root.start == i && root.end == j) return root.sum;

        int left = sumRange(root.left, i, j);
        int right = sumRange(root.right, i, j);

        return left + right;
    }
}


// Your NumArray object will be instantiated and called as such:
// NumArray numArray = new NumArray(nums);
// numArray.sumRange(0, 1);
// numArray.update(1, 10);
// numArray.sumRange(1, 2);

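By type, this problem looks a lot like the previous one. But since the operations it needs are simpler (queries only, no updates), it is really just a prefix-sum-array DP...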

Lesson learned: segment trees are awesome, but don't reach for one at the slightest provocation..

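  • A segment tree is the best fit when a problem meets all three of the following:

    • Range queries

    • Frequent updates

    • Frequent queries

  • With range queries but no updates, the problem is really a 1D/2D DP (prefix sum) problem, and a segment tree brings no advantage.

  • Remember to pad the front of the prefix sum array with a sum = 0 entry.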

public class NumArray {
    int[] prefixSum;

    public NumArray(int[] nums) {
        if(nums == null || nums.length == 0) return;

        // prefixSum[k] = nums[0] + ... + nums[k - 1]; prefixSum[0] = 0 is the padding.
        prefixSum = new int[nums.length + 1];
        for(int i = 0; i < nums.length; i++){
            prefixSum[i + 1] = prefixSum[i] + nums[i];
        }
    }

    // Sum of nums[i..j] in O(1), thanks to the leading 0.
    public int sumRange(int i, int j) {
        return prefixSum[j + 1] - prefixSum[i];
    }
}


// Your NumArray object will be instantiated and called as such:
// NumArray numArray = new NumArray(nums);
// numArray.sumRange(0, 1);
// numArray.sumRange(1, 2);
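The bullet above says the no-update case is a 1D/2D DP. The 1D version is the class above; a minimal sketch of the 2D version (2D prefix sums with one padding row and one padding column) could look like this, with NumMatrixImmutable as a hypothetical class name.

```java
// Hypothetical class: 2D prefix sums for a matrix that never changes.
// pre[r][c] = sum of matrix[0..r-1][0..c-1]; row 0 and column 0 are zero padding.
class NumMatrixImmutable {
    private final int[][] pre;

    NumMatrixImmutable(int[][] matrix) {
        int rows = (matrix == null) ? 0 : matrix.length;
        int cols = (rows == 0) ? 0 : matrix[0].length;
        pre = new int[rows + 1][cols + 1];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                // Inclusion-exclusion: add top and left, subtract the overlap counted twice.
                pre[r + 1][c + 1] = matrix[r][c] + pre[r][c + 1] + pre[r + 1][c] - pre[r][c];
            }
        }
    }

    // Sum of the rectangle from (row1, col1) to (row2, col2), inclusive. O(1).
    int sumRegion(int row1, int col1, int row2, int col2) {
        return pre[row2 + 1][col2 + 1] - pre[row1][col2 + 1]
             - pre[row2 + 1][col1] + pre[row1][col1];
    }
}
```

Construction is O(mn) and each query is O(1), which is why a segment tree has nothing to offer when there are no updates.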

This problem can of course be solved with a segment tree by flattening the matrix into 1D: split a region into several row intervals and add up their results. But it is slow.

This problem shows both how a segment tree can be applied and where it falls short.

  • With width = m and height = n, the complexity of the 1D segment tree below is:

    • build O(mn)

    • update O(log(mn))

    • query O(n * log(mn)) (one range query per row)

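  • This problem is a better fit for a binary indexed tree. Another tutorial is posted here, and here, along with this Chinese-language post recommended by Teacher Chen.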
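A minimal sketch of the 2D binary indexed tree approach for the mutable matrix, assuming a hypothetical class name NumMatrixBIT; it keeps the same update / sumRegion interface as the segment tree version below. For an n-row, m-column matrix both operations cost O(log n * log m), versus O(n * log(mn)) per query for the flattened segment tree.

```java
// Hypothetical class: 2D Binary Indexed Tree (Fenwick tree) for a mutable matrix.
class NumMatrixBIT {
    private final int rows, cols;
    private final int[][] vals; // current matrix values
    private final int[][] bit;  // 1-indexed 2D Fenwick tree

    NumMatrixBIT(int[][] matrix) {
        rows = (matrix == null) ? 0 : matrix.length;
        cols = (rows == 0) ? 0 : matrix[0].length;
        vals = new int[rows][cols];
        bit = new int[rows + 1][cols + 1];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                update(r, c, matrix[r][c]);
            }
        }
    }

    // Set matrix[row][col] = val by adding the difference to every covering cell.
    void update(int row, int col, int val) {
        int delta = val - vals[row][col];
        vals[row][col] = val;
        for (int r = row + 1; r <= rows; r += r & (-r)) {
            for (int c = col + 1; c <= cols; c += c & (-c)) {
                bit[r][c] += delta;
            }
        }
    }

    // Sum of matrix[0..row-1][0..col-1].
    private int prefix(int row, int col) {
        int sum = 0;
        for (int r = row; r > 0; r -= r & (-r)) {
            for (int c = col; c > 0; c -= c & (-c)) {
                sum += bit[r][c];
            }
        }
        return sum;
    }

    // Inclusion-exclusion over four prefix rectangles.
    int sumRegion(int row1, int col1, int row2, int col2) {
        return prefix(row2 + 1, col2 + 1) - prefix(row1, col2 + 1)
             - prefix(row2 + 1, col1) + prefix(row1, col1);
    }
}
```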

public class NumMatrix {

    private class SegmentTreeNode{
        int start;
        int end;
        int sum;
        SegmentTreeNode left;
        SegmentTreeNode right;
        public SegmentTreeNode(){}
        public SegmentTreeNode(int start, int end, int sum){
            this.start = start;
            this.end = end;
            this.sum = sum;
            this.left = null;
            this.right = null;
        }
    }

    int width;
    int height;
    SegmentTreeNode root;

    // Flatten (row, col) into a row-major 1D index.
    private int getIndex(int row, int col){
        return row * width + col;
    }

    public NumMatrix(int[][] matrix) {
        // Also reject empty rows: width = 0 would make buildTree divide by zero.
        if(matrix == null || matrix.length == 0 || matrix[0].length == 0) return;

        height = matrix.length;
        width = matrix[0].length;

        root = buildTree(matrix, 0, width * height - 1);
    }

    private SegmentTreeNode buildTree(int[][] matrix, int start, int end){
        if(start == end) return new SegmentTreeNode(start, end, matrix[start / width][start % width]);

        int mid = start + (end - start) / 2;
        SegmentTreeNode left = buildTree(matrix, start, mid);
        SegmentTreeNode right = buildTree(matrix, mid + 1, end);
        SegmentTreeNode root = new SegmentTreeNode(start, end, left.sum + right.sum);
        root.left = left;
        root.right = right;

        return root;
    }

    public void update(int row, int col, int val) {
        int index = getIndex(row, col);
        update(root, index, val);
    }

    private void update(SegmentTreeNode root, int index, int val){
        if(root == null) return ;
        if(index < root.start) return;
        if(index > root.end) return;

        if(root.start == index && root.end == index){
            root.sum = val;
            return;
        }

        update(root.left, index, val);
        update(root.right, index, val);

        root.sum = root.left.sum + root.right.sum;
    }

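    public int sumRegion(int row1, int col1, int row2, int col2) {
        int sum = 0;
        if(col1 == 0 && col2 == width - 1){
            // Full-width rows are contiguous in row-major order, so one range query suffices.
            sum = querySum(root, getIndex(row1, col1), getIndex(row2, col2));
        } else {
            // Otherwise, query each row's [col1, col2] interval separately.
            for(int row = row1; row <= row2; row++){
                sum += querySum(root, getIndex(row, col1), getIndex(row, col2));
            }
        }

        return sum;
    }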

    private int querySum(SegmentTreeNode root, int start, int end){
        if(root == null) return 0;
        if(start > root.end) return 0;
        if(end < root.start) return 0;

        start = Math.max(start, root.start);
        end = Math.min(end, root.end);

        if(start == root.start && end == root.end) return root.sum;

        int left = querySum(root.left, start, end);
        int right = querySum(root.right, start, end);

        return left + right;
    }
}


// Your NumMatrix object will be instantiated and called as such:
// NumMatrix numMatrix = new NumMatrix(matrix);
// numMatrix.sumRegion(0, 1, 2, 3);
// numMatrix.update(1, 1, 10);
// numMatrix.sumRegion(1, 2, 3, 4);
