Source code for mbank.flow.utils

"""
mbank.flow.utils
================

Plotting & validation utilities for the `mbank.flow` package
"""

import matplotlib.pyplot as plt
import warnings

import scipy.stats
import numpy as np
import tempfile

import os
import re
import torch

########################################################################

class validation_metric():
    """
    Base class for validation metrics: subclasses must implement :func:`get_dist`.

    At construction time the metric is calibrated: ``N_estimation`` batches of
    ``len(data)`` samples are drawn from ``flow._distribution`` and passed to
    :func:`get_dist`. The mean, the standard deviation and the standard
    deviation of the mean of the resulting values are stored in the attributes
    ``metric_mean``, ``metric_std`` and ``metric_std_of_mean``.

    Parameters
    ----------
        data:
            Validation data; it is later given to ``flow.transform_to_noise``
        flow:
            Flow model, exposing ``_distribution.sample`` and ``transform_to_noise``
        N_estimation: int
            Number of noise batches used to calibrate the metric
    """

    def __init__(self, data, flow, N_estimation = 1000):
        self.data = data
        self.flow = flow
        self.N_samples = len(data)

        # One distance per freshly drawn batch of noise
        reference_distances = np.array([
            self.get_dist(flow._distribution.sample(self.N_samples).detach().numpy())
            for _ in range(N_estimation)
        ])

        self.metric_mean = np.mean(reference_distances)
        self.metric_std = np.std(reference_distances)
        self.metric_std_of_mean = np.std(reference_distances)/np.sqrt(N_estimation)

    def get_dist(self, data):
        "Measures the distance between data and the noise distribution (a standard normal)"
        raise NotImplementedError("The distance between two dataset must be implemented by sub-classes!")

    def get_validation_metric(self):
        "Maps the stored data to noise with the flow and returns their distance from the noise distribution"
        mapped_to_noise = self.flow.transform_to_noise(self.data)
        return self.get_dist(mapped_to_noise.detach().numpy())
class ks_metric(validation_metric):
    "Class to compute the validation metric using the Kolmogorov-Smirnov test"

    def get_dist(self, data):
        """
        Returns the base-10 logarithm of the product of the per-column KS p-values.

        Each column of ``data`` is compared (with ``scipy.stats.kstest``) against
        the matching column of a fresh batch of ``self.N_samples`` samples drawn
        from ``self.flow._distribution``.
        """
        reference = self.flow._distribution.sample(self.N_samples).detach().numpy()

        joint_pvalue = 1.
        n_columns = data.shape[1]
        for col in range(n_columns):
            ks_outcome = scipy.stats.kstest(data[:, col], reference[:, col])
            joint_pvalue *= ks_outcome[1]

        # The tiny offset keeps the logarithm finite when the product underflows to zero
        return np.log10(joint_pvalue+1e-300)
class cross_entropy_metric(validation_metric):
    "Class to compute the validation metric using the Cross Entropy distance"

    def get_dist(self, data):
        "Returns the mean of the log-probabilities assigned to ``data`` by ``self.flow._distribution``"
        noise_log_prob = self.flow._distribution.log_prob(data)
        return noise_log_prob.mean()
