"""Parent class for network inference and network comparison.
"""
import ast
import copy as cp
import itertools as it
import os.path
from datetime import datetime
from pprint import pprint
from shutil import copyfile
import numpy as np
from . import idtxl_io as io
from . import idtxl_utils as utils
from .estimator import get_estimator
class NetworkAnalysis:
    """Provide an analysis setup for network inference or comparison.

    The class provides routines to check user input and set defaults.
    """

    def __init__(self):
        """Initialise an analysis object with empty settings and selections."""
        # Analysis settings, filled in by the inference/comparison routines.
        self.settings = {}
        # Index of the process currently analysed as target.
        self.target = None
        # Index tuple (index process, index sample) of the current value.
        self.current_value = None
        # Indices of selected variables: complete set, source part, target
        # part. Each entry is a tuple (index process, index sample).
        self.selected_vars_full = []
        self.selected_vars_sources = []
        self.selected_vars_target = []
        # Realisations belonging to the current value and to the selected
        # variables; None until data has been collected.
        self._current_value_realisations = None
        self._selected_vars_realisations = None
        self._min_stats_surr_table = None
@property
def current_value(self):
    """Return the index of the current value.

    The stored value is either None or a tuple
    (index process, index sample).
    """
    return self._current_value


@current_value.setter
def current_value(self, idx):
    """Store the index of the current value.

    Args:
        idx : tuple | None
            index of the current value (index process, index sample)

    Raises:
        TypeError
            if idx is neither None nor exactly of type tuple
    """
    acceptable = idx is None or type(idx) is tuple
    if not acceptable:
        raise TypeError(
            "The current value should be a tuple (index process, index sample)."
        )
    self._current_value = idx
@property
def _current_value_realisations(self):
"""Get realisations of the current_value."""
if self.__current_value_realisations is None:
print("Attribute has not been set yet.")
if type(self.__current_value_realisations) is tuple:
raise TypeError("something went wrong")
return self.__current_value_realisations
@_current_value_realisations.setter
def _current_value_realisations(self, realisations):
self.__current_value_realisations = realisations
@_current_value_realisations.deleter
def _current_value_realisations(self):
del self.__current_value_realisations
@property
def selected_vars_full(self):
    """List of indices of the full conditional set.

    Prints a notice if the stored value is None.
    """
    if self._selected_vars_full is None:
        print("Attribute has not been set yet.")
    return self._selected_vars_full


@selected_vars_full.setter
def selected_vars_full(self, idx_list):
    """Store the indices of the full conditional set.

    Args:
        idx_list : list of tuples
            indices (index process, index sample); may be empty

    Raises:
        TypeError
            if idx_list is not a list, or if its first entry is not a
            tuple
    """
    # Both conditions are checked independently. Joining them with `and`
    # let lists of non-tuples pass and evaluated idx_list[0] on non-list
    # input. Only the first entry is inspected, as before, to keep the
    # check cheap.
    if not isinstance(idx_list, list) or (
        idx_list and not isinstance(idx_list[0], tuple)
    ):
        raise TypeError("Expected a list of tuples (index process, index sample).")
    self._selected_vars_full = idx_list
@property
def selected_vars_target(self):
    """List of indices of target samples in the conditional set.

    Prints a notice if the stored value is None.
    """
    stored = self._selected_vars_target
    if stored is None:
        print("Attribute has not been set yet.")
    return stored


@selected_vars_target.setter
def selected_vars_target(self, idx_list):
    """Store the indices of the target samples.

    Args:
        idx_list : list of tuples | None
            indices (index process, index sample)

    Raises:
        TypeError
            if idx_list is neither None nor exactly of type list
    """
    if idx_list is None or type(idx_list) is list:
        self._selected_vars_target = idx_list
        return
    raise TypeError("Expected a list of tuples (index process, index sample).")
@property
def selected_vars_sources(self):
    """List of indices of source samples in the conditional set.

    Prints a notice if the stored value is None.
    """
    stored = self._selected_vars_sources
    if stored is None:
        print("Attribute has not been set yet.")
    return stored


