分享|c++ 暴力深度优先搜索到动态规划记忆化搜索之间的华丽转身 - python @cache 的模仿
1810
2024.04.26
2024.10.25
发布于 未知归属地

[TOC]
第一次看到python @cache 魔法的时候是在看 @灵茶山山神 在B站讲题, 当时的反应就是一脸震惊, 看着我c++ dfs 没有记忆化的代码陷入了沉思。

那么c++ 可不可以相对方便的像 python那样 稍微加亿点点代码实现 从暴力搜索到 记忆化搜索的 华丽的转身了,那就有了下面这个拙劣的模仿

附加代码

class null_param {
};

template<typename Sig, class F>
class memoize_helper;

template<typename R, typename... Args, class F>
class memoize_helper<R(Args...), F> {
private:
    using function_type = F;
    using args_tuple_type = tuple<Args...>;

    function_type f;
    mutable map<args_tuple_type, R> cache;

public:
    template<class Function>
    memoize_helper(Function &&f, null_param) : f(std::forward<Function>(f)) {}

    memoize_helper(const memoize_helper &other) : f(other.f) {}

    template<class ...InnerArgs>
    R operator()(InnerArgs &&... args) const {
        auto args_tuple = make_tuple(std::forward<InnerArgs>(args)...);
        auto it = cache.find(args_tuple);
        if (it != cache.end()) {
            return it->second;
        }
        return cache[args_tuple] = f(*this, std::forward<InnerArgs>(args)...);
    }
};


template<size_t Dim, typename R>
class cache_vec_helper : public cache_vec_helper<Dim - 1, R> {
public:
    using type = vector<typename cache_vec_helper<Dim - 1, R>::type>;
};

template<typename R>
class cache_vec_helper<0, R> {
public:
    using type = R;
};

template<size_t Dim, typename R, class F>
class memoize_nvec_helper;

template<size_t Dim, typename R, typename ...Args, class F>
class memoize_nvec_helper<Dim, R(Args...), F> {
public:
    using function_type = F;
    function_type f;
    mutable typename cache_vec_helper<Dim, R>::type cache;
    array<size_t, Dim> szs;
    R dv;

    template<class Function>
    memoize_nvec_helper(Function &&f, vector<int>& vec_sz, R r) : f(std::forward<Function>(f)), dv(r) {
        for (int i = 0; i < Dim; ++i) {
            szs[i] = vec_sz[i];
        }
        initialize_cache(cache, szs);
    }

    template<class ...InnerArgs>
    R operator()(InnerArgs &&... args) {
        static_assert(sizeof...(args) == Dim, "Number of arguments must match the dimension.");
        std::array<int, Dim> indices = {std::forward<InnerArgs>(args)...};

        // 检查是否越界
        for (int i = 0; i < Dim; ++i) {
            if (indices[i] < 0 || indices[i] >= static_cast<int>(szs[i])) {
                return dv;
                //throw std::out_of_range("Index out of bounds");
            }
        }
        // 访问或修改缓存
        return access_cache(cache, indices, std::make_index_sequence<Dim>{}, std::forward<InnerArgs>(args)...);
    }

private:
    // 辅助函数用于递归地初始化 cache
    template<typename T, size_t D>
    void initialize_cache(std::vector<T> &vec, const std::array<size_t, D> &sizes, size_t current_dim = 0) {
        vec.resize(sizes[current_dim]);
        if (current_dim + 1 < D) {
            for (auto &sub_vec: vec) {
                initialize_cache(sub_vec, sizes, current_dim + 1);
            }
        }
    }

    // 特化用于处理最后一个维度
    void initialize_cache(std::vector<R> &vec, const std::array<size_t, Dim> &sizes, size_t current_dim = 0) {
        vec.resize(sizes[current_dim], dv);
    }

