Source code for hetgpy.plot

'''
Suite of plotting functions for model checks/diagnostics/etc.
'''

import warnings
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt


[docs] def plot_optimization_iterates(object, keys_and_title = None): r''' Plot maximum likelihood iterates Parameters ---------- object: hetgpy.homGP.homGP or hetgpy.hetGP.hetGP model hetGPy object keys_and_title: iterable for model component to extract (theta, g, etc.) Returns ------- fig, ax: matplotlib figure and axes ''' def extract_variable(key): # extract iterates from model object out = np.array([d[key] for d in object['iterates']]) if len(out.shape)==1: out = out.reshape(-1,1) return out fig, ax = plt.subplots(nrows=1,ncols = len(keys_and_title),figsize=(11.5,8)) fig.supxlabel('Iteration') xs = np.arange(len(object['iterates'])) i = 0 for key, ax_title in keys_and_title.items(): ys = extract_variable(key) for j in range(ys.shape[1]): label = key if key in ('theta','Delta'): label = r'$\{}_{}$'.format(key,j+1) ax[i].plot(xs, ys[:,j],label=label) # axis options ax[i].set_title(ax_title) i+=1 return fig, ax
[docs] def plot_diagnostics(model, interval:str='predictive'): r''' Diagnostics plot which mirrors the plot(model) routine in hetGP Plots the LOO predctions against the model data Parameters ---------- model: hetGPy model interval: str one of 'confidence' or 'predictive' Returns ------- fig, ax: matplotlib figure and axes ''' preds = model.predict(model.X0, interval=interval, interval_lower=0.05, interval_upper=0.95) pred_interval = preds['confidence_interval'] if interval == 'confidence' else preds['predictive_interval'] fig, ax = plt.subplots() idxs = np.repeat(np.arange(len(model.X0)),model.mult) ax.hlines( y=preds['mean'], xmin=pred_interval['lower'], xmax=pred_interval['upper'], label='Prediction Interval' if interval == 'predictive' else 'Confidence Interval', zorder=-10) ax.scatter(model.Z, preds['mean'][idxs], facecolors='none', edgecolors='black', label='Observations',zorder=5) ax.axline((0, 0), slope=1,color='black',linestyle='dashed') ax.scatter(model.Z0[(model.mult>1).nonzero()[0]], preds['mean'][(model.mult>1).nonzero()[0]], label=r'Averages (if mult \textgreater 1)',color='red',zorder=10) ax.legend(loc='upper left',edgecolor='black') ax.set_title('Model Diagnostics') ax.set_xlabel('Observed') ax.set_ylabel('Predicted') return fig, ax