Segment Tree 基础操作

  • Segment Tree 是一个 Full Binary Tree,每个节点子节点数量为 0 或 2.

  • 给定含 n 个元素的数组区间,对应的 Segmeng Tree 节点数量最多为 2n - 1.

  • Build O(n)

  • Update O(log n)

  • Query O(log n)

  • 偶数长度区间会拆出两个偶数OR奇数长度区间;奇数长度区间会拆出一个偶数+一个奇数长度区间, 最终以长度为 1 的区间为叶节点。

比较基本,就是很普通的建二叉树。

public class Solution {
    /**
     *@param start, end: Denote an segment / interval
     *@return: The root of Segment Tree
     */
    public SegmentTreeNode build(int start, int end) {
        // write your code here
        if(start > end) return null;
        SegmentTreeNode root = new SegmentTreeNode(start, end);
        if(start == end) return root;
        int mid = start + (end - start) / 2;
        root.left = build(start, mid);
        root.right = build(mid + 1, end);

        return root;
    }
}

这个例子更具有实际意义一点,因为这个建出来的 Segment Tree 已经可以用了。

  • 对于给定无序数组,Segment Tree 可以利用递归在 O(n) 时间建立。

  • Segment Tree 总共节点个数为 2n - 1,建立每个节点操作时间为 O(1).

  • 另一个角度看的话,总共有 n 个叶节点,那么对应的 perfect binary tree 总节点个数就是 2n - 1.

  • 结构和 quick sort / merge sort 非常类似,每一层包含的所有区间覆盖整个数组。

public class Solution {
    /**
     *@param A: a list of integer
     *@return: The root of Segment Tree
     */
    public SegmentTreeNode build(int[] A) {
        // write your code here
        return buildHelper(A, 0, A.length - 1);
    }

    private SegmentTreeNode buildHelper(int[] A, int start, int end){
        if(start > end) return null;
        if(start == end) return new SegmentTreeNode(start, end, A[start]);

        int mid = start + (end - start) / 2;

        SegmentTreeNode left = buildHelper(A, start, mid);
        SegmentTreeNode right = buildHelper(A, mid + 1, end);

        SegmentTreeNode root = new SegmentTreeNode(start, end, Math.max(left.max, right.max));
        root.left = left;
        root.right = right;

        return root;
    }
}

第一次写的时候翻了个错误,只考虑了改的值变大,更新新的 max 的情况,而没考虑更新的值可能是原来某个区间节点的 max,变小之要更新整个到 Root 路径的 max.

  • 每次 update 都是一次 top-down 的递归,实际更新是由最小的区间单位 bottom-up 一直到 root 的路径更新。

public class Solution {
    /**
     *@param root, index, value: The root of segment tree and 
     *@ change the node's value with [index, index] to the new given value
     *@return: void
     */
    public void modify(SegmentTreeNode root, int index, int value) {
        // write your code here
        if(root == null) return;
        if(index < root.start || index > root.end) return;

        // Segment Tree 不会出现单独分叉的节点,所以到叶节点可以直接返回。
        if(index == root.start && index == root.end){
            root.max = value;
            return;
        }

        modify(root.left, index, value);
        modify(root.right, index, value);

        root.max = Math.max(root.left.max, root.right.max);
    }
}

对于区间覆盖问题可以多画点图,考虑下所有可能的情况。

对于每一层的查询,最多只会分裂成两个节点;

  • 如果目标区间完全不在 root 的区间里,直接返回;

  • 否则设有效查询区间为

    • max(root.start, query.start)

    • min(root.end, query.end)

  • 处理下正确叶节点的位置,递归处理。

public class Solution {
    /**
     *@param root, start, end: The root of segment tree and 
     *                         an segment / interval
     *@return: The maximum number in the interval [start, end]
     */
    public int query(SegmentTreeNode root, int start, int end) {
        // write your code here
        if(end < root.start) return Integer.MIN_VALUE;
        if(start > root.end) return Integer.MIN_VALUE;

        start = Math.max(start, root.start);
        end = Math.min(end, root.end);

        if(start == root.start && end == root.end) return root.max;

        int left = query(root.left, start, end);
        int right = query(root.right, start, end);

        return Math.max(left, right);
    }
}

这题和上一题没有任何区别。。你是想告诉我 TreeNode 里除了 max/min 还可以存 count 是吗。。。

public class Solution {
    /**
     *@param root, start, end: The root of segment tree and 
     *                         an segment / interval
     *@return: The count number in the interval [start, end]
     */
    public int query(SegmentTreeNode root, int start, int end) {
        // write your code here
        if(root == null) return 0;
        if(end < root.start) return 0;
        if(start > root.end) return 0;

        start = Math.max(start, root.start);
        end = Math.min(end, root.end);

        if(root.start == start && root.end == end) return root.count;

        int left = query(root.left, start, end);
        int right = query(root.right, start, end);

        return left + right;

    }
}

Last updated