from typing import Generator, Dict, Tuple, Any, Callable, List
import logging
from copy import deepcopy, copy
import collections
from functools import partial

import torch.nn
from torch.distributions import Categorical

from .analysis_method import _SingleAnalysisMethod, ResultGeneratorType, _StepAnalysisMethod
from .sensitivity import SingleTaskSensitivityAnalysis
from .helpers import perturb_input, perturb_input_blur, perturb_input_contrast, perturb_input_shift,\
    perturb_input_hue, perturb_model, perturb_conv_layer, DTHETA, DX, EPSILON, N_COMPARISONS
from path_learning.utils.result import TaskResult
from path_learning.models.models import reset_bias

# Lookup table from a perturbation name to the helper function of the same
# purpose imported from ``.helpers`` above.
PERTURBATIONS = {
    "perturb-model": perturb_model,
    "perturb-blur": perturb_input_blur,
    "perturb-contrast": perturb_input_contrast,
    "perturb-shift": perturb_input_shift,
    "perturb-hue": perturb_input_hue,
    "perturb-conv-layer": perturb_conv_layer,
    "perturb-input": perturb_input,
}


class SingleTaskGradientPlasticitySensitivityAnalysis(SingleTaskSensitivityAnalysis):
    """Score a batch by the mean pairwise dot product of its per-sample gradients."""

    name = "single_task_gradient_plasticity_sensitivity_analysis"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def criterion(self, model: torch.nn.Module, loss: Callable,
                  model_input: torch.tensor, model_output: torch.tensor, target: torch.tensor) -> torch.Tensor:
        """Return the result of ``compute_gradient_perturbation_sensitivity`` for the same arguments."""
        return self.compute_gradient_perturbation_sensitivity(model, loss, model_input, model_output, target)

    @staticmethod
    def get_grad(parameters) -> torch.Tensor:
        """
        Flatten and concatenate the gradients of every parameter with more than one dimension.

        Parameters without a gradient (``grad is None``) are skipped instead of raising.
        :param parameters: iterable of parameters, e.g. ``model.parameters()``
        :return: 1-D tensor of concatenated gradients; an empty tensor if no parameter qualifies
        """
        # Collect first and concatenate once: this avoids seeding torch.cat with a
        # CPU tensor (device mismatch for non-CPU models) and repeated reallocation.
        flat_grads = [
            param.grad.reshape(-1)
            for param in parameters
            if param.grad is not None and param.grad.dim() > 1
        ]
        if not flat_grads:
            return torch.zeros((0,))
        return torch.cat(flat_grads)

    def compute_gradient_perturbation_sensitivity(self, model: torch.nn.Module, loss: Callable,
                                                  model_input: torch.tensor, model_output: torch.tensor,
                                                  target: torch.tensor) -> torch.Tensor:
        """
        Compute the gradient of each sample's loss separately and return the mean dot product
        over all unordered pairs of per-sample gradient vectors.

        :param model: model to evaluate; it is run on ``model_input`` and back-propagated through
        :param loss: callable invoked as ``loss(prediction, target, reduction="none")``
        :param model_input: batch passed to ``model``
        :param model_output: unused by this method
        :param target: targets passed to ``loss``
        :return: mean pairwise gradient dot product; a zero scalar tensor when the batch holds
                 fewer than two samples (there is no pair to compare)
        """
        optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001, momentum=0.0)
        # NOTE(review): this sets a plain attribute on every module; it does not touch the
        # ``requires_grad`` flag of the parameters. Kept unchanged - confirm the intent.
        for name, m in model.named_modules():
            m.requires_grad = True

        grads = []
        with torch.enable_grad():
            pred_target0 = model(model_input)
            task_loss0 = loss(pred_target0, target, reduction="none")
            n_samples = task_loss0.size(0)
            for i in range(n_samples):
                optimizer.zero_grad()
                # The graph is shared by all samples, so keep it until the last backward pass.
                task_loss0[i].backward(retain_graph=i < n_samples - 1)
                # Snapshot the gradients before the next zero_grad() clears them.
                grads.append(self.get_grad(model.parameters()).detach().clone())

            stabilities0 = 0
            count = 0
            for i, grad_i in enumerate(grads):
                for grad_j in grads[i + 1:]:
                    count += 1
                    stabilities0 += torch.dot(grad_i, grad_j)

        # NOTE(review): this applies one SGD step using the gradients of the last sample,
        # i.e. the analysis modifies the model. Kept unchanged - confirm this is intended.
        optimizer.step()

        if count == 0:
            # Fewer than two samples: dividing by the pair count would raise ZeroDivisionError.
            return grads[0].new_zeros(()) if grads else torch.zeros(())
        return stabilities0 / count


