Source code for mbank.bank

"""
mbank.bank
==========

Module to implement a bank of gravitational waves signals.
It implement the class ``cbc_bank`` which provides a large number of functionalities to generate a bank, perform I/O operations on files
"""

import numpy as np
import warnings

	#ligo.lw imports for xml files: pip install python-ligo-lw
from ligo.lw import utils as lw_utils
from ligo.lw import ligolw
from ligo.lw import lsctables
from ligo.lw.utils import process as ligolw_process

from tqdm import tqdm

import ray

import scipy.spatial

from .placement import place_stochastically_in_tile, place_stochastically, place_iterative, place_random_tiling, place_pruning, create_mesh
from .utils import DefaultSnglInspiralTable, avg_dist, read_xml, partition_tiling, split_boundaries, plawspace, get_boundary_box

from .handlers import variable_handler, tiling_handler
from .metric import cbc_metric

############
#TODO: create a package for placing N_points in a box with lloyd algorithm (extra)
############

#############DEBUG LINE PROFILING
try:
	from line_profiler import LineProfiler

	def do_profile(follow=[]):
		def inner(func):
			def profiled_func(*args, **kwargs):
				try:
					profiler = LineProfiler()
					profiler.add_function(func)
					for f in follow:
						profiler.add_function(f)
					profiler.enable_by_count()
					return func(*args, **kwargs)
				finally:
					profiler.print_stats()
			return profiled_func
		return inner
except:
	pass

#pip install line_profiler
#add decorator @do_profile(follow=[]) before any function you need to track

####################################################################################################################
#FIXME: you are not able to perform the FFT of the WFs... Learn how to do it and do it well!

####################################################################################################################
####################################################################################################################

