Basic recursive life insurance models
Suppose that if a person is alive at the beginning of a year, they have a 1% chance of dying by the end of the year. Actuaries say they have a constant mortality rate of 0.01.
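The notation below is introduced here only for convenience: q is the mortality rate, A(t) is the probability of being alive at time t, and D(t) is the probability of dying in [t, t+1). Written this way, the model is a simple recurrence. Because the rate is constant, the recurrence also has a closed form, which is a handy check on the code in this post:

```latex
\begin{aligned}
A(0)   &= 1, \\
D(t)   &= q\,A(t), \\
A(t+1) &= A(t) - D(t) = (1 - q)\,A(t), \\
\Rightarrow\quad A(t) &= (1 - q)^{t}, \qquad D(t) = q\,(1 - q)^{t}.
\end{aligned}
```

For q = 0.01 this gives A(3) = 0.99^3 = 0.970299, which agrees with the iterative model's output.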
An iterative life insurance model
mortality_rate = .01
death_probabilities = {}         # probability of dying in [t, t+1)
alive_probabilities = {0: 1.0}   # probability of being alive at time t
for t in range(3):
    probability_death = alive_probabilities[t] * mortality_rate
    death_probabilities[t] = probability_death
    alive_probabilities[t + 1] = alive_probabilities[t] - probability_death
print(alive_probabilities)
{0: 1.0, 1: 0.99, 2: 0.9801, 3: 0.970299}
This code works, but we want something that is easier to reason about. Formulas like these are usually easier to read in a recursive style, where each year's value is defined in terms of the previous year's, as below.
A recursive life insurance model
def dead(t):
    """probability of dying in [t, t+1)"""
    return alive(t) * mortality_rate

def alive(t):
    """probability of being alive at time t"""
    if t <= 0:
        return 1  # everyone is alive at the start
    return alive(t-1) - dead(t-1)

print(f"recursive {alive(2)=}")
print(f"iterative {alive_probabilities[2]=}")
print("Same result, nice")
recursive alive(2)=0.9801
iterative alive_probabilities[2]=0.9801
Same result, nice
The problem with recursion
Let's add some print statements and see if we can trace what the code is doing.
def dead(t):
    print(f"dead({t})")
    return alive(t) * mortality_rate

def alive(t):
    print(f"alive({t})")
    if t <= 0:
        return 1
    return alive(t-1) - dead(t-1)

alive(2)
alive(2)
alive(1)
alive(0)
dead(0)
alive(0)
dead(1)
alive(1)
alive(0)
dead(0)
alive(0)
0.9801
From the logs above, we can picture the function calls made to compute alive(2) as the following tree:
                  alive(2)
           __________|__________
          /                     \
      alive(1)                dead(1)
     /        \                  |
 alive(0)   dead(0)           alive(1)
               |             /        \
           alive(0)      alive(0)   dead(0)
                                       |
                                   alive(0)
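The tree can also be produced mechanically. The sketch below is a variant of the traced functions above. It adds a `depth` parameter (our own addition, not part of the original model) and indents each log line by its depth in the call stack, so the printed trace has the same shape as the diagram.

```python
mortality_rate = .01

def dead(t, depth=0):
    """Probability of dying in [t, t+1); logs the call indented by depth."""
    print("    " * depth + f"dead({t})")
    return alive(t, depth + 1) * mortality_rate

def alive(t, depth=0):
    """Probability of being alive at time t; logs the call indented by depth."""
    print("    " * depth + f"alive({t})")
    if t <= 0:
        return 1
    # Both children are one level deeper than this call.
    return alive(t - 1, depth + 1) - dead(t - 1, depth + 1)

alive(2)
# Prints:
# alive(2)
#     alive(1)
#         alive(0)
#         dead(0)
#             alive(0)
#     dead(1)
#         alive(1)
#             alive(0)
#             dead(0)
#                 alive(0)
```

Each line's children are the lines directly below it that are indented one level further. The repeated alive(1) subtree under dead(1) is easy to spot in this output.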
The number of function calls grows exponentially. Each call to alive(t) calls alive(t-1) twice, once directly and once through dead(t-1). As a result, alive(t-2) is called four times, and so on until we reach the base case alive(0). Let's run another experiment to verify the exponential growth.
def experiment_problematic_recursion(t):
    def dead(t):
        """probability of dying in [t, t+1)"""
        return alive(t) * mortality_rate

    total = 0
    def alive(t):
        """probability of being alive at time t"""
        nonlocal total
        total += 1  # count every call, including repeated ones
        if t <= 0:
            return 1
        return alive(t-1) - dead(t-1)

    alive(t)
    return total
def run_experiment_problematic_recursion():
    print("Running experiment_no_cache")
    for t in range(5):
        print(f"alive({t}) makes {experiment_problematic_recursion(t)} calls to alive")

run_experiment_problematic_recursion()
Running experiment_no_cache
alive(0) makes 1 calls to alive
alive(1) makes 3 calls to alive
alive(2) makes 7 calls to alive
alive(3) makes 15 calls to alive
alive(4) makes 31 calls to alive
The number of calls appears to follow the formula pow(2, t+1) - 1.
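A short argument confirms this pattern. Let C(t) be the number of calls to alive made while computing alive(t). C is notation introduced here. The call itself counts once, and it triggers two full computations of alive(t-1): one directly and one inside dead(t-1). The dead calls themselves are not counted. So:

```latex
\begin{aligned}
C(0) &= 1, \qquad C(t) = 1 + 2\,C(t-1), \\
C(t) + 1 &= 2\,\bigl(C(t-1) + 1\bigr) = 2^{t}\,\bigl(C(0) + 1\bigr) = 2^{t+1}, \\
\Rightarrow\quad C(t) &= 2^{t+1} - 1.
\end{aligned}
```

This agrees with the measured counts 1, 3, 7, 15, 31 for t = 0, ..., 4.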
Fixing the problem with recursion
Once we have calculated alive(5), we can store the result and simply return it the next time alive(5) is needed. This avoids making an exponential number of recursive calls.
This is called recursion with memoization. When we store the result of a function, we say that it has been cached and it is stored in the cache.
dead_cache = {}
def dead(t):
    if t in dead_cache:
        return dead_cache[t]  # cache hit: reuse the stored value
    print(f"dead({t})")       # only reached on a cache miss
    dead_cache[t] = alive(t) * mortality_rate
    return dead_cache[t]

alive_cache = {}
def alive(t):
    if t in alive_cache:
        return alive_cache[t]
    print(f"alive({t})")
    if t <= 0:
        alive_cache[t] = 1
    else:
        alive_cache[t] = alive(t-1) - dead(t-1)
    return alive_cache[t]

alive(2)
alive(2)
alive(1)
alive(0)
dead(0)
dead(1)
0.9801
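Python's standard library can manage the cache dictionaries for us. The sketch below is an alternative to the hand-written caches above, not the approach used in this post. It relies on `functools.cache`, which is available since Python 3.9; on older versions, `functools.lru_cache(maxsize=None)` behaves the same way. The decorator stores each return value keyed by the argument `t`, so the function body, including the `print`, only runs on a cache miss.

```python
from functools import cache

mortality_rate = .01

@cache
def dead(t):
    """Probability of dying in [t, t+1)."""
    print(f"dead({t})")  # only printed on a cache miss
    return alive(t) * mortality_rate

@cache
def alive(t):
    """Probability of being alive at time t."""
    print(f"alive({t})")  # only printed on a cache miss
    if t <= 0:
        return 1
    return alive(t - 1) - dead(t - 1)

alive(2)  # logs alive(2), alive(1), alive(0), dead(0), dead(1)
```

After this call, `alive.cache_info()` reports 3 misses and 2 hits, and `alive.cache_clear()` empties the cache. Clearing matters if `mortality_rate` is changed, because the cache is keyed only on `t` and would otherwise return stale values.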
Calls that actually execute the function body, rather than returning a stored value, are called cache misses. Calls answered from the cache are cache hits. Here is the tree of cache misses:
              alive(2)
          _______|_______
         /               \
     alive(1)            dead(1)
    /        \
alive(0)   dead(0)
For comparison, here is the tree from earlier, without a cache:
                  alive(2)
           __________|__________
          /                     \
      alive(1)                dead(1)
     /        \                  |
 alive(0)   dead(0)           alive(1)
               |             /        \
           alive(0)      alive(0)   dead(0)
                                       |
                                   alive(0)
The difference is that, with the cache, each function body runs at most once for each value of t, so no subtree is ever recomputed. The savings grow as t becomes larger.
def experiment_cache(t):
    cache_misses_alive = 0
    alive_cache = {}
    def alive(t):
        nonlocal cache_misses_alive
        if t in alive_cache:
            return alive_cache[t]
        cache_misses_alive += 1
        if t <= 0:
            alive_cache[t] = 1
        else:
            alive_cache[t] = alive(t-1) - dead(t-1)
        return alive_cache[t]

    dead_cache = {}
    def dead(t):
        if t in dead_cache:
            return dead_cache[t]
        dead_cache[t] = alive(t) * mortality_rate
        return dead_cache[t]

    alive(t)
    return cache_misses_alive
def run_experiment_cache():
    print("Running experiment_cache")
    for t in range(5):
        print(f"alive({t}) makes {experiment_cache(t)} calls to alive")

run_experiment_cache()
print("\ncompared to\n")
run_experiment_problematic_recursion()
Running experiment_cache
alive(0) makes 1 calls to alive
alive(1) makes 2 calls to alive
alive(2) makes 3 calls to alive
alive(3) makes 4 calls to alive
alive(4) makes 5 calls to alive

compared to

Running experiment_no_cache
alive(0) makes 1 calls to alive
alive(1) makes 3 calls to alive
alive(2) makes 7 calls to alive
alive(3) makes 15 calls to alive
alive(4) makes 31 calls to alive
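With the cache, the body of alive runs only t + 1 times, once for each of alive(0) through alive(t). The number of calls therefore grows linearly instead of exponentially.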
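The sketch below reruns both counts for larger t to show how quickly the gap widens. `count_alive_calls` is a hypothetical helper that merges the two experiments above into one function with a `use_cache` flag. Like the originals, it counts how many times the body of alive runs.

```python
mortality_rate = .01

def count_alive_calls(t, use_cache):
    """Hypothetical helper: count executions of alive's body while computing alive(t).

    With use_cache=False every call runs the body; with use_cache=True
    only cache misses do.
    """
    alive_cache, dead_cache = {}, {}
    runs = 0

    def dead(t):
        if use_cache and t in dead_cache:
            return dead_cache[t]
        dead_cache[t] = alive(t) * mortality_rate  # stored either way, read only if use_cache
        return dead_cache[t]

    def alive(t):
        nonlocal runs
        if use_cache and t in alive_cache:
            return alive_cache[t]
        runs += 1
        alive_cache[t] = 1 if t <= 0 else alive(t - 1) - dead(t - 1)
        return alive_cache[t]

    alive(t)
    return runs

for t in (5, 10, 15):
    print(f"t={t}: {count_alive_calls(t, use_cache=False)} without cache, "
          f"{count_alive_calls(t, use_cache=True)} with cache")
```

This prints 63 versus 6, 2047 versus 11, and 65535 versus 16. These values match pow(2, t+1) - 1 for the uncached version and t + 1 for the cached one.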
Discussion
The root cause of our performance problems with recursion was that alive(t) calls alive(t-1) twice, once directly and once through dead(t-1). This results in alive(t-2) being called four times, and so on.
This problem is not specific to actuarial science. Consider the following code that calculates Fibonacci numbers recursively.
def fib(n):
    if n <= 1:
        return n
    return fib(n-1) + fib(n-2)
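fib has the same shape as alive. fib(n) calls fib(n-1) and fib(n-2), and fib(n-1) calls fib(n-2) again, so the number of calls grows exponentially with n. Memoizing it makes the body run once per value of n. The version below uses `functools.cache` as one convenient option; a hand-written dictionary like the ones above would work just as well.

```python
from functools import cache

@cache
def fib(n):
    """n-th Fibonacci number; each n is computed once, then looked up."""
    if n <= 1:
        return n
    return fib(n - 1) + fib(n - 2)

print(fib(10))  # 55
print(fib(50))  # 12586269025, with only 51 executions of the body
```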
Recursion with memoization is a common strategy for solving coding interview questions:
- Easy: https://leetcode.com/problems/climbing-stairs/description/
- Easy: https://leetcode.com/problems/min-cost-climbing-stairs/description/
- Medium: https://leetcode.com/problems/knight-probability-in-chessboard/description/
- Hard: https://leetcode.com/problems/sum-of-distances-in-tree/description/
- Hard: https://leetcode.com/problems/handshakes-that-dont-cross/description/
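The first problem asks for the number of distinct ways to climb n stairs when each move is 1 or 2 steps. The last move comes either from stair n-1 (a 1-step) or from stair n-2 (a 2-step), which gives the same recursive shape as fib. Below is a minimal memoized sketch. It is written as a plain function named `climb_stairs` (our own name) rather than in the form the site expects for submissions.

```python
from functools import cache

@cache
def climb_stairs(n):
    """Number of ways to climb n stairs taking 1 or 2 steps at a time."""
    if n <= 1:
        return 1  # zero stairs: do nothing; one stair: a single 1-step
    # The last move comes from stair n-1 (a 1-step) or stair n-2 (a 2-step).
    return climb_stairs(n - 1) + climb_stairs(n - 2)

print(climb_stairs(2))  # 2
print(climb_stairs(5))  # 8
```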