{ "cells": [ { "cell_type": "markdown", "id": "5dc31240", "metadata": {}, "source": [ "# Common Random Number GP\n", "\n", "This notebook walks through an example of a ''common random number'' (crn) GP. The crnGP is described in detail in [Fadkihar et. al 2022](https://ieeexplore.ieee.org/abstract/document/10408258). \n", "\n", "The main idea behind the crnGP is that we can utilize random seed information to assess the correlation between different random seeds. This makes it possible to do trajectory-level inference for the location and random seed pairing (x,r)." ] }, { "cell_type": "markdown", "id": "1b22dd58", "metadata": {}, "source": [ "In the following example, we design a synthetic function that is evaluated as:\n", "$$\n", "Y(x,s) = \\sin(x + \\frac{\\pi}{2}\\mathcal{1}(s==5)) + s\n", "$$\n", "\n", "Which evaluates the sine function and shifts the amplitude by the value of the random seed $s$, and also introduces a phase shift if the random seed is equal to 5.\n", "\n", "We evaluate the function over $x \\in [0,2\\pi]$ and for seeds $s \\in {1,5}$" ] }, { "cell_type": "code", "execution_count": 1, "id": "d4609ebe", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2026-01-13T09:24:11.006600\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.10.8, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%config InlineBackend.figure_formats = ['svg']\n", "from hetgpy import crnGP\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "rng = np.random.default_rng(1)\n", "\n", "pps = 10 # points per seed\n", "x = np.linspace(0,2*np.pi,pps).reshape(-1,1)\n", "X = np.vstack([x,x])\n", "seeds = ([1] * pps) + ([5] * pps)\n", "X = np.hstack([X,np.array(seeds).reshape(-1,1)])\n", "# amplitude and phase shift\n", "Z = np.sin(X[:,0] + (np.pi/2)*(X[:,-1]==5)) + X[:,-1]\n", "Ztrue = Z.copy()\n", "noise = 0.5 * rng.normal(size = Z.size)\n", "Z += noise\n", "fig, ax = plt.subplots(figsize=(9,6))\n", "\n", "m1 = X[:,-1]==1\n", "m2 = X[:,-1]==5\n", "ax.scatter(X[m1][:,0],Z[m1],color='blue',label='Seed 1')\n", "ax.scatter(X[m2][:,0],Z[m2],color='red',label='Seed 5')\n", "ax.set_xlabel('X'); ax.set_ylabel('Y')\n", "ax.legend(edgecolor='black');" ] }, { "cell_type": "code", "execution_count": 3, "id": "ef3c6b0a", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2026-01-13T09:25:45.123241\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.10.8, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from hetgpy.optim import crit_EI\n", "model = crnGP()\n", "with np.errstate(divide='ignore', invalid='ignore'):\n", " model.mle(X,Z,covtype=\"Matern5_2\")\n", "pp = 200\n", "xp = np.linspace(0,2*np.pi,pp).reshape(-1,1)\n", "Xp = np.vstack([xp,xp])\n", "seeds = ([1] * pp) + ([5] * pp)\n", "Xp = np.hstack([Xp,np.array(seeds).reshape(-1,1)])\n", "\n", "with np.errstate(divide='ignore', invalid='ignore'):\n", " preds = model.predict(Xp,xprime=Xp)\n", "\n", "fig, ax = plt.subplots(figsize=(9,6))\n", "\n", "# data\n", "m1 = X[:,-1]==1\n", "m2 = X[:,-1]==5\n", "ax.scatter(X[m1][:,0],Z[m1],color='blue',label='Seed 1')\n", "ax.scatter(X[m2][:,0],Z[m2],color='red',label='Seed 5')\n", "\n", "\n", "# preds\n", "\n", "ax.plot(xp,preds['mean'][0:pp],color='blue')\n", "ax.plot(xp,preds['mean'][pp:],color='red')\n", "\n", "Xns = Xp[0:pp,:].copy()\n", "Xns[:,-1] = 10\n", "with np.errstate(divide='ignore', invalid='ignore'):\n", " preds_new_seed = model.predict(Xns,xprime=Xns)\n", "ax.plot(xp,preds_new_seed['mean'],color='green',label = 'New Seed')\n", "\n", "\n", "ax.set_xlabel('X'); ax.set_ylabel('Y')\n", "ax.legend(edgecolor='black');" ] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" } }, "nbformat": 4, "nbformat_minor": 5 }