[TOC]
第一次看到python @cache 魔法的时候是在看 @灵茶山山神 在B站讲题, 当时的反应就是一脸震惊, 看着我c++ dfs 没有记忆化的代码陷入了沉思。
那么c++ 可不可以相对方便的像 python那样 稍微加亿点点代码实现 从暴力搜索到 记忆化搜索的 华丽的转身了,那就有了下面这个拙劣的模仿
class null_param {
};
template<typename Sig, class F>
class memoize_helper;
template<typename R, typename... Args, class F>
class memoize_helper<R(Args...), F> {
private:
using function_type = F;
using args_tuple_type = tuple<Args...>;
function_type f;
mutable map<args_tuple_type, R> cache;
public:
template<class Function>
memoize_helper(Function &&f, null_param) : f(std::forward<Function>(f)) {}
memoize_helper(const memoize_helper &other) : f(other.f) {}
template<class ...InnerArgs>
R operator()(InnerArgs &&... args) const {
auto args_tuple = make_tuple(std::forward<InnerArgs>(args)...);
auto it = cache.find(args_tuple);
if (it != cache.end()) {
return it->second;
}
return cache[args_tuple] = f(*this, std::forward<InnerArgs>(args)...);
}
};
template<size_t Dim, typename R>
class cache_vec_helper : public cache_vec_helper<Dim - 1, R> {
public:
using type = vector<typename cache_vec_helper<Dim - 1, R>::type>;
};
template<typename R>
class cache_vec_helper<0, R> {
public:
using type = R;
};
template<size_t Dim, typename R, class F>
class memoize_nvec_helper;
template<size_t Dim, typename R, typename ...Args, class F>
class memoize_nvec_helper<Dim, R(Args...), F> {
public:
using function_type = F;
function_type f;
mutable typename cache_vec_helper<Dim, R>::type cache;
array<size_t, Dim> szs;
R dv;
template<class Function>
memoize_nvec_helper(Function &&f, vector<int>& vec_sz, R r) : f(std::forward<Function>(f)), dv(r) {
for (int i = 0; i < Dim; ++i) {
szs[i] = vec_sz[i];
}
initialize_cache(cache, szs);
}
template<class ...InnerArgs>
R operator()(InnerArgs &&... args) {
static_assert(sizeof...(args) == Dim, "Number of arguments must match the dimension.");
std::array<int, Dim> indices = {std::forward<InnerArgs>(args)...};
// 检查是否越界
for (int i = 0; i < Dim; ++i) {
if (indices[i] < 0 || indices[i] >= static_cast<int>(szs[i])) {
return dv;
//throw std::out_of_range("Index out of bounds");
}
}
// 访问或修改缓存
return access_cache(cache, indices, std::make_index_sequence<Dim>{}, std::forward<InnerArgs>(args)...);
}
private:
// 辅助函数用于递归地初始化 cache
template<typename T, size_t D>
void initialize_cache(std::vector<T> &vec, const std::array<size_t, D> &sizes, size_t current_dim = 0) {
vec.resize(sizes[current_dim]);
if (current_dim + 1 < D) {
for (auto &sub_vec: vec) {
initialize_cache(sub_vec, sizes, current_dim + 1);
}
}
}
// 特化用于处理最后一个维度
void initialize_cache(std::vector<R> &vec, const std::array<size_t, Dim> &sizes, size_t current_dim = 0) {
vec.resize(sizes[current_dim], dv);
}
// 辅助函数用于递归地访问或修改缓存
template<size_t... Is, class ...InnerArgs>
R access_cache(typename cache_vec_helper<Dim, R>::type &cache, const std::array<int, Dim> &indices,
std::index_sequence<Is...>, InnerArgs &&... args) {
return access_cache_impl(cache, indices, std::make_index_sequence<Dim>{}, std::forward<InnerArgs>(args)...);
}
// 实际递归访问缓存的实现
template<size_t... Is, class ...InnerArgs>
R access_cache_impl(typename cache_vec_helper<sizeof ...(Is), R>::type &cache, const std::array<int, Dim> &indices,
std::index_sequence<Is...>, InnerArgs &&... args) {
return access_cache_impl(cache, indices, std::index_sequence<Is...>{}, std::forward<InnerArgs>(args)...);
}
template<size_t First, size_t... Rest, class ...InnerArgs>
R access_cache_impl(typename cache_vec_helper<sizeof ...(Rest) + 1, R>::type &cache,
const std::array<int, Dim> &indices, std::index_sequence<First, Rest...>,
InnerArgs &&... args) {
return access_cache_impl(cache[indices[First]], indices, std::index_sequence<Rest...>{},
std::forward<InnerArgs>(args)...);
}
// 递归访问缓存的辅助函数
template<size_t Last, class ...InnerArgs>
R access_cache_impl(typename cache_vec_helper<1, R>::type &cache, const std::array<int, Dim> &indices,
std::index_sequence<Last>, InnerArgs &&... args) {
auto &elem = cache[indices[Last]];
if (elem != dv) {
return elem;
}
// 获取每一参数
// 获取indices 的每一个成员
// 递归调用
return elem = f(*this, std::forward<InnerArgs>(args)...);
}
};
template<int Dim, typename R, class F>
class memoize_vec_helper;
// 一维数组的特化
template<typename R, typename ...Args, class F>
class memoize_vec_helper<1, R(Args...), F> {
private:
using function_type = F;
function_type f;
mutable vector<R> cache;
R dv;
public:
template<class Function>
memoize_vec_helper(Function &&f, int sz, R r) : f(std::forward<Function>(f)), cache(sz, r), dv(r) {}
template<class InnerArgs>
R operator()(InnerArgs &&arg) const {
if (arg < 0 || arg >= cache.size()) {
return dv;
}
if (cache[arg] != dv) {
return cache[arg];
}
return cache[arg] = f(*this, std::forward<InnerArgs>(arg));
}
};
// 二维数组的特化
template<typename R, typename ...Args, class F>
class memoize_vec_helper<2, R(Args...), F> {
private:
using function_type = F;
function_type f;
mutable vector<vector<R>> cache;
R dv;
public:
template<class Function>
memoize_vec_helper(Function &&f, int fs, int ss, R r) : f(std::forward<Function>(f)), cache(fs, vector<R>(ss, r)),
dv(r) {}
template<typename IndexType>
R operator()(IndexType first, IndexType second) const {
// 这里需要修改,因为arg现在是一个pair或者tuple
if (first < 0 || first >= cache.size() || second < 0 || second >= cache[first].size()) {
return dv;
}
if (cache[first][second] != dv) {
return cache[first][second];
}
return cache[first][second] = f(*this, first, second);
}
};
// 三维数组的特化
template<typename R, typename ...Args, class F>
class memoize_vec_helper<3, R(Args...), F> {
private:
using function_type = F;
function_type f;
mutable vector<vector<vector<int>>> cache;
R dv;
public:
template<class Function>
memoize_vec_helper(Function &&f, int fs, int ss, int ts, R r) : f(std::forward<Function>(f)),
cache(fs,vector<vector<int>>(ss,vector<int>(ts,r))), dv(r) {}
template<typename IndexType>
R operator()(IndexType first, IndexType second, IndexType third) const {
// 这里需要修改,因为arg现在是一个pair或者tuple
if (first < 0 || first >= cache.size() || second < 0 || second >= cache[first].size() || third < 0 || third >= cache[first][second].size()) {
return dv;
}
// 这里需要修改,因为arg现在是一个pair或者tuple
if (cache[first][second][third] != dv) {
return cache[first][second][third];
}
return cache[first][second][third] = f(*this, first, second, third);
}
};
/**
* @brief cache使用map
*/
template<class Sig, class F>
memoize_helper<Sig, std::decay_t<F>> cache(F &&f) {
return memoize_helper<Sig, std::decay_t<F>>(std::forward<F>(f), null_param{});
}
/*
* n纬度数组
*
* */
template<size_t Dim, class Sig, class F>
memoize_nvec_helper<Dim, Sig, std::decay_t<F>> cache_nvec(F &&f, vector<int> sz, int default_value) {
return memoize_nvec_helper<Dim, Sig, std::decay_t<F>>(std::forward<F>(f), sz, default_value);
}
// 创建一维数组的函数
/***
* @brief chache 使用一维数组
* @param f 函数
* @param sz 数组大小
* @param default_value 默认值
* @return
*/
template<class Sig, class F>
memoize_vec_helper<1, Sig, std::decay_t<F>> cache_vec(F &&f, int sz, int default_value) {
return memoize_vec_helper<1, Sig, std::decay_t<F>>(std::forward<F>(f), sz, default_value);
}
// 创建二维数组的函数
/***
* @brief cache 使用二维数组
* @param f 函数
* @param sz 第一维数组大小
* @param sz2 第二维数组大小
* @param default_value 默认值
* @return
*/
template<class Sig, class F>
memoize_vec_helper<2, Sig, std::decay_t<F>> cache_vec2(F &&f, int sz, int sz2, int default_value) {
return memoize_vec_helper<2, Sig, std::decay_t<F>>(std::forward<F>(f), sz, sz2, default_value);
}
// 创建三维数组的函数
/***
* @brief
* @param f 函数
* @param sz 第一维数组大小
* @param sz2 第二维数组大小
* @param sz3 第三维数组大小
* @param default_value 默认值
* @return
*/
template<class Sig, class F>
memoize_vec_helper<3, Sig, std::decay_t<F>> cache_vec3(F &&f, int sz, int sz2, int sz3, int default_value) {
return memoize_vec_helper<3, Sig, std::decay_t<F>>(std::forward<F>(f), sz, sz2, sz3, default_value);
}1、将上面的代码拷贝到代码文件
2、写出暴力dfs的代码
3、修改 , 我们会发现, 代码基本上一样的
比如暴力搜索代码如下
function<R(Args...)> dfs = [&](Args... args)->R {
// 逻辑代码
};
调用方式
dfs(args...);修改如下
auto memo = [&](auto& dfs, Args... args)->R {
// 逻辑代码, 这里的代码和之前不变
};
// 调用方式
cache<R(Args...)>(memo)(args...)
R 是函数的返回值
Args... 是参数列表的类型
args... 实际的调用参数
上面的cache 使用的是map, 但是更多时候我们可以使用一维数组 和二维数组
cache_vec<R(int)>(memo)(int, #数组大小, #数组默认值); cache_vec2<R(int, int)>(memo)(int, int, #数组大小, #第二维数组大小, #数组默认值); cache_vec3<R(int, int, int)>(memo)(int, int, int, #数组大小, #第二维数组大小, #第三维数组大小, #数组默认值);
举例当n 为2时
cache_nvec<2, int(int, int)>(memo, {m, n}, INT_MAX)(0, 0)2 代表数组为二维数组
int(int, int) 代表函数参数类型
memo
{m, n}使用一个vector 代表每一个纬度的大小
INT_MAX 代表默认值
class Solution {
public:
int calculateMinimumHP(vector<vector<int>>& dungeon) {
int m = dungeon.size(), n = dungeon[0].size();
auto memo = [&](auto&& dfs, int i, int j) {
if (i == m - 1 && j == n - 1) {
return max(1, 1 - dungeon[i][j]);
}
int val = min(dfs(i + 1, j), dfs(i, j + 1)) - dungeon[i][j];
return max(val, 1);
};
// 使用map
// return cache<int(int, int)>(memo)(0, 0);
// 使用二维数组
// return cache_vec2<int(int, int)>(memo, m, n, INT_MAX)(0, 0);
// 使用n纬度数组
return cache_nvec<2, int(int, int)>(memo, {m, n}, INT_MAX)(0, 0);
}
};class Solution {
public:
int climbStairs(int n) {
auto memo = [&](auto& dfs, int i)->int {
if (i <= 2) {
return i;
}
return dfs(i - 1) + dfs(i - 2);
};
return cache<int(int)>(memo)(n);
}
};class Solution {
public:
// 暴力代码
int rob_brute(vector<int>& nums) {
function<int(int)> dfs = [&](int n) -> int {
if(n < 0) return 0;
return max(dfs(n - 1), dfs(n - 2) + nums[n]);
};
return dfs(nums.size() - 1);
}
int rob(vector<int>& nums) {
auto memo = [&](auto& dfs, int n) -> int {
if(n < 0) return 0;
return max(dfs(n - 1), dfs(n - 2) + nums[n]);
};
return cache<int(int)>(memo)(nums.size() - 1);
}
};
class Solution {
public:
int uniquePaths(int m, int n) {
auto memo = [&](auto& dfs, int i, int j) {
if (i >= m || j >= n) {
return 0;
}
if (i == m - 1 and j == n - 1) {
return 1;
}
return dfs(i + 1, j) + dfs(i, j + 1);
};
return cache<int(int, int)>(memo)(0, 0);
}
};class Solution {
public:
bool canPartition_brute(vector<int>& nums) {
if (nums.size() <= 1) {
return false;
}
int total = accumulate(nums.begin(), nums.end(), 0);
if ((total & 0x1) != 0) {
return false;
}
ranges::sort(nums);
int half = total >> 1;
function<int(int, int)> dfs = [&](int i, int j)->int {
if (j == 0) {
return 1;
}
if (i >= nums.size() || j < 0 || j < nums[i]) {
return 0;
}
return dfs(i + 1, j) || dfs(i + 1, j - nums[i]);
};
return dfs(0, half);
}
bool canPartition(vector<int>& nums) {
if (nums.size() <= 1) {
return false;
}
int total = accumulate(nums.begin(), nums.end(), 0);
if ((total & 0x1) != 0) {
return false;
}
ranges::sort(nums);
int half = total >> 1;
auto memo = [&](auto& dfs, int i, int j)->int {
if (j == 0) {
return 1;
}
if (i >= nums.size() || j < 0 || j < nums[i]) {
return 0;
}
return dfs(i + 1, j) || dfs(i + 1, j - nums[i]);
};
return cache<int(int, int)>(memo)(0, half);
}
};