class SingleTaskFisherDiagSensitivityAnalysis(SingleTaskSensitivityAnalysis):
    """Score a batch by the summed diagonal of a sampled Fisher information estimate."""

    name = "single_task_fisher_diag_sensitivity_analysis"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def criterion(self, model: torch.nn.Module, loss: Callable,
                  model_input: torch.tensor, model_output: torch.tensor, target: torch.tensor) -> torch.Tensor:
        """Return the result of ``compute_gradient_perturbation_sensitivity`` for the same arguments."""
        return self.compute_gradient_perturbation_sensitivity(model, loss, model_input, model_output, target)

    @staticmethod
    def get_grad(model) -> torch.Tensor:
        """
        Concatenate the flattened weight gradients of the modules whose name contains
        ``"layer"`` and either ``"bn"`` or ``"conv"``.

        Modules without a weight or without a populated gradient are skipped.
        :param model: module whose ``named_modules()`` are inspected
        :return: 1-D tensor of concatenated gradients; an empty tensor if nothing qualifies
        """
        flat_grads = []
        for name, m in model.named_modules():
            if "layer" in name and ("bn" in name or "conv" in name):
                weight = getattr(m, "weight", None)
                if weight is not None and weight.grad is not None:
                    flat_grads.append(weight.grad.reshape(-1))
        if not flat_grads:
            return torch.zeros((0,))
        return torch.cat(flat_grads, dim=0)

    def compute_gradient_perturbation_sensitivity(self, model: torch.nn.Module, loss: Callable,
                                                  model_input: torch.tensor, model_output: torch.tensor,
                                                  target: torch.tensor) -> torch.Tensor:
        """
        Estimate the trace of the diagonal Fisher information for one batch.

        For every sample a class is drawn from the model's own predictive distribution, the
        log-probability of that class is back-propagated, and the squared gradients of all
        trainable parameters are accumulated.

        :param model: model to evaluate; assumed to return one row of class scores per sample
        :param loss: unused by this method
        :param model_input: batch passed to ``model``; its first dimension is the batch size
        :param model_output: unused by this method
        :param target: unused by this method
        :return: scalar tensor with the summed squared gradients over all samples and parameters
        """
        fim = {name: 0 for name, param in model.named_parameters() if param.requires_grad}
        with torch.enable_grad():
            pred_target0 = model(model_input)
            # The Fisher information is defined on gradients of log p(y|x). log_softmax is
            # idempotent, so this is also correct if the model already emits log-probabilities.
            log_probs = torch.nn.functional.log_softmax(pred_target0, dim=1)
            outdx = Categorical(logits=pred_target0).sample().unsqueeze(1).detach()
            samples = log_probs.gather(1, outdx)
            batch_size = model_input.size(0)
            for idx in range(batch_size):
                model.zero_grad()
                # The graph is shared by all samples; release it after the last backward pass.
                torch.autograd.backward(samples[idx], retain_graph=idx < batch_size - 1)
                for name, param in model.named_parameters():
                    # zero_grad() may leave grad as None for parameters outside the graph.
                    if param.requires_grad and param.grad is not None:
                        fim[name] += torch.sum(param.grad * param.grad).detach()
        print(f"fim: {fim}")
        fim_sum = sum(fim.values())
        print(f"fim sum: {fim_sum}")
        if not isinstance(fim_sum, torch.Tensor):
            # No trainable parameter received a gradient: still honour the tensor return type.
            fim_sum = torch.zeros(())
        return fim_sum


