from typing import Callable, Tuple

import torch

from .helpers import perturb_model, input_normalize, output_normalize, input_output_normalize
from .stability import _StabilityAnalysis

# How many independently perturbed models are evaluated per stability estimate
# (one column of the stability matrix per perturbation).
N_COMPARISONS: int = 10
# Perturbation scale handed to ``perturb_model``; every finite-difference
# stability quotient in this module is divided by it.
DTHETA: float = 0.1


# ----------- LOSS LEVEL ---------------

class LossStabilityAnalysis(_StabilityAnalysis):
    """Finite-difference stability of the task loss under random weight perturbations.

    Each of ``N_COMPARISONS`` rounds uses a perturbed model from
    ``perturb_model(model, DTHETA)``. The round re-evaluates the loss on both
    inputs of the pair ``model_inputs``. For each sample, the round's stability
    is the absolute difference between:

    * the loss change caused by the perturbation on ``model_inputs[1]``, and
    * the loss change caused by the perturbation on ``model_inputs[0]``.

    That difference is divided by ``DTHETA``.
    """

    name = "loss_stability_analysis"

    @staticmethod
    def compute_stabilities(model: torch.nn.Module, loss_function: Callable,
                            model_inputs: Tuple[torch.Tensor, torch.Tensor],
                            model_outputs: Tuple[torch.Tensor, torch.Tensor],
                            target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return per-sample loss stabilities and the unperturbed loss of the first input.

        Args:
            model: model handed to ``perturb_model`` once per round.
            loss_function: called as ``loss_function(pred, target, reduction="none")``.
                Its result is assumed to have the batch as its first dimension.
            model_inputs: pair ``(x0, x1)``; each perturbed model is run on both.
            model_outputs: pair ``(y0, y1)``, the unperturbed outputs for ``x0``/``x1``.
            target: targets passed to ``loss_function``; ``target.size(0)`` sets the
                number of rows of the result.

        Returns:
            ``(stabilities, task_loss0)``. ``stabilities`` has shape
            ``(target.size(0), N_COMPARISONS)`` with one column per perturbation.
            ``task_loss0`` is the unreduced loss of ``model_outputs[0]``.
        """
        task_loss0 = loss_function(model_outputs[0], target, reduction="none")
        task_loss1 = loss_function(model_outputs[1], target, reduction="none")

        batch_size = target.size(0)
        stabilities = torch.zeros((batch_size, N_COMPARISONS))

        for i in range(N_COMPARISONS):
            model_perturb = perturb_model(model, DTHETA)

            # Only finite differences are needed: do not build (and keep alive)
            # an autograd graph for the perturbed forward passes.
            with torch.no_grad():
                task_loss0_perturb = loss_function(model_perturb(model_inputs[0]), target, reduction="none")
                task_loss1_perturb = loss_function(model_perturb(model_inputs[1]), target, reduction="none")

                diff = torch.abs((task_loss1_perturb - task_loss1) - (task_loss0_perturb - task_loss0))
                # Average over the non-batch dimensions only, so each row keeps its own
                # sample's value. (Averaging over everything would broadcast a single
                # batch-wide scalar into every row.) A 0-dim loss is broadcast as before.
                per_sample = diff.reshape(batch_size, -1).mean(dim=1) if diff.dim() > 0 else diff
                stabilities[:, i] = per_sample / DTHETA

        return stabilities, task_loss0

    def criterion(self, model: torch.nn.Module, loss_function: Callable,
                  model_inputs: Tuple[torch.Tensor, torch.Tensor],
                  model_outputs: Tuple[torch.Tensor, torch.Tensor],
                  target: torch.Tensor) -> float:
        """Return the mean loss stability over all samples and perturbations."""
        stabilities, _ = self.compute_stabilities(model, loss_function, model_inputs, model_outputs, target)
        return float(torch.mean(stabilities))


class InputLossStabilityAnalysis(LossStabilityAnalysis):
    """Loss stability post-processed by ``input_normalize`` using the input pair."""

    name = "input_loss_stability_analysis"

    def criterion(self, model: torch.nn.Module, loss_function: Callable,
                  model_inputs: Tuple[torch.Tensor, torch.Tensor],
                  model_outputs: Tuple[torch.Tensor, torch.Tensor],
                  target: torch.Tensor) -> float:
        """Return the mean of the loss stabilities after ``input_normalize``.

        The exact normalisation is defined in ``.helpers`` (not visible here).
        It receives the stability matrix and ``model_inputs``.
        """
        loss_stabilities, _ = super().compute_stabilities(model, loss_function, model_inputs, model_outputs, target)
        input_loss_stabilities = input_normalize(loss_stabilities, model_inputs)
        return float(torch.mean(input_loss_stabilities))


class OutputLossStabilityAnalysis(LossStabilityAnalysis):
    """Loss stability post-processed by ``output_normalize`` using the unperturbed loss."""

    name = "output_loss_stability_analysis"

    def criterion(self, model: torch.nn.Module, loss_function: Callable,
                  model_inputs: Tuple[torch.Tensor, torch.Tensor],
                  model_outputs: Tuple[torch.Tensor, torch.Tensor],
                  target: torch.Tensor) -> float:
        """Return the mean of the loss stabilities after ``output_normalize``.

        The normaliser receives the unreduced, unperturbed loss of
        ``model_outputs[0]``. Its exact formula lives in ``.helpers``.
        """
        loss_stabilities, task_loss0 = super().compute_stabilities(model, loss_function,
                                                                   model_inputs, model_outputs, target)

        output_loss_stabilities: torch.Tensor = output_normalize(loss_stabilities, task_loss0)
        return float(torch.mean(output_loss_stabilities))


class IOLossStabilityAnalysis(LossStabilityAnalysis):
    """Loss stability post-processed by ``input_output_normalize``."""

    name = "io_loss_stability_analysis"

    def criterion(self, model: torch.nn.Module, loss_function: Callable,
                  model_inputs: Tuple[torch.Tensor, torch.Tensor],
                  model_outputs: Tuple[torch.Tensor, torch.Tensor],
                  target: torch.Tensor) -> float:
        """Return the mean of the loss stabilities after ``input_output_normalize``.

        The normaliser receives both ``model_inputs`` and the unreduced,
        unperturbed loss of ``model_outputs[0]``. Its exact formula lives in ``.helpers``.
        """
        loss_stabilities, task_loss0 = super().compute_stabilities(model, loss_function,
                                                                   model_inputs, model_outputs, target)

        io_loss_stabilities = input_output_normalize(loss_stabilities, model_inputs, task_loss0)
        return float(torch.mean(io_loss_stabilities))


# ----------- FUNCTION LEVEL ---------------

class FunctionStabilityAnalysis(_StabilityAnalysis):
    """Finite-difference stability of the model outputs under random weight perturbations.

    Each of ``N_COMPARISONS`` rounds uses a perturbed model from
    ``perturb_model(model, DTHETA)``. For each sample, the round's stability is
    the absolute difference between:

    * the output change caused by the perturbation on ``model_inputs[1]``, and
    * the output change caused by the perturbation on ``model_inputs[0]``.

    That difference is divided by ``DTHETA``.
    """

    name = "function_stability_analysis"

    def criterion(self, model: torch.nn.Module, loss_function: Callable,
                  model_inputs: Tuple[torch.Tensor, torch.Tensor],
                  model_outputs: Tuple[torch.Tensor, torch.Tensor],
                  target: torch.Tensor) -> float:
        """Return the mean function stability over all samples and perturbations.

        ``loss_function`` is accepted for interface compatibility but not used.
        """
        stabilities = self.compute_stabilities(model, model_inputs, model_outputs, target)
        return float(torch.mean(stabilities))

    @staticmethod
    def compute_stabilities(model: torch.nn.Module, model_inputs: Tuple[torch.Tensor, torch.Tensor],
                            model_outputs: Tuple[torch.Tensor, torch.Tensor],
                            target: torch.Tensor) -> torch.Tensor:
        """Return a ``(target.size(0), N_COMPARISONS)`` matrix of per-sample output stabilities.

        Args:
            model: model handed to ``perturb_model`` once per round.
            model_inputs: pair ``(x0, x1)``; each perturbed model is run on both.
            model_outputs: pair ``(y0, y1)``, the unperturbed outputs for ``x0``/``x1``.
            target: only its first dimension is used, to size the result.
        """
        batch_size = target.size(0)
        stabilities = torch.zeros((batch_size, N_COMPARISONS))

        for i in range(N_COMPARISONS):
            # Perturb with the same scale DTHETA that the quotient below divides by,
            # matching the loss-level analysis, instead of perturb_model's default.
            model_perturb = perturb_model(model, DTHETA)

            # Only finite differences are needed: skip building an autograd graph.
            with torch.no_grad():
                pred_target0_perturb = model_perturb(model_inputs[0])
                pred_target1_perturb = model_perturb(model_inputs[1])

                diff = torch.abs((pred_target1_perturb - model_outputs[1])
                                 - (pred_target0_perturb - model_outputs[0]))
                # Average over the non-batch dimensions only, so each row keeps its own
                # sample's value instead of a batch-wide scalar broadcast to every row.
                per_sample = diff.reshape(batch_size, -1).mean(dim=1) if diff.dim() > 0 else diff
                stabilities[:, i] = per_sample / DTHETA

        return stabilities


class InputStabilityAnalysis(FunctionStabilityAnalysis):
    """Function stability post-processed by ``input_normalize`` using the input pair."""

    name = "input_function_stability_analysis"

    def criterion(self, model: torch.nn.Module, loss_function: Callable,
                  model_inputs: Tuple[torch.Tensor, torch.Tensor],
                  model_outputs: Tuple[torch.Tensor, torch.Tensor],
                  target: torch.Tensor) -> float:
        """Return the mean of the function stabilities after ``input_normalize``.

        ``loss_function`` is accepted for interface compatibility but not used.
        """
        stabilities: torch.Tensor = self.compute_stabilities(model, model_inputs, model_outputs, target)
        normed_stabilities = input_normalize(stabilities, model_inputs)
        return float(torch.mean(normed_stabilities))


class OutputStabilityAnalysis(FunctionStabilityAnalysis):
    """Function stability post-processed by ``output_normalize`` using ``model_outputs[0]``."""

    name = "output_function_stability_analysis"

    def criterion(self, model: torch.nn.Module, loss_function: Callable,
                  model_inputs: Tuple[torch.Tensor, torch.Tensor],
                  model_outputs: Tuple[torch.Tensor, torch.Tensor],
                  target: torch.Tensor) -> float:
        """Return the mean of the function stabilities after ``output_normalize``.

        The normaliser receives the unperturbed output of the first input.
        ``loss_function`` is accepted for interface compatibility but not used.
        """
        stabilities = self.compute_stabilities(model, model_inputs, model_outputs, target)
        normed_stabilities = output_normalize(stabilities, model_outputs[0])
        return float(torch.mean(normed_stabilities))


class IOStabilityAnalysis(FunctionStabilityAnalysis):
    """Function stability post-processed by ``input_output_normalize``."""

    name = "io_function_stability_analysis"

    def criterion(self, model: torch.nn.Module, loss_function: Callable,
                  model_inputs: Tuple[torch.Tensor, torch.Tensor],
                  model_outputs: Tuple[torch.Tensor, torch.Tensor],
                  target: torch.Tensor) -> float:
        """Return the mean of the function stabilities after ``input_output_normalize``.

        The normaliser receives ``model_inputs`` and the unperturbed output of the
        first input. ``loss_function`` is accepted for interface compatibility but not used.
        """
        stabilities = self.compute_stabilities(model, model_inputs, model_outputs, target)
        normed_stabilities = input_output_normalize(stabilities, model_inputs, model_outputs[0])
        return float(torch.mean(normed_stabilities))


# Lookup table from a string key to the stability-analysis class it selects.
# Loss-level criteria are listed first, then function-level ones.
STABILITY_METHODS = dict([
    ("loss-stability", LossStabilityAnalysis),
    ("input-loss-stability", InputLossStabilityAnalysis),
    ("output-loss-stability", OutputLossStabilityAnalysis),
    ("io-loss-stability", IOLossStabilityAnalysis),
    ("plain-stability", FunctionStabilityAnalysis),
    ("input-stability", InputStabilityAnalysis),
    ("output-stability", OutputStabilityAnalysis),
    ("io-stability", IOStabilityAnalysis),
])