    // 辅助函数用于递归地访问或修改缓存
    template<size_t... Is, class ...InnerArgs>
    R access_cache(typename cache_vec_helper<Dim, R>::type &cache, const std::array<int, Dim> &indices,
                   std::index_sequence<Is...>, InnerArgs &&... args) {
        return access_cache_impl(cache, indices, std::make_index_sequence<Dim>{}, std::forward<InnerArgs>(args)...);
    }

    // 实际递归访问缓存的实现
    template<size_t... Is, class ...InnerArgs>
    R access_cache_impl(typename cache_vec_helper<sizeof ...(Is), R>::type &cache, const std::array<int, Dim> &indices,
                        std::index_sequence<Is...>, InnerArgs &&... args) {
        return access_cache_impl(cache, indices, std::index_sequence<Is...>{}, std::forward<InnerArgs>(args)...);
    }

    template<size_t First, size_t... Rest, class ...InnerArgs>
    R access_cache_impl(typename cache_vec_helper<sizeof ...(Rest) + 1, R>::type &cache,
                        const std::array<int, Dim> &indices, std::index_sequence<First, Rest...>,
                        InnerArgs &&... args) {
        return access_cache_impl(cache[indices[First]], indices, std::index_sequence<Rest...>{},
                                 std::forward<InnerArgs>(args)...);
    }

    // 递归访问缓存的辅助函数
    template<size_t Last, class ...InnerArgs>
    R access_cache_impl(typename cache_vec_helper<1, R>::type &cache, const std::array<int, Dim> &indices,
                        std::index_sequence<Last>, InnerArgs &&... args) {
        auto &elem = cache[indices[Last]];
        if (elem != dv) {
            return elem;
        }
        // 获取每一参数
        // 获取indices 的每一个成员
        // 递归调用
        return elem = f(*this, std::forward<InnerArgs>(args)...);
    }
};


template<int Dim, typename R, class F>
class memoize_vec_helper;

// 一维数组的特化
template<typename R, typename ...Args, class F>
class memoize_vec_helper<1, R(Args...), F> {
private:
    using function_type = F;
    function_type f;
    mutable vector<R> cache;
    R dv;

public:
    template<class Function>
    memoize_vec_helper(Function &&f, int sz, R r) : f(std::forward<Function>(f)), cache(sz, r), dv(r) {}

    template<class InnerArgs>
    R operator()(InnerArgs &&arg) const {
        if (arg < 0 || arg >= cache.size()) {
            return dv;
        }
        if (cache[arg] != dv) {
            return cache[arg];
        }
        return cache[arg] = f(*this, std::forward<InnerArgs>(arg));
    }
};

// 二维数组的特化
template<typename R, typename ...Args, class F>
class memoize_vec_helper<2, R(Args...), F> {
private:
    using function_type = F;
    function_type f;
    mutable vector<vector<R>> cache;
    R dv;

public:
    template<class Function>
    memoize_vec_helper(Function &&f, int fs, int ss, R r) : f(std::forward<Function>(f)), cache(fs, vector<R>(ss, r)),
                                                            dv(r) {}

    template<typename IndexType>
    R operator()(IndexType first, IndexType second) const {
        // 这里需要修改,因为arg现在是一个pair或者tuple
        if (first < 0 || first >= cache.size() || second < 0 || second >= cache[first].size()) {
            return dv;
        }
        if (cache[first][second] != dv) {
            return cache[first][second];
        }
        return cache[first][second] = f(*this, first, second);
    }
};

// 三维数组的特化
template<typename R, typename ...Args, class F>
class memoize_vec_helper<3, R(Args...), F> {
private:
    using function_type = F;
    function_type f;
    mutable vector<vector<vector<int>>> cache;
    R dv;

public:
    template<class Function>
    memoize_vec_helper(Function &&f, int fs, int ss, int ts, R r) : f(std::forward<Function>(f)),
                                                                    cache(fs,vector<vector<int>>(ss,vector<int>(ts,r))), dv(r) {}