class SingleTaskActivationActivityAnalysis(SingleTaskSensitivityAnalysis):
    """Score a batch by the fraction of non-positive outputs of selected modules."""

    name = "single_task_activation_analysis"

    def __init__(self, *args, **kwargs):
        # Filled per call with {module name: [recorded outputs]} by the forward hooks.
        self.activations = None
        super().__init__(*args, **kwargs)

    def criterion(self, model: torch.nn.Module, loss: Callable,
                  model_input: torch.tensor, model_output: torch.tensor, target: torch.tensor) -> torch.Tensor:
        """Return the result of ``compute_gradient_perturbation_sensitivity`` for the same arguments."""
        return self.compute_gradient_perturbation_sensitivity(model, loss, model_input, model_output, target)

    def get_activity(self) -> torch.Tensor:
        """
        Reduce the recorded outputs to a single number.

        For every hooked module the fraction of entries ``<= 0`` in its most recent output is
        computed; the return value is the mean of those fractions over all hooked modules.
        :return: scalar tensor (``nan`` when nothing was recorded)
        """
        activity = torch.zeros((len(self.activations.keys()),))
        for i, name in enumerate(self.activations.keys()):
            condition = (self.activations[name][-1] <= 0.0).float()
            per_position = torch.mean(condition, dim=0)
            print(f"shape: {condition.size()}, {per_position}")
            activity[i] = torch.mean(per_position)
        print(f"activity: {activity}, Arithmetic mean: {torch.mean(activity)} ")
        return torch.mean(activity)

    def save_activation(self, name, mod, inp, out):
        """Forward hook: store the module output on the CPU under ``name``."""
        # Detach so the stored copy does not keep the autograd graph alive.
        self.activations[name].append(out.detach().cpu())

    def compute_gradient_perturbation_sensitivity(self, model: torch.nn.Module, loss: Callable,
                                                model_input: Tuple[torch.tensor, torch.tensor],
                                                model_output: Tuple[torch.tensor, torch.tensor],
                                                target: torch.tensor) -> torch.Tensor:
        """
        Run one forward pass with hooks on selected modules and return ``get_activity()``.

        Hooks are attached to every module whose name contains ``"bn1"`` and to the modules
        named ``layer{1..4}.{0,1}``. They are always removed again, even if the forward pass fails.
        :param model: model to run
        :param loss: unused by this method
        :param model_input: input passed to ``model``
        :param model_output: unused by this method
        :param target: unused by this method
        :return: scalar tensor produced by ``get_activity``
        """
        block_names = {"layer1.0", "layer2.0", "layer3.0", "layer4.0",
                       "layer1.1", "layer2.1", "layer3.1", "layer4.1"}
        self.activations = collections.defaultdict(list)
        handles = []
        try:
            for name, m in model.named_modules():
                # None of the block names contains "bn1", so a module is never hooked twice.
                if "bn1" in name or name in block_names:
                    handles.append(m.register_forward_hook(partial(self.save_activation, name)))

            with torch.enable_grad():
                model(model_input)
                stabilities0 = self.get_activity()
        finally:
            # Never leave hooks or recorded activations behind on the model / instance.
            for handle in handles:
                handle.remove()
            self.activations = collections.defaultdict(list)

        return stabilities0


