{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# hetGPy Sequential Learning\n",
"\n",
"This notebook illustrates the use of the `update` method in `hetGPy`.\n",
"\n",
"Suppose we are interested in online learning: we have a trained GP model and want to keep updating the GP regression as new data arrive, each time initializing the maximum likelihood estimation routine from the current set of GP hyperparameters.\n",
"\n",
"In `hetGPy`, a model update costs on the order of $O(n^2)$, which is cheaper than the $O(n^3)$ cost of full training; see [Binois and Gramacy (2021)](https://www.jstatsoft.org/article/view/v098i13) for details.\n",
"\n",
"First, we import our libraries and set up a few helper functions.\n"
]
},
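{
"cell_type": "markdown",
"metadata": {},
"source": [
"To give some intuition for the $O(n^2)$ versus $O(n^3)$ comparison above, here is a minimal sketch of one standard way such savings can arise: when a single point is appended, the inverse covariance matrix can be refreshed with the partitioned-inverse (Schur complement) identity instead of being recomputed from scratch. This is a generic illustration and not `hetGPy`'s internal code; the names `sq_exp_kernel` and `append_point_inverse`, the lengthscale `theta`, and the `nugget` value are assumptions made for the example.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sq_exp_kernel(A, B, theta=0.1):\n",
"    # Gaussian kernel on 1-d inputs; theta is a hypothetical lengthscale\n",
"    d2 = (A[:, None] - B[None, :]) ** 2\n",
"    return np.exp(-d2 / theta)\n",
"\n",
"def append_point_inverse(K_inv, k_new, k_nn):\n",
"    # Partitioned-inverse update for one appended point: O(n^2) work\n",
"    v = K_inv @ k_new                 # matrix-vector product, O(n^2)\n",
"    s = k_nn - k_new @ v              # scalar Schur complement\n",
"    n = K_inv.shape[0]\n",
"    out = np.empty((n + 1, n + 1))\n",
"    out[:n, :n] = K_inv + np.outer(v, v) / s\n",
"    out[:n, n] = -v / s\n",
"    out[n, :n] = -v / s\n",
"    out[n, n] = 1.0 / s\n",
"    return out\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x = rng.uniform(0, 1, size=30)\n",
"nugget = 1e-2  # hypothetical noise term that keeps K well conditioned\n",
"K = sq_exp_kernel(x, x) + nugget * np.eye(len(x))\n",
"K_inv = np.linalg.inv(K)              # one-off O(n^3) cost\n",
"\n",
"# append one new design point and update the inverse\n",
"x_new = 0.5\n",
"k_new = sq_exp_kernel(x, np.array([x_new]))[:, 0]\n",
"k_nn = 1.0 + nugget\n",
"K_inv_updated = append_point_inverse(K_inv, k_new, k_nn)\n",
"\n",
"# compare against the covariance matrix rebuilt on all points\n",
"x_all = np.append(x, x_new)\n",
"K_full = sq_exp_kernel(x_all, x_all) + nugget * np.eye(len(x_all))\n",
"max_err = np.max(np.abs(K_inv_updated @ K_full - np.eye(len(x_all))))\n",
"```\n",
"\n",
"The update uses only matrix-vector and outer products of size $n$, whereas a fresh inversion of the enlarged matrix would cost $O(n^3)$.\n"
]
},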
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from hetgpy import hetGP\n",
"from hetgpy.test_functions import f1d\n",
"from copy import copy\n",
"from scipy.stats import norm\n",
"import os\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"%config InlineBackend.figure_formats = ['svg']\n",
"\n",
"rand = np.random.default_rng(42)\n",
"\n",
"def noise_fun(x,coef=1):\n",
" noise = coef * (1.1 + np.sin(2*np.pi*x))**2\n",
" return noise\n",
"\n",
"def f1d_n(x):\n",
" '''f1d function with spatially varying noise'''\n",
" # noise_fun(x) is passed as `scale`, so it acts as the noise standard deviation\n",
" noise_sd = noise_fun(x)\n",
" return f1d(x).squeeze() + rand.normal(loc = 0, scale = noise_sd.squeeze(), size = len(x))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we set up an online learning procedure in which we gather new data along a moving window in $[0,1]$. Our data-generating process is noisy and nonstationary: the noise scale returned by `noise_fun` peaks near $X = 0.25$ and is smallest near $X = 0.75$, so we will see higher noise for low values of $X$ and lower noise for high values of $X$. The true underlying function is the `f1d` function from the appendix of [Forrester, Sobester, and Keane (2008)](https://onlinelibrary.wiley.com/doi/book/10.1002/9780470770801),\n",
"\n",
"where:\n",
"\\begin{align*}\n",
"f(x) = (6x-2)^2 \\sin(12x-4)\n",
"\\end{align*}"
]
},
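{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check on the formula and the noise profile described above, the following sketch evaluates both directly with NumPy. The name `forrester` is a stand-in written from the stated formula (the notebook itself uses `f1d` from `hetgpy.test_functions`), and `noise_scale` simply repeats the expression in `noise_fun`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def forrester(x):\n",
"    # (6x - 2)^2 * sin(12x - 4), written from the formula above\n",
"    return (6 * x - 2) ** 2 * np.sin(12 * x - 4)\n",
"\n",
"def noise_scale(x, coef=1):\n",
"    # same expression as noise_fun in the helper cell\n",
"    return coef * (1.1 + np.sin(2 * np.pi * x)) ** 2\n",
"\n",
"x = np.array([0.25, 0.75])\n",
"scales = noise_scale(x)        # approximately [4.41, 0.01]\n",
"f_mid = forrester(0.5)         # equals sin(2), since (6 * 0.5 - 2)^2 = 1\n",
"```\n",
"\n",
"The two noise values differ by a factor of several hundred, which is why the spread of the observations changes so much as the sampling window moves across $[0,1]$.\n"
]
},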
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"## Initial data set\n",
"nvar = 1\n",
"n = 20\n",
"X = np.linspace(0, 0.1,n).reshape(-1,1)\n",
"mult = rand.choice(np.arange(1,6),size = n, replace = True)\n",
"X = np.vstack(np.repeat(X,mult,axis=0))\n",
"\n",
"testpts = np.linspace(0,1,int(10*n))\n",
"Z = f1d_n(X)\n",
"Ztrue = f1d(testpts)\n",
"\n",
"nsteps = 10\n",
"npersteps = 10\n",
"step = 1\n",
"\n",
"interval_size = int(len(testpts)/nsteps)\n",
"interval = np.arange(1,interval_size+1)\n",
"\n",
"\n",
"model = hetGP()\n",
"model.mle(\n",
" X = X,\n",
" Z = Z,\n",
" lower = 0.1 + 0.0*np.arange(X.shape[1]),\n",
" upper = 5 + 0.0*np.arange(X.shape[1]),\n",
" maxit = 500,\n",
" settings = {'checkHom':False}\n",
")\n",
"model_init = copy(model)\n",
"\n",
"fig, ax = plt.subplots(nrows=1,ncols=2, figsize = (11.5,8),sharey=True)\n",
"for i in range(2):\n",
" # plot true test function\n",
" ax[i].plot(testpts.squeeze(),Ztrue,alpha=0.2,color='red',label = 'True')\n",
" ax[i].set_xlabel('X')\n",
" ax[i].set_ylabel('f1d(X)')\n",
" \n",
"for i in range(nsteps):\n",
" \n",
" interval = np.clip(interval + interval_size,a_min=None,a_max=len(testpts)-1)\n",
" newIds = sorted(rand.choice(interval,size=npersteps,replace = False))\n",
" newmult = rand.choice(np.arange(1,6),len(newIds), replace = True)\n",
" newIds = np.repeat(newIds,newmult)\n",
" newX = testpts[newIds].reshape(-1,1)\n",
" newZ = f1d_n(newX.squeeze())\n",
" \n",
" model.update(Xnew = newX, Znew = newZ)\n",
" \n",
" X = np.vstack([X,newX]) \n",
" Z = np.hstack([Z,newZ.squeeze()])\n",
"\n",
" plot_steps = (2,nsteps-1)\n",
"\n",
" # show an early iteration and a later one\n",
" if i in plot_steps:\n",
" j = 0 if i == plot_steps[0] else 1\n",
" ax[j].scatter(X.squeeze().copy(),Z.copy(), alpha = 1.0, color = 'black',label = 'Data')\n",
" preds = model.predict(testpts.reshape(-1,1))\n",
" preds['upper'] = norm.ppf(0.95, loc = preds['mean'], scale = np.sqrt(preds['sd2'] + preds['nugs'])).squeeze()\n",
" preds['lower'] = norm.ppf(0.05, loc = preds['mean'], scale = np.sqrt(preds['sd2'] + preds['nugs'])).squeeze() \n",
" ax[j].plot(testpts.squeeze(),preds['mean'],color='blue',label='Predicted Mean')\n",
" ax[j].plot(testpts.squeeze(),preds['lower'],color='blue',linestyle='dashed')\n",
" ax[j].plot(testpts.squeeze(),preds['upper'],color='blue',linestyle='dashed')\n",
" ax[j].set_title(f'Iter {i+1}')\n",
" ax[j].legend(loc='upper left');\n",
"\n",
"fig.tight_layout()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 4
}