########################################################################
class early_stopper:
    """
    Implements early stopping for the training of the normalizing flow model.

    The object is called once per validation step. It keeps track of the lowest
    validation loss seen so far, checkpoints the flow every time a new minimum
    is reached and asks to stop the training when the loss failed to improve by
    more than ``min_delta`` for ``patience`` consecutive calls, or as soon as
    the validation loss becomes NaN.

    Parameters
    ----------
        patience: int
            Number of consecutive calls without a sufficient improvement after which the training is stopped
        min_delta: float
            Minimum decrease of the validation loss (w.r.t. the best value) that counts as an improvement. The absolute value is used.
        temp_file: str
            File where the best flow is checkpointed. If None, a hidden file with a random name (``.temp_flow_<number>.zip``) is used.
        return_best_model: bool
            Whether to load the checkpointed weights into the flow when the training is stopped
        verbose: bool
            Whether to print information about the checkpoint file and about the termination
    """

    def __init__(self, patience=10, min_delta=0, temp_file = None, return_best_model = True, verbose = False):
        self.patience = patience
        self.min_delta = np.abs(min_delta)
        self.counter = 0
        self.min_validation_loss = np.inf
        if temp_file is None:
            self.temp = '.temp_flow_{}.zip'.format(np.random.randint(0, np.iinfo(np.int32).max))
        else:
            self.temp = temp_file
        self.verbose = verbose
        self.return_best_model = return_best_model
        if self.verbose:
            print("Storing checkpoint flow in: ", self.temp)

    def __call__(self, flow, epoch, train_loss, validation_loss):
        """
        Updates the early stopping state with the latest validation loss.

        Parameters
        ----------
            flow:
                The flow being trained: it must expose ``save_weigths`` and ``load_weights``
            epoch:
                Current epoch (unused)
            train_loss:
                Current training loss (unused)
            validation_loss: float or single-element :class:`torch.Tensor`
                Current validation loss

        Returns
        -------
            stop: bool
                True if the training must be terminated, False otherwise
        """
        # Plain floats and single-element tensors are treated alike
        loss_value = float(validation_loss)

        # A NaN loss always terminates the training, even on the very first call:
        # it must not go through the comparisons below, as NaN/inf comparisons
        # against the initial (infinite) best loss are all False
        if np.isnan(loss_value):
            if self.verbose:
                print("nans appearing in the validation loss: terminating the training")
            self.counter = self.patience
            return self._stop(flow)

        if loss_value > self.min_validation_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                return self._stop(flow)
        else:
            self.counter = 0

        # NOTE(review): the checkpoint is updated independently of the patience
        # branch, so improvements smaller than min_delta still update the best
        # model - confirm this is the intended behaviour.
        if loss_value < self.min_validation_loss:
            self.min_validation_loss = loss_value
            # NOTE(review): `save_weigths` (sic) is the name called on the flow;
            # check the flow class before renaming.
            flow.save_weigths(self.temp)

        return False

    def _stop(self, flow):
        "Restores the best checkpoint (if requested and if one was written) and signals termination"
        # The best loss is still infinite only if no checkpoint was ever saved
        checkpoint_available = self.min_validation_loss < np.inf
        if self.return_best_model and checkpoint_available:
            flow.load_weights(self.temp)
        if self.verbose:
            print("Terminating training due to early stopping")
        return True
########################################################################
def plot_loss_functions(history, savefolder = None):
    """
    Plots the training/validation loss and, if available, the validation metric
    as a function of the epoch, starting from a history dict as returned by
    :func:`mbank.flow.flowmodel.GW_Flow.train_flow_forward_KL`.

    Parameters
    ----------
        history: dict
            An history dict (as returned by :func:`mbank.flow.flowmodel.GW_Flow.train_flow_forward_KL`).
            The entries `train_loss`, `validation_loss`, `valmetric_value` and `validation_step` are read.

        savefolder: str
            A folder where to save the plots: they will be saved with the names `loss.png` and `validation_metric.png`.
            If it is not a string, nothing is saved.
    """
    save_to_disk = isinstance(savefolder, str)
    if save_to_disk and not savefolder.endswith('/'):
        savefolder = savefolder+'/'

    train_loss = history['train_loss']
    validation_loss = history['validation_loss']
    metric = history['valmetric_value']
    n_epochs = len(train_loss)
    # The validation quantities are recorded once every `validation_step` epochs
    validation_epoch = range(0, n_epochs, history['validation_step'])

    plt.figure()
    plt.plot(range(n_epochs), train_loss, label = 'train')
    plt.plot(validation_epoch, validation_loss, label = 'validation')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.yscale('log')
    plt.legend()
    if save_to_disk:
        plt.savefig(savefolder+"loss.png")

    # The validation metric is plotted only if it is present in the history dict
    if len(metric) == 0:
        return

    plt.figure()
    plt.plot(validation_epoch, metric, c= 'b', label = 'validation metric')
    plt.xlabel("Epoch")
    plt.ylabel(r"$\log(D_{KL})$")
    plt.legend()
    if save_to_disk:
        plt.savefig(savefolder+"validation_metric.png")
    return
def create_gif(folder, savefile, fps = 1):
    """
    Given a folder of plots generated by a callback, it creates a gif summarizing the training history.

    Every file in the folder whose name contains at least one number is used as
    a frame; the frames are ordered according to the first number appearing in
    the file name. If `imageio` is not installed, a warning is issued and
    nothing is done.

    Parameters
    ----------
        folder: str
            Folder with the images to assemble

        savefile: str
            File where the gif is written

        fps: int
            Frames per second, passed to the `imageio` writer
    """
    #https://stackoverflow.com/questions/753190/programmatically-generate-video-or-animated-gif-in-python
    try:
        import imageio.v2 as iio
    except ImportError:
        warnings.warn("Unable to find one or more of the package `imageio`: you will not be able to create a gif. If interested, try `pip install imageio`.")
        return

    if not folder.endswith('/'):
        folder = folder+'/'

    # (epoch, file name) pairs for all the files carrying a number in their name
    numbered_frames = []
    for name in os.listdir(folder):
        numbers = re.findall(r'\d+', name)
        if numbers:
            numbered_frames.append((int(numbers[0]), name))

    frame_order = np.argsort([epoch for epoch, _ in numbered_frames])

    with iio.get_writer(savefile, mode='I', fps=fps) as gif_writer:
        for idx in frame_order:
            frame_name = numbered_frames[idx][1]
            gif_writer.append_data(iio.imread(folder+frame_name))
    return
