import nninfo
import copy
import torch
import numpy as np
from torch.utils.data import DataLoader
from scipy.stats import norm
from tqdm import tqdm
from external.SxPID.sxpid import SxPID
from external.nsb_estimator import make_mk, _H_nsb

# Module-level logger obtained from the project's logging helper.
log = nninfo.log.get_logger(__name__)

# Names exported by `from <this module> import *`.
__all__ = [
    "Analysis",
    "WeightMeasurement",
    "Measurement",
]


class Analysis:
    """
    After an experiment is done, the class Analysis allows for
    analyzing the training of the network with information theoretic
    tools, for example Partial Information Decomposition, Fisher information,
    etc.

    It works by loading the network's state from the checkpoint and
    creating the activations and other necessary quantities on demand.
    This ensures that the experiments can be reanalyzed at a later
    stage when more analysis tools might be implemented, since the full
    network state is still there.

    It is also possible to analyze multiple runs in the same step,
    using the choose_chapters keywords.
    """

    def __init__(self, exp, measurement_subdir='measurements/'):
        self._measurement_manager = None
        self._activations = dict()
        self.set_experiment(exp, measurement_subdir)

    def set_experiment(self, experiment, measurement_subdir):
        """
        Attach `experiment` and create a MeasurementManager for its
        experiment directory and the given `measurement_subdir`.
        """
        self.exp = experiment
        self._measurement_manager = nninfo.file_io.MeasurementManager(
            self.exp.experiment_dir, measurement_subdir
        )

    def _load_state(self, state):
        """
        Load the network state identified by `state`: a checkpoint filename (str)
        or a (run_id, chapter_id) tuple. Any other value is ignored.
        """
        if isinstance(state, str):
            self.exp.load_checkpoint(filename=state)
        elif isinstance(state, tuple):
            self.exp.load_checkpoint(run_id=state[0], chapter_id=state[1])

    def _extract_activations(self, x, before_noise, quantizer_params):
        """
        Run the current network on the input batch `x` and return its
        {layer_id: ndarray} activation dict.

        `before_noise` is only forwarded to noisy networks; plain networks are
        called without it.
        """
        if isinstance(self.exp.network, nninfo.model.NoisyNeuralNetwork):
            return self.exp.network.extract_activations(
                x, before_noise=before_noise, quantizer_params=quantizer_params
            )
        return self.exp.network.extract_activations(
            x, quantizer_params=quantizer_params
        )

    @staticmethod
    def _to_neuron_dict(act_dict):
        """
        Reshape a {layer_id: ndarray} dict into a neuron id dict
        {(layer_id, (neuron_index,)): 1-D ndarray}.

        Neuron indices start at 1. One-dimensional arrays (e.g. labels) are
        treated as a layer with a single neuron (1,).
        """
        return {
            (layer_id, (neuron_idx + 1,)): (
                acts[:, neuron_idx] if acts.ndim > 1 else acts
            )
            for layer_id, acts in act_dict.items()
            for neuron_idx in range(acts.shape[1] if acts.ndim > 1 else 1)
        }

    def capture_activations(
        self,
        dataset,
        repeat_dataset=1,
        batch_size=10 ** 4,
        before_noise=False,
        condition=None,
        quantizer_params=None,
    ):
        """
        Captures activations for current network state for the given dataset.

        Args:
            dataset (str or nninfo.task.DataSet): The dataset, or a key of self.exp.task.
            repeat_dataset (int): How often every batch is passed through the network.
            batch_size (int): Batch size of the DataLoader.
            before_noise (bool): Extract activations before the noise is applied
                (only allowed for noisy networks).
            condition: If given, the dataset is conditioned via dataset.condition.
            quantizer_params: Forwarded to the network's extract_activations.

        Returns:
            Iterator that yields, per batch and repetition, dictionaries mapping
                neuron ids (layer_id, (neuron_index,)) with layer_id in
                L1, L2, ..., X, Y, Yhat to ndarrays of size (batch_size, )
                containing activations
        """

        assert isinstance(self.exp.network, nninfo.model.NoisyNeuralNetwork) or not before_noise, 'Extraction before noise only possible for noisy network!'

        # Warn if network is not in eval mode
        if self.exp.network.training:
            log.warning(
                "Network is in training mode. Dropout will be active. "
            )

        if not isinstance(dataset, nninfo.task.DataSet):
            dataset = self.exp.task[dataset]

        if condition is not None:
            dataset = dataset.condition(len(dataset), condition)

        feeder = DataLoader(dataset, batch_size=batch_size)

        for x, y in feeder:
            for _ in range(repeat_dataset):
                act_dict = self._extract_activations(x, before_noise, quantizer_params)
                # The network's layers are L1..Ln, so the output layer is L<n>.
                output_layer = "L{}".format(len(act_dict))

                act_dict["X"] = x.detach().numpy()
                act_dict["Y"] = y.detach().numpy()

                # Decision function: index of the maximal output neuron.
                # If multiple output neurons have the same activation, choose the first!
                act_dict["Yhat"] = np.argmax(act_dict[output_layer], axis=1)

                yield self._to_neuron_dict(act_dict)

    def capture_activations_of_single_samle(
        self,
        single_sample,
        repeat_dataset=1,
        batch_size=10 ** 6,
        before_noise=False,
        quantizer=None,
    ):
        """
        Captures activations for current network state for the given sample,
        fed `repeat_dataset` times in batches of at most `batch_size` copies.

        Args:
            single_sample (sequence of float): One input sample.
            repeat_dataset (int): Total number of copies of the sample.
            batch_size (int): Maximal number of copies per batch.
            before_noise (bool): Extract activations before the noise is applied
                (only forwarded to noisy networks).
            quantizer: Forwarded to the network as `quantizer_params`.

        Returns:
            Iterator that yields dictionaries mapping neuron ids
                (layer_id, (neuron_index,)) with layer_id in L1, L2, ..., X, Yhat
                to ndarrays of size (<= batch_size, ) containing activations
        """
        sample = torch.FloatTensor([single_sample])

        for start in range(0, repeat_dataset, batch_size):
            # The last batch holds the remainder; a full final batch is kept full.
            n_rows = min(batch_size, repeat_dataset - start)
            x = sample.repeat(n_rows, 1)

            act_dict = self._extract_activations(x, before_noise, quantizer)
            # The network's layers are L1..Ln, so the output layer is L<n>.
            output_layer = "L{}".format(len(act_dict))

            act_dict["X"] = x.detach().numpy()

            # Decision function: index of the maximal output neuron.
            # If multiple output neurons have the same activation, choose the first!
            act_dict["Yhat"] = np.argmax(act_dict[output_layer], axis=1)

            yield self._to_neuron_dict(act_dict)

    def _choose_chapters(self, choose_chapters, custom_chapters_list=None):
        """
        For choosing several chapters at once (used by several perform_*_measurement functions).

        Args:
            choose_chapters (str): Can have the following values:
                * 'all' (performs meas on all the chapters found in the checkpoints directory)
                * 'all_chapters_same_run' (performs meas on all chapters of the current run)
                * 'all_runs_same_chapter' (performs meas on the same chapter of all runs)
                * 'this' (performs only on the current state of the network)
                * 'custom_list' (performs only on states given in custom_chapters_list)
            custom_chapters_list (list, optional): list of (run_id, chapter_id) tuples or list of filenames

        Returns:
            list: chapters_list, either the original custom_chapters_list or an extracted list

        Raises:
            AttributeError: If choose_chapters is none of the named options and
                no custom_chapters_list is given.
        """
        if choose_chapters == "this":
            return [(self.exp.run_id, self.exp.chapter_id)]
        if choose_chapters == "all":
            return self.exp.list_checkpoints()
        if choose_chapters == "all_runs_same_chapter":
            return self.exp.list_checkpoints(chapter_ids=[self.exp.chapter_id])
        if choose_chapters == "all_chapters_same_run":
            return self.exp.list_checkpoints(run_ids=[self.exp.run_id])
        if custom_chapters_list is None:
            raise AttributeError(
                "Not clear on which chapters measurement should be performed"
            )
        return custom_chapters_list

    def perform_fisher(
        self,
        dataset_name,
        choose_chapters="this",
        empirical=False,
        save_offdiag=False,
        custom_chapters_list=None,
        output_size=1,
    ):
        """
        Perform a Fisher Information measurement on the network states during the training.

        Args:
            dataset_name (str): The name of the dataset to perform the FI estimation with.

        Keyword Args:
            choose_chapters (str): See _choose_chapters ('all', 'all_chapters_same_run',
                'all_runs_same_chapter', 'this', 'custom_list').
            custom_chapters_list (list): list of (run_id, chapter_id) tuples or list of filenames
            empirical (bool): If true, the gradients are computed with respect to the true label. Default is with respect to the output distribution.
            save_offdiag (bool): Whether the offdiagonal elements of the FI matrix should be saved or not.
            output_size (int): Size of the last layer of the network.

        Returns:
            int: measurement_id of the performed measurements for quick access (all have the same id, because all have the same parameters)
        """
        chapters_list = self._choose_chapters(choose_chapters, custom_chapters_list)

        measurement_id = None
        for state in tqdm(chapters_list):
            self._load_state(state)
            fm = nninfo.analysis.fisher.FisherMeasurement(
                self.exp.run_id,
                self.exp.chapter_id,
                self.exp.epoch_id,
                dataset_name,
                self.exp.network,
                self.exp.task,
                empirical,
                save_offdiag,
                output_size,
            )
            fm.measure()
            # save the measurement
            measurement_id = self._measurement_manager.save(
                fm, measurement_id=measurement_id
            )
        return measurement_id

    def perform_custom_pid(
        self,
        dataset,
        target_id_list,
        source1_id_list,
        source2_id_list,
        source3_id_list=None,
        source4_id_list=None,
        source5_id_list=None,
        pid_definition="sxpid",
        binning_method="fixed_size",
        binning_params=None,
        pointwise=False,
        parts="all",
        choose_chapters="this",
        custom_chapters_list=None,
        repeat_dataset=1,
        n_samples=None,
        measurement_id=None,
        quantizer_params=None,
    ):
        """
        Perform a Partial Information Decomposition measurement on the
        chosen network states.

        Args:
            dataset (str or dataset): The dataset you want to perform the PID with.
            target_id_list (list): List of neuron_ids that form the target.
            source1_id_list (list): List of neuron_ids that form source 1.
            source2_id_list (list): List of neuron_ids that form source 2.

        Keyword Args:
            source3_id_list (list): List of neuron_ids that form source 3.
            source4_id_list (list): List of neuron_ids that form source 4.
            source5_id_list (list): List of neuron_ids that form source 5.
            pid_definition (str): Definition of PID. At the moment 'sxpid' and 'brojapid'
            binning_method (str): Method for discretization. At the moment 'fixed_size'
                and 'goldfeld:fixed_size'
            binning_params (dict): Additional parameters such as 'n_bins'. Defaults
                is 'n_bins' = 30. Limits for binning are automatically chosen for
                each activation function. The dict is copied, never modified.
            pointwise (bool): Whether to save the pointwise PID enabled by SxPID.
                Default is False.
            choose_chapters (str): See _choose_chapters ('all', 'all_chapters_same_run',
                'all_runs_same_chapter', 'this', 'custom_list').
            custom_chapters_list (list): list of (run_id, chapter_id) tuples or list of filenames
            n_samples (int): Only for LazyDataset: extend/shorten it to this many samples.

        Returns:
            int: measurement_id of the performed measurements for quick access (all have the
                same id, because all have the same parameters)
        """

        if isinstance(dataset, str):
            dataset = self.exp.task[dataset]

        # Lazy datasets can be extended or shortened to arbitrary numbers of samples
        if n_samples is not None:
            assert isinstance(
                dataset, nninfo.task.LazyDataset
            ), '"n_samples" keyword only allowed for LazyDataset'
            dataset = dataset.extend(n_samples)

        chapters_list = self._choose_chapters(choose_chapters, custom_chapters_list)

        # Work on a copy so neither the caller's dict nor a shared default gets
        # filled with limits/binning_method from a previous call.
        binning_params = dict(binning_params) if binning_params is not None else {}
        if "limits" not in binning_params:
            binning_params["limits"] = self.exp.network.get_binning_limits()
        if "binning_method" not in binning_params:
            binning_params["binning_method"] = binning_method

        source_id_lists = [source1_id_list, source2_id_list] + [
            ids
            for ids in (source3_id_list, source4_id_list, source5_id_list)
            if ids is not None
        ]

        for state in chapters_list:
            # load the right chapter
            self._load_state(state)

            activations_iter = self.capture_activations(
                dataset,
                repeat_dataset=repeat_dataset,
                before_noise=binning_method == "goldfeld:fixed_size",
                quantizer_params=quantizer_params,
            )

            # set all necessary parameters
            pid = nninfo.analysis.pid.PIDMeasurement(
                self.exp.run_id,
                self.exp.chapter_id,
                self.exp.epoch_id,
                dataset,
                pid_definition,
                binning_params,
                target_id_list,
                list(source_id_lists),
                pointwise=pointwise,
                parts=parts,
            )

            # actual measurement step
            pid.measure(activations_iter)
            # save the measurement
            measurement_id = self._measurement_manager.save(
                pid, measurement_id=measurement_id
            )
        return measurement_id

    def perform_single_neuron_vs_rest_layer_pid(
        self,
        layer_label,
        neuron_index,
        target_id_list,
        dataset_name,
        pid_definition="sxpid",
        binning_method="fixed_size",
        binning_params=None,
        pointwise=False,
        choose_chapters="this",
        custom_chapters_list=None,
        repeat_dataset=1,
    ):
        """
        Perform a specialized PID: Target will be target_id_list,
        source 1 will be one neuron, source 2 will be the rest of the respective layer.

        Can perform PID on all chapters at once.

        Args:
            layer_label (str): the label of the layer you want to analyse: 'L<n>'
            neuron_index (int): index of the neuron the measurement should be performed on.
                This is an index (integer, starting at 1) and not a neuron_id (tuple of
                layer_label and neuron index).

        All other arguments are passed on to self.perform_custom_pid.

        Returns:
            int: measurement_id of the performed measurements for quick access

        Raises:
            KeyError: If neuron_index is smaller than 1.
        """

        if neuron_index < 1:
            log.error("Neuron indices always start with 1.")
            raise KeyError(neuron_index)

        neuron_id = (layer_label, (neuron_index,))

        # Copy so the network's layer structure is not modified by remove().
        layer_neuron_id_list = copy.deepcopy(
            self.exp.network.get_layer_structure_dict()[layer_label]
        )
        layer_neuron_id_list.remove(neuron_id)

        return self.perform_custom_pid(
            dataset_name,
            target_id_list,
            [neuron_id],
            layer_neuron_id_list,
            pid_definition=pid_definition,
            binning_method=binning_method,
            binning_params=binning_params,
            pointwise=pointwise,
            choose_chapters=choose_chapters,
            custom_chapters_list=custom_chapters_list,
            repeat_dataset=repeat_dataset,
        )

    def compare_acc_loss_multiple_datasets(
        self,
        dataset_names_list,
        show_plot=True,
        save_plot=False,
        run_ids=None,
        ax=None,
        quantizer_params=None,
    ):
        """
        Evaluate accuracy and loss on several datasets for every checkpoint
        of the given runs, optionally plotting the result.

        Args:
            dataset_names_list (list of str): Datasets passed to self.exp.tester.test.
            show_plot (bool): Show the performance plot.
            save_plot (bool): Save the performance plot to <experiment_dir>plots/.
            run_ids (list, optional): Runs whose checkpoints are evaluated. Defaults to [0].
            ax: Optional axis to plot into; a plot is made if ax is given.
            quantizer_params: Forwarded to self.exp.tester.test.

        Returns:
            tuple: (epoch_list, performance_dict) where performance_dict maps every
                dataset name to {"acc": [...], "loss": [...]} with one entry per checkpoint.
        """
        if run_ids is None:
            run_ids = [0]

        epoch_list = []
        performance_dict = {
            dataset_name: {"acc": [], "loss": []} for dataset_name in dataset_names_list
        }
        for checkpoint in self.exp.list_checkpoints(run_ids=run_ids):
            self.exp.load_checkpoint(filename=checkpoint)
            epoch_list.append(self.exp.epoch_id)
            for dataset_name in dataset_names_list:
                performance_dict[dataset_name]["acc"].append(
                    self.exp.tester.test(
                        dataset_name, return_accuracy=True, quantizer_params=quantizer_params
                    )
                )
                performance_dict[dataset_name]["loss"].append(
                    self.exp.tester.test(dataset_name, quantizer_params=quantizer_params)
                )
        if show_plot or save_plot or ax is not None:
            nninfo.plot.Plot.performance_plot(
                epoch_list,
                performance_dict,
                show_plot=show_plot,
                save_plot=save_plot,
                save_dir=self.exp.experiment_dir + "plots/",
                ax=ax,
            )
        return epoch_list, performance_dict

    def perform_weight_measurement(
        self, choose_chapters="this", custom_chapters_list=None
    ):
        """
        Store the trainable parameters of the network for the chosen states.

        Args:
            choose_chapters (str): See _choose_chapters.
            custom_chapters_list (list): list of (run_id, chapter_id) tuples or list of filenames

        Returns:
            int: measurement_id shared by all saved weight measurements.
        """
        chapters_list = self._choose_chapters(choose_chapters, custom_chapters_list)
        measurement_id = None
        for state in chapters_list:
            self._load_state(state)
            wm = WeightMeasurement(
                self.exp.run_id,
                self.exp.chapter_id,
                self.exp.epoch_id,
                self.exp.network,
            )
            wm.measure()
            # save the measurement
            measurement_id = self._measurement_manager.save(
                wm, measurement_id=measurement_id
            )
        return measurement_id

    def count_unique_activations(
        self, dataset, run_id, chapter_id, neurons, quantizer_params=None
    ):
        """
        Count the distinct joint activation patterns of `neurons` over the whole
        dataset for the given checkpoint.

        Args:
            dataset (str or dataset): The dataset, or a key of self.exp.task.
            run_id, chapter_id: Checkpoint to load.
            neurons (list): Neuron ids (layer_id, (neuron_index,)).
            quantizer_params: Forwarded to capture_activations.

        Returns:
            int: Number of unique activation vectors (columns) across all samples.
        """

        if isinstance(dataset, str):
            dataset = self.exp.task[dataset]

        self.exp.load_checkpoint(run_id=run_id, chapter_id=chapter_id)

        # Collect every batch, not only the first one, so datasets larger than
        # one DataLoader batch are counted completely. Rows: neurons, columns: samples.
        per_batch = [
            np.array([activations[neuron] for neuron in neurons])
            for activations in self.capture_activations(
                dataset, quantizer_params=quantizer_params
            )
        ]
        relevant_activations = np.concatenate(per_batch, axis=1)

        return np.unique(relevant_activations, axis=1).shape[1]

    def perform_mi_estimation(
        self,
        dataset,
        rv1_id_list,
        rv2_id_list,
        estimation_method="naive",
        binning_method="fixed_size",
        estimation_params=None,
        choose_chapters="this",
        custom_chapters_list=None,
        repeat_dataset=1,
        n_samples=None,
        measurement_id=None,
        quantizer_params=None
    ):
        """
        Perform a mutual information estimation between two groups of neurons
        on the chosen network states.

        Args:
            dataset (str or dataset): The dataset you want to estimate the MI with.
            rv1_id_list (list): List of neuron_ids that form random variable 1.
            rv2_id_list (list): List of neuron_ids that form random variable 2.

        Keyword Args:
            estimation_method (str): Method used for mutual information estimation:
                'naive', 'nsb', 'nsb_ndd' (binning based) or 'goldfeld'.
            binning_method (str): Method for discretization. At the moment 'fixed_size'
                and 'goldfeld:fixed_size'
            estimation_params (dict): Additional parameters such as 'n_bins' and
                'n_labels' (binning estimators) or 'std_dev', 'n_x', 'n_mc' (goldfeld).
                Limits for binning are automatically chosen for each activation
                function. The dict is copied, never modified.
            choose_chapters (str): See _choose_chapters ('all', 'all_chapters_same_run',
                'all_runs_same_chapter', 'this', 'custom_list').
            custom_chapters_list (list): list of (run_id, chapter_id) tuples or list of filenames
            n_samples (int): Only for LazyDataset: extend/shorten it to this many samples.

        Returns:
            int: measurement_id of the performed measurements for quick access (all have the
                same id, because all have the same parameters)

        Raises:
            NotImplementedError: If estimation_method is unknown.
        """
        binning_estimators = ("naive", "nsb_ndd", "nsb")
        # Validate up front: otherwise `mi` would be unbound (or stale from the
        # previous chapter) when it is saved.
        if estimation_method not in binning_estimators and estimation_method != "goldfeld":
            raise NotImplementedError(
                'Unknown mutual information estimation method "{}".'.format(estimation_method)
            )

        if isinstance(dataset, str):
            dataset = self.exp.task[dataset]

        if n_samples is not None:
            assert isinstance(
                dataset, nninfo.task.LazyDataset
            ), '"n_samples" keyword only allowed for LazyDataset'
            dataset = dataset.extend(n_samples)

        chapters_list = self._choose_chapters(choose_chapters, custom_chapters_list)

        # Work on a copy so neither the caller's dict nor a shared default is mutated.
        estimation_params = dict(estimation_params) if estimation_params is not None else {}
        if "limits" not in estimation_params:
            estimation_params["limits"] = self.exp.network.get_binning_limits()
        if "binning_method" not in estimation_params:
            estimation_params["binning_method"] = binning_method

        for state in chapters_list:
            # load the right chapter
            self._load_state(state)
            run_id = self.exp.run_id
            chapter_id = self.exp.chapter_id
            epoch_id = self.exp.epoch_id

            if estimation_method in binning_estimators:
                activations_iter = self.capture_activations(
                    dataset,
                    repeat_dataset=repeat_dataset,
                    before_noise="goldfeld" in binning_method,
                    quantizer_params=quantizer_params
                )
                mi = BinningMIMeasurement(
                    run_id,
                    chapter_id,
                    epoch_id,
                    dataset,
                    estimation_params,
                    rv1_id_list,
                    rv2_id_list,
                    estimation_method,
                )
                mi.measure(activations_iter)
            else:
                mi = nninfo.analysis.analysis.GoldfeldMIMeasurement(
                    run_id,
                    chapter_id,
                    epoch_id,
                    dataset,
                    estimation_params,
                    rv1_id_list,
                    rv2_id_list,
                )
                # The Goldfeld measurement captures its own activations via this Analysis.
                mi.measure(self)

            # save the measurement
            measurement_id = self._measurement_manager.save(
                mi, measurement_id=measurement_id
            )
        return measurement_id


