{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# hetGPy Introduction and Example\n",
"\n",
"This notebook walks through a heteroskedastic Gaussian Process Regression example using `hetGPy`.\n",
"\n",
"The dataset of interest is the classic [mcycle](https://rdrr.io/cran/MASS/man/mcycle.html) data (Silverman 1985, Venables and Ripley 2002)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup\n",
"* For this example, we run the Gaussian Process Regression in Python with `hetGPy`."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# load libraries\n",
"from scipy.stats import norm # predictive intervals\n",
"import numpy as np # arrays\n",
"from hetgpy.example_data import mcycle\n",
"%config InlineBackend.figure_formats = ['svg']\n",
"\n",
"\n",
"# load dataset\n",
"m = mcycle()\n",
"X = m['times']\n",
"Z = m['accel']\n",
"xgrid = np.linspace(0,60,301).reshape(-1,1) # predictive grid"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Analysis with `hetGPy` is centered around the `hetGP` object (or, for homoskedastic noise, the `homGP` object).\n",
"\n",
"This object behaves similarly to `sklearn` model objects such as [Lasso](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html).\n",
"\n",
"Specifically, the general workflow is: \n",
"\n",
"* Instantiate the model object: `model = hetGP()`\n",
"* Fit to training data by maximum likelihood: `model.mleHetGP`\n",
"* Predict on unseen data: `model.predict`\n",
"\n",
"Subsequent notebooks will also show how to update a model object for online learning and how to use a model for sequential design."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
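"from hetgpy import hetGP\n",
"\n",
"# instantiate and fit the heteroskedastic GP by maximum likelihood\n",
"model = hetGP()\n",
"model.mleHetGP(\n",
"    X = X,\n",
"    Z = Z,\n",
"    lower = np.array([1]),   # lower bound on the lengthscale (theta)\n",
"    upper = np.array([100]), # upper bound on the lengthscale (theta)\n",
"    covtype = \"Gaussian\",\n",
"    maxit = 100              # maximum optimizer iterations\n",
")\n",
"\n",
"# predict on the grid, then form a 90% predictive interval (5% and 95% quantiles)\n",
"# from the total predictive variance sd2 + nugs\n",
"preds = model.predict(xgrid)\n",
"scale = np.sqrt(preds['sd2'] + preds['nugs'])\n",
"preds['upper'] = norm.ppf(0.95, loc = preds['mean'], scale = scale).squeeze()\n",
"preds['lower'] = norm.ppf(0.05, loc = preds['mean'], scale = scale).squeeze()"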
]
},
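{
"cell_type": "markdown",
"metadata": {},
"source": [
"The bounds computed above are normal quantiles. As a sketch, assume (as in the R `hetGP` package) that `preds['sd2']` is the variance of the predicted mean and `preds['nugs']` is the predicted noise variance, and write $\\hat{\\mu}(x)$, $\\sigma^2(x)$ and $r(x)$ for `preds['mean']`, `preds['sd2']` and `preds['nugs']` (notation introduced here for illustration only). Each bound is then\n",
"\n",
"$$\n",
"\\hat{\\mu}(x) \\pm z_{0.95} \\sqrt{\\sigma^2(x) + r(x)}, \\qquad z_{0.95} = \\Phi^{-1}(0.95) \\approx 1.645,\n",
"$$\n",
"\n",
"so, if the model is well specified, about 90% of new observations at $x$ should fall between `preds['lower']` and `preds['upper']`. Including the noise variance $r(x)$ is what makes this a predictive interval for new observations rather than a band for the mean alone."
]
},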
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"fig, ax = plt.subplots()\n",
"\n",
"# observed data, predictive mean, and 90% predictive interval\n",
"ax.scatter(X.squeeze(), Z, color='black', label='Data')\n",
"ax.plot(xgrid.squeeze(), preds['mean'], color='blue', label='Mean')\n",
"ax.plot(xgrid.squeeze(), preds['upper'], color='blue', linestyle='dashed', label='90% Predictive Interval')\n",
"ax.plot(xgrid.squeeze(), preds['lower'], color='blue', linestyle='dashed')\n",
"ax.set_xlabel('times (milliseconds)')\n",
"ax.set_ylabel('accel (g)')\n",
"ax.legend();"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 4
}