# Source code for idtxl.multivariate_pid

"""Estimate partial information decomposition (PID).

Estimate PID for multiple sources (up to 4 sources) and one target process
using SxPID estimator.

Note:
    Written for Python 3.4+
"""
import numpy as np
from .single_process_analysis import SingleProcessAnalysis
from .estimator import get_estimator
from .results import ResultsMultivariatePID


class MultivariatePID(SingleProcessAnalysis):
    """Perform partial information decomposition for individual processes.

    Perform partial information decomposition (PID) for multiple source
    processes (up to 4 sources) and a target process in the network.
    Estimate unique, shared, and synergistic information in the multiple
    sources about the target. Call analyse_network() on the whole network
    or a set of nodes or call analyse_single_target() to estimate PID for
    a single process. See docstrings of the two functions for more
    information.

    References:

    - Williams, P. L., & Beer, R. D. (2010). Nonnegative Decomposition of
      Multivariate Information, 1-14. Retrieved from
      http://arxiv.org/abs/1004.2515
    - Makkeh, A. & Gutknecht, A. & Wibral, M. (2020). A Differentiable
      measure for shared information. 1-27. Retrieved from
      http://arxiv.org/abs/2002.03356

    Attributes:
        target : int
            index of target process
        sources : array type
            multiple of indices of source processes
        settings : dict
            analysis settings
        results : dict
            estimated PID
    """

    def __init__(self):
        super().__init__()