class SingleTaskModelAnalysis(_SingleAnalysisMethod):
    """Report, per convolution module, the mean sign agreement between pairs of filters."""

    name = "task_model_sensitivity_analysis"
    # TODO: add more values of interest here such as weight norm differentiated by layer

    def __init__(self, *args, **kwargs):
        """
        :param n_batches: required keyword argument, stored on the instance
        :param reset_bias: required keyword argument; if truthy, ``analyze_model`` resets biases
        :raises KeyError: if either keyword argument is missing
        """
        try:
            self.n_batches = kwargs.pop("n_batches")
            self.reset_bias = kwargs.pop("reset_bias")
        except KeyError:
            logging.warning(f"invalid kwargs {self.name}, received : {args} and {kwargs}")
            raise
        super().__init__(*args, **kwargs)

    def criterion(self, model: torch.nn.Module, loss_function: Callable, model_inputs: Tuple[torch.tensor, torch.tensor],
                  model_outputs: Tuple[torch.tensor, torch.tensor], target: torch.tensor) -> float:
        """Not available for this analysis; always raises ``NotImplementedError``."""
        raise NotImplementedError("SingleTaskModelAnalysis does not define a criterion")

    def analyze_model(self, task_result: TaskResult,
                      model: torch.nn.Module) -> ResultGeneratorType:
        """
        Yield one ``("batch_stabilities0_<module name>", value)`` pair per module whose name
        contains ``"conv"`` and whose weight has more than one dimension.

        The weight is flattened to one row per filter. For every unordered pair of rows the
        mean of ``sign(w_i) * sign(w_j)`` over all positions is taken, and ``value`` is the
        average of that quantity over all pairs. Modules without a weight or with a single
        filter (no pair exists) are skipped.

        :param task_result: unused by this method
        :param model: model to inspect; it is switched to training mode
        """
        # NOTE(review): this assigns a global of *this* module, while ``reset_bias`` is imported
        # from ``path_learning.models.models``; whether it sees this value cannot be told from
        # here - confirm.
        global BIAS_VALUE
        if self.reset_bias:
            BIAS_VALUE = self.reset_bias
            print("Applying bias reset")
            model.apply(reset_bias)

        model.train()
        for name, m in model.named_modules():
            print(f"name: {name}")
            if "conv" not in name:
                continue
            print("PICKED")
            weight = getattr(m, "weight", None)
            if weight is None or weight.dim() <= 1:
                continue
            n_filters = weight.size(0)
            count = n_filters * (n_filters - 1) // 2
            if count == 0:
                # A single filter has no partner; dividing by the pair count would fail.
                continue
            # sign(a * b) == sign(a) * sign(b), so a single product of the sign matrix with its
            # transpose yields the per-pair sums, replacing a Python loop over all pairs.
            # float64 on the CPU keeps the integer-valued sums exact.
            signs = torch.sign(weight.detach().reshape(n_filters, -1)).cpu().double()
            pair_sums = signs @ signs.t()
            signed_distance_sum = torch.triu(pair_sums, diagonal=1).sum() / signs.size(1)
            yield f"batch_stabilities0_{name}", float(signed_distance_sum) / count