@selected_vars_sources.setter
def selected_vars_sources(self, idx_list):
    """Store the indices of the source samples.

    Args:
        idx_list : list of tuples | None
            indices (index process, index sample)

    Raises:
        TypeError
            if idx_list is neither None nor exactly of type list
    """
    if idx_list is None or type(idx_list) is list:
        self._selected_vars_sources = idx_list
        return
    raise TypeError("Expected a list of tuples (index process, index sample).")
@property
def _selected_vars_realisations(self):
"""Get realisations of the full conditional set."""
return self.__selected_vars_realisations
@_selected_vars_realisations.setter
def _selected_vars_realisations(self, realisations):
self.__selected_vars_realisations = realisations
@property
def _selected_vars_target_realisations(self):
"""Get realisations of the target samples in the conditional.
Note:
Each time this property is called, realisations are actually
extracted from the array of all realisations, which may be slow!
Use temporary variables to speed things up.
"""
if self.selected_vars_target is None:
return None
indices = np.zeros(len(self.selected_vars_target)).astype(int)
for i, idx in enumerate(self.selected_vars_target):
indices[i] = self.selected_vars_full.index(idx)
self._selected_vars_target_realisations = self._selected_vars_realisations[
:, indices
]
return self.__selected_vars_target_realisations
@_selected_vars_target_realisations.setter
def _selected_vars_target_realisations(self, realisations):
self.__selected_vars_target_realisations = realisations
@property
def _selected_vars_sources_realisations(self):
"""Get realisations of the source samples in the conditional.
Note:
Each time this property is called, realisations are actually
extracted from the array of all realisations, which may be slow!
Use temporary variables to speed things up.
"""
indices = np.zeros(len(self.selected_vars_sources)).astype(int)
for i, idx in enumerate(self.selected_vars_sources):
indices[i] = self.selected_vars_full.index(idx)
self._selected_vars_sources_realisations = self._selected_vars_realisations[
:, indices
]
return self.__selected_vars_sources_realisations
@_selected_vars_sources_realisations.setter
def _selected_vars_sources_realisations(self, realisations):
self.__selected_vars_sources_realisations = realisations
def _append_selected_vars_realisations(self, realisations):
"""Append realisations of conditionals to existing realisations.
Returns:
realisations: numpy array with dimensions replications x number
of indices.
"""
if self._selected_vars_realisations is None or realisations.size == 0:
self._selected_vars_realisations = realisations
else:
self._selected_vars_realisations = np.hstack(
(self._selected_vars_realisations, realisations)
)
def _idx_to_lag(self, idx_list, current_value_sample=None):
"""Change sample indices to lags for each sample in the list."""
if current_value_sample is None:
try:
current_value_sample = self.current_value[1]
except (AttributeError, TypeError):
raise AttributeError("Current value not set.")
lag_list = cp.copy(idx_list)
for c in idx_list:
if c[1] > current_value_sample:
raise IndexError("Sample time index larger than current " "value.")
lag_list[idx_list.index(c)] = (c[0], current_value_sample - c[1])
return lag_list
def _lag_to_idx(self, lag_list, current_value_sample=None):
"""Change sample lags to indices for each sample in the list."""
if current_value_sample is None:
try:
current_value_sample = self.current_value[1]
except (AttributeError, TypeError):
raise AttributeError("Current value not set.")
idx_list = cp.copy(lag_list)
for c in lag_list:
if c[1] > current_value_sample:
raise IndexError("Sample lag larger than current value.")
idx_list[lag_list.index(c)] = (c[0], current_value_sample - c[1])
return idx_list
def _set_cmi_estimator(self):
    """Check and set requested CMI estimator.

    Sets self._cmi_estimator. If the user requested local values, it
    additionally sets self._cmi_estimator_local. Internally, the average
    estimator is used for building the non-uniform embedding, etc. The
    local estimator is used to estimate single-link MI/TE or
    single-process AIS in the end.

    Raises:
        AssertionError
            if 'cmi_estimator' is missing from self.settings
    """
    assert "cmi_estimator" in self.settings, "Estimator was not specified!"
    if self.settings["local_values"]:
        # The average estimator must be created with local values turned
        # off. The flag is restored in a finally clause so that a failing
        # estimator construction cannot leave the settings altered.
        self.settings["local_values"] = False
        try:
            self._cmi_estimator = get_estimator(
                self.settings["cmi_estimator"], self.settings
            )
        finally:
            self.settings["local_values"] = True
        self._cmi_estimator_local = get_estimator(
            self.settings["cmi_estimator"], self.settings
        )
    else:
        self._cmi_estimator = get_estimator(
            self.settings["cmi_estimator"], self.settings
        )