class Measurement:
    """
    Common base of all measurement classes.

    Holds the parameters shared by every measurement (type, run, chapter
    and epoch identifiers) in the dict `self._params`.
    """

    def __init__(self, measurement_type, run_id, chapter_id, epoch_id, dataset=None):
        """
        Store the parameters common to every measurement.

        Args:
            measurement_type (str): Identifier of the measurement type (e.g. 'fisher', 'pid' etc.)
            run_id (int): Identifier of the run.
            chapter_id (int): Identifier of the chapter.
            epoch_id (int): Identifier of the epoch.

        Kwargs:
            dataset: Required when measurement_type is 'fisher' or 'pid'. Its
                `name` and `len(dataset)` are stored as 'dataset_name' and 'n_samples'.

        Raises:
            AttributeError: If measurement_type is 'fisher' or 'pid' and no dataset is given.
        """
        self._params = {
            "measurement_type": measurement_type,
            "run_id": run_id,
            "chapter_id": chapter_id,
            "epoch_id": epoch_id,
        }

        # Only these measurement types record dataset information here.
        if measurement_type not in ("fisher", "pid"):
            return

        if dataset is None:
            raise AttributeError

        self._params.update(dataset_name=dataset.name, n_samples=len(dataset))


class WeightMeasurement(Measurement):
    """
    Recommended way of getting all weights and biases of the network
    at certain stages of the training process. The underlying data
    structure should be very comparable to FisherMeasurement.
    """

    def __init__(self, run_id, chapter_id, epoch_id, net):
        """
        Args:
            run_id, chapter_id, epoch_id: Identifiers of the network state.
            net (torch.nn.Module): Network whose parameters are recorded.
        """
        super().__init__("weight", run_id, chapter_id, epoch_id)
        self.measurement_data = {}
        self._net = net

    def measure(self):
        """Flatten all parameters with requires_grad into one list of floats."""
        trainable = [param.data for param in self._net.parameters() if param.requires_grad]
        flat_vector = torch.nn.utils.parameters_to_vector(trainable)
        self.measurement_data = flat_vector.detach_().tolist()

    def get_measurement_dict(self, measurement_id):
        """Return a copy of the parameters plus 'meas_id' and the flat 'weights' list."""
        result = copy.deepcopy(self._params)
        result.update(meas_id=measurement_id, weights=self.measurement_data)
        return result


