{
"cells": [
{
"cell_type": "markdown",
"id": "5dc31240",
"metadata": {},
"source": [
"# Common Random Number GP\n",
"\n",
"This notebook walks through an example of a \"common random number\" (CRN) GP. The crnGP is described in detail in [Fadikar et al. 2022](https://ieeexplore.ieee.org/abstract/document/10408258).\n",
"\n",
"The main idea behind the crnGP is that random seed information can be used to estimate the correlation between outputs generated under different random seeds. This makes it possible to do trajectory-level inference for each pairing of location and random seed, $(x, s)$."
]
},
{
"cell_type": "markdown",
"id": "1b22dd58",
"metadata": {},
"source": [
"In the following example, we design a synthetic function defined as:\n",
"$$\n",
"Y(x,s) = \\sin\\left(x + \\frac{\\pi}{2}\\,\\mathbf{1}(s = 5)\\right) + s\n",
"$$\n",
"\n",
"This evaluates the sine function, shifts it vertically by the value of the random seed $s$, and introduces a phase shift of $\\pi/2$ when the random seed equals 5.\n",
"\n",
"We evaluate the function over $x \\in [0,2\\pi]$ and for seeds $s \\in \\{1,5\\}$."
]
},
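{
"cell_type": "markdown",
"id": "7e2a9c41",
"metadata": {},
"source": [
"As a quick check of the definition above, substituting the two seeds and using the identity $\\sin(x + \\pi/2) = \\cos(x)$ gives:\n",
"$$\n",
"Y(x,1) = \\sin(x) + 1, \\qquad Y(x,5) = \\sin\\left(x + \\frac{\\pi}{2}\\right) + 5 = \\cos(x) + 5\n",
"$$\n",
"\n",
"So, before any noise is added, the seed-1 trajectory is a sine curve centered at 1 and the seed-5 trajectory is a cosine curve centered at 5."
]
},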
{
"cell_type": "code",
"execution_count": 1,
"id": "d4609ebe",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%config InlineBackend.figure_formats = ['svg']\n",
"from hetgpy import crnGP\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"rng = np.random.default_rng(1)\n",
"\n",
"pps = 10 # points per seed\n",
"x = np.linspace(0,2*np.pi,pps).reshape(-1,1)\n",
"X = np.vstack([x,x])\n",
"seeds = ([1] * pps) + ([5] * pps)\n",
"X = np.hstack([X,np.array(seeds).reshape(-1,1)])\n",
"# vertical offset (by the seed value) and phase shift (for seed 5)\n",
"Z = np.sin(X[:,0] + (np.pi/2)*(X[:,-1]==5)) + X[:,-1]\n",
"Ztrue = Z.copy()\n",
"noise = 0.5 * rng.normal(size = Z.size)\n",
"Z += noise\n",
"fig, ax = plt.subplots(figsize=(9,6))\n",
"\n",
"m1 = X[:,-1]==1\n",
"m2 = X[:,-1]==5\n",
"ax.scatter(X[m1][:,0],Z[m1],color='blue',label='Seed 1')\n",
"ax.scatter(X[m2][:,0],Z[m2],color='red',label='Seed 5')\n",
"ax.set_xlabel('X'); ax.set_ylabel('Y')\n",
"ax.legend(edgecolor='black');"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ef3c6b0a",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = crnGP()\n",
"# suppress NumPy divide/invalid floating-point warnings during fitting\n",
"with np.errstate(divide='ignore', invalid='ignore'):\n",
"    model.mle(X,Z,covtype=\"Matern5_2\")\n",
"\n",
"# prediction grid: pp locations for each of the two observed seeds\n",
"pp = 200\n",
"xp = np.linspace(0,2*np.pi,pp).reshape(-1,1)\n",
"Xp = np.vstack([xp,xp])\n",
"seeds = ([1] * pp) + ([5] * pp)\n",
"Xp = np.hstack([Xp,np.array(seeds).reshape(-1,1)])\n",
"\n",
"with np.errstate(divide='ignore', invalid='ignore'):\n",
"    preds = model.predict(Xp,xprime=Xp)\n",
"\n",
"fig, ax = plt.subplots(figsize=(9,6))\n",
"\n",
"# data\n",
"m1 = X[:,-1]==1\n",
"m2 = X[:,-1]==5\n",
"ax.scatter(X[m1][:,0],Z[m1],color='blue',label='Seed 1')\n",
"ax.scatter(X[m2][:,0],Z[m2],color='red',label='Seed 5')\n",
"\n",
"# predictive means for the two observed seeds\n",
"ax.plot(xp,preds['mean'][0:pp],color='blue')\n",
"ax.plot(xp,preds['mean'][pp:],color='red')\n",
"\n",
"# predictive mean for a seed (10) that is not in the training data\n",
"Xns = Xp[0:pp,:].copy()\n",
"Xns[:,-1] = 10\n",
"with np.errstate(divide='ignore', invalid='ignore'):\n",
"    preds_new_seed = model.predict(Xns,xprime=Xns)\n",
"ax.plot(xp,preds_new_seed['mean'],color='green',label = 'New Seed')\n",
"\n",
"\n",
"ax.set_xlabel('X'); ax.set_ylabel('Y')\n",
"ax.legend(edgecolor='black');"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}