def _separate_realisations(self, idx_full, idx_single):
"""Separate single index realisations from a set of realisations.
Return the realisations of a single index and the realisations of the
remaining set of indices. The function takes realisations from the
array in self._selected_vars_realisations. This allows to reuse the
collected realisations when pruning the conditional set after
candidates have been included.
Args:
idx_full : list of tuples
indices indicating the full set
idx_single : tuple
index to be removed
Returns:
numpy array
realisations of the set without the single index
numpy array
realisations of the variable at the single index
"""
# Get indices of the remaining variables.
idx_remaining = cp.copy(idx_full)
idx_remaining.pop(idx_remaining.index(idx_single))
# Find the indices of the columns with the realisations of the
# requested variables (the single one to be removed and the remaining
# variables).
array_col_single = self.selected_vars_full.index(idx_single)
array_col_remain = np.zeros(len(idx_remaining)).astype(int)
for i, idx in enumerate(idx_remaining):
array_col_remain[i] = self.selected_vars_full.index(idx)
# Get realisations of the single and remaining variables.
real_single = np.expand_dims(
self._selected_vars_realisations[:, array_col_single], axis=1
)
if len(idx_full) == 1:
# If no realiastions remain, set variable to None instead of and
# empty array so the JIDT estimator doesn't break
real_remain = None
else:
real_remain = self._selected_vars_realisations[:, array_col_remain]
return real_remain, real_single
def _define_candidates(self, processes, samples):
"""Build a list of candidate indices.
Build a list of candidate indices. Note that variables that were
manually added to the conditioning set via the 'add_conditionals'
setting are removed from the candidate set if both sets are not
disjoint.
Args:
processes : list of int
process indices
samples: list of int
sample indices
Returns:
a list of tuples, where each tuple holds the index of one
candidate and has the form (process index, sample index), indices
are absolute values with respect to some data array.
"""
candidate_set = self._build_variable_list(processes, samples)
# Remove candidates that were already manullay added to the
# conditioning set via the 'add_conditionals' setting. Otherwise the
# candidates get tested in the inclusion step.
candidate_set = self._remove_forced_conditionals(candidate_set)
return candidate_set
def _build_variable_list(self, processes, samples):
"""Build a list of variable tuples with (process index, sample index).
Args:
processes : list of int
process indices
samples: list of int
sample indices
Returns:
a list of variable tuples
"""
var_list = []
for idx in it.product(processes, samples):
var_list.append(idx)
return var_list
def _remove_forced_conditionals(self, candidate_set):
"""Remove enforced conditioning variables from candidate set."""
if self.settings["add_conditionals"] is not None:
cond = self.settings["add_conditionals"]
if type(cond) is tuple: # easily add single variable
cond = [cond]
elif type(cond) is dict: # add conditioning variables per target
try:
cond = cond[self.target]
except KeyError:
return # no additional variables for the current target
cond_idx = self._lag_to_idx(cond)
candidate_set = list(set(candidate_set).difference(set(cond_idx)))
return candidate_set
def _append_selected_vars_idx(self, idx):
"""Append indices of conditionals to existing list.
Args:
idx : list of tuples
indices of selected variables, where each entry is a tuple
(idx process, idx sample), where indices are absolute values
with respect to entries in a data array
"""
if self.selected_vars_full is None:
self.selected_vars_full = idx
else:
for i in idx:
self.selected_vars_full.append(i)
# separate indices into source and target indices
for i in idx:
if i[0] == self.target:
self.selected_vars_target.append(i)
else:
self.selected_vars_sources.append(i)
def _append_selected_vars(self, idx, realisations):
"""Append indices and realisation of selected variables.
Args:
idx : list of tuples
indices of selected variables, where each entry is a tuple
(idx process, idx sample), where indices are absolute values
with respect to entries in a data array
realisations : numpy array
realisations of the selected variables
"""
assert len(idx) == realisations.shape[1], (
"Dimensionality of realisations array ({0}) and length of index "
"list ({1}) do not match.".format(realisations.shape[1], len(idx))
)
self._append_selected_vars_idx(idx)
self._append_selected_vars_realisations(realisations)
def _remove_selected_var(self, idx):
    """Remove a single selected variable and its realisations.

    The variable's column is removed from self._selected_vars_realisations
    and its index is removed from self.selected_vars_full and from either
    self.selected_vars_target or self.selected_vars_sources, depending on
    whether its process index equals self.target.

    Args:
        idx : tuple
            index of the variable to remove (idx process, idx sample)
    """
    column = self.selected_vars_full.index(idx)
    self._selected_vars_realisations = utils.remove_column(
        self._selected_vars_realisations, column
    )
    del self.selected_vars_full[column]

    if idx[0] == self.target:
        owner = self.selected_vars_target
    else:
        owner = self.selected_vars_sources
    del owner[owner.index(idx)]