class BinningMIMeasurement(Measurement):
    """
    Mutual information measurement between two groups of neurons (RV1 and RV2)
    based on binned (discretized) activations.

    The joint distribution of RV1 and RV2 is accumulated batch-wise by a binning
    object selected via binning_params["binning_method"]; the mutual information
    (in bits) is then estimated according to `estimation_method`.
    """

    def __init__(
        self,
        run_id,
        chapter_id,
        epoch_id,
        dataset,
        binning_params,
        rv1_id_list,
        rv2_id_list,
        estimation_method,
    ):
        """
        Args:
            run_id, chapter_id, epoch_id: Identifiers of the network state.
            dataset: Dataset the activations stem from (its `name` is stored).
            binning_params (dict): Must contain 'binning_method' and 'n_labels';
                'n_bins' is additionally read by the NSB estimators. The whole dict
                is forwarded as keyword arguments to the binning class.
            rv1_id_list (list): Neuron ids forming the first random variable.
            rv2_id_list (list): Neuron ids forming the second random variable.
            estimation_method (str): 'naive', 'nsb' or 'nsb_ndd'.

        Raises:
            NotImplementedError: If binning_params["binning_method"] is unknown.
        """
        super().__init__("mi", run_id, chapter_id, epoch_id, dataset)
        self._params["RV1"] = sorted(rv1_id_list)
        self._params["RV2"] = sorted(rv2_id_list)
        self._params["estimation_method"] = estimation_method
        self._params["binning_params"] = binning_params
        self._params['n_labels'] = binning_params['n_labels']
        self._params['dataset_name'] = dataset.name

        binning_method = binning_params["binning_method"]
        if binning_method == "fixed_size":
            # one of the two (groupings and rv1_id_list, rv2_id_list) is actually redundant ..
            self._joint = nninfo.analysis.binning.FixedSizeBinning(
                [self._params["RV1"], self._params["RV2"]], **binning_params
            )
        elif binning_method == "none":
            # NOTE(review): class name spelled "NoneBinnig" — confirm it matches nninfo.analysis.binning.
            self._joint = nninfo.analysis.binning.NoneBinnig(
                [self._params["RV1"], self._params["RV2"]], **binning_params
            )
        elif binning_method == "goldfeld:fixed_size":
            self._joint = nninfo.analysis.binning.GoldfeldBinning(
                [self._params["RV1"]], self._params["RV2"], **binning_params
            )
        else:
            # Fail early instead of with an AttributeError on self._joint in measure().
            raise NotImplementedError(
                'Unknown binning method "{}".'.format(binning_method)
            )

        self._measurement_results = None
        self._measurement_results_uncertainty = None

    def measure(self, activations_iter):
        """
        Accumulate all activation batches into the joint histogram and estimate
        the mutual information, stored in self._measurement_results.

        Args:
            activations_iter: Iterable of neuron id dicts as yielded by
                Analysis.capture_activations.

        Raises:
            NotImplementedError: If the estimation method is unknown or 'nsb_ndd'.
        """
        n_samples = 0
        for activations in activations_iter:
            n_samples += len(next(iter(activations.values())))
            self._joint.apply(activations)

        joint_pdf = self._joint.get_pdf()

        if not isinstance(joint_pdf, SxPID.PDF):
            joint_pdf = SxPID.PDF.from_dict(joint_pdf)

        marginal1_pdf = self._marginalize(joint_pdf, 0)
        marginal2_pdf = self._marginalize(joint_pdf, 1)

        estimation_method = self._params["estimation_method"]
        if estimation_method == "naive":
            self._perform_naive_estimation(joint_pdf, marginal1_pdf, marginal2_pdf)
            return

        if estimation_method not in ("nsb", "nsb_ndd"):
            raise NotImplementedError(
                'No mutual information estimation method apart from "naive" and "nsb" implemented.'
            )

        # Total numbers of bins, only needed by the NSB estimators (so that
        # 'n_bins' is not required for the naive estimator).
        n_bins1 = self._total_bins("RV1")
        n_bins2 = self._total_bins("RV2")
        n_bins_joined = n_bins1 * n_bins2

        if estimation_method == "nsb_ndd":
            self._perform_nsb_estimation_ndd(joint_pdf, marginal1_pdf, marginal2_pdf, n_bins1, n_bins2, n_bins_joined, n_samples)
        else:
            self._perform_nsb_estimation(joint_pdf, marginal1_pdf, marginal2_pdf, n_bins1, n_bins2, n_bins_joined, n_samples)

    def _total_bins(self, rv_key):
        """Number of possible states of RV1/RV2: n_labels for label ('Y...') variables, else n_bins ** n_neurons."""
        neuron_ids = self._params[rv_key]
        if neuron_ids[0][0] == 'Y':
            return self._params['n_labels']
        return self._params["binning_params"]["n_bins"] ** len(neuron_ids)

    def _marginalize(self, pdf, var_index):
        """Return the marginal PDF of column `var_index` of pdf.coords (summing over the other columns)."""
        uniques = np.unique(pdf.coords[:, var_index])
        marg_probs = np.array(
            [np.sum(pdf.probs[pdf.coords[:, var_index] == unq]) for unq in uniques]
        )

        return SxPID.PDF(uniques[:, np.newaxis], marg_probs)

    def _perform_naive_estimation(self, joint, marg1, marg2):
        """Plug-in estimate: sum over joint states of p(a,b) * log2(p(a,b) / (p(a) p(b)))."""
        mi = 0
        for i in range(len(joint.coords)):
            p1 = marg1.probs[np.argmax(marg1.coords[:, 0] == joint.coords[i, 0])]
            p2 = marg2.probs[np.argmax(marg2.coords[:, 0] == joint.coords[i, 1])]
            mi += joint.probs[i] * np.log2(joint.probs[i] / (p1 * p2))
        self._measurement_results = mi

    def _perform_nsb_estimation_ndd(self, joint, marg1, marg2, n_bins1, n_bins2, n_bins_joined, n_samples):
        """Placeholder for an ndd-based NSB estimator; no implementation exists in this module."""
        raise NotImplementedError(
            'Estimation method "nsb_ndd" is not implemented; use "nsb" instead.'
        )

    def _perform_nsb_estimation(self, joint, marg1, marg2, n_bins1, n_bins2, n_bins_joined, n_samples):
        """
        MI = H(RV1) + H(RV2) - H(RV1, RV2) with NSB entropy estimates, converted
        from nats to bits. Counts are recovered by rounding probs * n_samples.
        """
        h1  = self._compute_entropy_nsb(np.round(marg1.probs * n_samples).astype(int), n_bins1, n_samples)
        h2  = self._compute_entropy_nsb(np.round(marg2.probs * n_samples).astype(int), n_bins2, n_samples)
        h12 = self._compute_entropy_nsb(np.round(joint.probs * n_samples).astype(int), n_bins_joined, n_samples)

        if h1 is not None and h2 is not None and h12 is not None:
            self._measurement_results = 1/np.log(2) * (h1 + h2 - h12) #convert to bits
        elif h1 is not None and h2 is None and h12 is None:
            self._measurement_results = 1/np.log(2) * h1 #convert to bits
        elif h2 is not None and h1 is None and h12 is None:
            self._measurement_results = 1/np.log(2) * h2 #convert to bits
        else:
            self._measurement_results = np.nan

    def _compute_entropy_nsb(self, frequencies, n_bins, n_samples):
        """
        NSB entropy estimate (as returned by _H_nsb) of the given bin counts.

        Returns None when there are as many occupied states as samples
        (presumably every sample falls into its own state), where the estimate is skipped.
        """
        if len(frequencies) == n_samples:
            return None

        mk = make_mk(frequencies, n_bins)
        return float(_H_nsb(mk, n_bins, n_samples))

    def get_measurement_dict(self, measurement_id):
        """Return a copy of the parameters plus the MI result, its uncertainty and the binning parameters."""
        output_dict = copy.deepcopy(self._params)
        output_dict["meas_id"] = measurement_id
        output_dict["mutual_info"] = self._measurement_results
        output_dict["mutual_info_uncertainty"] = self._measurement_results_uncertainty
        output_dict["binning_params"] = self._joint.get_params()
        return output_dict


