{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Integrated Mean-Squared Prediction Error\n",
    "\n",
    "This notebook demonstrates the `crit_IMSPE` function, following the corresponding example from the `hetGP` R package.\n",
    "\n",
    "Suppose that, given a trained GP, we wish to choose the next design point $x_{n+1}$ so as to minimize the global predictive variance over the input space $\\mathcal{X}$. A natural acquisition criterion is then the Integrated Mean-Squared Prediction Error (IMSPE) [Gramacy (2020) Ch. 10.3.1](https://bookdown.org/rbg/surrogates/chap10.html#chap10imspe):\n",
    "\n",
    "\\begin{align*}\n",
    "I_{n+1}(x_{n+1}) \\equiv \\mathrm{IMSPE}(\\bar{x}_1, \\dots, \\bar{x}_n, x_{n+1}) = \\int_{x \\in \\mathcal{X}} \\breve{\\sigma}^2_{n+1}(x) \\, dx\n",
    "\\end{align*}\n",
    "\n",
    "where $\\bar{x}_1, \\dots, \\bar{x}_n$ are the current unique design locations and $\\breve{\\sigma}^2_{n+1}(x)$ is the nugget-free predictive variance at $x$ once $x_{n+1}$ has been added to the design. Since $\\breve{\\sigma}^2_{n+1}$ does not depend on the not-yet-observed response $y_{n+1}$, the criterion can be evaluated for any candidate $x_{n+1}$ before running the experiment.\n",
    "\n",
    "As noted in [Gramacy (2020) Ch. 10.3.1](https://bookdown.org/rbg/surrogates/chap10.html#chap10imspe), this integral generally requires numerical evaluation. However, conditional on the GP hyperparameters, and when $\\mathcal{X}$ is a hyperrectangle (for example the unit cube $[0,1]^m$, a common choice when inputs are rescaled for GP modeling), the IMSPE can be calculated in closed form. Taking $X \\sim \\mathrm{Unif}(\\mathcal{X})$ on a domain of unit volume, the integral is an expectation:\n",
    "\n",
    "\\begin{align}\n",
    "I_{n+1}(x_{n+1}) &= \\mathbb{E} \\{ \\breve{\\sigma}^2_{n+1}(X) \\} = \\mathbb{E} \\{K_\\theta(X, X) - k_{n+1}^\\top(X) K_{n+1}^{-1} k_{n+1}(X) \\} \\\\\n",
    "&= \\mathbb{E} \\{K_\\theta(X,X)\\} - \\mathrm{tr}(K_{n+1}^{-1} W_{n+1})\n",
    "\\end{align}\n",
    "\n",
    "where $k_{n+1}(x)$ is the vector of covariances between $x$ and the $n+1$ design locations and $K_{n+1}$ is the covariance matrix of those locations (including nugget terms). The second line uses $k_{n+1}^\\top K_{n+1}^{-1} k_{n+1} = \\mathrm{tr}(K_{n+1}^{-1} k_{n+1} k_{n+1}^\\top)$ together with the linearity of expectation and trace, so that $W_{n+1} = \\mathbb{E}\\{k_{n+1}(X) k_{n+1}^\\top(X)\\}$ has entries $W_{ij} = \\int_{x \\in \\mathcal{X}} k(x_i, x) k(x_j, x)\\, dx$. These integrals are available in closed form when $\\mathcal{X}$ is a hyperrectangle. For example, [Binois et al. (2019)](https://arxiv.org/pdf/1710.03206) give them for the separable Gaussian kernel $k(x, x') = \\prod_{k=1}^m \\exp\\{-(x_k - x'_k)^2/\\theta_k\\}$ on $[0,1]^m$:\n",
    "\n",
    "\\begin{align*}\n",
    "W_{ij} = \\prod_{k=1}^m \\dfrac{\\sqrt{2\\pi \\theta_k}}{4} \\exp\\left\\{-\\dfrac{(x_{ik}-x_{jk})^2}{2 \\theta_k}\\right\\}\n",
    "\\left[\\mathrm{erf}\\left\\{\\dfrac{2-(x_{ik}+x_{jk})}{\\sqrt{2 \\theta_k}}\\right\\} + \\mathrm{erf}\\left\\{\\dfrac{x_{ik}+x_{jk}}{\\sqrt{2 \\theta_k}}\\right\\} \\right].\n",
    "\\end{align*}\n",
    "\n",
    "Here $\\mathrm{erf}$ denotes the [error function](https://en.wikipedia.org/wiki/Error_function)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%config InlineBackend.figure_formats = ['svg']\n",
    "from hetgpy.IMSE import crit_IMSPE, Wij\n",
    "from hetgpy import hetGP\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Seeded generator so that the example is reproducible\n",
    "random = np.random.default_rng(42)\n",
    "\n",
    "def ftest(x, coef=0.1):\n",
    "    # sin(2*pi*x) plus zero-mean Gaussian noise with standard deviation coef\n",
    "    return np.sin(2 * np.pi * x) + random.normal(loc=0.0, scale=coef)\n",
    "\n",
    "# n unique design locations in [0.1, 0.9], each replicated between 1 and 10 times\n",
    "n = 9\n",
    "designs = np.linspace(0.1, 0.9, n).reshape(-1, 1)\n",
    "reps = random.choice(1 + np.arange(10), size=n)\n",
    "\n",
    "X = designs[np.repeat(np.arange(n), reps)]\n",
    "Z = np.array([ftest(x) for x in X]).squeeze()\n",
    "\n",
    "# Fit the heteroskedastic GP by maximum likelihood; lower/upper bound the lengthscale theta\n",
    "model = hetGP()\n",
    "model.mle(\n",
    "    X=X,\n",
    "    Z=Z,\n",
    "    lower=np.array([0.1]),\n",
    "    upper=np.array([5]),\n",
    "    known={},\n",
    "    init={})\n",
    "\n",
    "# Left axis: the noisy observations\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(X, Z)\n",
    "ax.set_xlabel('X')\n",
    "ax.set_ylabel('f(X)')\n",
    "\n",
    "# Candidate locations for x_{n+1}: a fine grid over [0, 1]\n",
    "ngrid = 501\n",
    "xgrid = np.linspace(0, 1, ngrid).reshape(-1, 1)\n",
    "\n",
    "## Precalculations: W_ij among the existing unique designs (model.X0), computed once and reused for every candidate\n",
    "Wijs = Wij(mu1=model.X0, theta=model.theta, type=model.covtype)\n",
    "IMSPE_grid = np.array([crit_IMSPE(x, model=model, Wijs=Wijs) for x in xgrid])\n",
    "\n",
    "# Right axis: IMSPE of each candidate (lower is better)\n",
    "ax2 = ax.twinx()\n",
    "ax2.plot(xgrid.squeeze(), IMSPE_grid.squeeze(), 'r--')\n",
    "ax2.set_ylabel('IMSPE');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, the IMSPE is smallest at the boundary of the input domain ($x=0$): among the grid candidates, adding a design point there would most reduce the integrated predictive variance."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}