def _calculate_single_link(
self,
data,
current_value,
source_vars,
target_vars=None,
sources="all",
conditioning="full",
):
"""Calculate dependency measure for all links into a target.
Calculate dependency measure for all links into a target. A single link
may consist of information that multiple past variables in a source
have about the target. The measure can be transfer entropy or mutual
information and is estimated as the joint information all selected past
variables from a single source have about the target.
The conditioning defines which variables are included in the
conditioning set when estimating a dependency measure. This can be set
to
- 'full' to include all selected variables (for multivariate TE this
includes the target's past variables and past variables from all
other inferred sources, for multivariate MI this includes past
variables from all other inferred sources) from all other inferred
sources and the target's past,
- 'target' to include variables from the target's past alone (for
bivariate TE estimation),
- 'none' for no conditioning (for bivariate MI estimation).
For transfer entropy, the information transfer is calculated
conditional on the target's past. For multivariate TE or MI, the
information (transfer) is calculated conditionally on selected
variables from further sources in the network.
Measures can be estimated either for 'all' sources (determined from the
selected source variables) or for individual sources. A list of
estimated values for each link (source-target combination) is returned.
Args:
data : Data instance
raw data for analysis
current_value : tuple
index of the current value used for estimation, (idx process,
idx sample)
source_vars : np array of tuples
array of past source variables, where one tuple describes a
single variable as (idx process, idx sample)
target_vars : np array of tuples [optional]
array of past target variables
sources : list of ints | 'all' [optional]
return estimates for selected sources or all sources (default)
conditioning : str [optional]
set conditioning set, 'full' for all selected variables
(target's and sources' past), 'target' for variables from the
target's past only, 'none' for no conditioning
Returns:
numpy array
estimate of dependency measure for each link
Raises:
ex.AlgorithmExhaustedError
Raised from estimate() when calculation cannot be made
"""
# Get realisations of target variables and the current value, constant
# over sources. Permute current value realisations to generate
# surrogates if requested.
target_realisations = data.get_realisations(current_value, target_vars)[0]
current_value_realisations = data.get_realisations(
current_value, [current_value]
)[0]
# Check requested sources.
if sources == "all":
sources = np.unique([s[0] for s in source_vars])
else:
if type(sources) is int: # handle integer inputs
sources = [sources]
sources = np.array(sources)
if any(sources > (data.n_processes - 1)):
raise RuntimeError(
"At least one source ({0}) is not in no. "
"nodes in the data ({1}).".format(sources, data.n_processes)
)
# Allocate memory: either a multidimensional array if local values are
# required, or a 1D-array for averaged values for each link.
if self.settings["local_values"]:
# Collect local values in a [sources x samples x replications]
# matrix.
links = np.zeros(
(
len(sources),
data.n_realisations_samples(current_value),
data.n_replications,
)
)
else:
links = np.zeros(len(sources))
# Loop over individual sources.
for i, s in enumerate(sources):
# Separate source variables in variables belonging to the current
# link and variables belonging to the conditioning set. Get
# realisations for the current link's selected source variables.
link_vars = [i for i in source_vars if i[0] == s]
conditional_vars = [i for i in source_vars if i[0] != s]
source_realisations, replication_ind = data.get_realisations(
current_value, link_vars
)
# Determine which type of conditioning is requested.
if conditioning == "full":
if target_realisations is None:
# Use sources' pasts only, returns None if conditional vars
# is empty.
conditional_realisations = data.get_realisations(
current_value, conditional_vars
)[0]
else:
# Use target's and sources' past, check if conditional vars
# is not empty, otherwise np.hstack crashes.
if conditional_vars:
conditional_realisations = np.hstack(
(
data.get_realisations(current_value, conditional_vars)[
0
],
target_realisations,
)
)
else: # use target's past only
conditional_realisations = target_realisations
elif conditioning == "target": # use target's past only (biv. TE)
conditional_realisations = target_realisations
elif conditioning == "none": # no conditioning (bivariate MI)
conditional_realisations = None
else:
raise RuntimeError("Unknown conditioning: {0}.".format(conditioning))
if self.settings["local_values"]:
local_values = self._cmi_estimator_local.estimate(
var1=current_value_realisations,
var2=source_realisations,
conditional=conditional_realisations,
)
links[i] = local_values.reshape(
max(replication_ind) + 1, sum(replication_ind == 0)
).T
else:
links[i] = self._cmi_estimator.estimate(
var1=current_value_realisations,
var2=source_realisations,
conditional=conditional_realisations,
)
return links
def _set_checkpointing_defaults(self, settings, data, sources, target):
"""Set defaults for writing analysis checkpoints."""
settings.setdefault("write_ckp", False)
if settings["write_ckp"]:
settings.setdefault("filename_ckp", "./idtxl_checkpoint")
filename_ckp = "{0}.ckp".format(settings["filename_ckp"])
if not os.path.isfile(filename_ckp):
self._initialise_checkpoint(settings, data, sources, target)
return settings
else:
return settings
def _initialise_checkpoint(self, settings, data, sources, targets):
    """Write first checkpoint file, data, and settings to disk.

    Called once at the beginning of an analysis using checkpointing. Write
    data and analysis settings to disk. This needs to be done only once.
    Initialise checkpoint file: write header with time stamp, path to data
    and settings, and targets and sources to be analysed. The checkpoint
    file is updated during the analysis.

    Args:
        settings : dict
            analysis settings; 'filename_ckp' is the base name of all
            files written (*.dat, *.json, *.ckp)
        data : Data instance
            raw data, written to the *.dat file
        sources : sources to be analysed, written to the header as is
        targets : int | list of int
            targets to be analysed; a single int is wrapped in a list
    """
    # Allow a single target to be passed as a plain int.
    if type(targets) is int:
        targets = [targets]

    base_name = settings["filename_ckp"]
    # Write data and settings to disk.
    io.save_pickle(data, "{0}.dat".format(base_name))
    io.save_json(settings, "{0}.json".format(base_name))

    # Initialise checkpoint file for later updates.
    with open("{0}.ckp".format(base_name), "w") as text_file:
        emit = text_file.write
        emit("IDTxl checkpoint file.\n")
        emit("{:%Y-%m-%d %H:%M:%S}\n".format(datetime.now()))
        emit("Raw data path: {}.dat\n".format(os.path.abspath(base_name)))
        emit("Settings path: {}.json\n".format(os.path.abspath(base_name)))
        emit("Targets to be analyzed: {}\n".format(targets))
        emit("Sources to be analyzed: {}\n\n".format(sources))
        # The file ends with the first target and no line break; this last
        # line is filled in when the checkpoint gets updated.
        emit(
            "Selected variables (target: [sources]: [selected variables]):"
            "\n{}".format(targets[0])
        )
