import torch
import copy
from nninfo.analysis import Measurement
from torch.utils.data import DataLoader

# Public API: the only name exported by a star-import of this module.
__all__ = [
    "FisherMeasurement",
]


class FisherMeasurement(Measurement):
    """
    Fisher information of a network's trainable parameters on one dataset.

    The estimate is the average over all samples ``x`` of the outer product
    ``g g^T`` of the per-sample gradient ``g = d log p(y | x) / d theta``:

    * ``empirical=False`` ("true" Fisher): ``y`` is sampled from the
      network's own output distribution.
    * ``empirical=True`` (empirical Fisher): ``y`` is the dataset label.

    The diagonal of this matrix is always stored. The full matrix is stored
    only when ``save_offdiag=True``; it needs memory quadratic in the number
    of trainable parameter elements.

    The network output is used directly as a probability vector (it is
    passed to ``Categorical(probs=...)`` and its ``log`` is taken), so the
    network must output probabilities, not logits.

    Under construction.
    """

    def __init__(
        self,
        run_id,
        chapter_id,
        epoch_id,
        dataset,
        net,
        task,
        empirical=False,
        save_offdiag=False,
        output_size=1,
    ):
        """
        Args:
            run_id (int): Passed through to ``Measurement``.
            chapter_id (int): Passed through to ``Measurement``.
            epoch_id (int): Passed through to ``Measurement``.
            dataset (str or dataset): Dataset to measure on, or its name,
                which is then looked up via ``task[dataset]``.
            net (nn.Module): Network whose trainable parameters are measured.
            task (nninfo.task.TaskManager): Provides datasets by name.
            empirical (bool): If True, use dataset labels as targets
                (empirical Fisher); otherwise sample targets from the
                network's output distribution (true Fisher).
            save_offdiag (bool): Whether the off-diagonal elements of the FI
                matrix should be calculated too (the full matrix is stored).
            output_size (int): Size of the output layer. If it is 1, the
                probability of the second state is computed as
                ``1 - output``.
        """
        if isinstance(dataset, str):
            dataset = task[dataset]
        super().__init__("fisher", run_id, chapter_id, epoch_id, dataset=dataset)
        self._task = task
        self._dataset_name = dataset.name
        self._net = net
        self._empirical = empirical
        self._save_offdiag = save_offdiag
        self._output_size = output_size
        self._measurement_data = {}

    def _log_likelihoods(self, inputs, labels):
        """Return the per-sample log-likelihoods ``log p(y_i | x_i)`` as an (N, 1) tensor.

        ``y_i`` is the label (empirical Fisher) or a class index sampled from
        the network's output (true Fisher).
        """
        outputs = self._net(inputs)
        if self._output_size == 1:
            # A single output unit gives the probability of one of two states.
            # Append its complement so each row is a two-class distribution.
            # ``1 - outputs`` keeps the network's device and dtype, unlike
            # ``torch.ones(...)``, which allocates on the default device.
            # NOTE(review): column 0 is the raw output and column 1 its
            # complement, so with empirical=True label 0 is matched with the
            # raw output. If the network outputs P(label == 1) (e.g. a
            # sigmoid trained with BCE), this mapping is inverted -- confirm.
            outputs = torch.cat((outputs, 1 - outputs), dim=1)

        if self._empirical:
            # gather() requires int64 indices; labels may arrive as floats.
            targets = labels.unsqueeze(1).long()
        else:
            targets = (
                torch.distributions.Categorical(probs=outputs).sample().unsqueeze(1)
            )
        return torch.log(outputs.gather(1, targets.detach()))

    @staticmethod
    def _flat_grad(params):
        """Concatenate the gradients of ``params`` into one flat vector.

        A parameter that did not take part in the backward pass has
        ``grad is None`` (recent PyTorch's ``zero_grad()`` sets grads to
        ``None`` by default). It contributes zeros instead of raising.
        """
        return torch.cat(
            [
                (p.grad.detach() if p.grad is not None else torch.zeros_like(p)).reshape(-1)
                for p in params
            ]
        )

    def measure(self):
        """
        Estimate the Fisher information and store it in ``self._measurement_data``.

        The stored value is a tuple ``(fisher_diag, fisher_off_diag)`` of
        plain Python lists:

        * ``fisher_diag``: one entry per trainable parameter element.
        * ``fisher_off_diag``: the full averaged outer-product matrix as a
          nested list when ``save_offdiag`` is set, otherwise ``None``.

        The network's ``.grad`` fields are used as scratch space. They are
        cleared with ``zero_grad()`` before returning.

        Raises:
            ValueError: if the dataset is empty.
        """
        dataset = self._task[self._dataset_name]
        n_samples = len(dataset)
        if n_samples == 0:
            raise ValueError(
                f"Cannot measure Fisher information on empty dataset {self._dataset_name!r}"
            )
        # One batch holding the whole dataset: a single forward pass, then
        # one backward pass per sample.
        dataloader = DataLoader(dataset, batch_size=n_samples)

        params = [p for p in self._net.parameters() if p.requires_grad]
        # Detached flat copy, used only as a dtype/device/size template.
        template = torch.nn.utils.parameters_to_vector(params).detach()
        fisher_diag = torch.zeros_like(template)
        fisher_off_diag = (
            template.new_zeros((template.numel(), template.numel()))
            if self._save_offdiag
            else None
        )

        for inputs, labels in dataloader:
            log_likelihoods = self._log_likelihoods(inputs, labels)
            for i in range(inputs.size(0)):
                self._net.zero_grad()
                # All samples share the graph of the batched forward pass,
                # so it must survive each per-sample backward pass.
                log_likelihoods[i].backward(retain_graph=True)
                grad_vec = self._flat_grad(params)

                fisher_diag += grad_vec.pow(2)
                if fisher_off_diag is not None:
                    fisher_off_diag += torch.outer(grad_vec, grad_vec)

        # Do not leave the last sample's gradient behind in the parameters.
        self._net.zero_grad()

        fisher_diag /= float(n_samples)
        if self._save_offdiag:
            fisher_off_diag /= float(n_samples)
            fisher_off_diag = fisher_off_diag.tolist()

        self._measurement_data = (
            fisher_diag.tolist(),
            fisher_off_diag,
        )

    def get_measurement_dict(self, measurement_id):
        """
        Return the stored measurement as a dictionary.

        The dictionary starts from a deep copy of ``self._params`` (set up
        by the ``Measurement`` base class). It adds ``meas_id``,
        ``empirical``, ``save_offdiag`` and ``fisher_diag``. It also adds
        ``fisher_off_diag`` when ``save_offdiag`` is set.

        ``measure()`` must have been called first.
        """
        output_dict = copy.deepcopy(self._params)
        output_dict["meas_id"] = measurement_id
        output_dict["empirical"] = self._empirical
        output_dict["save_offdiag"] = self._save_offdiag
        output_dict["fisher_diag"] = self._measurement_data[0]
        if self._save_offdiag:
            output_dict["fisher_off_diag"] = self._measurement_data[1]
        return output_dict
