Source code for idtxl.network_inference

"""Parent class for all network inference."""
import numpy as np

from . import idtxl_exceptions as ex
from . import stats
from .network_analysis import NetworkAnalysis


class NetworkInference(NetworkAnalysis):
    """Parent class for network inference algorithms.

    Holds variables that are relevant for network inference using, for
    example, bivariate and multivariate transfer entropy.

    Attributes:
        settings : dict
            settings for estimation of information theoretic measures and
            statistical testing, see child classes for documentation
        target : int
            target process of analysis
        current_value : tuple
            index of the current value
        selected_vars_full : list of tuples
            indices of the full set of random variables to be conditioned on
        selected_vars_target : list of tuples
            indices of the set of conditionals coming from the target process
        selected_vars_sources : list of tuples
            indices of the set of conditionals coming from source processes
    """

    def __init__(self):
        # Create class attributes for estimation
        self.statistic_omnibus = None
        self.sign_omnibus = False
        self.pvalue_omnibus = None
        self.statistic_sign_sources = None
        self.pvalues_sign_sources = None
        self._cmi_estimator = None
        self.source_set = []
        super().__init__()

    def _check_target(self, target, n_processes):
        """Set and check the target provided by the user."""
        if not isinstance(target, int) or target < 0:
            raise RuntimeError(
                f"The index of the target process ({target}) has to "
                "be an int >= 0."
            )
        if target >= n_processes:
            raise RuntimeError(
                f"Trying to analyse target with index {target}, which is not "
                f"a valid process index for data with {n_processes} processes."
            )
        self.target = target

    def _check_source_set(self, sources, n_processes):
        """Set default if no source set was provided by the user."""
        if sources == "all":
            sources = list(range(n_processes))
            sources.pop(self.target)
        elif isinstance(sources, int):
            sources = [sources]
        elif isinstance(sources, list):
            assert isinstance(sources[0], int), "Source list has to contain ints."
        else:
            raise TypeError(
                "Sources have to be passed as a single int, list "
                'of ints or "all".'
            )

        if self.target in sources:
            raise RuntimeError(
                f"The target ({self.target}) should not be in the list "
                f"of sources ({sources})."
            )
        if max(sources) >= n_processes:
            raise RuntimeError(
                f"The list of sources {sources} contains indices greater than "
                f"or equal to the number of processes ({n_processes}) in the data."
            )
        if min(sources) < 0:
            raise RuntimeError(
                f"The source list ({sources}) can not contain negative indices."
            )
        self.source_set = sources
        if self.settings["verbose"]:
            print(f"\nTarget: {self.target} - testing sources {self.source_set}")

    def _include_candidates(self, candidate_set, data):
        """Include informative candidates into the conditioning set.

        Loop over each candidate in the candidate set and test if it has
        significant mutual information with the current value, conditional
        on all samples that were informative in previous rounds and are
        already in the conditioning set. If this conditional mutual
        information is significant using maximum statistics, add the current
        candidate to the conditional set.

        Args:
            candidate_set : list of tuples
                candidate set to be tested, where each entry is a tuple
                (process index, sample index)
            data : Data instance
                raw data

        Returns:
            bool
                True if a candidate with significant MI was found
        """
        success = False
        if self.settings["verbose"]:
            print(f"candidate set: {self._idx_to_lag(candidate_set)}")
        while candidate_set:
            # Get realisations for all candidates.
            cand_real = data.get_realisations(self.current_value, candidate_set)[0]
            # Reshape candidates to a 1D-array, where realisations for a
            # single candidate are treated as one chunk.
            cand_real = cand_real.T.reshape(cand_real.size, 1)

            # Calculate the (C)MI for each candidate and the target.
            try:
                temp_te = self._cmi_estimator.estimate_parallel(
                    n_chunks=len(candidate_set),
                    re_use=["var2", "conditional"],
                    var1=cand_real,
                    var2=self._current_value_realisations,
                    conditional=self._selected_vars_realisations,
                )
            except ex.AlgorithmExhaustedError as aee:
                # The algorithm cannot continue here, so we'll terminate the
                # search for more candidates, though those identified already
                # remain valid.
                print(
                    f"AlgorithmExhaustedError encountered in estimations: "
                    f"{aee.message}. Halting current estimation set."
                )
                # For now we don't need a stack trace:
                # traceback.print_tb(aee.__traceback__)
                break

            # Test max CMI for significance with maximum statistics.
            te_max_candidate = max(temp_te)
            max_candidate = candidate_set[np.argmax(temp_te)]
            if self.settings["verbose"]:
                print(
                    f"testing candidate: {self._idx_to_lag([max_candidate])[0]} ",
                    end="",
                )
            try:
                significant = stats.max_statistic(
                    self,
                    data,
                    candidate_set,
                    te_max_candidate,
                    conditional=self._selected_vars_realisations,
                )[0]
            except ex.AlgorithmExhaustedError as aee:
                # The algorithm cannot continue here, so we'll terminate the
                # check of significance for this candidate, though those
                # identified already remain valid.
                print(
                    f"AlgorithmExhaustedError encountered in estimations: "
                    f"{aee.message}"
                )
                print("Halting candidate max stats test")
                # For now we don't need a stack trace:
                # traceback.print_tb(aee.__traceback__)
                break

            # If the max is significant, keep it and test the next candidate.
            # If it is not significant, break: there will be no further
            # significant sources because they all have lower TE.
            if significant:
                # if self.settings['verbose']:
                #     print(' -- significant')
                success = True
                candidate_set.pop(np.argmax(temp_te))
                self._append_selected_vars(
                    [max_candidate],
                    data.get_realisations(self.current_value, [max_candidate])[0],
                )
                if self.settings["write_ckp"]:
                    self._write_checkpoint()
            else:
                if self.settings["verbose"]:
                    print(" -- not significant")
                break
        return success

    def _force_conditionals(self, cond, data):
        """Enforce a given conditioning set.

        Manually add variables to the conditioning set before analysis.
        Added variables are not tested in the inclusion step of the
        algorithm, but are tested in the pruning step and may be removed
        there. Source and target past and current variables can be included.

        Args:
            cond : str | dict | list | tuple
                variables added to the conditioning set, 'faes' adds all
                source variables with zero lag to condition out shared
                information due to instantaneous mixing, a dict can contain a
                list of variables for each target
                ({target ind: [(source ind, lag), ...]}), a list of the same
                variables added for each target ([(source ind, lag), ...]),
                a tuple with a single variable that is added for each target
            data : Data instance
                input data
        """
        if isinstance(cond, str):
            # Get realisations and indices of source variables with lag 0.
            # Note that _define_candidates returns tuples with absolute
            # indices and not lags.
            if cond == "faes":
                cond = self._build_variable_list(
                    self.source_set, [self.current_value[1]]
                )
                self._append_selected_vars(
                    cond, data.get_realisations(self.current_value, cond)[0]
                )
        else:
            # If specific variables for conditioning were provided, convert
            # lags to absolute sample indices and add variables.
            if isinstance(cond, tuple):  # easily add single variable
                cond = [cond]
            elif isinstance(cond, dict):  # add conditioning variables per target
                try:
                    cond = cond[self.target]
                except KeyError:
                    return  # no additional variables for the current target
            print(f"Adding the following variables to the conditioning set: {cond}.")
            cond_idx = self._lag_to_idx(cond)
            self._append_selected_vars(
                cond_idx, data.get_realisations(self.current_value, cond_idx)[0]
            )

    def _remove_non_significant(self, s, p, stat):
        # Remove non-significant sources from the candidate set. Loop
        # backwards over the candidates to remove them iteratively.
        print(f"removing {np.sum(np.invert(s))} variables after seq. max stats")
        for i in range(s.shape[0] - 1, -1, -1):
            if not s[i]:
                self._remove_selected_var(self.selected_vars_sources[i])
                p = np.delete(p, i)
                stat = np.delete(stat, i)
        return p, stat
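
# A self-contained sketch (illustration only, not IDTxl API) of the greedy
# scheme implemented by _include_candidates above: repeatedly pick the
# candidate with the maximum statistic and stop at the first candidate that
# fails the significance test. `estimate` and `is_significant` are
# hypothetical stand-ins for the CMI estimator and the maximum-statistics
# test.
def _example_greedy_include(candidates, estimate, is_significant):
    selected = []
    while candidates:
        # Estimate the statistic for every remaining candidate, conditional
        # on the variables selected so far.
        values = [estimate(c, selected) for c in candidates]
        best = max(range(len(values)), key=values.__getitem__)
        if not is_significant(values[best]):
            break  # no further candidate can be significant
        selected.append(candidates.pop(best))
    return selected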

class NetworkInferenceMI(NetworkInference):
    """Parent class for mutual information network inference algorithms."""

    def __init__(self):
        self.measure = "mi"
        super().__init__()

    def _initialise(self, settings, data, sources, target):
        """Check input, set initial or default values for analysis settings."""
        # Check analysis settings and set defaults.
        self.settings = settings.copy()
        self.settings.setdefault("verbose", True)
        self.settings.setdefault("add_conditionals", None)
        self.settings.setdefault("tau_sources", 1)
        self.settings.setdefault("local_values", False)

        # Check lags and taus for multivariate embedding.
        if "max_lag_sources" not in self.settings:
            raise RuntimeError(
                "The maximum lag for source embedding (max_lag_sources) "
                "needs to be specified."
            )
        if "min_lag_sources" not in self.settings:
            raise RuntimeError(
                "The minimum lag for source embedding (min_lag_sources) "
                "needs to be specified."
            )
        if (
            not isinstance(self.settings["min_lag_sources"], int)
            or self.settings["min_lag_sources"] < 0
        ):
            raise RuntimeError("min_lag_sources has to be an integer >= 0.")
        if (
            not isinstance(self.settings["max_lag_sources"], int)
            or self.settings["max_lag_sources"] < 0
        ):
            raise RuntimeError("max_lag_sources has to be an integer >= 0.")
        if (
            not isinstance(self.settings["tau_sources"], int)
            or self.settings["tau_sources"] < 0
        ):
            raise RuntimeError("tau_sources must be an integer >= 0.")
        if self.settings["min_lag_sources"] > self.settings["max_lag_sources"]:
            raise RuntimeError(
                f"min_lag_sources ({self.settings['min_lag_sources']}) must be "
                f"smaller than or equal to max_lag_sources "
                f"({self.settings['max_lag_sources']})."
            )
        # max_lag_sources can be 0 for MI estimation; in this case we don't
        # require the tau to be smaller than the max lag. Still, tau has to
        # be one to later generate the candidate set by enumerating all
        # samples.
        if (
            self.settings["max_lag_sources"] > 0
            and self.settings["tau_sources"] > self.settings["max_lag_sources"]
        ):
            raise RuntimeError(
                f"tau_sources ({self.settings['tau_sources']}) has to be "
                f"smaller than max_lag_sources "
                f"({self.settings['max_lag_sources']})."
            )

        # Set CMI estimator.
        self._set_cmi_estimator()

        # Check the provided target and sources.
        self._check_target(target, data.n_processes)
        self._check_source_set(sources, data.n_processes)

        # Check provided search depths (lags) for sources, set the
        # current_value.
        assert data.n_samples >= self.settings["max_lag_sources"] + 1, (
            f"Not enough samples in data ({data.n_samples}) to allow for the "
            f"chosen maximum lag ({self.settings['max_lag_sources']})"
        )
        self.current_value = (self.target, self.settings["max_lag_sources"])
        [cv_realisation, repl_idx] = data.get_realisations(
            current_value=self.current_value, idx_list=[self.current_value]
        )
        self._current_value_realisations = cv_realisation

        # Remember which realisations come from which replication. This may
        # be needed for surrogate creation at a later point.
        self._replication_index = repl_idx

        # Check the permutation type and no. permutations requested by the
        # user. This tests if there is sufficient data to do all tests.
        # surrogates.check_permutations(self, data)

        # Check and set defaults for checkpointing. If requested, initialise
        # checkpointing.
        self.settings = self._set_checkpointing_defaults(
            self.settings, data, sources, target
        )

        # Reset all attributes to initial values if the instance of
        # MultivariateTE has been used before.
        if self.selected_vars_full:
            self.selected_vars_full = []
            self._selected_vars_realisations = None
            self.selected_vars_sources = []
            self.mi_omnibus = None
            self.pvalue_omnibus = None
            self.pvalues_sign_sources = None
            self.mi_sign_sources = None
            self._min_stats_surr_table = None

        # Check if the user provided a list of candidates that must go into
        # the conditioning set. These will be added and used for TE
        # estimation, but never tested for significance.
        if self.settings["add_conditionals"] is not None:
            self._force_conditionals(self.settings["add_conditionals"], data)

    def _reset(self):
        """Reset instance after analysis."""
        self.__init__()
        del self.settings
        del self.source_set
        del self.pvalues_sign_sources
        del self.statistic_sign_sources
        del self.statistic_omnibus
        del self.pvalue_omnibus
        del self.sign_omnibus
        del self._cmi_estimator
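
# Usage sketch for the MI algorithms that inherit from NetworkInferenceMI
# (class names, estimator name, and data generator follow the IDTxl
# documentation; parameter values are examples only). Wrapped in a function
# so that nothing runs at import time.
def _example_mi_usage():
    from idtxl.data import Data
    from idtxl.multivariate_mi import MultivariateMI

    data = Data()
    data.generate_mute_data(n_samples=1000, n_replications=5)
    settings = {
        "cmi_estimator": "JidtGaussianCMI",
        "min_lag_sources": 0,  # MI allows lag 0, see checks in _initialise
        "max_lag_sources": 5,
    }
    return MultivariateMI().analyse_single_target(settings, data, target=0)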

class NetworkInferenceTE(NetworkInference):
    """Parent class for transfer entropy network inference algorithms."""

    def __init__(self):
        self.measure = "te"
        super().__init__()

    def _initialise(self, settings, data, sources, target):
        """Check input, set initial or default values for analysis settings."""
        # Check analysis settings and set defaults.
        self.settings = settings.copy()
        self.settings.setdefault("verbose", True)
        self.settings.setdefault("add_conditionals", None)
        self.settings.setdefault("tau_target", 1)
        self.settings.setdefault("tau_sources", 1)
        self.settings.setdefault("local_values", False)

        # Check lags and taus for multivariate embedding.
        if "max_lag_sources" not in self.settings:
            raise RuntimeError(
                "The maximum lag for source embedding (max_lag_sources) "
                "needs to be specified."
            )
        if "min_lag_sources" not in self.settings:
            raise RuntimeError(
                "The minimum lag for source embedding (min_lag_sources) "
                "needs to be specified."
            )
        self.settings.setdefault("max_lag_target", settings["max_lag_sources"])

        if (
            not isinstance(self.settings["min_lag_sources"], int)
            or self.settings["min_lag_sources"] < 0
        ):
            raise RuntimeError("min_lag_sources has to be an integer >= 0.")
        if (
            not isinstance(self.settings["max_lag_sources"], int)
            or self.settings["max_lag_sources"] < 0
        ):
            raise RuntimeError("max_lag_sources has to be an integer >= 0.")
        if (
            not isinstance(self.settings["max_lag_target"], int)
            or self.settings["max_lag_target"] <= 0
        ):
            raise RuntimeError("max_lag_target must be an integer > 0.")
        if (
            not isinstance(self.settings["tau_sources"], int)
            or self.settings["tau_sources"] < 0
        ):
            raise RuntimeError("tau_sources must be an integer >= 0.")
        if (
            not isinstance(self.settings["tau_target"], int)
            or self.settings["tau_target"] < 1
        ):
            raise RuntimeError("tau_target must be an integer > 0.")
        if self.settings["min_lag_sources"] > self.settings["max_lag_sources"]:
            raise RuntimeError(
                f"min_lag_sources ({self.settings['min_lag_sources']}) must be "
                f"smaller than or equal to max_lag_sources "
                f"({self.settings['max_lag_sources']})."
            )
        if self.settings["tau_sources"] > self.settings["max_lag_sources"]:
            raise RuntimeError(
                f"tau_sources ({self.settings['tau_sources']}) has to be "
                f"smaller than max_lag_sources "
                f"({self.settings['max_lag_sources']})."
            )
        if self.settings["tau_target"] > self.settings["max_lag_target"]:
            raise RuntimeError(
                f"tau_target ({self.settings['tau_target']}) has to be "
                f"smaller than max_lag_target "
                f"({self.settings['max_lag_target']})."
            )

        # Set CMI estimator.
        self._set_cmi_estimator()

        # Check the provided target and sources.
        self._check_target(target, data.n_processes)
        self._check_source_set(sources, data.n_processes)

        # Check provided search depths (lags) for source and target, set the
        # current_value.
        max_lag = max(
            self.settings["max_lag_sources"], self.settings["max_lag_target"]
        )
        assert data.n_samples >= max_lag + 1, (
            f"Not enough samples in data ({data.n_samples}) to allow for the "
            f"chosen maximum lag ({max_lag})"
        )
        self.current_value = (self.target, max_lag)
        [cv_realisation, repl_idx] = data.get_realisations(
            current_value=self.current_value, idx_list=[self.current_value]
        )
        self._current_value_realisations = cv_realisation

        # Remember which realisations come from which replication. This may
        # be needed for surrogate creation at a later point.
        self._replication_index = repl_idx

        # Check the permutation type and no. permutations requested by the
        # user. This tests if there is sufficient data to do all tests.
        # surrogates.check_permutations(self, data)

        # Check and set defaults for checkpointing. If requested, initialise
        # checkpointing.
        self.settings = self._set_checkpointing_defaults(
            self.settings, data, sources, target
        )

        # Reset all attributes to initial values if the instance of
        # MultivariateTE has been used before.
        if self.selected_vars_full:
            self.selected_vars_full = []
            self._selected_vars_realisations = None
            self.selected_vars_sources = []
            self.selected_vars_target = []
            self.statistic_omnibus = None
            self.pvalue_omnibus = None
            self.pvalues_sign_sources = None
            self.te_sign_sources = None
            self._min_stats_surr_table = None

        # Check if the user provided a list of candidates that must go into
        # the conditioning set. These will be added and used for TE
        # estimation, but never tested for significance.
        if self.settings["add_conditionals"] is not None:
            self._force_conditionals(self.settings["add_conditionals"], data)

    def _include_target_candidates(self, data):
        """Test candidates from the target's past."""
        procs = [self.target]
        # Make samples
        samples = np.arange(
            self.current_value[1] - 1,
            self.current_value[1] - self.settings["max_lag_target"] - 1,
            -self.settings["tau_target"],
        ).tolist()
        candidates = self._define_candidates(procs, samples)
        sources_found = self._include_candidates(candidates, data)

        # If no candidates were found in the target's past, add at least one
        # sample so we are still calculating a proper TE.
        if not sources_found:
            print(
                "\nNo informative sources in the target's past - "
                "adding target sample with lag 1."
            )
            idx = (self.current_value[0], self.current_value[1] - 1)
            realisations = data.get_realisations(self.current_value, [idx])[0]
            self._append_selected_vars([idx], realisations)

    def _reset(self):
        """Reset instance after analysis."""
        self.__init__()
        del self.settings
        del self.source_set
        del self.pvalues_sign_sources
        del self.statistic_sign_sources
        del self.statistic_omnibus
        del self.pvalue_omnibus
        del self.sign_omnibus
        del self._cmi_estimator
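
# Illustration (example values only) of the candidate-sample construction in
# _include_target_candidates above: with the current value at sample index 5,
# max_lag_target = 3 and tau_target = 1, the candidate samples run backwards
# from index 4. Defined as a function so nothing runs at import time.
def _example_target_candidate_samples():
    current_value = (1, 5)  # (process index, sample index of current value)
    max_lag_target, tau_target = 3, 1
    samples = np.arange(
        current_value[1] - 1,
        current_value[1] - max_lag_target - 1,
        -tau_target,
    ).tolist()
    assert samples == [4, 3, 2]  # lags 1, 2, 3 relative to the current value
    return samples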

class NetworkInferenceBivariate(NetworkInference):
    """Parent class for bivariate network inference algorithms."""

    def __init__(self):
        super().__init__()

    def _include_source_candidates(self, data):
        """Include informative candidates into the conditioning set.

        Loop over each candidate in the candidate set and test if it has
        significant mutual information with the current value, conditional
        on all samples that were informative in previous rounds and are
        already in the conditioning set. If this conditional mutual
        information is significant using maximum statistics, add the current
        candidate to the conditional set.

        Args:
            data : Data instance
                raw data
        """
        # Define samples for candidate sets.
        if self.settings["max_lag_sources"] == 0:
            samples = np.zeros(1).astype(int)
        else:
            samples = np.arange(
                self.current_value[1] - self.settings["min_lag_sources"],
                self.current_value[1] - self.settings["max_lag_sources"] - 1,
                -self.settings["tau_sources"],
            )

        # Check if target variables were selected to distinguish between TE
        # and MI analysis.
        if len(self._selected_vars_target) == 0:
            conditional_realisations_target = None
        else:
            conditional_realisations_target = (
                self._selected_vars_target_realisations
            )

        # Iterate over all potential sources in the analysis. This way, the
        # conditioning uses past variables from the current source only
        # (opposed to past variables from all sources as in multivariate
        # network inference).
        success = False
        for source in self.source_set:
            candidate_set = self._define_candidates([source], samples)
            if self.settings["verbose"]:
                print(
                    f"candidate set current source: "
                    f"{self._idx_to_lag(candidate_set)}\n",
                    end="",
                )
            # Initialise conditional realisations. This gets updated if
            # sources are selected in the iterative conditioning. For MI
            # calculation this is None.
            conditional_realisations = conditional_realisations_target
            while candidate_set:
                # Get realisations for all candidates.
                cand_real = data.get_realisations(
                    self.current_value, candidate_set
                )[0]
                # Reshape candidates to a 1D-array, where realisations for a
                # single candidate are treated as one chunk.
                cand_real = cand_real.T.reshape(cand_real.size, 1)

                # Calculate the (C)MI for each candidate and the target.
                try:
                    temp_te = self._cmi_estimator.estimate_parallel(
                        n_chunks=len(candidate_set),
                        re_use=["var2", "conditional"],
                        var1=cand_real,
                        var2=self._current_value_realisations,
                        conditional=conditional_realisations,
                    )
                except ex.AlgorithmExhaustedError as aee:
                    # The algorithm cannot continue here, so we'll terminate
                    # the search for more candidates, though those identified
                    # already remain valid.
                    print(
                        "AlgorithmExhaustedError encountered in estimations: "
                        + aee.message
                    )
                    print("Halting current estimation set.")
                    # For now we don't need a stack trace:
                    # traceback.print_tb(aee.__traceback__)
                    break

                # Test max CMI for significance with maximum statistics.
                te_max_candidate = max(temp_te)
                max_candidate = candidate_set[np.argmax(temp_te)]
                if self.settings["verbose"]:
                    print(
                        f"testing candidate: "
                        f"{self._idx_to_lag([max_candidate])[0]} ",
                        end="",
                    )
                try:
                    significant = stats.max_statistic(
                        self,
                        data,
                        candidate_set,
                        te_max_candidate,
                        conditional_realisations,
                    )[0]
                except ex.AlgorithmExhaustedError as aee:
                    # The algorithm cannot continue here, so we'll terminate
                    # the significance check for this candidate, though those
                    # identified already remain valid.
                    print(
                        "AlgorithmExhaustedError encountered in estimations: "
                        + aee.message
                    )
                    print("Halting candidate max stats test")
                    # For now we don't need a stack trace:
                    # traceback.print_tb(aee.__traceback__)
                    break

                # If the max is significant, move it from the candidate set
                # to the set of selected sources and test the next candidate.
                # If it is not significant, break: there will be no further
                # significant sources because they all have lower TE.
                if significant:
                    success = True
                    candidate_set.pop(np.argmax(temp_te))
                    candidate_realisations = data.get_realisations(
                        self.current_value, [max_candidate]
                    )[0]
                    self._append_selected_vars(
                        [max_candidate], candidate_realisations
                    )
                    # Update conditioning set for max. statistics in the next
                    # round.
                    if conditional_realisations is None:
                        conditional_realisations = candidate_realisations
                    else:
                        conditional_realisations = np.hstack(
                            (conditional_realisations, candidate_realisations)
                        )
                    if self.settings["write_ckp"]:
                        self._write_checkpoint()
                else:
                    if self.settings["verbose"]:
                        print(" -- not significant")
                    break
        return success

    def _prune_candidates(self, data):
        """Remove uninformative candidates from the final conditional set.

        For each sample in the final conditioning set, check if it is
        informative about the current value given all other samples in the
        final set. If a sample is not informative, it is removed from the
        final set.

        Args:
            data : Data instance
                raw data
        """
        # FOR LATER we don't need to test the last included in the first round
        if self.settings["verbose"]:
            if not self.selected_vars_sources:
                print("no sources selected, nothing to prune ...")

        # Check if target variables were selected to distinguish between TE
        # and MI analysis.
        if len(self._selected_vars_target) == 0:
            conditional_realisations_target = None
            cond_target_dim = 0
        else:
            conditional_realisations_target = (
                self._selected_vars_target_realisations
            )
            cond_target_dim = conditional_realisations_target.shape[1]

        # Prune all selected sources separately. This way, the conditioning
        # uses past variables from the current source only (opposed to past
        # variables from all sources as in multivariate network inference).
        significant_sources = np.unique(
            [s[0] for s in self.selected_vars_sources]
        )
        for source in significant_sources:
            # Find selected past variables for the current source.
            print(f"selected vars sources {self.selected_vars_sources}")
            source_vars = [s for s in self.selected_vars_sources if s[0] == source]
            print(
                f"selected candidates current source: "
                f"{self._idx_to_lag(source_vars)}"
            )

            # If only a single variable was selected for the current source,
            # no pruning is necessary. The minimum statistic would be equal
            # to the maximum statistic for this variable.
            if len(source_vars) == 1:
                if self.settings["verbose"]:
                    print(" -- significant")
                continue

            # Find the candidate with the minimum TE/MI into the target.
            while source_vars:
                # Allocate memory, collect realisations, and calculate TE/MI
                # in parallel for all selected variables in the current
                # process.
                temp_te = np.empty(len(source_vars))
                cond_dim = cond_target_dim + len(source_vars) - 1
                candidate_realisations = np.empty(
                    (data.n_realisations(self.current_value) * len(source_vars), 1)
                ).astype(data.data_type)
                conditional_realisations = np.empty(
                    (
                        data.n_realisations(self.current_value) * len(source_vars),
                        cond_dim,
                    )
                ).astype(data.data_type)
                i_1 = 0
                i_2 = data.n_realisations(self.current_value)
                for candidate in source_vars:
                    temp_cond = data.get_realisations(
                        self.current_value,
                        set(source_vars).difference(set([candidate])),
                    )[0]
                    temp_cand = data.get_realisations(
                        self.current_value, [candidate]
                    )[0]
                    if temp_cond is None:
                        conditional_realisations = conditional_realisations_target
                        re_use = ["var2", "conditional"]
                    else:
                        re_use = ["var2"]
                        if conditional_realisations_target is None:
                            conditional_realisations[i_1:i_2, :] = temp_cond
                        else:
                            conditional_realisations[i_1:i_2, :] = np.hstack(
                                (temp_cond, conditional_realisations_target)
                            )
                    candidate_realisations[i_1:i_2, :] = temp_cand
                    i_1 = i_2
                    i_2 += data.n_realisations(self.current_value)

                try:
                    temp_te = self._cmi_estimator.estimate_parallel(
                        n_chunks=len(source_vars),
                        re_use=re_use,
                        var1=candidate_realisations,
                        var2=self._current_value_realisations,
                        conditional=conditional_realisations,
                    )
                except ex.AlgorithmExhaustedError as aee:
                    # The algorithm cannot continue here, so we'll terminate
                    # the pruning check, assuming that we need not prune any
                    # more.
                    print(
                        f"AlgorithmExhaustedError encountered in estimations: "
                        f"{aee.message}. Halting current estimation set."
                    )
                    # For now we don't need a stack trace:
                    # traceback.print_tb(aee.__traceback__)
                    break

                # Find the variable with minimum MI/TE. Test the min TE/MI
                # for significance with minimum statistics. Build the
                # conditioning set for minimum statistics by removing the
                # minimum candidate.
                te_min_candidate = min(temp_te)
                min_candidate = source_vars[np.argmin(temp_te)]
                if self.settings["verbose"]:
                    print(
                        f"testing candidate: "
                        f"{self._idx_to_lag([min_candidate])[0]} ",
                        end="",
                    )
                remaining_candidates = set(source_vars).difference(
                    set([min_candidate])
                )
                conditional_realisations_sources = data.get_realisations(
                    self.current_value, remaining_candidates
                )[0]
                if conditional_realisations_target is None:
                    conditional_realisations = conditional_realisations_sources
                elif conditional_realisations_sources is None:
                    conditional_realisations = conditional_realisations_target
                else:
                    conditional_realisations = np.hstack(
                        (
                            conditional_realisations_target,
                            conditional_realisations_sources,
                        )
                    )
                try:
                    [significant, p, surr_table] = stats.min_statistic(
                        self,
                        data,
                        source_vars,
                        te_min_candidate,
                        conditional_realisations,
                    )
                except ex.AlgorithmExhaustedError as aee:
                    # The algorithm cannot continue here, so we'll terminate
                    # the pruning check, assuming that we need not prune any
                    # more.
                    print(
                        "AlgorithmExhaustedError encountered in estimations: "
                        + aee.message
                    )
                    print("Halting current pruning and allowing others to remain.")
                    # For now we don't need a stack trace:
                    # traceback.print_tb(aee.__traceback__)
                    break

                # Remove the minimum if it is not significant and test the
                # next min. candidate. If the minimum is significant, break:
                # all other sources will be significant as well (b/c they
                # have higher TE/MI).
                if not significant:
                    self._remove_selected_var(min_candidate)
                    source_vars.pop(np.argmin(temp_te))
                    if len(source_vars) == 0:
                        print("No remaining candidates after pruning.")
                    if self.settings["write_ckp"]:
                        self._write_checkpoint()
                else:
                    if self.settings["verbose"]:
                        print(" -- significant")
                    break

    def _test_final_conditional(self, data):
        """Perform statistical test on the final conditional set."""
        if not self.selected_vars_sources:
            if self.settings["verbose"]:
                print("no sources selected ...")
            self.statistic_omnibus = None
            self.sign_omnibus = False
            self.pvalue_omnibus = None
            self.pvalues_sign_sources = None
            self.statistic_sign_sources = None
            self.statistic_single_link = None
        else:
            if self.settings["verbose"]:
                print(
                    f"selected variables: "
                    f"{self._idx_to_lag(self.selected_vars_full)}"
                )
            try:
                [s, p, stat] = stats.omnibus_test(self, data)
            except ex.AlgorithmExhaustedError as aee:
                # The algorithm cannot continue here, so we'll set the
                # results to not significant.
                print(
                    "AlgorithmExhaustedError encountered in estimations: "
                    + aee.message
                )
                print("Halting omnibus test and setting to not significant.")
                # For now we don't need a stack trace:
                # traceback.print_tb(aee.__traceback__)
                stat = 0
                s = False
                p = 1
            self.statistic_omnibus = stat
            self.sign_omnibus = s
            self.pvalue_omnibus = p

            # Test individual links if the omnibus test is significant using
            # the sequential max stats. Remove non-significant links.
            if self.sign_omnibus:
                # If there is an ex.AlgorithmExhaustedError exception inside
                # max_stats_sequential, it will catch it and return
                # everything as not significant:
                [s, p, stat] = stats.max_statistic_sequential_bivariate(self, data)
                p, stat = self._remove_non_significant(s, p, stat)
                self.pvalues_sign_sources = p
                self.statistic_sign_sources = stat
                if self.measure == "te":
                    conditioning = "target"
                elif self.measure == "mi":
                    conditioning = "none"
                try:
                    self.statistic_single_link = self._calculate_single_link(
                        data=data,
                        current_value=self.current_value,
                        source_vars=self.selected_vars_sources,
                        target_vars=self.selected_vars_target,
                        sources="all",
                        conditioning=conditioning,
                    )
                except ex.AlgorithmExhaustedError as aee:
                    # The algorithm cannot continue here, so we'll terminate
                    # the computation of single-link statistics. Since max
                    # stats sequential etc. all passed up to here, it seems
                    # OK to let everything through and just write a 0 for
                    # the final values.
                    print(
                        "AlgorithmExhaustedError encountered in "
                        "final_conditional estimations: " + aee.message
                    )
                    print("Halting final_conditional estimations")
                    # For now we don't need a stack trace:
                    # traceback.print_tb(aee.__traceback__)
                    self.statistic_single_link = np.zeros(
                        len(self.selected_vars_sources)
                    )
            else:
                self.selected_vars_sources = []
                self.selected_vars_full = self.selected_vars_target
                self.pvalues_sign_sources = None
                self.statistic_sign_sources = None
                self.statistic_single_link = None
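
# Design note as a sketch (illustrative only, not IDTxl API): the bivariate
# classes above condition each source's candidates on the target's past plus
# previously selected variables from the *same* source, whereas the
# multivariate classes below condition on selected variables from *all*
# sources. `selected_vars` holds (process, sample) tuples as used throughout
# this module.
def _example_conditioning_set(candidate_source, selected_vars, target_vars,
                              multivariate):
    if multivariate:
        return target_vars + selected_vars
    return target_vars + [v for v in selected_vars if v[0] == candidate_source]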

class NetworkInferenceMultivariate(NetworkInference):
    """Parent class for multivariate network inference algorithms."""

    def __init__(self):
        super().__init__()

    def _include_source_candidates(self, data):
        """Test candidates in the source's past."""
        procs = self.source_set
        if self.settings["max_lag_sources"] == 0:
            samples = np.zeros(1).astype(int)
        else:
            samples = np.arange(
                self.current_value[1] - self.settings["min_lag_sources"],
                self.current_value[1] - self.settings["max_lag_sources"] - 1,
                -self.settings["tau_sources"],
            )
        candidates = self._define_candidates(procs, samples)
        # Possible extension in the future: include non-selected target
        # candidates as further candidates, they may get selected due to
        # synergies.
        self._include_candidates(candidates, data)

    def _prune_candidates(self, data):
        """Remove uninformative candidates from the final conditional set.

        For each sample in the final conditioning set, check if it is
        informative about the current value given all other samples in the
        final set. If a sample is not informative, it is removed from the
        final set.

        Args:
            data : Data instance
                raw data
        """
        # FOR LATER we don't need to test the last included in the first round
        if self.settings["verbose"]:
            if self.selected_vars_sources:
                print(
                    f"selected candidates: "
                    f"{self._idx_to_lag(self.selected_vars_sources)}"
                )
            else:
                print("no sources selected, nothing to prune ...")

        # If only a single variable was selected, no pruning is necessary.
        # The minimum statistic would be equal to the maximum statistic for
        # this variable.
        if len(self.selected_vars_sources) == 1:
            if self.settings["verbose"]:
                print(" -- significant")
            return

        while self.selected_vars_sources:
            # Find the candidate with the minimum TE into the target.
            temp_te = np.empty(len(self.selected_vars_sources))
            cond_dim = len(self.selected_vars_full) - 1
            candidate_realisations = np.empty(
                (
                    data.n_realisations(self.current_value)
                    * len(self.selected_vars_sources),
                    1,
                )
            ).astype(data.data_type)
            conditional_realisations = np.empty(
                (
                    data.n_realisations(self.current_value)
                    * len(self.selected_vars_sources),
                    cond_dim,
                )
            ).astype(data.data_type)

            # Calculate TE simultaneously for all candidates.
            i_1 = 0
            i_2 = data.n_realisations(self.current_value)
            for candidate in self.selected_vars_sources:
                # Separate the candidate realisations and all other
                # realisations to test the candidate's individual
                # contribution.
                [temp_cond, temp_cand] = self._separate_realisations(
                    self.selected_vars_full, candidate
                )
                if temp_cond is None:
                    conditional_realisations = None
                    re_use = ["var2", "conditional"]
                else:
                    conditional_realisations[i_1:i_2, :] = temp_cond
                    re_use = ["var2"]
                candidate_realisations[i_1:i_2, :] = temp_cand
                i_1 = i_2
                i_2 += data.n_realisations(self.current_value)

            try:
                temp_te = self._cmi_estimator.estimate_parallel(
                    n_chunks=len(self.selected_vars_sources),
                    re_use=re_use,
                    var1=candidate_realisations,
                    var2=self._current_value_realisations,
                    conditional=conditional_realisations,
                )
            except ex.AlgorithmExhaustedError as aee:
                # The algorithm cannot continue here, so we'll terminate the
                # pruning check, assuming that we need not prune any more.
                print(
                    "AlgorithmExhaustedError encountered in estimations: "
                    + aee.message
                )
                print("Halting current pruning and allowing others to remain.")
                # For now we don't need a stack trace:
                # traceback.print_tb(aee.__traceback__)
                break

            # Find the variable with minimum MI/TE. Test the min TE/MI for
            # significance with minimum statistics. Build the conditioning
            # set for minimum statistics by removing the minimum candidate.
            te_min_candidate = min(temp_te)
            min_candidate = self.selected_vars_sources[np.argmin(temp_te)]
            if self.settings["verbose"]:
                print(
                    f"testing candidate: "
                    f"{self._idx_to_lag([min_candidate])[0]} ",
                    end="",
                )
            remaining_candidates = set(self.selected_vars_full).difference(
                set([min_candidate])
            )
            conditional_realisations = data.get_realisations(
                self.current_value, remaining_candidates
            )[0]
            try:
                [significant, p, surr_table] = stats.min_statistic(
                    self,
                    data,
                    self.selected_vars_sources,
                    te_min_candidate,
                    conditional_realisations,
                )
            except ex.AlgorithmExhaustedError as aee:
                # The algorithm cannot continue here, so we'll terminate the
                # pruning check, assuming that we need not prune any more.
                print(
                    "AlgorithmExhaustedError encountered in estimations: "
                    + aee.message
                )
                print("Halting current pruning and allowing others to remain.")
                # For now we don't need a stack trace:
                # traceback.print_tb(aee.__traceback__)
                break

            # Remove the minimum if it is not significant and test the next
            # min. candidate. If the minimum is significant, break: all
            # other sources will be significant as well (b/c they have
            # higher TE).
            if not significant:
                # if self.settings['verbose']:
                #     print(' -- not significant\n')
                self._remove_selected_var(min_candidate)
                if len(self.selected_vars_sources) == 0:
                    print("No remaining candidates after pruning.")
                if self.settings["write_ckp"]:
                    self._write_checkpoint()
            else:
                if self.settings["verbose"]:
                    print(" -- significant")
                self._min_stats_surr_table = surr_table
                break

    def _test_final_conditional(self, data):
        """Perform statistical test on the final conditional set."""
        if not self.selected_vars_sources:
            if self.settings["verbose"]:
                print("no sources selected ...")
            self.statistic_omnibus = None
            self.sign_omnibus = False
            self.pvalue_omnibus = None
            self.pvalues_sign_sources = None
            self.statistic_sign_sources = None
            self.statistic_single_link = None
        else:
            if self.settings["verbose"]:
                print(
                    f"selected variables: "
                    f"{self._idx_to_lag(self.selected_vars_full)}"
                )
            try:
                [s, p, stat] = stats.omnibus_test(self, data)
            except ex.AlgorithmExhaustedError as aee:
                # The algorithm cannot continue here, so we'll set the
                # results to not significant.
                print(
                    f"AlgorithmExhaustedError encountered in estimations: "
                    f"{aee.message}. Halting current estimation set."
                )
                # For now we don't need a stack trace:
                # traceback.print_tb(aee.__traceback__)
                stat = 0
                s = False
                p = 1
            self.statistic_omnibus = stat
            self.sign_omnibus = s
            self.pvalue_omnibus = p

            # Test individual links if the omnibus test is significant using
            # the sequential max stats. Remove non-significant links.
            if self.sign_omnibus:
                # If there is an ex.AlgorithmExhaustedError exception inside
                # max_stats_sequential, it will catch it and return
                # everything as not significant:
                [s, p, stat] = stats.max_statistic_sequential(self, data)
                p, stat = self._remove_non_significant(s, p, stat)
                self.pvalues_sign_sources = p
                self.statistic_sign_sources = stat
                # Calculate TE for all links in the network. Calculate local
                # TE if requested by the user.
                try:
                    self.statistic_single_link = self._calculate_single_link(
                        data=data,
                        current_value=self.current_value,
                        source_vars=self.selected_vars_sources,
                        target_vars=self.selected_vars_target,
                        sources="all",
                        conditioning="full",
                    )
                except ex.AlgorithmExhaustedError as aee:
                    # The algorithm cannot continue here, so we'll terminate
                    # the computation of single-link statistics.
                    # Since max stats sequential etc. all passed up to here,
                    # it seems OK to let everything through and just write a
                    # 0 for the final values.
                    print(
                        "AlgorithmExhaustedError encountered in "
                        "final_conditional estimations: " + aee.message
                    )
                    print("Halting final_conditional estimations")
                    # For now we don't need a stack trace:
                    # traceback.print_tb(aee.__traceback__)
                    self.statistic_single_link = np.zeros(
                        len(self.selected_vars_sources)
                    )
            else:
                self.selected_vars_sources = []
                self.selected_vars_full = self.selected_vars_target
                self.pvalues_sign_sources = None
                self.statistic_sign_sources = None
                self.statistic_single_link = None
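
# A self-contained sketch (illustration only, not IDTxl API) of the pruning
# scheme implemented by _prune_candidates: repeatedly test the weakest
# selected variable against the minimum statistic and remove it until the
# weakest remaining contribution is significant. `estimate` and
# `is_significant` are hypothetical stand-ins for the CMI estimator and the
# minimum-statistics test.
def _example_prune(selected, estimate, is_significant):
    while selected:
        # Each variable's contribution, conditional on all other variables.
        values = [estimate(c, [s for s in selected if s != c]) for c in selected]
        worst = min(range(len(values)), key=values.__getitem__)
        if is_significant(values[worst]):
            break  # all remaining variables contribute significantly
        selected.pop(worst)
    return selected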