class GoldfeldMIMeasurement(Measurement):
    """
    Mutual information between a group of source neurons and the network
    input X or the label Y, estimated with a Gaussian-mixture model.

    The pre-noise activations of the source neurons are used as centres of
    isotropic Gaussians whose standard deviation is ``std_dev`` times the
    width of the source activation range. The differential entropy of such a
    mixture is estimated by Monte Carlo sampling, and

        I(source; target) = H(source) - H(source | target)

    is stored (in bits) in ``_measurement_results``.
    """

    def __init__(
        self,
        run_id,
        chapter_id,
        epoch_id,
        dataset,
        estimation_params,
        source_id_list,
        target_id_list,
    ):
        """
        Args:
            run_id, chapter_id, epoch_id: Forwarded to ``Measurement``.
            dataset: Dataset fed through the network; its ``name`` is stored
                in the measurement parameters.
            estimation_params: dict with the keys
                ``'limits'``: mapping neuron id -> (low, high) activation
                    range; only the entry of the first source neuron is used,
                ``'std_dev'``: Gaussian std. dev. as a fraction of that range,
                ``'n_x'``: repetitions of each input when the target is X,
                ``'n_mc'``: Monte Carlo samples per Gaussian centre.
            source_id_list: List of Neuron IDs of noisy neurons
            target_id_list: List of non-noisy neuron IDs (X or Y)
        """
        super().__init__("mi", run_id, chapter_id, epoch_id, dataset)

        self._dataset = dataset

        self._source_id_list = source_id_list
        self._target_id_list = target_id_list
        # The target kind is read from the first element of the first target id.
        assert self._target_id_list[0][0] in ('X', 'Y'), "target_id_list must refer to X or Y"

        # Noise scale = std_dev * width of the (first) source neuron's range.
        self._limits = estimation_params['limits'][self._source_id_list[0]]
        self._stddev = estimation_params["std_dev"] * (self._limits[1] - self._limits[0])
        self._normal = norm(loc=0, scale=self._stddev)

        self._n_x = estimation_params["n_x"]
        self._n_mc = estimation_params["n_mc"]

        self._params.update({
            'dataset_name': dataset.name,
            'RV1': source_id_list,
            'RV2': target_id_list,
            'std_dev': estimation_params["std_dev"],
            'n_mc': self._n_mc,
            'n_x': self._n_x,
        })

    def measure(self, analysis):
        """
        Estimate I(source; target) in bits and store it in ``_measurement_results``.

        Args:
            analysis: ``Analysis`` object used to capture pre-noise activations.
        """
        activation_iter = analysis.capture_activations(
            self._dataset,
            repeat_dataset=1,
            before_noise=True,
        )
        source_activations, target_activations = self._stack_activations(
            activation_iter, self._source_id_list, self._target_id_list
        )

        # Unconditional entropy of the source mixture.
        entropy = self.compute_entropy(source_activations, self._n_mc, progressbar=True)

        if self._target_id_list[0][0] == 'Y':
            cond_entropy = self._conditional_entropy_given_y(source_activations, target_activations)
        else:
            cond_entropy = self._conditional_entropy_given_x(analysis, target_activations)

        mutual_info = entropy - cond_entropy
        log.info("Goldfeld MI estimate: %s bits", mutual_info)

        self._measurement_results = mutual_info

    @staticmethod
    def _stack_activations(activation_iter, *id_lists):
        """
        Collect the activations of several neuron groups over all batches.

        Returns one float64 array of shape (n_samples, len(ids)) per entry of
        ``id_lists``, rows in the order the batches were yielded. The batches
        are concatenated once, instead of re-growing an array per batch.
        """
        per_group = [[] for _ in id_lists]
        for activations in activation_iter:
            for batches, ids in zip(per_group, id_lists):
                batches.append(np.array([activations[neuron_id] for neuron_id in ids]).T)
        return tuple(
            np.concatenate(batches, axis=0).astype(float, copy=False)
            if batches
            else np.empty((0, len(ids)))
            for batches, ids in zip(per_group, id_lists)
        )

    def _conditional_entropy_given_y(self, source_activations, target_activations):
        """
        H(source | Y) = sum_y p(y) * H(source | Y = y).

        Each label class is weighted by its empirical frequency, so the
        estimate stays correct for imbalanced labels.
        """
        _, inverse, counts = np.unique(
            target_activations, return_inverse=True, return_counts=True, axis=0
        )
        # Flatten defensively: the shape of the inverse returned together with
        # ``axis`` has differed between NumPy versions.
        inverse = np.asarray(inverse).reshape(-1)
        weights = counts / counts.sum()

        cond_entropy = 0.0
        for class_index, weight in enumerate(weights):
            class_activations = source_activations[inverse == class_index]
            cond_entropy += weight * self.compute_entropy(class_activations, self._n_mc)
        return cond_entropy

    def _conditional_entropy_given_x(self, analysis, target_activations):
        """
        H(source | X), averaged over the inputs in ``target_activations``.

        Each input is re-submitted ``n_x`` times to obtain the source
        activations conditioned on it.
        """
        cond_entropy_sum = 0
        for target in target_activations:
            # NOTE(review): the method name's spelling ("samle") must match the
            # one defined on Analysis - confirm before renaming.
            single_iter = analysis.capture_activations_of_single_samle(
                single_sample=target,
                repeat_dataset=self._n_x,
                before_noise=True,
            )
            (cond_source_activations,) = self._stack_activations(single_iter, self._source_id_list)
            cond_entropy_sum += self.compute_entropy(cond_source_activations, self._n_mc)
        return cond_entropy_sum / len(target_activations)

    def compute_entropy(self, mu, n_mc=10, progressbar=False):
        """
        Computes the differential entropy (in bits) of an equally weighted
        mixture of isotropic Gaussians with centres ``mu`` and standard
        deviation ``self._stddev``, by Monte Carlo integration.

        Args:
            mu: Centres of Gaussians, (n_samples, n_dim)
            n_mc: Number of Monte Carlo samples per centre
            progressbar: If True, show a tqdm progress bar over the samples.

        Returns:
            Estimated entropy in bits.
        """
        n_dim = mu.shape[1]

        # Draw n_mc points around every centre: (n_mc * n_samples, n_dim).
        samples = (mu[np.newaxis] + self._normal.rvs(size=(n_mc,) + mu.shape)).reshape(-1, n_dim)

        # log2 of the Gaussian normalisation constant (2 pi sigma^2)^(n_dim / 2).
        log2_norm_const = n_dim / 2 * np.log2(2 * np.pi * self._stddev ** 2)
        ln2 = np.log(2)

        sample_iter = tqdm(samples) if progressbar else samples
        entropy_sum = 0
        for sample in sample_iter:
            exponents = -np.sum(np.square(sample - mu), axis=-1) / (2 * self._stddev ** 2)
            # log2(mean(exp(exponents))) via log-sum-exp: the plain form
            # underflows to log2(0) = -inf when all exponents are very negative
            # (e.g. many dimensions / small std_dev).
            max_exp = np.max(exponents)
            log2_mean = (max_exp + np.log(np.mean(np.exp(exponents - max_exp)))) / ln2
            entropy_sum -= log2_mean - log2_norm_const

        # Monte Carlo average of -log2 p(sample).
        return entropy_sum / len(samples)

    def get_measurement_dict(self, measurement_id):
        """Return a deep copy of the parameters plus ``meas_id`` and ``mutual_info``."""
        output_dict = copy.deepcopy(self._params)
        output_dict["meas_id"] = measurement_id
        output_dict["mutual_info"] = self._measurement_results
        return output_dict


