{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Integrated Mean Squared Prediction Error\n",
"\n",
"This notebook demonstrates the `crit_IMSPE` function in `hetgpy`, adapted from an example in the `hetGP` R package.\n",
"\n",
"\n",
"Suppose that, given a trained GP, we wish to minimize the global predictive variance over the input space $\\mathcal{X}$. A natural acquisition function for this goal is the Integrated Mean-Squared Prediction Error (IMSPE) [Gramacy (2020) Ch. 10.3.1](https://bookdown.org/rbg/surrogates/chap10.html#chap10imspe):\n",
"\n",
"\\begin{align*}\n",
"I_{n+1}(x_{n+1}) \\equiv \\mathrm{IMSPE}(\\bar{x}_1, \\dots, \\bar{x}_n, x_{n+1}) = \\int_{x \\in \\mathcal{X}} \\breve{\\sigma}^2_{n+1}(x) \\, dx\n",
"\\end{align*}\n",
"\n",
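"where $\\breve{\\sigma}^2_{n+1}(x)$ is the nugget-free predictive variance at $x$ (i.e., excluding observation noise at $x$) once the candidate $x_{n+1}$ has been added to the design $\\bar{x}_1, \\dots, \\bar{x}_n$.\n",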
"\n",
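"As noted in [Gramacy (2020) Ch. 10.3.1](https://bookdown.org/rbg/surrogates/chap10.html#chap10imspe), such an integral generally requires numerical evaluation. However, conditional on the GP hyperparameters, and when $\\mathcal{X}$ is a hyperrectangle (such as $[0,1]^m$, a common choice since inputs are often scaled to the unit cube when modeling with GPs), the IMSPE is available in closed form. Writing the integral as an expectation over $X \\sim \\mathrm{Unif}(\\mathcal{X})$ (which rescales it by the volume of $\\mathcal{X}$, equal to 1 for $[0,1]^m$, and does not change the minimizer):\n",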
"\n",
"\\begin{align}\n",
"I_{n+1}(x_{n+1}) &= \\mathbb{E} \\{ \\breve{\\sigma}^2_{n+1}(X) \\} = \\mathbb{E} \\{K_\\theta(X, X) - k_{n+1}^\\top(X) K_{n+1}^{-1} k_{n+1}(X) \\} \\\\\n",
"&= \\mathbb{E} \\{K_\\theta(X,X)\\} - \\mathrm{tr}(K_{n+1}^{-1} W_{n+1})\n",
"\\end{align}\n",
"\n",
"where $W_{n+1}$ is the $(n+1) \\times (n+1)$ matrix with entries $W_{ij} = \\int_{x \\in \\mathcal{X}} k(x_i, x) k(x_j, x)\\, dx$, which also has a closed form when $\\mathcal{X}$ is a hyperrectangle. For example, [Binois et al. (2019)](https://arxiv.org/pdf/1710.03206) give it for the separable Gaussian kernel on $[0,1]^m$:\n",
"\n",
"\\begin{align*}\n",
"W_{ij}= \\prod_{k=1}^m \\dfrac{\\sqrt{2\\pi \\theta_k} }{4} \\exp\\left\\{-\\dfrac{(x_{ik}-x_{jk})^2}{2 \\theta_k}\\right\\} \n",
"\\left[\\mathrm{erf}\\left\\{\\dfrac{2-(x_{ik}+x_{jk})}{\\sqrt{2 \\theta_k}}\\right\\}+ \\mathrm{erf}\\left\\{\\dfrac{x_{ik}+x_{jk}}{\\sqrt{ 2\\theta_k}}\\right\\} \\right].\n",
"\\end{align*}\n",
"\n",
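"Here $\\mathrm{erf}$ is the [error function](https://en.wikipedia.org/wiki/Error_function), $m$ is the input dimension, $x_{ik}$ is the $k$-th coordinate of $x_i$, and the kernel is parameterized as $k(x, x') = \\prod_{k=1}^m \\exp\\{-(x_k - x'_k)^2/\\theta_k\\}$."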
]
},
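{
"cell_type": "markdown",
"metadata": {},
"source": [
"The closed form for $W_{ij}$ can be sanity-checked numerically. The sketch below is a minimal, standalone illustration that assumes the separable Gaussian kernel $k(x, x') = \\prod_{k=1}^m \\exp\\{-(x_k - x'_k)^2/\\theta_k\\}$ on $[0,1]^m$; the helper names `w_closed_form` and `w_quadrature_1d` and the lengthscale value are hypothetical and are not part of `hetgpy` (which provides its own `Wij`). It compares the erf expression with a trapezoidal-rule approximation of $\\int_0^1 k(x_i, x) k(x_j, x)\\, dx$ in one dimension.\n",
"\n",
"```python\n",
"import math\n",
"import numpy as np\n",
"\n",
"# Hypothetical helper (not hetgpy's Wij): closed-form W_ij for the separable\n",
"# Gaussian kernel k(x, x') = prod_k exp(-(x_k - x'_k)^2 / theta_k) on [0, 1]^m.\n",
"def w_closed_form(xi, xj, theta):\n",
"    out = 1.0\n",
"    for a, b, t in zip(xi, xj, theta):\n",
"        s, r = a + b, math.sqrt(2 * t)\n",
"        out *= (math.sqrt(2 * math.pi * t) / 4\n",
"                * math.exp(-(a - b) ** 2 / (2 * t))\n",
"                * (math.erf((2 - s) / r) + math.erf(s / r)))\n",
"    return out\n",
"\n",
"# Brute-force reference in 1-D: trapezoidal rule for the integral over [0, 1].\n",
"def w_quadrature_1d(xi, xj, theta, ngrid=20001):\n",
"    x = np.linspace(0.0, 1.0, ngrid)\n",
"    f = np.exp(-(xi - x) ** 2 / theta) * np.exp(-(xj - x) ** 2 / theta)\n",
"    h = x[1] - x[0]\n",
"    return h * (f.sum() - 0.5 * (f[0] + f[-1]))\n",
"\n",
"theta = 0.05  # assumed lengthscale, chosen for illustration only\n",
"for xi, xj in [(0.1, 0.1), (0.1, 0.3), (0.0, 0.9)]:\n",
"    closed = w_closed_form([xi], [xj], [theta])\n",
"    quad = w_quadrature_1d(xi, xj, theta)\n",
"    print(f'xi={xi:.1f} xj={xj:.1f} closed={closed:.8f} quad={quad:.8f}')\n",
"```\n",
"\n",
"For a smooth integrand like this one, the two columns should agree to many decimal places. The multivariate case follows because both the kernel and the integral over a hyperrectangle factor across dimensions, which is why `w_closed_form` simply multiplies the one-dimensional terms."
]
},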
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%config InlineBackend.figure_formats = ['svg']\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from hetgpy import hetGP\n",
"from hetgpy.IMSE import crit_IMSPE, Wij\n",
"\n",
"random = np.random.default_rng(42)\n",
"\n",
"def ftest(x, coef = 0.1):\n",
"    # Noisy test function: sin(2*pi*x) plus zero-mean Gaussian noise with sd `coef`\n",
"    return np.sin(2 * np.pi * x) + random.normal(loc = 0, scale = coef)\n",
"\n",
"# 9 unique design locations on [0.1, 0.9], each replicated between 1 and 10 times\n",
"n = 9\n",
"designs = np.linspace(0.1, 0.9, n).reshape(-1, 1)\n",
"reps = random.choice(1 + np.arange(10), size = n)\n",
"\n",
"X = designs[np.repeat(np.arange(n), reps)]\n",
"Z = np.array([ftest(x) for x in X]).squeeze()\n",
"\n",
"# Fit the GP by maximum likelihood, with lengthscale bounds [0.1, 5]\n",
"model = hetGP()\n",
"model.mle(\n",
"    X = X,\n",
"    Z = Z,\n",
"    lower = np.array([0.1]),\n",
"    upper = np.array([5]),\n",
"    known = {},\n",
"    init = {})\n",
"\n",
"# Plot the data, then overlay the IMSPE criterion on a secondary y-axis\n",
"fig, ax = plt.subplots()\n",
"ax.scatter(X, Z)\n",
"ax.set_xlabel('X')\n",
"ax.set_ylabel('f(X)')\n",
"\n",
"# Candidate locations x_{n+1} on [0, 1]\n",
"ngrid = 501\n",
"xgrid = np.linspace(0, 1, ngrid).reshape(-1, 1)\n",
"\n",
"# Precompute W for the existing unique design locations (model.X0) once,\n",
"# so it is not recomputed for every candidate\n",
"Wijs = Wij(mu1 = model.X0, theta = model.theta, type = model.covtype)\n",
"IMSPE_grid = np.array([crit_IMSPE(x, model = model, Wijs = Wijs) for x in xgrid])\n",
"ax2 = ax.twinx()\n",
"ax2.plot(xgrid.squeeze(), IMSPE_grid.squeeze(), 'r--')\n",
"ax2.set_ylabel('IMSPE');"
]
},
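{
"cell_type": "markdown",
"metadata": {},
"source": [
"The grid loop above evaluates the criterion at every candidate $x_{n+1}$, and the smallest value marks the most informative next design point. As a hedged, self-contained sketch of what that computation involves (not `crit_IMSPE` itself, which works with the fitted model, its replicates and its noise estimates), the code below assumes a homoskedastic GP with unit process variance, a Gaussian kernel with a made-up lengthscale `theta` and nugget `g`, and nine unique design points on $[0.1, 0.9]$; all function names are hypothetical. Because $K_\\theta(x,x)=1$ under this kernel, $\\mathbb{E}\\{K_\\theta(X,X)\\}=1$ for $X$ uniform on $[0,1]$, so $I_{n+1}(x_{n+1}) = 1 - \\mathrm{tr}(K_{n+1}^{-1} W_{n+1})$, which the sketch checks against a direct numerical average of the nugget-free variance.\n",
"\n",
"```python\n",
"import math\n",
"import numpy as np\n",
"\n",
"# Hypothetical settings for illustration (not fitted values from the notebook)\n",
"theta, g = 0.05, 1e-3          # Gaussian lengthscale and nugget\n",
"X = np.linspace(0.1, 0.9, 9)   # unique design points, no replicates\n",
"\n",
"def kern(a, b):\n",
"    # Gaussian kernel k(x, x') = exp(-(x - x')^2 / theta), unit variance\n",
"    return np.exp(-(a[:, None] - b[None, :]) ** 2 / theta)\n",
"\n",
"def w_matrix(x):\n",
"    # Closed-form W_ij = int_0^1 k(x_i, u) k(x_j, u) du for 1-D inputs\n",
"    s = x[:, None] + x[None, :]\n",
"    d = x[:, None] - x[None, :]\n",
"    r = math.sqrt(2 * theta)\n",
"    erf = np.vectorize(math.erf)\n",
"    return (math.sqrt(2 * math.pi * theta) / 4 * np.exp(-d ** 2 / (2 * theta))\n",
"            * (erf((2 - s) / r) + erf(s / r)))\n",
"\n",
"def augment(x_new):\n",
"    # Design augmented with the candidate, and its noisy covariance K_{n+1}\n",
"    xs = np.append(X, x_new)\n",
"    return xs, kern(xs, xs) + g * np.eye(len(xs))\n",
"\n",
"def imspe(x_new):\n",
"    # I_{n+1} = E{K(X, X)} - tr(K_{n+1}^{-1} W_{n+1}), with E{K(X, X)} = 1\n",
"    xs, K = augment(x_new)\n",
"    return 1.0 - np.trace(np.linalg.solve(K, w_matrix(xs)))\n",
"\n",
"def imspe_bruteforce(x_new, ngrid=20001):\n",
"    # Trapezoidal average of 1 - k(u)^T K_{n+1}^{-1} k(u) over u in [0, 1]\n",
"    xs, K = augment(x_new)\n",
"    u = np.linspace(0.0, 1.0, ngrid)\n",
"    ku = kern(xs, u)\n",
"    var = 1.0 - np.sum(ku * np.linalg.solve(K, ku), axis=0)\n",
"    h = u[1] - u[0]\n",
"    return h * (var.sum() - 0.5 * (var[0] + var[-1]))\n",
"\n",
"xgrid = np.linspace(0, 1, 501)\n",
"crit = np.array([imspe(x) for x in xgrid])\n",
"x_next = xgrid[np.argmin(crit)]\n",
"print('closed form vs brute force at x=0.5:', imspe(0.5), imspe_bruteforce(0.5))\n",
"print('candidate minimizing IMSPE:', x_next)\n",
"```\n",
"\n",
"The closed form needs only one $(n+1) \\times (n+1)$ solve per candidate, whereas the brute-force check needs a fine grid over the whole domain. Because this assumed design is symmetric, $x=0$ and $x=1$ may give essentially the same criterion value, so which edge `argmin` returns can come down to floating-point rounding."
]
},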
{
"cell_type": "markdown",
"metadata": {},
"source": [
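"Note that in this case the IMSPE is minimized at the edge of the input domain ($x=0$), outside the range $[0.1, 0.9]$ spanned by the existing design; adding the next point there would reduce the integrated predictive variance the most."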
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 4
}