class SingleStepSelectedTaskLossAnalysis(_StepAnalysisMethod):
    """
    Purpose:
    Analyze how well (in loss) a model does on the next task
    Split performance along percentiles
    """
    name = "single_step_task_loss_selected_analysis"

    def __init__(self, *args, **kwargs):
        """
        :param n_batches: optional keyword, number of batches to evaluate (default 3)
        :param reset_bias: optional keyword, stored on the instance (default False)
        :param ranks: optional keyword, percentiles in the range 0-100 used to split the results
        """
        # Every pop has a default, so no KeyError can occur here.
        self.n_batches = kwargs.pop("n_batches", 3)
        self.reset_bias = kwargs.pop("reset_bias", False)
        self.ranks = kwargs.pop("ranks", [5., 20., 40., 60., 80., 100.])
        super().__init__(*args, **kwargs)

    def infer_labels(self, dataloader, model, loss, certainty_threshold=0, batch_size_inference=100) \
            -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Evaluate ``model`` on at most ``self.n_batches`` batches of ``dataloader``.

        The model is switched to evaluation mode and run without gradient tracking.
        :param dataloader: iterable yielding ``(data, target)`` batches
        :param model: model to evaluate
        :param loss: callable invoked as ``loss(prediction, target, reduction="none")``
        :param certainty_threshold: currently unused
        :param batch_size_inference: currently unused
        :return: per-sample ``(losses, max softmax scores, targets, predicted classes)``, each
                 concatenated over the evaluated batches; four ``None`` values if the loader
                 yields no batch
        """
        losses, scores, targets, predictions = [], [], [], []
        with torch.no_grad():
            # Important for batch normalization
            model.eval()
            for batch_idx, (data, target) in enumerate(dataloader):
                pred_target = model(data)
                task_loss = loss(pred_target, target, reduction="none")
                # Softmax yields normalized class scores; keep the best score and its class index.
                batch_scores, batch_predictions = torch.max(torch.softmax(pred_target, dim=1), dim=1)

                # Collect what was actually produced instead of writing into buffers pre-sized
                # from len(dataset): unfilled buffer rows would count as zero-loss, correct samples.
                losses.append(task_loss.to(batch_scores))
                scores.append(batch_scores)
                targets.append(target)
                predictions.append(batch_predictions)
                if batch_idx >= (self.n_batches - 1):
                    break
        if not losses:
            return None, None, None, None
        return torch.cat(losses), torch.cat(scores), torch.cat(targets), torch.cat(predictions)

    def analyze_model(self, task_result: TaskResult, model: torch.nn.Module) -> ResultGeneratorType:
        """
        Yield overall mean loss, sample count and accuracy, then the same three values restricted
        to the samples whose score lies strictly above each configured percentile of the scores.

        :param task_result: used to build the loss and the dataloader
        :param model: model to evaluate
        """
        loss = task_result.generate_loss(self.logdir)
        loss_function = loss.loss_functions["test"]["callable"]
        dataloader = self.generate_dataloader(task_result)
        # Use model "i" for task "i -1"
        losses, pred_scores, targets, predictions = self.infer_labels(dataloader, model, loss_function)
        # Return overall performance
        yield "Mean losses", float(torch.mean(losses))
        yield "Number of all data points", int(losses.size(0))
        yield "Mean accuracy", float(torch.mean((predictions == targets).float()))

        ranks = torch.FloatTensor(self.ranks)
        # torch.quantile needs q as a scalar or 1-D tensor with the dtype of its input; wrapping
        # the rank tensor in a list does not produce that.
        q = (ranks / 100).to(pred_scores)
        quantiles = torch.quantile(pred_scores, q)
        for index in range(quantiles.size(0)):
            threshold: float = quantiles[index].item()
            selected = pred_scores > threshold
            yield f"Number of data points of percentile {ranks[index]} and threshold {threshold}", \
                  int(losses[selected].size(0))
            yield f"Mean loss of percentile {ranks[index]} and threshold {threshold}", \
                  float(torch.mean(losses[selected]))
            yield f"Mean accuracy of percentile {ranks[index]} and threshold {threshold}", \
                  float(torch.mean((predictions[selected] == targets[selected]).float()))


# Registry mapping a string key to each analysis class defined in this module.
# Presumably used by callers to select an analysis by name - confirm against usage.
SINGLE_TASK_EXPLORATORY_METHODS = {
    "single-step-task-loss-analysis": SingleStepSelectedTaskLossAnalysis,
    "single-model-weight-correlation-sensitivity": SingleTaskModelAnalysis,
    "single-gradient-plasticity": SingleTaskGradientPlasticitySensitivityAnalysis,
    "single-fisher-sensitivity": SingleTaskFisherDiagSensitivityAnalysis,
    "single-activation-analysis": SingleTaskActivationActivityAnalysis,
}