import copy
from nninfo.analysis import Measurement
from nninfo.analysis import Binning
from nninfo.config import N_WORKERS
from nninfo.config import CLUSTER_MODE
from external.SxPID.sxpid import SxPID
from external.BROJA_2PID import BROJA_2PID
import nninfo.log

# Public API of this module: only the PID measurement class is exported.
__all__ = ["PIDMeasurement"]

# Module-level logger obtained through the nninfo logging helper.
log = nninfo.log.get_logger(__name__)


class PIDMeasurement(Measurement):
    """
    A single PID Measurement, that can be saved into the file structure after computation.

    The measurement bins the activations of the requested target and source
    neurons (via ``Binning``), obtains the binned pdf from
    ``Binning.get_pdf()`` and evaluates the configured PID definition on it:

    - ``'sxpid'``:    ``SxPID.pid`` (target is the last variable)
    - ``'brojapid'``: ``BROJA_2PID.pid`` (target is the first variable)
    - ``'n_rlz'``:    no PID; only ``len(pdf_dict)`` is recorded
    """

    def __init__(
        self,
        run_id,
        chapter_id,
        epoch_id,
        dataset,
        pid_definition,
        binning_params,
        target_id_list,
        source_id_lists,
        pointwise=False,
        parts="all",
    ):
        """
        Args:
            run_id, chapter_id, epoch_id, dataset: forwarded unchanged to
                ``Measurement.__init__`` together with the type ``"pid"``.
            pid_definition: ``'sxpid'``, ``'brojapid'`` or ``'n_rlz'``.
            binning_params: dict containing a ``'binning_method'`` key; all
                remaining items are passed as keyword arguments to
                ``Binning.from_binning_method``. It is deep-copied, so the
                caller's dict is never modified.
            target_id_list: neuron ids forming the target variable ``T``.
            source_id_lists: sequence of neuron-id lists, one per source
                variable ``S1``, ``S2``, ... . A ``None`` entry terminates the
                sequence; it and everything after it are ignored.
            pointwise: if True, also compute the pointwise PID (sxpid only).
            parts: 'all' - informative and misinformative part
                   'inf' - informative part only
                   'mis' - misinformative part only

        Raises:
            AssertionError: if ``source_id_lists`` has fewer than two entries.
            NotImplementedError: if one of the first two sources is ``None``.
        """

        super().__init__("pid", run_id, chapter_id, epoch_id, dataset)

        self._params["pid_definition"] = pid_definition
        self._params["parts"] = parts

        self._params["T"] = sorted(target_id_list)
        self._pointwise = pointwise

        assert (
            len(source_id_lists) >= 2
        ), "Number of sources for PID measurement must be larger or equal to two."

        # Keep only the sources before the first None entry and register them
        # as self._params['S1'], self._params['S2'], ...
        sources = []
        self._source_target_labels = []
        for i, el in enumerate(source_id_lists):
            if el is None:
                if i < 2:
                    raise NotImplementedError(
                        "Less than 2 sources are not allowed at the moment."
                    )
                # NOTE(review): 'n_sources' is only recorded when the list is
                # explicitly terminated by None (unchanged behaviour, since
                # _params is serialised and may be used to match saved files).
                self._params["n_sources"] = i
                break
            sources.append(el)
            label = "S" + str(i + 1)
            self._params[label] = sorted(el)
            self._source_target_labels.append(label)

        # Number of real (non-None) sources.
        self._n_sources = len(sources)

        # Depending on the PID definition, target is either first or last
        # variable. Labels and id lists are built from the same truncated
        # source list so that they stay aligned index by index.
        if pid_definition == "sxpid":
            self._source_target_labels = self._source_target_labels + ["T"]
            self._source_target_id_lists = sources + [target_id_list]
        else:
            self._source_target_labels = ["T"] + self._source_target_labels
            self._source_target_id_lists = [target_id_list] + sources

        # Work on a private copy: 'binning_method' is popped below and the
        # caller's dict must stay intact.
        binning_params = copy.deepcopy(binning_params)
        binning_method = binning_params.pop("binning_method")
        self._binning = Binning.from_binning_method(
            binning_method,
            self._source_target_id_lists,
            sources,
            target_id_list,
            **binning_params
        )

        # Filled by measure().
        self._measurement_results = None

    def measure(self, activations_iter):
        """
        Bin all activations and compute the configured PID measure.

        Args:
            activations_iter: iterator; every element is passed to
                ``Binning.apply``.

        The result is stored in ``self._measurement_results`` and is
        converted into its saved form by ``get_measurement_dict``.

        Raises:
            NotImplementedError: if ``pid_definition`` is not supported.
        """
        pid_definition = self._params["pid_definition"]
        # Fail before doing the (potentially expensive) binning pass.
        if pid_definition not in ("n_rlz", "sxpid", "brojapid"):
            raise NotImplementedError(
                "Unsupported PID definition: {!r}".format(pid_definition)
            )

        self._binning.reset()
        for activations in activations_iter:
            self._binning.apply(activations)

        pdf_dict = self._binning.get_pdf()

        if pid_definition == "n_rlz":
            self._measurement_results = {"n_rlz": len(pdf_dict)}
        elif pid_definition == "sxpid":
            self._perform_accelerated_sxpid(
                pdf_dict, parts=self._params["parts"], pointwise=self._pointwise
            )
        else:  # "brojapid"
            self._perform_brojapid(pdf_dict)

    def _perform_sxpid(self, sxpid_pdf):
        """
        Deprecated entry point of the previous (non-accelerated) SxPID.

        Always raises; ``_perform_accelerated_sxpid`` is used instead.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "This is the previous version of SxPID, "
            "and not recommended to use any more."
        )

    def _perform_accelerated_sxpid(self, sxpid_pdf, parts='all', pointwise=False):
        """
        Run ``SxPID.pid`` on the binned pdf and store its result.

        Verbosity is silenced in cluster mode; ``N_WORKERS`` threads are used.
        """
        self._measurement_results = SxPID.pid(
            sxpid_pdf, verbose=0 if CLUSTER_MODE else 2, n_threads=N_WORKERS, parts=parts, pointwise=pointwise
        )

    def _perform_brojapid(self, brojapid_pdf):
        """Run ``BROJA_2PID.pid`` (ECOS cone solver) and store its result."""
        log.info("Broja was not tested yet and is therefore not recommended to use.")
        self._measurement_results = BROJA_2PID.pid(
            brojapid_pdf, cone_solver="ECOS", output=0
        )

    @staticmethod
    def _select_from_activations(activations, id_list):
        """Return ``{neuron_id: activations[neuron_id]}`` for every id in ``id_list``."""
        return {neuron_id: activations[neuron_id] for neuron_id in id_list}

    def get_measurement_dict(self, measurement_id):
        """
        Return a dict describing this measurement and its results.

        The dict is a deep copy of the measurement parameters, extended by
        ``'meas_id'`` and ``'binning_params'`` (from ``Binning.get_params()``).
        Depending on the PID definition it additionally contains:

        - brojapid: ``'average_pid'``
        - sxpid: ``'average_pid'``, ``'informative_pid'``,
          ``'misinformative_pid'`` and, if pointwise, ``'pointwise_pid'``
        - n_rlz: ``'n_rlz'``

        Args:
            measurement_id: stored under ``'meas_id'``.

        Raises:
            NotImplementedError: for unsupported PID definitions.
        """
        output_dict = copy.deepcopy(self._params)
        output_dict["meas_id"] = measurement_id
        output_dict["binning_params"] = self._binning.get_params()

        pid_definition = self._params["pid_definition"]
        if pid_definition == "brojapid":
            results = self._measurement_results
            # Re-key the BROJA atoms with the tuple-of-tuples keys so that
            # 'average_pid' has the same format for broja and sxpid.
            output_dict["average_pid"] = {
                ((1, 2,),): results["CI"],     # synergy
                ((1,),): results["UIY"],       # unique information, source 1
                ((2,),): results["UIZ"],       # unique information, source 2
                ((1,), (2,),): results["SI"],  # redundancy
            }
        elif pid_definition == "sxpid":
            if self._pointwise:
                ptw, avg = self._measurement_results
                output_dict["pointwise_pid"] = copy.deepcopy(ptw)
            else:
                avg = self._measurement_results

            # Split every (informative, misinformative, average) triple into
            # three separate dicts, making the output comparable to the broja
            # output (uses more hard disk space though).
            output_dict["average_pid"] = {k: v[2] for k, v in avg.items()}
            output_dict["informative_pid"] = {k: v[0] for k, v in avg.items()}
            output_dict["misinformative_pid"] = {k: v[1] for k, v in avg.items()}
        elif pid_definition == "n_rlz":
            output_dict["n_rlz"] = self._measurement_results["n_rlz"]
        else:
            raise NotImplementedError
        return output_dict
