{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# hetGPy SIR Example\n", "\n", "This document walks through a two-dimensional SIR example.\n", "\n", "### The SIR Simulation Model\n", "\n", "A Susceptible-Infected-Recovered (SIR) model is a [canonical epidemiological compartmental model](https://en.wikipedia.org/wiki/Compartmental_models_in_epidemiology). The core idea is that we can model the spread of an infectious disease by dividing the population into _compartments_ of those who have not had the disease (**S**), those who are infected (**I**) and those who have recovered (**R**), subject to infection rates $\\beta$ and recovery rates $\\gamma$. This can be compactly summarized with a flow diagram (from the above-linked Wikipedia page):\n", "\n", "\n", "![SIR Example Model](https://upload.wikimedia.org/wikipedia/commons/thumb/3/30/Diagram_of_SIR_epidemic_model_states_and_transition_rates.svg/1024px-Diagram_of_SIR_epidemic_model_states_and_transition_rates.svg.png)\n", "\n", "\n", "Our specific implementation is a stochastic version of the SIR model, simulated via the Gillespie algorithm, from [Hu and Ludkovski (2017)](https://epubs.siam.org/doi/10.1137/15M1045168). The two model inputs are the size of the initial susceptible population $S_0$ and the initial infected population $I_0$, and the output is the number of cumulative infections in the population (note that both the inputs and the output are scaled to $[0,1]$ to simplify the calculations; the scaled values are otherwise not meaningful).\n", "\n", "\n", "`hetGPy` contains an implementation of this stochastic SIR model (which is also available in the original `hetGP` R package). Note that the example simulations take 1-2 minutes to run.\n", "\n", "This walkthrough is also essentially a Python replication of the R examples from [Binois and Gramacy (2018)](https://cran.r-project.org/web/packages/hetGP/vignettes/hetGP_vignette.pdf) and [Gramacy (2020)](https://bookdown.org/rbg/surrogates/chap10.html#chap10varp) Ch. 
10.2.1.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from hetgpy import hetGP\n", "from hetgpy.test_functions import sirEval\n", "from scipy.stats.qmc import LatinHypercube\n", "\n", "\n", "# space-filling design of inputs\n", "seed = 10\n", "rand = np.random.default_rng(seed)\n", "lhs = LatinHypercube(d = 2, seed = seed)\n", "X = lhs.random(200)\n", "\n", "# replicate each input location between 0 and max_reps times\n", "# (a draw of 0 drops that location from the design)\n", "max_reps = 100\n", "reps = rand.choice(max_reps + 1, size = len(X), replace = True)\n", "idxs = np.repeat(np.arange(len(X)),reps)\n", "\n", "# run the SIR simulation once per (replicated) input row\n", "X = X[idxs,:]\n", "Y = np.zeros(len(X))\n", "for i in range(len(X)):\n", "    Y[i] = sirEval(X[[i],:],seed = i).squeeze()\n", "\n", "# predictive grid: the first column varies fastest\n", "xseq = np.linspace(0,1,100)\n", "xgrid = np.array([(y,x) for x in xseq for y in xseq])" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# fit the heteroskedastic GP by maximum likelihood\n", "model = hetGP()\n", "model.mle(\n", "    X = X,\n", "    Z = Y,\n", "    covtype = \"Matern5_2\",\n", "    lower = np.array([0.05,0.05]),\n", "    upper = np.array([10.0,10.0]),\n", "    maxit = 1e4\n", ")\n", "preds = model.predict(xgrid)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And then we can plot our results:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "%config InlineBackend.figure_formats = ['svg']\n", "import matplotlib.pyplot as plt\n", "f, ax = plt.subplots(nrows=1,ncols=2, figsize = (9,6))\n", "\n", "# predictive mean and variance (sd2 plus nugget) on the 100 x 100 grid\n", "m = preds['mean'].reshape(100,100)\n", "v = (preds['sd2'] + preds['nugs']).reshape(100,100)\n", "\n", "ax0 = ax[0].imshow(m,origin='lower',extent=[0,1,0,1],cmap='viridis')\n", "ax1 = ax[1].imshow(v,origin='lower',extent=[0,1,0,1],cmap='magma')\n", "ax[0].set_title('Predictive Mean')\n", "ax[1].set_title('Predictive Variance')\n", "for a in ax:\n", "    a.set_xlabel('S0')\n", "    a.set_ylabel('I0')\n", "\n", "kws = dict(fraction=0.046,pad=0.04)\n", "f.colorbar(ax0, ax = ax[0],**kws)\n", "f.colorbar(ax1, ax = ax[1],**kws)\n", "f.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As discussed in the resources linked at the beginning of this example, the interpretation is that when there is a large number of initial susceptibles ($S_0 \\gtrsim 0.8$) and a moderate number of initial infecteds ($I_0 \\gtrsim 0.2$), we see high variance: some stochastic realizations result in large numbers of cumulative infections, while others do not. This is due to the high nonlinearity of the SIR model." ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.15" } }, "nbformat": 4, "nbformat_minor": 4 }