def plotting_callback(model, epoch, train_loss, validation_loss, dirname, data_to_plot, variable_format, basefilename = None):
    """
    An example callback for plotting the KDE pairplots.

    It draws as many samples from the model as there are rows in `data_to_plot`
    and compares them with `data_to_plot` through
    :func:`compare_probability_distribution`. The plot is saved in `dirname`
    (created if missing) as `<basefilename>_<epoch>.png`, or as `<epoch>.png`
    if `basefilename` is not a string.

    Parameters
    ----------
        model:
            Model to sample from: it must expose a `sample` method returning a torch tensor

        epoch: int
            Current epoch, used in the plot title and in the file name

        train_loss:
            Current training loss (unused)

        validation_loss:
            Current validation loss (unused)

        dirname: str
            Folder where the plot is saved

        data_to_plot: :class:`~numpy:numpy.ndarray`
            Reference data to compare the model samples with

        variable_format: str
            Variable format, forwarded to :func:`compare_probability_distribution`

        basefilename: str
            Optional prefix for the name of the saved file

    Returns
    -------
        stop: bool
            Always False (presumably meaning that the training must not be stopped - confirm against the training loop)
    """
    # Creates nested folders as well and tolerates an already existing one
    os.makedirs(dirname, exist_ok = True)

    if isinstance(basefilename, str):
        filename = '{}_{}.png'.format(basefilename, epoch)
    else:
        filename = '{}.png'.format(epoch)
    savefile = os.path.join(dirname, filename)

    data_flow = model.sample(data_to_plot.shape[0]).detach().numpy()

    compare_probability_distribution(data_flow, data_true = data_to_plot,
        variable_format = variable_format,
        title = 'epoch = {}'.format(epoch),
        savefile = savefile)
    return False
def compare_probability_distribution(data_flow, data_true = None, variable_format = None, title = None, hue_labels = ('flow', 'train'), savefile = None, show = False):
    """
    Shows the probability distribution learnt by the flow and compares it with the training one.
    It makes a pair plot to visualize the 2D slices of the multidimensional PDF:
    scatter plots above the diagonal, KDE contours below it and histograms on the diagonal.

    If `pandas` or `seaborn` are not installed, a warning is issued and nothing is plotted.

    Parameters
    ----------
        data_flow: :class:`~numpy:numpy.ndarray`
            Samples from the normalizing flow

        data_true: :class:`~numpy:numpy.ndarray`
            Samples from the target (true) distribution (if None, it will not be plotted).
            It does not need to have the same number of samples as `data_flow`.

        variable_format: str
            Variable format, to place the axes labels properly

        title: str
            A title for the plot

        hue_labels: list/tuple
            Labels for the two distributions: they will appear in the legend.
            The first refers to `data_flow`, the second to `data_true`.

        savefile: str
            File to save the plot at

        show: bool
            Whether to show the plot
    """
    try:
        import pandas as pd
        import seaborn as sns
    except ImportError:
        msg = "Unable to find the packages `pandas` and `seaborn`: you will not be able to use the function `compare_probability_distribution`.\nIf you want to go ahead, try `pip install pandas seaborn`."
        warnings.warn(msg)
        return

    from mbank.handlers import variable_handler
    var_handler = variable_handler()

    labels = var_handler.labels(variable_format, latex = False) if isinstance(variable_format, str) else None
    hue_labels = list(hue_labels)

    plot_data = pd.DataFrame(data_flow, columns = labels)
    # Rows [0, n_flow) hold the flow samples: the true samples (if any) are appended after them
    n_flow = len(plot_data)
    distribution = [hue_labels[0]]*n_flow

    if data_true is not None:
        true_data = pd.DataFrame(data_true, columns = labels)
        plot_data = pd.concat([plot_data, true_data], axis=0, ignore_index = True)
        distribution = distribution + [hue_labels[1]]*len(true_data)

    # The label column is assigned in one go: no chained assignment on a column slice
    plot_data['distribution'] = distribution

    # The hue order is reversed, so that the true distribution comes first when present
    hue_order = hue_labels[::-1] if data_true is not None else hue_labels[:1]

    g = sns.PairGrid(plot_data, hue="distribution", hue_order = hue_order)
    g.map_upper(sns.scatterplot, s = 1)
    g.map_lower(sns.kdeplot, levels=8)
    g.map_diag(sns.histplot, element = 'step')
    g.add_legend()

    if isinstance(title, str):
        plt.suptitle(title)
    if isinstance(savefile, str):
        plt.savefig(savefile)
    if show:
        plt.show()
    return