    template<typename IndexType>
    R operator()(IndexType first, IndexType second, IndexType third) const {
        // 这里需要修改,因为arg现在是一个pair或者tuple
        if (first < 0 || first >= cache.size() || second < 0 || second >= cache[first].size() || third < 0 || third >= cache[first][second].size()) {
            return dv;
        }
        // 这里需要修改,因为arg现在是一个pair或者tuple
        if (cache[first][second][third] != dv) {
            return cache[first][second][third];
        }
        return cache[first][second][third] = f(*this, first, second, third);
    }
};


/**
 * @brief  cache使用map
 */
template<class Sig, class F>
memoize_helper<Sig, std::decay_t<F>> cache(F &&f) {
    return memoize_helper<Sig, std::decay_t<F>>(std::forward<F>(f), null_param{});
}

/*
 * n纬度数组
 *
 * */
template<size_t Dim, class Sig, class F>
memoize_nvec_helper<Dim, Sig, std::decay_t<F>> cache_nvec(F &&f, vector<int> sz, int default_value) {
    return memoize_nvec_helper<Dim, Sig, std::decay_t<F>>(std::forward<F>(f), sz, default_value);
}

// 创建一维数组的函数
/***
 * @brief  chache 使用一维数组
 * @param f 函数
 * @param sz 数组大小
 * @param default_value 默认值
 * @return
 */
template<class Sig, class F>
memoize_vec_helper<1, Sig, std::decay_t<F>> cache_vec(F &&f, int sz, int default_value) {
    return memoize_vec_helper<1, Sig, std::decay_t<F>>(std::forward<F>(f), sz, default_value);
}

// 创建二维数组的函数
/***
 * @brief cache 使用二维数组
 * @param f 函数
 * @param sz 第一维数组大小
 * @param sz2 第二维数组大小
 * @param default_value 默认值
 * @return
 */
template<class Sig, class F>
memoize_vec_helper<2, Sig, std::decay_t<F>> cache_vec2(F &&f, int sz, int sz2, int default_value) {
    return memoize_vec_helper<2, Sig, std::decay_t<F>>(std::forward<F>(f), sz, sz2, default_value);
}


// 创建三维数组的函数
/***
 * @brief
 * @param f 函数
 * @param sz 第一维数组大小
 * @param sz2 第二维数组大小
 * @param sz3 第三维数组大小
 * @param default_value 默认值
 * @return
 */
template<class Sig, class F>
memoize_vec_helper<3, Sig, std::decay_t<F>> cache_vec3(F &&f, int sz, int sz2, int sz3, int default_value) {
    return memoize_vec_helper<3, Sig, std::decay_t<F>>(std::forward<F>(f), sz, sz2, sz3, default_value);
}

使用方式

1、将上面的代码拷贝到代码文件
2、写出暴力dfs的代码
3、修改 , 我们会发现, 代码基本上一样的

比如暴力搜索代码如下
function<R(Args...)> dfs = [&](Args... args)->R {
    // 逻辑代码
};
调用方式 
dfs(args...);

修改如下

auto memo = [&](auto& dfs, Args... args)->R {
    // 逻辑代码, 这里的代码和之前不变
};
// 调用方式
cache<R(Args...)>(memo)(args...)

R 是函数的返回值
Args... 是参数列表的类型
args... 实际的调用参数

补充 使用数组代替map做为缓存,应对map可能超时的情况

上面的cache 使用的是map, 但是更多时候我们可以使用一维数组 和二维数组

使用一维数组的方式

