from typing import Generator, Dict, Tuple, Any, Callable, List
import logging
from copy import deepcopy

import torch.nn

from .analysis_method import _SingleAnalysisMethod, ResultGeneratorType, _StepAnalysisMethod
from path_learning.utils.result import TaskResult
from path_learning.models.models import reset_bias


class SingleTaskSensitivityAnalysis(_SingleAnalysisMethod):
    """Base class for sensitivity analyses evaluated on one task's model.

    Subclasses implement :meth:`criterion`. :meth:`analyze_model` evaluates that
    criterion on batches drawn from the task's dataloader and yields the
    averaged result.

    Required keyword arguments (removed from ``kwargs`` before the parent
    constructor is called):
        n_batches: Maximum number of batches evaluated by ``analyze_model``.
        reset_bias: When truthy, the imported ``reset_bias`` function is applied
            to a copy of the model before the analysis runs.
    """

    name = "task_sensitivity_analysis"

    def __init__(self, *args, **kwargs):
        """Extract this class's options and forward the rest to the parent.

        Raises:
            KeyError: If ``n_batches`` or ``reset_bias`` is missing from ``kwargs``.
        """
        # Validate before popping so the warning shows the kwargs exactly as received.
        missing = [key for key in ("n_batches", "reset_bias") if key not in kwargs]
        if missing:
            logging.warning(f"invalid kwargs {self.name}, received : {args} and {kwargs}")
            raise KeyError(missing[0])
        self.n_batches = kwargs.pop("n_batches")
        self.reset_bias = kwargs.pop("reset_bias")
        super().__init__(*args, **kwargs)

    def criterion(self, model: torch.nn.Module, loss_function: Callable,
                  model_inputs: Tuple[torch.Tensor, torch.Tensor],
                  model_outputs: Tuple[torch.Tensor, torch.Tensor],
                  target: torch.Tensor) -> torch.Tensor:
        """Compute the per-batch sensitivity value; must be overridden.

        The return value is accumulated and passed to ``torch.sum`` by
        :meth:`get_mean`, so implementations are expected to return a tensor.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("SingleTaskSensitivityAnalysis.criterion is abstract")

    @staticmethod
    def get_mean(data_list: List) -> float:
        """Return the sum of all entries of all items divided by the item count.

        Args:
            data_list: Non-empty list of tensors (they are added together and
                reduced with ``torch.sum``).

        Returns:
            ``float(torch.sum(item_0 + item_1 + ...)) / len(data_list)``.

        Raises:
            ValueError: If ``data_list`` is empty.
        """
        if not data_list:
            raise ValueError("get_mean requires at least one entry")
        logging.debug("len datalist %d", len(data_list))
        total = data_list[0]
        for item in data_list[1:]:
            # Out-of-place add: ``+=`` would silently modify the caller's first tensor.
            total = total + item
        return float(torch.sum(total)) / len(data_list)

    def analyze_model(self, task_result: TaskResult,
                      model: torch.nn.Module) -> ResultGeneratorType:
        """Evaluate :meth:`criterion` on at most ``n_batches`` batches.

        The model is deep-copied, so the caller's instance is neither modified
        nor switched to train mode.

        Yields:
            ``("batch_stabilities0", [mean])`` followed by
            ``("average_stability0", mean)``, where ``mean`` is the
            :meth:`get_mean` of the per-batch criterion values. If no batch was
            evaluated, only ``("batch_stabilities0", [])`` is yielded.
        """
        # NOTE(review): this binds BIAS_VALUE in *this* module's namespace. If the
        # imported ``reset_bias`` reads a BIAS_VALUE from its own module, the
        # assignment below does not reach it -- confirm against its definition.
        global BIAS_VALUE

        loss = task_result.generate_loss(self.logdir)
        loss_function = loss.loss_functions["test"]["callable"]
        dataloader = self.generate_dataloader(task_result)

        model = deepcopy(model)
        if self.reset_bias:
            BIAS_VALUE = self.reset_bias
            logging.info("Applying bias reset")
            model.apply(reset_bias)
        model.train()

        batch_results = []
        # Gradients are enabled only while the criterion is evaluated. Nothing is
        # yielded inside this context, so the grad mode does not leak into the
        # consumer while the generator is suspended.
        with torch.enable_grad():
            for batch_idx, (feature0, target) in enumerate(dataloader):
                # Checked before processing so exactly ``n_batches`` batches are used.
                if batch_idx >= self.n_batches:
                    break
                # The target is passed both as ``model_outputs`` and as ``target``.
                batch_results.append(
                    self.criterion(model, loss_function, feature0, target, target))

        if not batch_results:
            logging.warning("%s: no batch was evaluated, no average is reported", self.name)
            yield "batch_stabilities0", []
            return

        mean_stability = self.get_mean(batch_results)
        yield "batch_stabilities0", [mean_stability]
        yield "average_stability0", mean_stability
