6/25, 区间类DP

6/25, 动态规划,区间类DP

  • 求一段区间的解 min/max/count

  • 相比划分类 DP ,区间类 DP 为连续相连的 subproblem ,中间不留空,更有 divide & conquer 的味道。

  • 转移方程通过区间更新

  • 从大到小的更新

Matrix-chain multiplication (算法导论)

  • 给定矩阵向量 [A1, A2, A3 .. An]

  • 矩阵乘法有结合律,所以任意的 parenthesization 结果一样

  • Dimension (a x b) 乘 (b x c) 得到 (a x c) ,总计算量为 a x b x c

  • 可能的括号加法为 Catalan 数,O(2^n),因而搜索不合适。

  • 让 dp[i][j] 代表(i , j) 区间内最优的括号顺序运算次数

    • 符合 optimal substructure,反证法

  • A = rows * cols,假如从 k 分开左右 i <= k < j ,如下 k = 5 时:

  • [A1, A2, A3, A4, A5 || A6, A7,A8,A9]

  • 左子问题为 A1.rows x A5.cols

  • 右子问题为 A6.rows x A9.cols

  • 其中 A5.cols = A6.rows

  • 其总花费为 dp[1,5] + dp[6,9] + A1.rows * A5.cols * A9.cols

至此,对于任意 size (i , j) 的向量区间,我们都可以遍历所有合理 k 的切分点,实现记忆化的 divide & conquer,当前区间的最优解一定由其最优子区间拼接而成。

子问题图如下。其实就是一个 n x n 的矩阵对角线,代表所有的子区间。

上一题的求所有区间最优解进行拼接的思路和 optimal substructure 结构和这题非常像,再贴一遍,感受一下。

不过 Matrix Chain Multiplication 要比这个复杂,时间复杂度为 O(n^3). 毕竟每个切点上会生成两个 subproblems.

public class Solution {
    public int minCut(String s) {
        if(s == null || s.length() <= 1) return 0;
        int len = s.length();

        boolean[][] isPalindrome = new boolean[len][len];
        int[] dp = new int[len];

        for(int i = 0; i < len; i++){
            dp[i] = i;
            for(int j = 0; j <= i; j++){
                if(s.charAt(i) == s.charAt(j) && (i - j < 2 || isPalindrome[j + 1][i - 1])){
                    isPalindrome[i][j] = isPalindrome[j][i] = true;
                    if(j == 0){
                        dp[i] = 0;
                    } else {
                        dp[i] = Math.min(dp[i], dp[j - 1] + 1);
                    }
                }
            }
        }

        return dp[len - 1];
    }
}

著名的区间类 DP 入门题 -- 石子归并

以数组【3,4,5,6】 为例,进行归并的 subproblem graph 如下,path 上的数字代表每一步的 cost.

  • 我们一定可以得到一个 height balanced complete tree,因为每步都只归并两堆石子,只是每步的 branching factor 不同;

  • 所有 subproblem 的叶节点 cost 一致,为所有石子的总和。

这种画法的结构是对的,但是并不合理,因为没有体现出“overlap subproblems”,每个子问题看起来都像独立问题一样。

这样就明显多了,而且和前面的 “区间划分DP” 联系紧密。

自己用记忆化搜索写的第一版,比较粗糙~

  • 每次归并的 cost = 归并两个区间的最优 cost + 两个区间的区间和

  • 因此区间最优用 dp[][] 记忆化搜素,区间和 sum[][] 可以 O(n ^ 2) 时间预处理。

  • O(n^2) preprocess + O(n^2) number of intervals * O(n) number of candidate cuts = O(n^2) + O(n^3)

  • 可以看到,记忆化搜索中,dp[][] 每一个位置只会被遍历一次而且不会再生成新的 subproblems,其时间复杂度和 bottom-up 的迭代循环是一样的。