def _write_checkpoint(self):
"""Write checkpoint to disk.
Write checkpoint to disk. The checkpoint contains variables already
selected by network analysis algorithms. To recover from a checkpoint
use the 'recover_checkpoint()‘ method.
Note: IDTxl will always keep the current (*.ckp) and the previous
version (*.ckp.old) of the checkpoint file to ensure a recoverable
state even if writing of the current checkpoint fails.
"""
filename_ckp = "{0}.ckp".format(self.settings["filename_ckp"])
# Check if a checkpoint file already exists. If yes,
# 1. make a copy using the same file name plus the .old extension
# (overwriting the last *.ckp.old file);
# 2. update current checkpoint file.
if os.path.isfile(filename_ckp):
copyfile(filename_ckp, "{}.old".format(filename_ckp))
self._update_checkpoint(filename_ckp)
else:
raise RuntimeError(
"Could not find checkpoint file for updating. "
"Initialise checkpoint first."
)
def _update_checkpoint(self, filename_ckp):
"""Update existing checkpoint file.
Add the last selected variable to the *.ckp file while keeping the
path to data and settings. Overwrite time stamp in header.
"""
# We don't expect these files to become very big. Hence, it is the
# easiest to load the whole file into a data structure and then write
# it back (https://stackoverflow.com/a/328007). Alternatively, we can
# just add the last selected variable as a tuple -> then we have to
# make sure, the last selected candidate always ends up at the end of
# the selected candidates list.
# Write time stamp and info
timestamp = datetime.now()
# Convert absolute indices to lags with respect to the current value.
selected_variables = self._idx_to_lag(
self.selected_vars_full, self.current_value[1]
)
# Read file as list of lines and replace first and last line. Write
# modified file back to disk.
with open(filename_ckp, "r") as f:
lines = f.readlines()
lines[1] = "{:%Y-%m-%d %H:%M:%S}\n".format(timestamp)
if int(lines[-1][0]) == self.target:
lines[-1] = "{0}: {1}: {2}\n".format(
self.target, self.source_set, selected_variables
)
else:
lines.append(
"{0}: {1}: {2}\n".format(
self.target, self.source_set, selected_variables
)
)
with open(filename_ckp, "w") as f:
f.writelines(lines)
def resume_checkpoint(self, file_path):
    """Resume analysis from a checkpoint saved to disk.

    Read the checkpoint file, load the data and settings files it points
    to, and add the variables already selected per target to the settings
    as 'add_conditionals'.

    Args:
        file_path : str
            path to checkpoint file (excluding extension: .ckp)

    Returns:
        data
            object loaded from the data file named in the checkpoint
        dict
            settings loaded from the settings file, with
            'add_conditionals' set to the selected variables per target
        targets
            targets to be analysed, as listed in the checkpoint header
        sources
            sources to be analysed, as listed in the checkpoint header
    """
    # Read checkpoint
    with open("{}.ckp".format(file_path), "r") as f:
        lines = f.readlines()
    timestamp = lines[1]
    # Both path lines start with a 15 character label ('Raw data path: '
    # and 'Settings path: ').
    data_path = lines[2][15:].strip()
    settings_path = lines[3][15:].strip()

    # Load settings and data.
    # NOTE(review): load_pickle presumably unpickles the data file; only
    # resume from checkpoint files that come from a trusted source.
    data = io.load_pickle(data_path)
    settings = io.load_json(settings_path)
    verbose = settings.get("verbose", True)
    if verbose:
        print(
            "Resuming analysis from file {}.ckp, saved {}".format(
                file_path, timestamp
            )
        )

    # Read targets and sources. Split at the first colon only, so the
    # label is cut off whatever follows it.
    targets = ast.literal_eval(lines[4].split(":", 1)[1].strip())
    sources = ast.literal_eval(lines[5].split(":", 1)[1].strip())

    # Read selected variables
    # Format: target - sources analyzed - selected variables
    selected_variables = {}  # vars as lags wrt. the current value
    for line in lines[8:]:
        result = [x.strip() for x in line.split(":")]
        try:
            selected_variables[int(result[0])] = ast.literal_eval(result[2])
        except IndexError:
            # The line holds a target only, i.e., no variables have been
            # written for it yet.
            if verbose:
                print("No variables previously selected.")
    if verbose:
        print("Selected variables per target:")
        pprint(selected_variables)

    # Add already selected candidates as conditionals to be added to the
    # settings dict. Note that the time stamp in the selected variables
    # list is a lag wrt. the current value. This format is also expected by
    # the method that manually adds conditionals.
    settings["add_conditionals"] = selected_variables
    return data, settings, targets, sources