0
点赞
收藏
分享

微信扫一扫

Java DFS记忆化搜索

DFS之记忆化搜索

  1. 比如说一张图,有时候dfs不断递归的时候,发现不能达到目标状态(位置),需要回到原来状态,但是递归的时候中间结果是没有记录的,比如一个递归下去,当前的dfs有几个方向,但是发现这些方向都不能使用,于是返回上一层,由于没有中间态的记录,在上一层的另一个方向的递归下一层又是新的方向,就造了太多计算。
  2. 再说个极端点的例子,就类似于递归到目标态跟前了,但就是达不到,没办法只能一直往回走,同时在往回走的时候还会再试试当前点其他方向,再继续往下递归,然后下一个递归又有很多个方向,明知道已经完成不了目标但还是继续递归下去,这样就有大量的重复运算,一旦数据量大点时间就会爆炸,所以要引进记忆化搜索,记录递归过程的中间结果,方便以后判断哪个方向已经试过了,行不通!
    在这里插入图片描述
    反正就差不多这个意思吧,可能不太精确,会有很多不必要的计算。

记忆化要点

  1. 加入记忆化的时候,就需要把DFS方法中的可变参数作为维度,把DFS返回的值作为存储值。一般使用数组来存,例如 int[x][y],自己设定0为为计算,-1为false,1为true。再看一下 Leetcode.403这题,这题的维度是石头列表的下标和已跳跃步数,这两都是可变参数,返回的值就是是否能走通。当然在这个数组空间不好判断的时候,我们可以直接用哈希表来存储,也可以降低爆空间风险,如果数据量小的话就用一般数组方便些,还有会遇到可变参数为负数的时候,就用哈希表来存。
    看看代码:
    未加入记忆化搜索的dfs:(时间会爆)
class Solution {
    Map<Integer, Integer> map = new HashMap<>();
    public boolean canCross(int[] ss) {
        int n = ss.length;
        // 将石子信息存入哈希表
        // 为了快速判断是否存在某块石子,以及快速查找某块石子所在下标
        for (int i = 0; i < n; i++) {
            map.put(ss[i], i);
        }
        // check first step
        // 根据题意,第一步是固定经过步长 1 到达第一块石子(下标为 1)
        if (!map.containsKey(1)) return false;
        return dfs(ss, ss.length, 1, 1);
    }
    /**
     * 判定是否能够跳到最后一块石子
     * @param ss 石子列表【不变】
     * @param n  石子列表长度【不变】
     * @param u  当前所在的石子的下标
     * @param k  上一次是经过多少步跳到当前位置的
     * @return 是否能跳到最后一块石子
     */
    boolean dfs(int[] ss, int n, int u, int k) {
        if (u == n - 1) return true;
        for (int i = -1; i <= 1; i++) {
            // 如果是原地踏步的话,直接跳过
            if (k + i == 0) continue;
            // 下一步的石子理论编号
            int next = ss[u] + k + i;
            // 如果存在下一步的石子,则跳转到下一步石子,并 DFS 下去
            if (map.containsKey(next)) {
                boolean cur = dfs(ss, n, map.get(next), k + i);
                if (cur) return true;
            }
        }
        return false;
    }
}

加入记忆化搜索的dfs(正常):

class Solution {
    Map<Integer, Integer> map = new HashMap<>();
    // int[][] cache = new int[2009][2009];
    Map<String, Boolean> cache = new HashMap<>();
    public boolean canCross(int[] ss) {
        int n = ss.length;
        for (int i = 0; i < n; i++) {
            map.put(ss[i], i);
        }
        // check first step
        if (!map.containsKey(1)) return false;
        return dfs(ss, ss.length, 1, 1);
    }
    boolean dfs(int[] ss, int n, int u, int k) {
        String key = u + "_" + k;
        // if (cache[u][k] != 0) return cache[u][k] == 1;
        if (cache.containsKey(key)) return cache.get(key);
        if (u == n - 1) return true;
        for (int i = -1; i <= 1; i++) {
            if (k + i == 0) continue;
            int next = ss[u] + k + i;
            if (map.containsKey(next)) {
                boolean cur = dfs(ss, n, map.get(next), k + i);
                // cache[u][k] = cur ? 1 : -1;
                cache.put(key, cur);
                if (cur) return true;
            }
        }
        // cache[u][k] = -1;
        cache.put(key, false);
        return false;
    }
}
  1. 再看一下 Leetcode.576 ,像这个用例
    在这里插入图片描述
    就会有重复往返的一条路,由于数据量并不大,就用数组来存,免得map的get和put方法一直调用,int[x][y][step] ,x和y就为坐标,step为走的步数,存储的值为0,说明还未计算,-1为不可达,1为可以出去。先看代码:
class Solution {
    int[][] dirs = new int[][]{{1,0},{0,1},{-1,0},{0,-1}};//四个方向
    int[][][] cache;
    int MOD = (int) 1e9 + 7;
    int m,n;
    public int findPaths(int _m, int _n, int maxMove, int startRow, int startColumn) {
        m = _m;n = _n;
        cache = new int[m][n][maxMove+1];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j <n; j++) {
                for (int k = 0; k <= maxMove; k++) {
                    cache[i][j][k] = -1;//初始化全为不可达,为计算状态
                }
            }
        }
        return dfs(startRow,startColumn,maxMove);
    }
    int dfs(int x,int y,int step){
        if(x < 0 || x>=m || y < 0 || y >= n) return 1;//判断是否出界
        if(step == 0) return 0;//未出界并且步数为0则返回0,表示已经计算过但不能出界,否则继续走下去,往回走也行
        if(cache[x][y][step] != -1) return cache[x][y][step];
        int cnt = 0;//每次递归都重新定义一个计数,防止互相干扰
        for(int[] dir:dirs){
            int nx = x + dir[0],ny = y + dir[1];
            cnt += dfs(nx, ny, step - 1); //累加有多少条路
            cnt %= MOD;
        }
        cache[x][y][step] = cnt;//无路可走,记录一下。
        return cnt;
    }
}

当然这里的记忆数组同时还可以判重,防止走重复的路,有点跟BFS的visit类似。

举报

相关推荐

0 条评论