public class Solution {
    /**
     * @param A an integer array
     * @return an integer
     */
    public int stoneGame(int[] A) {
        // Write your code here
        if(A == null || A.length == 0) return 0;
        int n = A.length;

        // Minimum cost to merge interval dp[i][j]
        int[][] dp = new int[n][n];
        int[][] sum = new int[n][n];

        // Pre-process interval sum
        for(int i = 0; i < n; i++){
            for(int j = i; j >= 0; j--){
                if(j == i) sum[i][j] = A[i];
                else sum[i][j] = sum[j][i] = A[j] + sum[j + 1][i];
            }
        }

        return memoizedSearch(0, n - 1, A, dp, sum);
    }

    private int memoizedSearch(int start, int end, int[] A, int[][] dp, int[][] sum){
        if(start > end) return 0;
        if(start == end) return 0;
        if(start + 1 == end) return A[start] + A[end];

        if(dp[start][end] != 0) return dp[start][end];

        int min = Integer.MAX_VALUE;
        for(int i = start; i < end; i++){
            int cost = memoizedSearch(start, i, A, dp, sum) + memoizedSearch(i + 1, end, A, dp, sum) + sum[start][i] + sum[i + 1][end];
            min = Math.min(min, cost);
        }

        dp[start][end] = min;

        return min;
    }
}

对于 interval sum ,根据搜索结构可以做一个显而易见的优化,因为每次 split 的 start, pivot, end 我们都知道,而且合并(start, end) 区间的两堆石子,最终的区间和一定为 (start, end) 的区间和,用一维的 prefix sum 数组就可以了。

用 prefix sum 数组要记得初始化时候的 int[n + 1] zero padding,还有取值时候对应的 sum[end + 1] - sum[start + 1 - 1] offset.

public class Solution {
    /**
     * @param A an integer array
     * @return an integer
     */
    public int stoneGame(int[] A) {
        // Write your code here
        if(A == null || A.length == 0) return 0;
        int n = A.length;

        // Minimum cost to merge interval dp[i][j]
        int[][] dp = new int[n][n];
        int[] sum = new int[n + 1];

        // Pre-process interval sum
        for(int i = 0; i < n; i++){
            sum[i + 1] = sum[i] + A[i];
        }

        return memoizedSearch(0, n - 1, A, dp, sum);
    }

    private int memoizedSearch(int start, int end, int[] A, int[][] dp, int[] sum){
        if(start > end) return 0;
        if(start == end) return 0;
        if(start + 1 == end) return A[start] + A[end];

        if(dp[start][end] != 0) return dp[start][end];

        int min = Integer.MAX_VALUE;
        for(int i = start; i < end; i++){
            int cost = memoizedSearch(start, i, A, dp, sum) + memoizedSearch(i + 1, end, A, dp, sum) + sum[end + 1] - sum[start];
            min = Math.min(min, cost);
        }

        dp[start][end] = min;

        return min;
    }
}

这题和石子归并很像,更像 Matrix Chain Multiplication. 都是区间类 DP,而且原数组会随着操作逐渐减小,动态变化。

然而就算是动态变化的数组,变化的也并不是状态,而只是子状态的范围,记忆化搜索中的 (start, end).

所以这题的难点在于,如何在动态变化的数组中,依然正确定义并计算 subproblem.

问题一:边界气球

  • 考虑到计算方式为相邻气球乘积,可以两边放上 1 来做 padding,不会影响最后结果的正确性。

问题二:子问题返回后,如何处理相邻气球?

  • 在stone game中,最后融合两个区间要靠区间和;

  • 在busrt balloon中,两个区间返回时已经都被爆掉了,融合区间靠的是两个区间最外面相邻的气球。(因此 padding 才很重要)

  • 正如 Matrix Chain Multiplication 中,左右区间相乘结束返回时,最后融合那步的 cost = A(start).rows * A(k).cols * A(end).cols

public class Solution {
    public int maxCoins(int[] nums) {
        if(nums == null || nums.length == 0) return 0;
        int n = nums.length;
        int[] arr = new int[n + 2];
        arr[0] = 1;
        arr[n + 1] = 1;
        for(int i = 0; i < n; i++){
            arr[i + 1] = nums[i];
        }

        int[][] dp = new int[n + 2][n + 2];


        return memoizedSearch(1, n, arr, dp);
    }

