from typing import Generator, Dict, Tuple, Any, Callable, List
import logging
from copy import deepcopy, copy

import torch.nn

from .sensitivity import SingleTaskSensitivityAnalysis
from .helpers import perturb_input, perturb_input_blur, perturb_input_contrast, perturb_input_shift,\
    perturb_input_hue, perturb_model, perturb_conv_layer, DTHETA, DX, EPSILON, N_COMPARISONS

# Lookup table from a perturbation name to the helper that implements it.
# SingleTaskPerturbGradientSensitivityAnalysis indexes this dict with the value
# of its ``perturb_method`` keyword argument, so the keys are the accepted
# values for that argument.
PERTURBATIONS = {
    "perturb-model": perturb_model,
    "perturb-blur": perturb_input_blur,
    "perturb-contrast": perturb_input_contrast,
    "perturb-shift": perturb_input_shift,
    "perturb-hue": perturb_input_hue,
    "perturb-conv-layer": perturb_conv_layer,
    "perturb-input": perturb_input,
}


class SingleTaskPerturbGradientSensitivityAnalysis(SingleTaskSensitivityAnalysis):
    """Compare the parameter gradient on an input with that on a perturbed input.

    The criterion is the cosine similarity between the flattened parameter
    gradient of the loss evaluated on the original input and on the input
    after it has been passed through the selected perturbation helper.

    The perturbation is selected with the ``perturb_method`` keyword argument,
    which must be one of the keys of ``PERTURBATIONS``.
    """

    name = "single_task_perturb_gradient_sensitivity_analysis"

    def __init__(self, *args, **kwargs):
        """Resolve the perturbation helper, then defer to the base class.

        :param args: forwarded unchanged to the base class.
        :param kwargs: forwarded unchanged to the base class; ``perturb_method``
            is additionally read here to pick the perturbation helper.
        :raises KeyError: if ``perturb_method`` is missing or not a key of
            ``PERTURBATIONS``.
        """
        self.perturb_method_name = kwargs.get("perturb_method")
        try:
            self.perturb_method = PERTURBATIONS[self.perturb_method_name]
        except KeyError:
            # Same exception type as the bare lookup, but the message now says
            # which names are accepted.
            raise KeyError(
                "Unknown perturb_method {!r}; expected one of {}".format(
                    self.perturb_method_name, sorted(PERTURBATIONS))
            ) from None
        super().__init__(*args, **kwargs)

    def criterion(self, model: torch.nn.Module, loss: Callable,
                  model_input: torch.Tensor,
                  model_output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the gradient-perturbation sensitivity for one batch.

        Thin wrapper around ``compute_gradient_perturbation_sensitivity``; see
        that method for the parameters and the returned value.
        """
        return self.compute_gradient_perturbation_sensitivity(model, loss, model_input,
                                                              model_output, target)

    @staticmethod
    def get_param_grad(parameters) -> torch.Tensor:
        """Concatenate the gradients of ``parameters`` into one flat tensor.

        A parameter whose ``grad`` is ``None`` contributes zeros of its own
        shape, so the result has the same length and layout on every call for
        the same parameters.

        :param parameters: iterable of parameters (e.g. ``model.parameters()``).
        :return: 1-D tensor holding all gradients back to back; an empty tensor
            when ``parameters`` yields nothing.
        """
        flat_grads = [
            (param.grad if param.grad is not None else torch.zeros_like(param)).reshape(-1)
            for param in parameters
        ]
        if not flat_grads:
            return torch.zeros((0,))
        # A single concatenation keeps the result on the gradients' own device
        # and avoids re-copying the growing buffer once per parameter.
        return torch.cat(flat_grads, dim=0)

    @staticmethod
    def get_grad(data) -> torch.Tensor:
        """Return the ``grad`` attribute of ``data``."""
        return data.grad

    def compute_gradient_perturbation_sensitivity(self, model: torch.nn.Module, loss: Callable,
                                                  model_input: torch.Tensor, model_output: torch.Tensor,
                                                  target: torch.Tensor) -> torch.Tensor:
        """
        Compute the parameter gradient for the regular input and for the perturbed
        input, and compare the two with a cosine similarity.

        Side effects: every model parameter ends up with ``requires_grad = True``
        and holds the gradient of the perturbed pass; ``model_input.requires_grad``
        is left ``False``.

        :param model: network to analyse.
        :param loss: callable invoked as ``loss(prediction, target, reduction="mean")``.
        :param model_input: input batch fed to ``model``.
        :param model_output: not used by this method.
        :param target: target passed to ``loss``.
        :return: scalar tensor with the cosine similarity of the two flattened
            parameter gradients (NaN when either gradient has zero norm).
        """
        model_input.requires_grad = True

        with torch.enable_grad():
            for param in model.parameters():
                param.requires_grad = True

            # Gradient on the unperturbed input.
            model.zero_grad()
            pred_target0 = model(model_input)
            task_loss0 = loss(pred_target0, target, reduction="mean")
            task_loss0.backward()
            # get_param_grad builds a fresh tensor, so this snapshot is not
            # affected by the zero_grad / backward calls below.
            grad0_x_before = self.get_param_grad(model.parameters())

            # Gradient on the perturbed input.
            model.zero_grad()
            model_input.requires_grad = False
            perturb_input0 = self.perturb_method(model_input)
            perturb_input0.requires_grad = True
            pred_target0_p = model(perturb_input0)
            task_loss0_p = loss(pred_target0_p, target, reduction="mean")
            task_loss0_p.backward()
            grad0_x_after = self.get_param_grad(model.parameters())

            # Cosine similarity between the two flattened gradients.
            stabilities0 = (torch.dot(grad0_x_after, grad0_x_before)
                            / torch.norm(grad0_x_before, p=2)
                            / torch.norm(grad0_x_after, p=2))
        return stabilities0


class SingleTaskGradientSensitivityAnalysis(SingleTaskSensitivityAnalysis):
    """Measure sensitivity as the L2 norm of the loss gradient w.r.t. layer weights.

    Only the ``weight`` attribute of each sub-module contributes; other
    parameters (for example biases) are not included.
    """

    name = "single_task_gradient_sensitivity_analysis"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def criterion(self, model: torch.nn.Module, loss: Callable,
                  model_input: torch.Tensor, model_output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the weight-gradient norm for one batch.

        Thin wrapper around ``compute_gradient_perturbation_sensitivity``; see
        that method for the parameters and the returned value.
        """
        return self.compute_gradient_perturbation_sensitivity(model, loss, model_input,
                                                              model_output, target)

    @staticmethod
    def get_grad(model) -> torch.Tensor:
        """Concatenate the ``weight`` gradients of all sub-modules of ``model``.

        ``named_modules()`` also yields the root module and any layer without
        weights; those, and weights that currently have no gradient, are
        skipped instead of raising ``AttributeError``.

        :param model: module whose sub-modules are inspected.
        :return: 1-D tensor of all collected weight gradients; an empty tensor
            when nothing was collected.
        """
        flat_grads = []
        for name, module in model.named_modules():
            # NOTE: filter on ``name`` here to restrict which layers contribute,
            # e.g. only names containing "layer1" together with "bn" or "conv".
            weight = getattr(module, "weight", None)
            if not isinstance(weight, torch.Tensor) or weight.grad is None:
                continue
            flat_grads.append(weight.grad.reshape(-1))
        if not flat_grads:
            return torch.zeros((0,))
        # A single concatenation keeps the result on the gradients' own device.
        return torch.cat(flat_grads, dim=0)

    def compute_gradient_perturbation_sensitivity(self, model: torch.nn.Module, loss: Callable,
                                                model_input: torch.Tensor, model_output: torch.Tensor,
                                                target: torch.Tensor) -> torch.Tensor:
        """
        Run one forward/backward pass and return the norm of the weight gradients.

        Side effects: every model parameter ends up with ``requires_grad = True``
        and holds the gradient of this pass. The parameter values themselves are
        not modified.

        :param model: network to analyse.
        :param loss: callable invoked as ``loss(prediction, target, reduction="mean")``.
        :param model_input: input batch fed to ``model``.
        :param model_output: not used by this method.
        :param target: target passed to ``loss``.
        :return: scalar tensor with the L2 norm of the concatenated weight gradients.
        """
        # Assigning ``requires_grad`` on a Module only creates a plain attribute,
        # so the flag has to be set on the parameters themselves.
        for param in model.parameters():
            param.requires_grad = True

        with torch.enable_grad():
            model.zero_grad()
            pred_target0 = model(model_input)
            task_loss0 = loss(pred_target0, target, reduction="mean")
            task_loss0.backward()
            # get_grad builds a fresh tensor, so no extra copy is needed.
            stabilities0 = self.get_grad(model)
            # No optimizer step here: measuring sensitivity must not update the
            # weights, otherwise results depend on how often this was called.

        return torch.norm(stabilities0, p=2)


class SingleTaskInputGradientSensitivityAnalysis(SingleTaskSensitivityAnalysis):
    """Measure sensitivity as the mean squared L2 norm of the loss gradient w.r.t. the input."""

    # NOTE(review): unlike the sibling classes there is no underscore between
    # "task" and "input"; left untouched because the string may be used as a key.
    name = "single_taskinput_gradient_sensitivity_analysis"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def criterion(self, model: torch.nn.Module, loss: Callable,
                  model_input: torch.Tensor,
                  model_output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the input-gradient sensitivity for one batch.

        Thin wrapper around ``compute_gradient_perturbation_sensitivity``; see
        that method for the parameters and the returned value.
        """
        return self.compute_gradient_perturbation_sensitivity(model, loss, model_input,
                                                              model_output, target)

    @staticmethod
    def get_grad(data) -> torch.Tensor:
        """Return the ``grad`` attribute of ``data``."""
        return data.grad

    def compute_gradient_perturbation_sensitivity(self, model: torch.nn.Module, loss: Callable,
                                                  model_input: torch.Tensor, model_output: torch.Tensor,
                                                  target: torch.Tensor) -> torch.Tensor:
        """
        Back-propagate the loss to the input and summarise the input gradient.

        Side effects: every model parameter ends up with ``requires_grad = True``;
        ``model_input`` ends up with ``requires_grad = True`` and its ``grad``
        holds the gradient of this call only.

        :param model: network to analyse.
        :param loss: callable invoked as ``loss(prediction, target, reduction="mean")``.
        :param model_input: input batch fed to ``model``; its first dimension is
            treated as the sample dimension and all others are flattened.
        :param model_output: not used by this method.
        :param target: target passed to ``loss``.
        :return: scalar tensor, the squared L2 norm of each sample's input
            gradient averaged over the first dimension.
        """
        for param in model.parameters():
            param.requires_grad = True
        model_input.requires_grad = True
        # backward() accumulates into ``.grad``; clear any gradient left from an
        # earlier call so that repeated calls on the same tensor give the same result.
        model_input.grad = None

        with torch.enable_grad():
            model.zero_grad()
            pred_target0 = model(model_input)
            task_loss0 = loss(pred_target0, target, reduction="mean")
            task_loss0.backward()
            grad0 = self.get_grad(model_input).detach().clone()
            # Flatten everything but the first dimension so that inputs of any
            # rank >= 2 are supported, not only 4-D batches.
            per_sample_sq_norm = torch.norm(grad0.flatten(start_dim=1), p=2, dim=1) ** 2
            stabilities0 = torch.mean(per_sample_sq_norm, dim=0)
        return stabilities0


class SingleTaskInputAndParameterSensitivityAnalysis(SingleTaskSensitivityAnalysis):
    """Measure how strongly the parameter gradient changes under input perturbations.

    For ``N_COMPARISONS`` perturbed copies of the input, the squared norm of the
    change in the flattened parameter gradient is divided by the squared norm of
    the input change; the criterion is the mean of these ratios.
    """

    name = "single_task_input_and_parameter_sensitivity_analysis"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def get_param_grad(parameters) -> torch.Tensor:
        """Concatenate the gradients of ``parameters`` into one flat tensor.

        A parameter whose ``grad`` is ``None`` contributes zeros of its own
        shape, so the result has the same length and layout on every call for
        the same parameters.

        :param parameters: iterable of parameters (e.g. ``model.parameters()``).
        :return: 1-D tensor holding all gradients back to back; an empty tensor
            when ``parameters`` yields nothing.
        """
        flat_grads = [
            (param.grad if param.grad is not None else torch.zeros_like(param)).reshape(-1)
            for param in parameters
        ]
        if not flat_grads:
            return torch.zeros((0,))
        # A single concatenation keeps the result on the gradients' own device
        # and avoids re-copying the growing buffer once per parameter.
        return torch.cat(flat_grads, dim=0)

    def criterion(self, model: torch.nn.Module, loss: Callable,
                  model_input: torch.Tensor, model_output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the input-and-parameter sensitivity for one batch.

        Thin wrapper around ``compute_input_perturbation_sensitivity``; see that
        method for the parameters and the returned value.
        """
        return self.compute_input_perturbation_sensitivity(model, loss, model_input, model_output, target)

    def compute_input_perturbation_sensitivity(self, model: torch.nn.Module, loss: Callable,
                                               model_input: torch.Tensor, model_output: torch.Tensor,
                                               target: torch.Tensor) -> torch.Tensor:
        """
        Compare the parameter gradient on the input with that on perturbed copies.

        Side effects: every model parameter ends up with ``requires_grad = True``
        and holds the gradient of the last perturbed pass;
        ``model_input.requires_grad`` is left ``False``.

        :param model: network to analyse.
        :param loss: callable invoked as ``loss(prediction, target, reduction="mean")``.
        :param model_input: input batch fed to ``model``.
        :param model_output: not used by this method.
        :param target: target passed to ``loss``.
        :return: scalar tensor, the mean over ``N_COMPARISONS`` perturbations of
            ``sum(((grad_after - grad_before) / ||input - perturbed||_2) ** 2)``.
        """
        stabilities0 = torch.zeros((N_COMPARISONS,))
        model_input.requires_grad = True

        with torch.enable_grad():
            for param in model.parameters():
                param.requires_grad = True

            # Reference gradient on the unperturbed input.
            model.zero_grad()
            pred_target0 = model(model_input)
            task_loss0 = loss(pred_target0, target, reduction="mean")
            task_loss0.backward()
            # get_param_grad builds a fresh tensor, so this snapshot is not
            # affected by the zero_grad / backward calls in the loop.
            grad0_x_before = self.get_param_grad(model.parameters())

            # The perturbation helper receives an input that does not require grad.
            model_input.requires_grad = False
            for i in range(N_COMPARISONS):
                model.zero_grad()
                # The loop index doubles as the seed, so each comparison uses a
                # different seed and the sequence of seeds is the same on every call.
                perturb_input0 = perturb_input(model_input, seed=i)
                perturb_input0.requires_grad = True
                pred_target0_p = model(perturb_input0)
                task_loss0_p = loss(pred_target0_p, target, reduction="mean")
                task_loss0_p.backward()
                grad0_x_after = self.get_param_grad(model.parameters())

                # Gradient change relative to the size of the input change.
                input_delta_norm = torch.norm(model_input - perturb_input0, p=2)
                stabilities0[i] = torch.sum(
                    (torch.abs(grad0_x_after - grad0_x_before) / input_delta_norm) ** 2, dim=0)
        return torch.mean(stabilities0)


# Registry mapping a method identifier string to the sensitivity-analysis class
# defined in this module that implements it.
# NOTE(review): presumably looked up by code outside this file to select an
# analysis by name — confirm against the callers.
SINGLE_TASK_SENSITIVITY_METHODS = {
    "single-input-perturbation-gradient-sensitivity": SingleTaskPerturbGradientSensitivityAnalysis,
    "single-gradient-sensitivity": SingleTaskGradientSensitivityAnalysis,
    "single-input-gradient-sensitivity": SingleTaskInputGradientSensitivityAnalysis,
    "single-input-and-parameter-sensitivity": SingleTaskInputAndParameterSensitivityAnalysis,
}
