题目:
Consider the following recursive function:
int recur(int n) {
if (n < 3)
return 1;
return recur(n-1) * recur(n-2) * recur(n-3);
}
if memoisation is applied and recur(n-1) is calculated and stored before calculating recur(n-2) and recur(n-3), for the call recur(6), how many calls will be made to recur()?
我的解法这样的:
int count_x = 0;
unordered_map map_x;
int recur(int n) {
count_x++;
if (n < 3)国外服务器
return 1;
// cout << n << endl;
if (map_x.find(n) != map_x.end()) {
return map_x[n];
}
map_x[n] = recur(n-1) * recur(n-2) * recur(n-3);
return map_x[n];
}
TEST(test, recur) {
cout << recur(6) << endl;
cout << count_x << endl;
}
输出结果是:
1
13
现在明确题目的答案不是 13,我这么验证哪里存在问题吗
答案是 7 么(包括 recur(6)本身)
是的!请问为什么呢?
你那个解法的代码里 count_x 的计算没尊重 memoisation 呀
map_x[0], map_x[1], map_x[2] 先放到 map 里
“if memoisation is applied and recur(n-1) is calculated and stored before calculating recur(n-2) and recur(n-3)”
没太搞懂,如果答案是 7,是不是 recur(0) - recur(6)各计算一次?这样的话得按照从小到大进行计算与储存,即 recur(n-3) => recur(n-2) => recur(n-1)。
但是题目里写的是先储存 recur(n-1),那岂不是 recur(n-3)和 recur(n-2)的结果会被重复计算?这样答案为何还是 7 ?
count_x 在最上面啊,每次 recur function 被 call 都会+1,这个计算方式不适用于 memoisation 吗?
答案是 7 的话,执行前明显先把 3 的解也存储起来了:
6 > 5, 4, 3
5 > 4, 3, 2
4 > 3, 2, 1
3 直接查,为 1 次
4 为 3, 2, 1: 都是直接查,为 3 次
5 为 4, 3, 2: 因为上面 4 已经存储,也是直接查, 为 3 次
当然,程序还可以优化为先找相对小的递归的解,这样大的问题就能从小的问题中直接查询答案:
map_x[n] = recur(n-3) * recur(n-1) * recur(n-1);
memoisation 的意义不就是减少原始调用么?所以在你的代码里,当从 map 中返回值时(相当于 memoisation ),就不应该再算作是调用次数了。
那句话是在说 recur(n-1) * recur(n-2) * recur(n-3) 这个表达式中 recur(n-1) 被确保先于后两者求值。在 C++中,同个表达式内的求值顺序通常是不确定的。
太感谢了!恍然大悟!
这样子输出是 7 !!!
```c++
int count_x = 0;
unordered_map
map_x;
int recur(int n) {
// cout << n << endl;
if (map_x.find(n) != map_x.end()) {
return map_x[n];
}
count_x++;
if (n < 3){
map_x[n] = 1;
return 1;
}
map_x[n] = recur(n-1) * recur(n-2) * recur(n-3);
return map_x[n];
}
TEST(test, recur) {
cout << recur(6) << endl;
cout << count_x << endl;
}
```
对本次调用你查表,但你没在内部递归调用前对每个子调用查表,导致你多了一些函数调用。
至于次数,6 、5 、4 、3 、2 、1 、0,每一个都算一遍并缓存,闭着眼睛也是 7 次。
简单直接的判断方法:
3 为临界值,n=3 时最多向下 n-1/n-2/n-3,即 recur 最小的参数 n 为 0 ;
每个 n 对应的值都进行记忆的前提下,每个 n 只需要调用一次 recur ;
首次调用 recur(6),则整个过程需要对 0-6 挨个计算、记忆,所以是 7 次(不管小于 3 的参数是否记忆,当 n=3/4/5 时 n 也被缓存过、不会被重复调用,所以小于 3 的也不会被重复调用)。
可以类推首次调用大于等于 3 和小于 3 的次数
只要进入 recur 就算一次递归调用吧,不管是查表还是继续递归计算。
也就是 recur 被调用的次数。
不得不承认,题目确实不严谨。如果结果是 7 的话问题可以问题成 how many times of recur function actually computed 。我的理解是 return memo 就不算是 computed, 而是 cache,但是依然会有 function be called 。
换种写法
```cpp
int count_x = 0;
unordered_map map_x;
int recur(int n) {
count_x++;
if (map_x.find(n) != map_x.end()) {
return map_x[n];
}
if (n < 3) {
map_x[n] = 1;
return map_x[n];
}
int next_1 = (map_x.find(n-1) == map_x.end()) ? recur(n-1) : map_x[n-1];
int next_2 = (map_x.find(n-2) == map_x.end()) ? recur(n-2) : map_x[n-2];
int next_3 = (map_x.find(n-3) == map_x.end()) ? recur(n-3) : map_x[n-3];
map_x[n] = next_1 * next_2 * next_3;
return map_x[n];
}
```
顺着 lz 思路,修改如下
```cpp
int recur(int n) {
if (map_x.find(n) != map_x.end()) {
return map_x[n];
}
count_x++;
if (n < 3) {
map_x[n] = 1;
} else {
map_x[n] = recur(n-1) * recur(n-2) * recur(n-3);
}
return map_x[n];
}
```
int recur(int n) {
if (map_x.find(n) != map_x.end()) {
return map_x[n];
}
count_x++;
if (n < 3) {
map_x[n] = 1;
} else {
map_x[n] = recur(n-1) * recur(n-2) * recur(n-3);
}
return map_x[n];
}
好吧,谢谢解释。原来是这个意思。
不过,如果是指这一行内的求值顺序的话,题目为啥要特意强调一下?除非这个顺序对于这个题目的答案(total of recur() invoked)会有影响,但是基于 memoisation 的话两者应该一样?还是我忽略了啥
即
// 1)
auto n_3 = recur(n-3);
auto n_2 = recur(n-2);
auto n_1 = recur(n-1);
return n_3 + n_2 + n1;
2)
auto n_1 = recur(n-1);
auto n_2 = recur(n-2);
auto n_3 = recur(n-3);
return n_1 + n_2 + n3;
这个取决于你的 cache lookup op 是在 caller side 还是 callee side~
像 15 楼一样,放在 caller side 的话,就会实质上而不是概念上减少 recur 调用的数量。
cache lookup 也可能是在 callee wrapper
比如 Python 可以直接拿 functools.lru_cache() 修饰符,SICP 3.5 也随便地写了一个 memo-proc
这提倡了一种正交性。
如果有并行求值的情况呢?
题目本身不严谨,我觉得应该说明是最少调用次数。
如果是最少调用次数的话,那应该在 caller 的时候对 recur(n-1),recur(n-2),recur(n-3)也进行查表,这样就不会进入 recur()。
而要得出 f(n-1),必然会得到 f(n-2),f(n-3)...f(n-n), 所以 f(n)对于 f()的调用必然等于 n+1 。
int count_x = 0;
unordered_map map_x;
int recur(int n) {
count_x++;
if (map_x.find(n) != map_x.end()) {
return map_x[n];
}
if (n < 3){
map_x[n] = 1;
return 1;
}
int temp_n1 = map_x.find(n-1) == map_x.end()?recur(n-1) :map_x[n-1];
int temp_n2 = map_x.find(n-2) == map_x.end()?recur(n-2) :map_x[n-2];
int temp_n3 = map_x.find(n-3) == map_x.end()?recur(n-3) :map_x[n-3];
map_x[n] = temp_n1*temp_n2*temp_n3;
return map_x[n];
}