class ContinuousMIMeasurement(Measurement):
    """
    Mutual information between two groups of neurons (RV1 and RV2), treating
    their activations as continuous and estimating with ``JidtKraskovCMI``.

    With ``permutation_test=True`` the samples of RV2 are shuffled before the
    estimate, destroying their pairing with RV1 (a permutation baseline).
    """

    def __init__(
        self,
        run_id,
        chapter_id,
        epoch_id,
        dataset,
        rv1_id_list,
        rv2_id_list,
        estimator_settings=None,
        permutation_test=False
    ):
        """
        Args:
            run_id, chapter_id, epoch_id: Forwarded to ``Measurement``.
            dataset: Dataset the activations come from; its ``name`` is stored.
            rv1_id_list, rv2_id_list: Neuron ids of the two variables. They are
                stored sorted, so activations are stacked in sorted-id order.
            estimator_settings: Passed unchanged as ``settings`` to
                ``JidtKraskovCMI``.
            permutation_test: If True, shuffle RV2's samples before estimating.
        """
        super().__init__("cmi", run_id, chapter_id, epoch_id, dataset)
        self._params["RV1"] = sorted(rv1_id_list)
        self._params["RV2"] = sorted(rv2_id_list)
        self._params["estimator_settings"] = estimator_settings
        self._params["dataset_name"] = dataset.name
        self._params['permutation_test'] = permutation_test

        self._measurement_results = None

    def measure(self, activations_iter):
        """
        Estimate I(RV1; RV2) and store it in ``_measurement_results``.

        Args:
            activations_iter: Iterable of dicts mapping neuron id -> per-batch
                activations (assumed 1-D, one value per sample - TODO confirm).
        """
        rv1_ids = self._params["RV1"]
        rv2_ids = self._params["RV2"]

        # Gather per-batch blocks and concatenate once; growing an array with
        # np.append inside the loop is quadratic in the number of batches.
        rv1_batches, rv2_batches = [], []
        for activations in activations_iter:
            rv1_batches.append([activations[neuron] for neuron in rv1_ids])
            rv2_batches.append([activations[neuron] for neuron in rv2_ids])

        # (n_neurons, n_samples) float64 arrays.
        rv1_activations = self._join_batches(rv1_batches, len(rv1_ids))
        rv2_activations = self._join_batches(rv2_batches, len(rv2_ids))

        # NOTE(review): JidtKraskovCMI is not imported in the visible head of
        # this module - confirm it is brought into scope elsewhere.
        estimator = JidtKraskovCMI(settings=self._params["estimator_settings"])

        if self._params['permutation_test']:
            # Shuffling the rows of the transposed view permutes the samples
            # (columns) of rv2_activations in place.
            np.random.shuffle(rv2_activations.T)

        # The estimator receives (n_samples, n_neurons) arrays.
        self._measurement_results = estimator.estimate(rv1_activations.T, rv2_activations.T)

        log.info("Continuous MI estimate: %s", self._measurement_results)

    @staticmethod
    def _join_batches(batches, n_neurons):
        """Concatenate per-batch (n_neurons, batch_size) blocks along the sample axis as float64."""
        if not batches:
            return np.empty((n_neurons, 0))
        return np.concatenate([np.asarray(batch) for batch in batches], axis=1).astype(float, copy=False)

    def get_measurement_dict(self, measurement_id):
        """Return a deep copy of the parameters plus ``meas_id`` and ``continuous_mutual_info``."""
        output_dict = copy.deepcopy(self._params)
        output_dict["meas_id"] = measurement_id
        output_dict["continuous_mutual_info"] = self._measurement_results
        return output_dict