cache_vec<R(int)>(memo)(int, #数组大小, #数组默认值); 

使用二维数组的方式

cache_vec2<R(int, int)>(memo)(int, int, #数组大小, #第二维数组大小, #数组默认值); 

使用三维数组的方式

cache_vec3<R(int, int, int)>(memo)(int, int, int, #数组大小, #第二维数组大小, #第三维数组大小, #数组默认值);

使用n维数组的方式

举例当n 为2时

cache_nvec<2, int(int, int)>(memo, {m, n}, INT_MAX)(0, 0)

2 代表数组为二维数组
int(int, int) 代表函数参数类型
memo
{m, n}使用一个vector 代表每一个纬度的大小
INT_MAX 代表默认值

举例, 使用map 二维数组 n纬数组的方式

197.地下城游戏

class Solution {
public:
    int calculateMinimumHP(vector<vector<int>>& dungeon) {
        int m = dungeon.size(), n = dungeon[0].size();
        auto memo = [&](auto&& dfs, int i, int j) {
            if (i == m - 1 && j == n - 1) {
                return max(1, 1 - dungeon[i][j]);
            }
            int val = min(dfs(i + 1, j), dfs(i, j + 1)) - dungeon[i][j];
            return max(val, 1);
        };
        // 使用map
        // return cache<int(int, int)>(memo)(0, 0);
        // 使用二维数组
        // return cache_vec2<int(int, int)>(memo, m, n, INT_MAX)(0, 0);
        // 使用n纬度数组
        return cache_nvec<2, int(int, int)>(memo, {m, n}, INT_MAX)(0, 0);
    }
};

实战,请详细比较 —brute的代码和 记忆话后的代码

爬楼梯问题 70.爬楼梯

class Solution {
public:
    int climbStairs(int n) {
        auto memo = [&](auto& dfs, int i)->int {
            if (i <= 2) {
                return i;
            }
            return dfs(i - 1) + dfs(i - 2);
        };
        return cache<int(int)>(memo)(n);
    }
};

打家劫舍问题 198.打家劫舍

class Solution {
public:
    // 暴力代码
    int rob_brute(vector<int>& nums) {
        function<int(int)> dfs = [&](int n) -> int {
            if(n < 0) return 0;
            return max(dfs(n - 1), dfs(n - 2) + nums[n]);
        };
        return dfs(nums.size() - 1);   
    }

    int rob(vector<int>& nums) {
        auto memo = [&](auto& dfs, int n) -> int {
            if(n < 0) return 0;
            return max(dfs(n - 1), dfs(n - 2) + nums[n]);
        };
        return cache<int(int)>(memo)(nums.size() - 1);   
    }
};

网格问题 62. 不同路径

class Solution {
public:
    int uniquePaths(int m, int n) {
        auto memo = [&](auto& dfs, int i, int j) {
            if (i >= m || j >= n) {
                return 0;
            }
            if (i == m - 1 and j == n - 1) {
                return 1;
            }
            return dfs(i + 1, j) + dfs(i, j + 1);
        };
        return cache<int(int, int)>(memo)(0, 0);
    }
};

01背包问题 416.分割等和子集

class Solution {
public:
    bool canPartition_brute(vector<int>& nums) {
        if (nums.size() <= 1) {
            return false;
        }
        int total = accumulate(nums.begin(), nums.end(), 0);
        if ((total & 0x1) != 0) {
            return false;
        }
        ranges::sort(nums);
        int half = total >> 1;
        function<int(int, int)> dfs = [&](int i, int j)->int {
            if (j == 0) {
                return 1;
            }
            if (i >= nums.size() || j < 0 || j < nums[i]) {
                return 0;
            }
            return dfs(i + 1, j) || dfs(i + 1, j - nums[i]);
        };
        return dfs(0, half);
    }

    bool canPartition(vector<int>& nums) {
        if (nums.size() <= 1) {
            return false;
        }
        int total = accumulate(nums.begin(), nums.end(), 0);
        if ((total & 0x1) != 0) {
            return false;
        }
        ranges::sort(nums);
        int half = total >> 1;
        auto memo = [&](auto& dfs, int i, int j)->int {
            if (j == 0) {
                return 1;
            }
            if (i >= nums.size() || j < 0 || j < nums[i]) {
                return 0;
            }
            return dfs(i + 1, j) || dfs(i + 1, j - nums[i]);
        };
        return cache<int(int, int)>(memo)(0, half);
    }
};
评论 (16)