    private int memoizedSearch(int start, int end, int[] arr, int[][] dp){
        if(dp[start][end] != 0) return dp[start][end];

        int max = 0;
        for(int i = start; i <= end; i++){
            int cur = arr[start - 1] * arr[i] * arr[end + 1];
            int left = memoizedSearch(start, i - 1, arr, dp);
            int right = memoizedSearch(i + 1, end, arr, dp);

            max = Math.max(max, cur + left + right);
        }

        dp[start][end] = max;

        return max;
    }
}

弄了半天写了个错误的版本,只考虑了 cut 位置对齐的情况,可以过 157 / 281 个 test cases, 然而像 "abc" 和 "bca" 这种起始位置就不对齐的就会出错。

a | bc

bc | a

所以很显然的,O(n^3) 泡汤了~

http://www.blogjava.net/sandy/archive/2013/05/22/399605.html

下面的是基于九章答案的记忆化搜素解法,改了我好久。。。

改写过程中一直在犯的错误是,在 subcall 中 s1,s2 已经是 substring 的情况下,依然用上一层传过来的参数作为参考去切分新的 substring. 这是错误的,只需要在参数中得到的 s1, s2 上切割就好了,因为传进来的并不是最原始的 string.

每一层 search 中,参数里面的 start / end / n 代表着相对于原始 string 的位置,用于查询和记录 DP; 而这一层的 s1, s2 又是新的子问题,除了涉及传参和DP之外的地方,都以 s1, s2 为准。

s.substring(i,j) 中,最后截取的 substring 长度就是 j - i.

public class Solution {
    public boolean isScramble(String s1, String s2) {
        if(!isAnagram(s1, s2)) return false;

        int n = s1.length();

        // dp[i][j][k] : s1 starting from index i, s2 string from index j
        //               pick k chars, are we getting scrambled strings ?

        // 0 : not searched, 1 : true, -1 : false;
        int[][][] dp = new int[n][n][n + 1];

        return isScrambleMemo(s1, s2, 0, 0, n, dp);
    }

    private boolean isScrambleMemo(String s1, String s2, int oneStart, int twoStart, int n, int[][][] dp){
        if(dp[oneStart][twoStart][n] != 0) return (dp[oneStart][twoStart][n] == 1) ? true : false;

        if(s1.equals(s2)){
            dp[oneStart][twoStart][n] = 1;
            return true;
        }
        if(!isAnagram(s1, s2)){
            dp[oneStart][twoStart][n] = -1;
            return false;
        }


        // i = number of characters we take
        for(int i = 1; i < s1.length() ; i++){
            String s1Left = s1.substring(0, i);
            String s1Right = s1.substring(i, s1.length());

            String leftSideS2Left = s2.substring(0, i);
            String leftSideS2Right = s2.substring(i, s2.length());

            String rightSideS2Left = s2.substring(0, s2.length() - i);
            String rightSideS2Right = s2.substring(s2.length() - i, s2.length());

            if(isScrambleMemo(s1Left, leftSideS2Left, oneStart, twoStart, i, dp) && 
                isScrambleMemo(s1Right, leftSideS2Right, oneStart + i, twoStart + i, n - i, dp)) {

                dp[oneStart][twoStart][n] = 1;
                return true;    
            }
            if(isScrambleMemo(s1Left, rightSideS2Right, oneStart, twoStart + n - i, i, dp) && 
               isScrambleMemo(s1Right, rightSideS2Left, oneStart + i, twoStart, n - i, dp)) {

                dp[oneStart][twoStart][n] = 1;
                return true;
            }
        }

        dp[oneStart][twoStart][n] = -1;
        return false;
    }

    // Assuming only lower case letters
    private boolean isAnagram(String s1, String s2){
        if(s1.length() != s2.length()) return false;
        int[] hash = new int[26];
        for(int i = 0; i < s1.length(); i++){
            int index = s1.charAt(i) - 'a';
            hash[index] ++;
        }
        for(int i = 0; i < s2.length(); i++){
            int index = s2.charAt(i) - 'a';
            hash[index] --;
            if(hash[index] < 0) return false;
        }
        return true;
    }
}

Last updated