A tutorial for computing W for the SEIR model

Author

Dylan Morris

Published

May 30, 2024

Introduction

This notebook provides a step-by-step guide to working with RandomTimeShifts.jl. It covers how to solve for the moments (Section 3 of the paper), estimate the LST using the PE method, and draw realisations of W using the M3 method. To start, make sure the package is installed and its functions are accessible.

if Base.find_package("RandomTimeShifts") === nothing 
    using Pkg
    Pkg.add(url = "https://github.com/djmorris7/RandomTimeShifts.jl")
end
import RandomTimeShifts as rts
using OrdinaryDiffEq
using Random
using Distributions
using CairoMakie

Model specification

We’ll work with the SEIR model as specified in Section 4.1 of the paper, with parameters given by (\beta, \sigma, \gamma) = (0.56, 0.5, 0.33). For this example we’ll consider a population of size N = 10^6 and assume the trivial initial condition of the branching process (BP) approximation as \boldsymbol{Z}(0) = (1, 0). The progeny generating functions are given by,

f_1(\boldsymbol{s}) = s_2 ,\quad f_2(\boldsymbol{s}) = \frac{\gamma + \beta s_1 s_2}{\gamma + \beta}

and the lifetimes of individuals are specified by the vector,

\boldsymbol{a} = (\sigma, \gamma + \beta).

From the progeny generating functions we can specify the matrix \Omega:

\Omega = \begin{pmatrix} -\sigma & \sigma \\ \beta & -\gamma \end{pmatrix}.

In code we can detail these components as follows:

Z0 = [1, 0]
pars = (0.56, 0.5, 0.33)
β, σ, γ = pars
Ω = [-σ σ
     β -γ]

λ1, u_norm, v_norm = rts.calculate_BP_contributions(Ω)
lifetimes = [σ, β + γ]
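
As a cross-check that needs only Julia's standard library, the growth rate and a right eigenvector of \Omega can be computed directly. This is a minimal sketch: it assumes, without confirming it from the package source, that rts.calculate_BP_contributions returns the dominant eigenvalue of \Omega together with normalised eigenvectors. The names Ω_check, λ1_check and u_check are introduced here for illustration only.

```julia
using LinearAlgebra

β, σ, γ = 0.56, 0.5, 0.33
Ω_check = [-σ σ
           β -γ]

# Eigen-decomposition of the 2×2 matrix; pick the eigenvalue with the largest real part.
vals, vecs = eigen(Ω_check)
idx = argmax(real.(vals))
λ1_check = real(vals[idx])

# Right eigenvector, rescaled so that its entries sum to one
# (this choice of normalisation is an assumption).
u_check = real.(vecs[:, idx])
u_check ./= sum(u_check)
```

With these parameters λ1_check is positive, so the branching process is supercritical, and u_check agrees to the printed precision with the first row of the moments matrix reported in the PE method section.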

PE method

As outlined in the paper, computing the moments reduces to identifying the parameters \alpha_{ij} and \beta_{ikl} in the functional equations, Eq. (29). For this model there are two non-zero constants:

\alpha_{12} = \sigma \textrm{ and }\beta_{212} = \beta.

We can then use the functions in the package to calculate the moments by defining dictionaries of the non-zero constants:

αs = Dict(1 => Dict([1, 2] => σ), 2 => nothing)
βs = Dict(1 => nothing, 2 => Dict([2, 1, 2] => β))
Dict{Int64, Union{Nothing, Dict{Vector{Int64}, Float64}}} with 2 entries:
  2 => Dict([2, 1, 2]=>0.56)
  1 => nothing
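
To see where these two constants come from, one can multiply each progeny generating function by the corresponding lifetime rate. This is a sketch that uses only the definitions of f_i and \boldsymbol{a} given in the model specification; the exact layout of Eq. (29) is in the paper and is not reproduced here.

```latex
a_1 f_1(\boldsymbol{s}) = \sigma s_2
\quad\Rightarrow\quad \alpha_{12} = \sigma,
\qquad
a_2 f_2(\boldsymbol{s}) = \gamma + \beta s_1 s_2
\quad\Rightarrow\quad \beta_{212} = \beta .
```

The linear term \sigma s_2 in the first equation corresponds to the key [1, 2] in αs, and the quadratic term \beta s_1 s_2 in the second corresponds to the key [2, 1, 2] in βs.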

The moments are then simply computed by calling the function rts.calculate_moments_generic():

num_moments = 3
moments = rts.calculate_moments_generic(Ω, αs, βs, lifetimes; num_moments = num_moments)
3×2 Matrix{Float64}:
 0.446057  0.553943
 0.959241  1.42326
 3.08301   5.32004

Note that the source code for rts.calculate_moments_generic() is quite simple and can be readily extended. The code also makes it straightforward to extract the linear system of equations for the moments, although these equations are rarely useful on their own.

LSTs, inverse LSTs, and distributions

Computing the LST

Using the previous section we can now compute the LST. Doing this involves three parts:

  1. Estimate \lambda, \boldsymbol{u} and \boldsymbol{v} for the BP model.
  2. Using the functions defined above, calculate the moments and the error bound.
  3. Set up the system of ODEs governing the neighbourhood extension part.

Following these, we simply call rts.construct_lst() to compute the estimate to the LST. The only extra part we need is a function for solving for \boldsymbol{F}(\boldsymbol{s}, t) at a fixed value of \boldsymbol{s}:

function F_fixed_s!(du, u, pars, t)
    β, σ, γ = pars
    # Backward equations dF_i/dt = a_i * (f_i(F) - F_i) for the SEIR model
    du[1] = -σ * u[1] + σ * u[2]
    du[2] = γ - (β + γ) * u[2] + β * u[1] * u[2]
    return nothing
end
F_fixed_s! (generic function with 1 method)
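
The right-hand side of this system can be checked against the model specification. The sketch below assumes the standard backward equation for a continuous-time Markov branching process, dF_i/dt = a_i (f_i(\boldsymbol{F}) - F_i), and compares it numerically with the hand-expanded form. The names rhs_from_pgf and rhs_expanded are illustrative and are not package functions.

```julia
β, σ, γ = 0.56, 0.5, 0.33

# Progeny generating functions and lifetime rates from the model specification.
f(s) = [s[2], (γ + β * s[1] * s[2]) / (γ + β)]
a = [σ, γ + β]

# Backward equation written generically: a_i * (f_i(u) - u_i).
rhs_from_pgf(u) = a .* (f(u) .- u)

# The same right-hand side expanded by hand for the SEIR model.
rhs_expanded(u) = [-σ * u[1] + σ * u[2],
                   γ - (β + γ) * u[2] + β * u[1] * u[2]]

# Compare the two forms at an arbitrary test point.
u_test = [0.3, 0.7]
max_diff = maximum(abs.(rhs_from_pgf(u_test) .- rhs_expanded(u_test)))
```

The two forms agree up to round-off, and both vanish at \boldsymbol{s} = (1, 1), as expected because every generating function equals 1 there.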

Now we can calculate the LST. This could be written more compactly, but we have left it verbose so that each step is explicit.

function construct_lst(pars, Z0, num_moments; ϵ = 1e-6, h = 2.0)
    # 1. Calculate quantities from the branching process
    β, σ, γ = pars
    Ω = [-σ σ
         β -γ]

    λ1, u_norm, v_norm = rts.calculate_BP_contributions(Ω)
    EW = sum(Z0 .* u_norm)

    # 2. Calculate the moments and error bounds
    αs = Dict(1 => Dict([1, 2] => σ), 2 => nothing )
    βs = Dict(1 => nothing, 2 => Dict([2, 1, 2] => β))
    lifetimes = [σ, β + γ]
    
    moments = rts.calculate_moments_generic(Ω, αs, βs, lifetimes; num_moments = num_moments)
    
    error_moment = moments[end, :]
    moments = moments[begin:(end - 1), :]
    coeffs = rts.moment_coeffs(moments)
    L = rts.error_bounds(ϵ, error_moment, num_moments)

    # 3. Setup the system to extend the neighbourhood
    prob = ODEProblem(F_fixed_s!, zeros(length(Z0)), (0, h), pars,
                      save_start = false, saveat = h)
    F_offspring(s) = rts.F_offspring_ode(s, prob)
    μ = exp(h * λ1)

    # Construct the LST
    lst_w = rts.construct_lst(coeffs, μ, F_offspring, L, Z0, λ1, h)
    
    return lst_w 
end
construct_lst (generic function with 1 method)

We can then easily construct a function handle for the LST (where num_moments and h are chosen as in the paper):

num_moments = 30
lst = construct_lst(pars, Z0, num_moments)
(::RandomTimeShifts.var"#lst_w#26"{Float64, Vector{Int64}, RandomTimeShifts.var"#lst_s_rest_x#25"{Float64, var"#F_offspring#12"{ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Tuple{Float64, Float64, Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, typeof(F_fixed_s!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol}, @NamedTuple{save_start::Bool, saveat::Float64}}, SciMLBase.StandardODEProblem}}, Float64, Float64, Float64, RandomTimeShifts.var"#lst_s0_x#24"{Matrix{Float64}}}, RandomTimeShifts.var"#lst_s0_x#24"{Matrix{Float64}}}) (generic function with 1 method)

which can be evaluated at a single point, for example at 0.8:

lst(0.8)
0.8083652639976009

or for vectors (using Julia’s dot syntax) as

x = rand(Uniform(0, 10), 5)
lst.(x)
5-element Vector{Float64}:
 0.9098108378377037
 0.6601254228590385
 0.654892134111919
 0.6259138260123503
 0.6565342427150288
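
A few properties are worth checking on any estimated LST: it equals 1 at s = 0, it decreases in s, and it tends to the probability of W = 0 as s grows. The sketch below illustrates these on a toy closed-form LST (a point mass at zero mixed with an exponential). This toy model is an assumption made purely for illustration and is not the SEIR LST; toy_lst, q_toy and rate_toy are hypothetical names.

```julia
# Toy LST: P(W = 0) = q_toy and W | W > 0 ~ Exponential with rate rate_toy,
# so E[exp(-s W)] = q_toy + (1 - q_toy) * rate_toy / (rate_toy + s).
q_toy, rate_toy = 0.6, 1.0
toy_lst(s) = q_toy + (1 - q_toy) * rate_toy / (rate_toy + s)

s_grid = 0.0:0.5:10.0
lst_vals = toy_lst.(s_grid)

at_zero = toy_lst(0.0)                    # an LST equals 1 at s = 0
is_decreasing = all(diff(lst_vals) .< 0)  # strictly decreasing on the grid
far_tail = toy_lst(1e9)                   # approaches the point mass at zero
```

The values printed for the SEIR model fit this picture: they all lie strictly below 1 and decrease as the argument grows.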

Distributions from the LST

To invert the LST we need the extinction probabilities of the BP model. We estimate them by integrating the ODEs for \boldsymbol{F} from \boldsymbol{s} = \boldsymbol{0} over a long time horizon:

function calculate_extinction_probs(pars)
    q0 = [0, 0]
    tspan = (0, 10000)

    prob = ODEProblem(F_fixed_s!, q0, tspan, pars)
    sol = solve(prob,
                Tsit5();
                save_start = false,
                save_everystep = false,
                save_end = true,
                abstol = 1e-9,
                reltol = 1e-9)

    q1 = sol.u[1]

    return q1
end
calculate_extinction_probs (generic function with 1 method)

Note that solving the differential equations in this way is equivalent to finding the minimal non-negative solution of the equation \boldsymbol{f}(\boldsymbol{q}) = \boldsymbol{q}.
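
For this SEIR model the fixed-point equation can also be solved by hand, which gives a quick check on the ODE-based estimate. Substituting f_1 and f_2 gives q_1 = q_2 = q with \beta q^2 - (\beta + \gamma) q + \gamma = 0, whose roots are 1 and \gamma / \beta. The sketch below verifies this numerically; q_closed and f_pgf are illustrative names.

```julia
β, σ, γ = 0.56, 0.5, 0.33

# Progeny generating functions from the model specification.
f_pgf(s) = [s[2], (γ + β * s[1] * s[2]) / (γ + β)]

# Minimal non-negative root of β q^2 - (β + γ) q + γ = 0 (valid when β > γ).
q_closed = fill(γ / β, 2)

# Residual of the fixed-point equation f(q) = q; should be at round-off level.
residual = maximum(abs.(f_pgf(q_closed) .- q_closed))
```

Here \gamma / \beta ≈ 0.5893, which agrees with the ODE-based values reported in the Moment matching section to within roughly the solver tolerance.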

Next we simply apply the inversion method from the package.

function invert_lst(Z0, pars, lst)
    q1 = calculate_extinction_probs(pars)
    q_star = prod(q1 .^ Z0)
    W_cdf = rts.construct_W_cdf_ilst(lst, q_star)
    
    return W_cdf
end
invert_lst (generic function with 1 method)
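
The line q_star = prod(q1 .^ Z0) relies on the lineages started by each initial individual being independent: the whole process goes extinct only if every lineage does. A small illustration follows, with q_example and the initial conditions chosen here only for demonstration.

```julia
# Per-type extinction probabilities (illustrative values, both equal to γ / β).
q_example = [0.33 / 0.56, 0.33 / 0.56]

# One exposed individual: the overall extinction probability is q_example[1].
q_star_single = prod(q_example .^ [1, 0])

# Two exposed and one infectious individual: three independent lineages.
q_star_multi = prod(q_example .^ [2, 1])
```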

With the distribution of W in hand we can readily estimate the distribution and hence the density of the time-shift conditional on W > 0.
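
The change of variables used in the function below can be written out explicitly. This is a sketch read off from the code, assuming the time-shift \tau is defined through W = \mathbb{E}[W] e^{\lambda \tau}; the paper should be consulted for the exact definition and sign convention.

```latex
\tau = \frac{1}{\lambda}\log\frac{W}{\mathbb{E}[W]},
\qquad
P(\tau \le t \mid W > 0)
  = \frac{F_W\!\left(\mathbb{E}[W]\, e^{\lambda t}\right) - q^*}{1 - q^*},
\qquad q^* = F_W(0).
```

Subtracting q^* removes the point mass at W = 0, and dividing by 1 - q^* renormalises so that the conditional CDF runs from 0 to 1.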

function time_shift_from_W_cdf(t_range, W_cdf, EW, λ1)
    q_star = W_cdf(0.0)
    # Grid spacing; assumes t_range is an evenly spaced range (so that step is defined)
    Δt = step(t_range)

    # P(τ ≤ t) = P(W ≤ EW * exp(λ1 * t))
    F_τ_cdf(t) = W_cdf(exp(λ1 * t) * EW)
    cdf_vals = rts.eval_cdf(F_τ_cdf, t_range)
    pdf_vals = rts.pdf_from_cdf(cdf_vals, Δt) ./ (1 - q_star)
    # Remove the jump of size q_star at zero and renormalise: the conditional CDF
    # (F(x) - q_star) / (1 - q_star) goes from 0 as x -> 0 to 1 as x -> ∞.
    cdf_vals = (cdf_vals .- q_star) ./ (1 - q_star)

    return pdf_vals, cdf_vals
end
time_shift_from_W_cdf (generic function with 1 method)
W_cdf = invert_lst([1, 0], pars, lst)
(::RandomTimeShifts.var"#W_cdf#15"{Bool, RandomTimeShifts.var"#lst_w#26"{Float64, Vector{Int64}, RandomTimeShifts.var"#lst_s_rest_x#25"{Float64, var"#F_offspring#12"{ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Tuple{Float64, Float64, Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, typeof(F_fixed_s!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol}, @NamedTuple{save_start::Bool, saveat::Float64}}, SciMLBase.StandardODEProblem}}, Float64, Float64, Float64, RandomTimeShifts.var"#lst_s0_x#24"{Matrix{Float64}}}, RandomTimeShifts.var"#lst_s0_x#24"{Matrix{Float64}}}, Float64, Vector{ComplexF64}, Vector{ComplexF64}}) (generic function with 1 method)

Moment matching

The moment matching method uses much of the same setup code as we needed for the PE method.

We begin by calculating the first five moments of W and the extinction probabilities:

# 1. Calculate quantities from the branching process
β, σ, γ = pars
Ω = [-σ σ
     β -γ]

λ1, u_norm, v_norm = rts.calculate_BP_contributions(Ω)
EW = sum(Z0 .* u_norm)

# 2. Calculate the first five moments and the extinction probabilities
num_moments = 5

αs = Dict(1 => Dict([1, 2] => σ), 2 => nothing)
βs = Dict(1 => nothing, 2 => Dict([2, 1, 2] => β))
lifetimes = [σ, β + γ]

moments = rts.calculate_moments_generic(Ω, αs, βs, lifetimes; num_moments = num_moments)
q1 = calculate_extinction_probs(pars)
2-element Vector{Float64}:
 0.5892857140820947
 0.589285714463803

Now we optimise the loss function to get the parameters of the surrogate distribution for W | W > 0.

# Estimate the distribution of timeshifts through m3 method
pars_m3 = rts.minimise_loss(moments, q1)
2-element Vector{Vector{Float64}}:
 [1.0652559497393368, 1.0202889336446201, 1.0004603116552397]
 [1.4689002848789088, 1.0245523002997907, 1.0842464006256929]

Rounded to three decimal places, this gives W_1 | W_1 > 0 \sim GG(1.065, 1.020, 1.000) and W_2 | W_2 > 0 \sim GG(1.469, 1.025, 1.084).
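
The surrogate is fitted to W conditional on W > 0, whereas the moments matrix appears to hold unconditional moments. Assuming that is how the two are related in the package (the internals of rts.minimise_loss are not shown here), the conditional moments follow by dividing by the survival probability, since W^k vanishes on extinction. The sketch below uses the three rows printed in the PE method section and the closed-form extinction probability \gamma / \beta; all variable names are illustrative.

```julia
# Moments E[W_i^k] for k = 1, 2, 3 (rows) and i = 1, 2 (columns),
# copied from the output printed in the PE method section.
moments_printed = [0.446057 0.553943
                   0.959241 1.42326
                   3.08301  5.32004]

# Extinction probability for each type; γ / β for this model.
q = 0.33 / 0.56

# E[W^k | W > 0] = E[W^k] / (1 - q), because W^k = 0 on the extinction event.
cond_moments = moments_printed ./ (1 - q)
```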

Now we can sample realisations and look at the distribution:

W_samples = rts.sample_W(10000, pars_m3, q1, Z0)

fig, ax = hist(W_samples, normalization = :probability)
ax.xlabel = "w"
ax.ylabel = "probability"
fig
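
Conceptually, a draw of W mixes a point mass at zero (extinction) with a draw from the fitted surrogate. The sketch below shows that structure using an exponential stand-in for the surrogate, which is an assumption made for illustration. It is not the implementation of rts.sample_W, and sample_W_toy is a hypothetical name.

```julia
using Random

# With probability q return 0 (extinction); otherwise draw from the stand-in
# distribution for W | W > 0 (here an exponential with mean mean_pos).
function sample_W_toy(rng, n; q = 0.33 / 0.56, mean_pos = 1.0)
    w = zeros(n)
    for i in 1:n
        if rand(rng) >= q
            w[i] = mean_pos * randexp(rng)
        end
    end
    return w
end

rng = MersenneTwister(1)
w_toy = sample_W_toy(rng, 100_000)

# Fraction of exact zeros, which should be close to q.
frac_zero = count(iszero, w_toy) / length(w_toy)
```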

Consistency between methods

As a sanity check we can visually compare the distributions obtained from the two methods. The easiest way to do this is through the CDF, which we renormalise to remove the point mass at W = 0.

W_samples = rts.sample_W(10000, pars_m3, q1, Z0)

w_range = 0:0.5:maximum(W_samples)
cdf_vals = rts.eval_cdf(W_cdf, w_range)
cdf_vals = (cdf_vals .- W_cdf(0)) ./ (1 - W_cdf(0))

fig, ax = ecdfplot(W_samples, color = "red")
scatter!(ax, w_range, cdf_vals)
ax.xlabel = "w"
ax.ylabel = "CDF"
fig
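
The visual comparison can be complemented by a single number: the largest gap between the empirical CDF of the samples and a reference CDF (the Kolmogorov-Smirnov distance). The sketch below demonstrates the calculation on a toy example with a known exponential CDF, since the SEIR objects require the package; ks_distance is a hypothetical helper, not part of RandomTimeShifts.jl.

```julia
using Random

# Largest absolute gap between the empirical CDF of `samples` and a reference `cdf`.
function ks_distance(samples, cdf)
    x = sort(samples)
    n = length(x)
    d = 0.0
    for (i, xi) in enumerate(x)
        Fx = cdf(xi)
        # The ECDF jumps from (i - 1)/n to i/n at xi; check both sides of the jump.
        d = max(d, abs(i / n - Fx), abs((i - 1) / n - Fx))
    end
    return d
end

rng = MersenneTwister(2024)
toy_samples = randexp(rng, 20_000)   # unit-rate exponential draws
exp_cdf(x) = 1 - exp(-x)             # matching reference CDF
d_toy = ks_distance(toy_samples, exp_cdf)
```

For the comparison in this section one could pass W_samples (restricted to positive values if zeros are included) together with the renormalised CDF; a small distance would support the visual agreement.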

The empirical CDF of the moment-matched samples closely tracks the CDF obtained by inverting the LST, so the two methods agree as expected.