[docs] def analyse_network(self, settings, data, targets, sources): """Estimate partial information decomposition for network nodes. Estimate, for multiple nodes (target processes), the partial information decomposition (PID) for multiple source processes (up to 4 sources) and each of these target processes in the network. Note: For a detailed description of the algorithm and settings see documentation of the analyse_single_target() method and references in the class docstring. Example: >>> n = 20 >>> alph = 2 >>> s1 = np.random.randint(0, alph, n) >>> s2 = np.random.randint(0, alph, n) >>> s3 = np.random.randint(0, alph, n) >>> target1 = np.logical_xor(s1, s2).astype(int) >>> target = np.logical_xor(target1, s3).astype(int) >>> data = Data(np.vstack((s1, s2, s3, target)), 'ps', >>> normalise=False) >>> settings = { >>> 'lags_pid': [[1, 1, 1], [3, 2, 7]], >>> 'verbose': False, >>> 'pid_estimator': 'SxPID'} >>> targets = [0, 1] >>> sources = [[1, 2, 3], [0, 2, 3]] >>> pid_analysis = MultivariatePID() >>> results = pid_analysis.analyse_network(settings, data, targets, >>> sources) Args: settings : dict parameters for estimation and statistical testing, see documentation of analyse_single_target() for details, can contain - lags_pid : list of lists of ints [optional] - lags in samples between sources and target (default=[[1, 1, ..., 1], [1, 1, ..., 1], ...]) data : Data instance raw data for analysis targets : list of int index of target processes sources : list of lists indices of the multiple source processes for each target, e.g., [[0, 1, 2], [1, 0, 3]], all must lists be of the same lenght and list of lists must have the same length as targets Returns: ResultsMultivariatePID instance results of network inference, see documentation of ResultsMultivariatePID() """ # Set defaults for PID estimation. settings.setdefault("verbose", True) settings.setdefault( "lags_pid", np.array([[1 for i in range(len(sources[0]))]] * len(targets)) ) # Check inputs. 
if not len(targets) == len(sources) == len(settings["lags_pid"]): raise RuntimeError( "Lists of targets, sources, and lags must have" "the same lengths." ) for lis_1 in sources: for lis_2 in sources: if not len(lis_1) == len(lis_2): raise RuntimeError( "Lists in the list sources must have" "the same lengths." ) # ^ if # ^ for # ^ for list_of_lags = settings["lags_pid"] # Perform PID estimation for each target individually results = ResultsMultivariatePID( n_nodes=data.n_processes, n_realisations=data.n_realisations(), normalised=data.normalise, ) for t in range(len(targets)): if settings["verbose"]: print( "\n####### analysing target with index {0} from list {1}".format( t, targets ) ) settings["lags_pid"] = list_of_lags[t] res_single = self.analyse_single_target( settings, data, targets[t], sources[t] ) results.combine_results(res_single) # Get no. realisations actually used for estimation from single target # analysis. results.data_properties.n_realisations = ( res_single.data_properties.n_realisations ) return results
[docs] def analyse_single_target(self, settings, data, target, sources): """Estimate partial information decomposition for a network node. Estimate partial information decomposition (PID) for multiple source processes (up to 4 sources) and a target process in the network. Note: For a description of the algorithm and the method see references in the class and estimator docstrings. Example: >>> n = 20 >>> alph = 2 >>> s1 = np.random.randint(0, alph, n) >>> s2 = np.random.randint(0, alph, n) >>> s3 = np.random.randint(0, alph, n) >>> target1 = np.logical_xor(s1, s2).astype(int) >>> target = np.logical_xor(target1, s3).astype(int) >>> data = Data(np.vstack((s1, s2, s3, target)), 'ps', >>> normalise=False) >>> settings = { >>> 'verbose' : false, >>> 'pid_estimator': 'SxPID', >>> 'lags_pid': [2, 3, 1]} >>> pid_analysis = MultivariatePID() >>> results = pid_analysis.analyse_single_target(settings=settings, >>> data=data, >>> target=0, >>> sources=[1, 2, 3]) Args: settings : dict parameters for estimator use and statistics: - pid_estimator : str - estimator to be used for PID estimation (for estimator settings see the documentation in the estimators_pid modules) - lags_pid : list of ints [optional] - lags in samples between sources and target (default=[1, 1, ..., 1]) - verbose : bool [optional] - toggle console output (default=True) data : Data instance raw data for analysis target : int index of target processes sources : list of ints indices of the multiple source processes for the target Returns: ResultsMultivariatePID instance results of network inference, see documentation of ResultsPID() """ # Check input and initialise values for analysis. self._initialise(settings, data, target, sources) # Estimate PID and significance. self._calculate_pid(data) # Add analyis info. 
results = ResultsMultivariatePID( n_nodes=data.n_processes, n_realisations=data.n_realisations(self.current_value), normalised=data.normalise, ) results._add_single_result( settings=self.settings, target=self.target, results=self.results ) self._reset() return results
def _initialise(self, settings, data, target, sources): """Check input, set initial or default values for analysis settings.""" # Check requested PID estimator. assert "pid_estimator" in settings, "Estimator was not specified!" self._pid_estimator = get_estimator(settings["pid_estimator"], settings) self.settings = settings.copy() self.settings.setdefault("lags_pid", [1 for i in range(len(sources))]) self.settings.setdefault("verbose", True) # Check if provided lags are correct and work with the number of # samples in the data. if len(self.settings["lags_pid"]) not in [2, 3, 4]: raise RuntimeError("List of lags must have length 2 or 3 or 4.") # number of lags is equal to number of sources if not len(self.settings["lags_pid"]) == len(sources): raise RuntimeError( "List of lags must have same length as the list sources." ) for i in range(len(self.settings["lags_pid"])): if self.settings["lags_pid"][0] >= data.n_samples: raise RuntimeError( "Lag {0} ({1}) is larger than the number of samples in the data " "set ({2}).".format(i, self.settings["lags_pid"][i], data.n_samples) ) # Check if target and sources are provided correctly. if type(target) is not int: raise RuntimeError("Target must be an integer.") if len(sources) not in [2, 3, 4]: raise RuntimeError("List of sources must have length 2 or 3 or 4.") if target in sources: raise RuntimeError( "The target ({0}) should not be in the list " "of sources ({1}).".format(target, sources) ) self.current_value = (target, max(self.settings["lags_pid"])) self.target = target # TODO works for single vars only, change to multivariate? self.sources = self._lag_to_idx( [(sources[i], self.settings["lags_pid"][i]) for i in range(len(sources))] ) def _calculate_pid(self, data): # TODO Discuss how and if the following statistical testing should be # included included. Remove dummy results. 
# [orig_pid, sign_1, p_val_1, # sign_2, p_val_2] = stats.unq_against_surrogates(self, data) # [orig_pid, sign_shd, # p_val_shd, sign_syn, p_val_syn] = stats.syn_shd_against_surrogates( # self, # sign_1 = sign_2 = sign_shd = sign_syn = False # p_val_1 = p_val_2 = p_val_shd = p_val_syn = 1.0 target_realisations = data.get_realisations( self.current_value, [self.current_value] )[0] # CHECK! make sure self.source has the same idx as sources data.get_realisations(self.current_value, [self.sources[0]])[0] list_sources_var_realisations = [ data.get_realisations(self.current_value, [self.sources[i]])[0] for i in range(len(self.sources)) ] orig_pid = self._pid_estimator.estimate( s=list_sources_var_realisations, t=target_realisations ) self.results = orig_pid for i in range(len(self.sources)): self.results["source_" + str(i + 1)] = self._idx_to_lag([self.sources[i]]) # ^ for self.results["selected_vars_sources"] = [ self.results["source_" + str(i + 1)][0] for i in range(len(self.sources)) ] self.results["current_value"] = self.current_value # self.results['unq_s1_sign'] = sign_1 # self.results['unq_s2_sign'] = sign_2 # self.results['unq_s1_p_val'] = p_val_1 # self.results['unq_s2_p_val'] = p_val_2 # self.results['syn_sign'] = sign_syn # self.results['syn_p_val'] = p_val_syn # self.results['shd_sign'] = sign_shd # self.results['shd_p_val'] = p_val_shd # TODO make mi_against_surrogates in stats more generic, such that # it becomes an arbitrary permutation test where one arguemnt gets # shuffled and then all arguents are passed to the provided estimator def _reset(self): """Reset instance after analysis.""" self.__init__() del self.results del self.settings del self._pid_estimator