[docs]class cbc_bank: """ The class implements a bank for compact binary coalescence signals (CBC). A bank is a collection of templates (saved in the method ``bank.templates``). Each template is a row of the array; the columns are specified by a given variable format. To know more about the available variable formats, please refer to :class:`mbank.handlers.variable_handler`. A bank is generated from a tiling object (created internally) that speeds up the template placing. The tiling file is not part of a bank and lives as an independent object :class:`mbank.handlers.tiling_handler`. A bank can be saved in txt or in the std ligo xml file. """ def __init__(self, variable_format, filename = None): """ Initialize the bank with a given variable format. If a filename is given, the bank is loaded from file. Parameters ---------- variable_format: str How to handle the variables. See class variable_handler for more details filename: str Optional filename to load the bank from (if None, the bank will be initialized empty) """ #TODO: start dealing with spins properly from here... self.var_handler = variable_handler() self.variable_format = variable_format self.templates = None #empty bank assert self.variable_format in self.var_handler.valid_formats, "Wrong variable format given" if isinstance(filename, str): self.load(filename) #Adding the variable names as properties #TODO: Can you do this better? Now you're making a class just to store an integer... class get_var: def __init__(self, var_id): self.var_id = var_id def __call__(self, bank_obj): if bank_obj.templates is None: return None return bank_obj.templates[:,self.var_id] for i, name in enumerate(self.var_handler.labels(self.variable_format)): setattr(cbc_bank, name, property(get_var(i))) return @property def D(self): """ The dimensionality of the space Returns ------- D: float Keeps the dimensionality of the space """ return self.var_handler.D(self.variable_format) #handy shortening @property def placing_methods(self): """ List all the available placing methods Returns ------- placing_methods: list The available methods for placing the templates """ return ['uniform', 'qmc', 'geometric', 'iterative', 'stochastic', 'random', 'tile_random', 'geo_stochastic', 'random_stochastic', 'tile_stochastic', 'pruning', 'iterative_stochastic']
[docs] def load(self, filename): """ Load a template bank from file. They are added to the existing templates (if any). Parameters ---------- filename: str Filename to load the bank from """ if filename.endswith('.npy'): templates_to_add = np.load(filename) elif filename.endswith('.txt') or filename.endswith('.dat'): templates_to_add = np.loadtxt(filename) elif filename.endswith('.xml') or filename.endswith('.xml.gz'): if self.var_handler.format_info[self.variable_format]['e']: warnings.warn("Currently loading from an xml file does not support eccentricity") #reading the BBH components BBH_components = read_xml(filename, lsctables.SnglInspiralTable) #making the templates suitable for the bank templates_to_add = self.var_handler.get_theta(BBH_components, self.variable_format) #(N,D) else: raise ValueError("Type of file not recognized!") self.add_templates(templates_to_add) return
def _save_xml(self, filename, f_max = 1024., ifo = 'L1'): """ Save the bank to an xml file suitable for LVK applications Parameters ---------- filename: str Filename to save the bank at f_max: float End frequency (in Hz) for the templates ifo: str Name of the interferometer the bank refers to """ #getting the masses and spins of the rows m1, m2, s1x, s1y, s1z, s2x, s2y, s2z, e, meanano, iota, phi = self.var_handler.get_BBH_components(self.templates, self.variable_format).T if np.any(e != 0.): msg = "Currently xml format does not support eccentricity... The saved bank '{}' will have zero eccentricity".format(filename) warnings.warn(msg) #preparing the doc #See: https://git.ligo.org/RatesAndPopulations/lvc-rates-and-pop/-/blob/master/bin/lvc_rates_injections#L168 xmldoc = ligolw.Document() xmldoc.appendChild(ligolw.LIGO_LW()) signl_inspiral_table = lsctables.New(lsctables.SnglInspiralTable) #register a process_table about what code made the file process = ligolw_process.register_to_xmldoc( xmldoc, program="mbank", paramdict={},#process_params, #what should I enter here? comment="A bank of BBH, generated using a metric approach") #here we add the rows one by one for i in range(m1.shape[0]): #defining the row row = DefaultSnglInspiralTable() #This is a dirty trick for a std initialization (works) #row = lsctables.New(lsctables.SnglInspiralTable).RowType() #setting bank parameters row.mass1, row.mass2 = m1[i], m2[i] row.spin1x, row.spin1y, row.spin1z = s1x[i], s1y[i], s1z[i] row.spin2x, row.spin2y, row.spin2z = s2x[i], s2y[i], s2z[i] row.alpha3 = iota[i] row.alpha5 = phi[i] #are you sure it's alpha5? See here: https://github.com/gwastro/sbank/blob/7072d665622fb287b3dc16f7ef267f977251d8af/sbank/waveforms.py#L845 #shall I need to set other things by hand? E.g. taus... row.mtotal = row.mass1 + row.mass2 row.eta = row.mass1 * row.mass2 / row.mtotal**2 row.mchirp = ((row.mass1 * row.mass2)**3/row.mtotal)**0.2 row.chi = (row.mass1 *row.spin1z + row.mass2 *row.spin2z) / row.mtotal #is this the actual chi? #this is chi from https://git.ligo.org/lscsoft/gstlal/-/blob/master/gstlal-inspiral/python/_spawaveform.c#L896 #row.chi = (np.sqrt(row.spin1x**2+row.spin1y**2+row.spin1z**2)*m1 + np.sqrt(row.spin2x**2+row.spin2y**2+row.spin2z**2)*m2)/row.mtotal row.f_final = f_max row.ifo = ifo #setting the ifo chosen by the user #Setting additional parameters row.process_id = process.process_id #This must be an int row.event_id = i row.Gamma0 = float(i) #apparently Gamma0 is the template id in gstlal (for some very obscure reason) #for k, v in std_extra_params.items(): # setattr(row, k, v) signl_inspiral_table.append(row) #xmldoc.appendChild(ligolw.LIGO_LW()).appendChild(signl_inspiral_table) #ligolw_process.set_process_end_time(process) xmldoc.childNodes[-1].appendChild(signl_inspiral_table) lw_utils.write_filename(xmldoc, filename, verbose=False) xmldoc.unlink() return
[docs] def save_bank(self, filename, f_max = 1024., ifo = 'L1'): #TODO: change this name to `save` """ Save the bank to file **WARNING: xml file format currently does not support eccentricity** Parameters ---------- filename: str Filename to save the bank at f_max: float End frequency (in Hz) for the templates (applies only to xml format) ifo: str Name of the interferometer the bank refers to (only applies to xml files) """ if self.templates is None: raise ValueError('Bank is empty: cannot save an empty bank!') if filename.endswith('.npy'): templates_to_add = np.save(filename, self.templates) elif filename.endswith('.txt') or filename.endswith('.dat'): templates_to_add = np.savetxt(filename, self.templates) elif filename.endswith('.xml') or filename.endswith('.xml.gz'): self._save_xml(filename, f_max, ifo) else: raise RuntimeError("Type of file not understood. The file can only end with 'npy', 'txt', 'data, 'xml', 'xml.gx'") return
@property def BBH_components(self): """ Returns the BBH components of the templates in the bank. They are: `m1, m2, s1x, s1y, s1z, s2x, s2y, s2z, e, meanano iota, phi` Returns ------- BBH_components: :class:`~numpy:numpy.ndarray` shape: (N,12) Array of BBH components of the templates in the bank. They have the same layout as :func:`mbank.handlers.variable_handler.get_BBH_components` """ if self.templates is not None: return self.var_handler.get_BBH_components(self.templates, self.variable_format) return
[docs] def add_templates(self, new_templates): """ Adds a bunch of templates to the bank. They must be of a shape suitable for the variable format Parameters ---------- new_templates: :class:`~numpy:numpy.ndarray` shape: (N,D)/(D,) New templates to add. They need to be stored in an array of shape (N,D) or (D,), where D is the dimensionality of the bank """ new_templates = np.asarray(new_templates) if new_templates.ndim == 1: new_templates = new_templates[None,:] assert new_templates.ndim == 2, "The new templates are provided with a wrong shape!" assert self.D == new_templates.shape[1], "The templates to add have the wrong dimensionality. Must be {}, but given {}".format(self.D, new_templates.shape[1]) new_templates = self.var_handler.switch_BBH(new_templates, self.variable_format) if self.templates is None: self.templates = new_templates else: self.templates = np.concatenate([self.templates, new_templates], axis = 0) #(N,4) return
[docs] def place_templates(self, tiling, minimum_match, placing_method, N_livepoints = 10000, covering_fraction = 0.01, empty_iterations = 100, verbose = True): """ Given a tiling, it places the templates according to the given method and **adds** them to the bank Parameters ---------- tiling: tiling_handler A tiling handler with a non-empty tiling minimum_match: float Minimum match for the bank: it controls the distance between templates as in ``utils.avg_dist()`` placing_method: str The placing method to set templates in each tile. It can be: - `uniform` -> Uniform drawing in each hyper-rectangle, according to the volume - `qmc` -> Quasi Monte Carlo drawing in each hyper-rectangle, according to the volume - `random` -> Random points sampled from the tiling are added to the bank until only a fraction eta = 0.01 of the original volume is left uncovered - `stochastic` -> Stochastic placement: proposal are made and accepted only if they are more distant than sqrt(1-minimum_match) from the rest of the templates - `random_stochastic` -> Random placement, followed by a stochastic placement to "fill" the holes left by the random method - `geometric` -> Geometric placement: templates are placed on a lattice, with spacing computed according to the minimum match requirement - `iterative` -> Each tile is split iteratively until the number of templates in each subtile is equal to one - `iterative_stochastic` -> The outcome of the iterative method is given as input to the stochastic method - `geo_stochastic` -> Geometric placement + stochastic placement - `tile_stochastic` -> Stochastic placement performed for each tile separately - `pruning` -> The volume is covered with some point that are killed by placing the templates - `tile_random` -> Random placement for each tile separately Those methods are listed in `cbc_bank.placing_methods` N_livepoints: float Only applies for `random` method or `pruning` method. For `random` (or related), it represents the number of livepoints to use for the estimation of the coverage fraction. For `pruning`, it amounts to the the ratio between the number of livepoints and the number of templates placed by ``uniform`` placing method. covering_fraction: float Only applies for `random` method or `pruning` method. Fraction of livepoints to be covered before terminating the loop empty_iterations: int Number of consecutive proposal inside a tile to be rejected before the tile is considered full. It only applies to the ``stochastic`` placing method (or related). verbose: bool Whether to print the output Returns ------- new_templates: :class:`~numpy:numpy.ndarray` The templates generated (already added to the bank) """ ####### # Some initial checks ####### assert placing_method in self.placing_methods, ValueError("Wrong placing method '{}' selected. The methods available are: ".format(placing_method, self.placing_methods)) assert self.D == tiling[0][0].maxes.shape[0], ValueError("The tiling doesn't match the chosen variable format (space dimensionality mismatch)") if self.variable_format.startswith('m1m2_'): raise RuntimeError("Currently mbank does not support template placing with m1m2 format for the masses") for R, M in tiling: eigs, _ = np.linalg.eig(M) #sanity checks on the metric eigenvalues if np.any(eigs < 0): warnings.warn("The metric has a negative eigenvalue @ {}: the template placing in this tile may be unreliable. This is pathological as the metric computation may have failed.\nEigvs are: {}".format((R.maxes+R.mins)/2, eigs)) abs_det = np.abs(np.prod(eigs)) if abs_det < 1e-50: #checking if the determinant is close to zero... msg = "The determinant of the metric is zero! It is impossible to place templates into this tile: maybe the approximant you are using is degenerate with some of the sampled quantities?\nRectangle: {}\nMetric: {}".format(R, M) raise ValueError(msg) ####### # Initializing some useful quantities ####### if placing_method in ['geometric', 'geo_stochastic', 'random', 'random_stochastic', 'iterative_stochastic'] : coarse_boundaries = np.stack([tiling.boundaries.mins, tiling.boundaries.maxes], axis =0) #(2,D) dist = avg_dist(minimum_match, self.D) #desired average distance between templates if verbose: print("Approx number of templates {}".format(int(tiling.compute_volume()[0] / np.power(dist, self.D)))) #total number of points according to volume placement N_points = lambda t: N_livepoints*t.compute_volume()[0] / np.power(np.sqrt(1-minimum_match), self.D) new_templates = [] if placing_method in ['stochastic', 'random', 'pruning', 'uniform', 'qmc', 'random_stochastic']: it = iter(()) elif verbose: it = tqdm(range(len(tiling)), desc = 'Placing the templates within each tile', leave = True) else: it = range(len(tiling)) ####### # Loop on the tiling (for methods that requires it) ####### for i in it: t = tiling[i] #current tile boundaries_ij = np.stack([t[0].mins, t[0].maxes], axis =0) #boundaries of the tile eigs, _ = np.linalg.eig(t[1]) #eigenvalues if placing_method in ['geometric', 'geo_stochastic']: new_templates_ = create_mesh(2*np.sqrt(1-minimum_match), t, coarse_boundaries = None) #(N,D) elif placing_method in ['iterative', 'iterative_stochastic']: new_templates_ = place_iterative(minimum_match, t) elif placing_method == 'tile_stochastic': new_templates_ = place_stochastically_in_tile(minimum_match, t) elif placing_method == 'tile_random': temp_t_ = tiling_handler(t) new_templates_ = place_random_tiling(minimum_match, temp_t_, N_livepoints = N_livepoints, tolerance = 0.01, verbose = False) new_templates.extend(new_templates_) ####### # Placing the templates with the remaining methods ####### if placing_method in ['uniform', 'qmc']: vol_tot, _ = tiling.compute_volume() N_templates = int( vol_tot/(dist**self.D) )+1 if tiling.flow and placing_method == 'uniform': new_templates = tiling.sample_from_flow(N_templates) else: new_templates = tiling.sample_from_tiling(N_templates, qmc = (placing_method=='qmc')) if placing_method in ['random', 'random_stochastic']: new_templates = place_random_tiling(minimum_match, tiling, N_livepoints = N_livepoints, covering_fraction = covering_fraction, verbose = verbose) if placing_method in ['geo_stochastic', 'random_stochastic', 'iterative_stochastic', 'stochastic']: new_templates = place_stochastically(minimum_match, tiling, empty_iterations = empty_iterations, seed_bank = None if placing_method == 'stochastic' else new_templates, verbose = verbose) if placing_method in ['pruning']: if tiling.flow: warnings.warn("Currently the flow is not implemented with the pruning method") #As a rule of thumb, the fraction of templates/N_livepoints must be below 10% (otherwise, bad injection recovery) N_points_max = int(1e6) N_points_tot = N_points(tiling) if N_points_tot >N_points_max: thresholds = plawspace(coarse_boundaries[0,0], coarse_boundaries[1,0], -8./3., int(N_points_tot/N_points_max)+2)[1:-1] partition = partition_tiling(thresholds, 0, tiling) #print("\tThresholds: ",thresholds) #print("\tN_points: ", [int(N_points(p)) for p in partition]) else: partition = [tiling] #print(N_points_tot, len(partition)) new_templates = [] if verbose: it = tqdm(partition, desc = 'Loops on the partitions for random placement') else: it = partition for p in it: #TODO: make this a ray function? Too much memory expensive, probably... #The template volume for random is sqrt(1-MM) (not dist) new_templates_ = place_pruning(minimum_match, p, N_points = int(N_points(p)), covering_fraction = covering_fraction, verbose = verbose) new_templates.extend(new_templates_) new_templates = np.stack(new_templates, axis =0) self.add_templates(new_templates) return new_templates
[docs] def generate_bank(self, metric_obj, minimum_match, boundaries, tolerance, placing_method = 'random', metric_type = 'hessian', grid_list = None, train_flow = False, use_ray = False, N_livepoints = 50, covering_fraction = 0.01, empty_iterations = 100, max_depth = 6, n_layers = 2, hidden_features = 4, N_epochs = 1000, verbose = True): #FIXME: here you should use kwargs, directing the user to the docs of other functions? """ **DEPRECATED** Generates a bank using a hierarchical hypercube tesselation. The bank generation consists in two steps: 1. Tiling generation by iterative splitting of the parameter space 2. Template placing in each tile, according to the method given in ``placing_method`` Parameters ---------- metric_obj: cbc_metric A cbc_metric object to compute the match with minimum_match: float Average match between templates boundaries: :class:`~numpy:numpy.ndarray` shape: (2,D) - An array with the boundaries for the model. Lower limit is boundaries[0,:] while upper limits is boundaries[1,:] tolerance: float Threshold used for the tiling algorithm. It amounts to the maximum tolerated relative change between the metric determinant of the child and the parent ``|M|``. For more information, see `mbank.handlers.tiling_handler.create_tiling` placing_method: str The placing method to set templates in each tile. See `place_templates` for more information. metric_type: str The method computation method to use. For more information, you can check ``metric.cbc_metric.get_metric``. train_flow: bool Whether to train a normalizing flow model after the tiling is generated. It will be used for metric interpolation during the template placing grid_list: list A list of ints, each representing the number of coarse division of the space. If use ray option is set, the subtiling of each coarse division will run in parallel If None, no prior splitting will be made. use_ray: bool Whether to use ray to parallelize N_livepoints: float Only applies for `random` method or `pruning` method. For `random` (or related), it represents the number of livepoints to use for the estimation of the coverage fraction. For `pruning`, it amounts to the the ratio between the number of livepoints and the number of templates placed by ``uniform`` placing method. covering_fraction: float Only applies for `random` method or `pruning` method. Fraction of livepoints to be covered before terminating the template placing empty_iterations: int Number of consecutive proposal inside a tile to be rejected before the tile is considered full. It only applies to the ``stochastic`` placing method. max_depth: int Maximum number of splitting before quitting the iteration. If None, the iteration will go on until the volume condition is not met n_layers: int Number of layers of the flow See `mbank.flow.STD_GW_flow` for more information hidden_features: int Number of hidden features for the masked autoregressive flow in use. See `mbank.flow.STD_GW_flow` for more information N_epochs: int Number of epochs for the training of the flow verbose: bool Whether to print some output Returns ------- tiling: tiling_handler A list of tiles used for the bank generation """ ## #Initialization & various checks assert minimum_match<1. and minimum_match>0., "`minimum_match` should be in the range (0,1)!" if minimum_match <0.9: msg = "Average match is set to be smaller than 0.9. Although the code will work, this can give unreliable results as the metric match approximation is less accurate." warnings.warn(msg) dist = avg_dist(minimum_match, self.D) #desired average distance in the metric space between templates if self.variable_format.startswith('m1m2_'): raise RuntimeError("Currently mbank does not support template placing with m1m2 format for the masses") if grid_list is None: grid_list = [1 for i in range(self.D)] assert len(grid_list) == self.D, "Wrong number of grid sizes. Expected {}; given {}".format(self.D, len(grid_list)) ### #creating a proper grid list for a coarse boundary creation boundaries_list = split_boundaries(boundaries, grid_list, use_plawspace = True) ### #creating the tiling metric_fun = lambda center: metric_obj.get_metric(center, overlap = False, metric_type = metric_type) #metric_type = 'hessian') #metric_type = 'block_diagonal_hessian') #metric_type = 'parabolic_fit_hessian', target_match = 0.9, N_epsilon_points = 10, log_epsilon_range = (-4, 1)) t_obj = tiling_handler() #empty tiling handler t_obj.create_tiling_from_list(boundaries_list, tolerance, metric_fun, max_depth = max_depth, use_ray = use_ray ) if train_flow: t_obj.train_flow(N_epochs = N_epochs, n_layers = n_layers, hidden_features = hidden_features, verbose = verbose) ## #placing the templates #(if there is KeyboardInterrupt, the tiling is returned anyway) try: self.place_templates(t_obj, minimum_match, placing_method = placing_method, N_livepoints = N_livepoints, covering_fraction =covering_fraction, empty_iterations = empty_iterations, verbose = verbose) except KeyboardInterrupt: self.templates = None return t_obj
[docs] def enforce_boundaries(self, boundaries): """ Remove from the bank the templates that do not lie within the given boundaries Parameters ---------- boundaries: :class:`~numpy:numpy.ndarray` shape: (2,D) - An array with the boundaries for the model. Lower limit is boundaries[0,:] while upper limits is boundaries[1,:] """ boundaries = np.asarray(boundaries) if self.templates is None: return ids_ok = np.logical_and(np.all(self.templates > boundaries[0,:], axis =1), np.all(self.templates < boundaries[1,:], axis = 1)) #(N,) new_bank_size = sum(ids_ok) if new_bank_size == 0: self.templates = None warnings.warn("No template fits into the boundaries") elif new_bank_size < self.templates.shape[0]: self.templates = self.templates[ids_ok,:] else: pass #print("The bank already fits into the boundaries") return
[docs] def generate_bank_mcmc(self, metric_obj, N_templates, boundaries, n_walkers = 100, use_ray = False, thin_factor = None, load_chain = None, save_chain = None, verbose = True): """ Fills the bank with a Markov Chain Monte Carlo (MCMC) method. The MCMC sample from the probability distribution function induced by the metric: .. math:: p(\\theta) \propto \\sqrt{|M(\\theta)|} The function uses `emcee` package, not in the `mbank` dependencies. Parameters ---------- metric_obj: cbc_metric A cbc_metric objec to compute the PDF to distribute the templates N_templates: int Number of new templates to sample from the PDF boundaries: :class:`~numpy:numpy.ndarray` shape: (2,D) - An array with the boundaries for the model. Lower limit is boundaries[0,:] while upper limits is boundaries[1,:] n_walkers: int Number of independent walkers during the chain. If `use_ray` option is `True`, they will be run in parellel. use_ray: bool Whether to use `ray` to parallelize the sampling. #DO THIS OPTION thin_factor: int How many MC steps to discard before selecting one. If `None` it is computed authomatically based on the autocorrelation: this is the recommended behaviour load_chain: str Path to a file where the position of each walker is stored, togheter with integrated aucorellation tau. The file must keep a np.array of dimensions (n_walkers, D). The first line of the file is intended to be the autocorrelation time for each variable. If it is not provided, a standard value of 4 (meaning a thin step of 2) is assumed. If set, the sampler will start from there and the burn-in phase will not be required. save_chain: str If not None, it saves the path in which to save the status of the sampler. The file saved is ready to be loaded with option `load_chain` verbose: bool whether to print to screen the output """ try: import emcee except ModuleNotFoundError: raise ModuleNotFoundError("Unable to sample from the metric PDF as package `emcee` is not installed. Please try `pip install emcee`") burnin_steps = lambda tau: int(2 * np.max(tau)) if burnin else 0 #initializing the sampler sampler = emcee.EnsembleSampler(n_walkers, self.D, metric_obj.log_pdf, args=[boundaries], vectorize = True) #tau is a (D,) vector holding the autocorrelation for each variable #it is used to estimate the thin factor if isinstance(load_chain, str): #this will output an estimate of tau and a starting chain. The actual sampling will start straight away burnin = False loaded_chain = np.loadtxt(load_chain) if loaded_chain.shape[0]<n_walkers: raise RuntimeError("The given input file does not have enough walkers. Required {} but given {}".format(n_walkers, loaded_chain.shape[0])) elif loaded_chain.shape[0] == n_walkers: start = loaded_chain tau = 4 + np.zeros((self.D,)) else: tau, start = loaded_chain[0,:], loaded_chain[1:n_walkers+1,:] print('tau', tau) assert start.shape == (n_walkers, self.D), "Wrong shape for the starting chain. Expected {} but given {}. Unable to continue".format((n_walkers, self.D), start.shape) else: burnin = True start = np.random.uniform(*boundaries, (n_walkers, self.D)) n_burnin = 0 ########### #This part has two purposes: # - Give a first estimation for tau parameters (required to decide the size of burn-in steps and the thin step) # - Do a burn in phase (discard some samples to achieve stationariety) ########### if burnin: tau_list = [] step = 30 def dummy_generator(): #dummy generator for having an infinite loop while True: yield if verbose: it_obj = tqdm(dummy_generator(), desc='Burn-in/calibration phase') else: it_obj = dummy_generator() for _ in it_obj: n_burnin += step state = sampler.run_mcmc(start, nsteps = step, progress = False, tune = False) start = state.coords #very important! The chain will start from here tau = sampler.get_autocorr_time(tol = 0) tau_list.append(tau) if len(tau_list)>1 and np.all(np.abs(tau_list[-2]-tau_list[-1]) < 0.001*tau_list[-1]): tau = tau_list[-1] break if verbose: print("") ########### #doing the actual sampling if thin_factor is None: thin = max(int(0.5 * np.min(tau)),1) else: thin = thin_factor if verbose: print('Thin factor: {} | burn-in: {} '.format( thin, burnin_steps(tau))) n_steps = int((N_templates*thin)/n_walkers) - int(n_burnin) + burnin_steps(tau) + thin #steps left to do... if verbose: print("Steps done: {} | Steps to do: {}".format(n_burnin, n_steps)) if n_steps > 0: try: state = sampler.run_mcmc(start, n_steps, progress = verbose, tune = True) except KeyboardInterrupt: pass #FIXME: understand whether you want to change the thin factor... it is likely underestimated during the burn-in phase #On the other hand, it makes difficult to predict how many steps you will need #updating thin factor #FIXME: this needs to be taken into account! if thin_factor is None and False: tau = sampler.get_autocorr_time(tol = 0) thin = max(int(0.5 * np.min(tau)),1) if verbose: print('Updated -- Thin factor: {} | burn-in: {} '.format( thin, burnin_steps(tau))) chain = sampler.get_chain(discard = burnin_steps(tau), thin = thin, flat=True)[-N_templates:,:] if isinstance(save_chain, str) and (state is not None): chain_to_save = state.coords #(n_walkers, D) to_save = np.concatenate([tau[None,:], chain_to_save], axis = 0) np.savetxt(save_chain, to_save) #adding chain to the bank self.add_templates(chain) return