import itertools
import logging
from typing import Tuple, Callable

import torch

from path_learning.utils.result import TaskResult
from path_learning.models.models import reset_bias
from .analysis_method import _PairwiseAnalysisMethod, ResultGeneratorType


class _StabilityAnalysis(_PairwiseAnalysisMethod):
    """Abstract pairwise analysis that scores two models batch by batch.

    For every batch produced by ``generate_dataloader`` both models are run on
    their respective features, and the subclass-specific :meth:`criterion`
    turns the pair of predictions into a scalar "stability" score. The
    per-batch scores and their mean are yielded as named results.

    Required keyword arguments (removed before the remaining arguments are
    forwarded to the base class):

    * ``n_batches`` -- maximum number of batches to evaluate; ``None``
      evaluates every batch the dataloader produces.
    * ``reset_bias`` -- when truthy, the ``reset_bias`` helper is applied to
      every sub-module of both models before evaluation.

    Subclasses must implement :meth:`criterion`.
    """

    name = "stability_analysis"

    def __init__(self, *args, **kwargs):
        self.n_batches = None
        self.reset_bias = None
        try:
            # Both options are mandatory and specific to this analysis, so they
            # are popped here and never reach the base-class constructor.
            self.n_batches = kwargs.pop("n_batches")
            self.reset_bias = kwargs.pop("reset_bias")
        except KeyError:
            logging.warning("invalid kwargs %s, received : %s and %s", self.name, args, kwargs)
            raise
        super().__init__(*args, **kwargs)

    def criterion(self, model: torch.nn.Module, loss_function: Callable,
                  model_inputs: Tuple[torch.Tensor, torch.Tensor],
                  model_outputs: Tuple[torch.Tensor, torch.Tensor],
                  target: torch.Tensor) -> float:
        """Return the stability score for a single batch.

        Args:
            model: a model of the pair (``analyze_model`` passes the first one).
            loss_function: the ``"test"`` loss callable of the second task.
            model_inputs: ``(feature0, feature1)`` fed to the two models.
            model_outputs: ``(pred_target0, pred_target1)`` produced by them.
            target: the batch target.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError("_StabilityAnalysis is abstract")

    def analyze_model(self, task_results: Tuple[TaskResult, TaskResult],
                      models: Tuple[torch.nn.Module, torch.nn.Module]) -> ResultGeneratorType:
        """Evaluate :meth:`criterion` on at most ``n_batches`` batches.

        Args:
            task_results: pair of task results. The second one supplies the
                loss (its ``"test"`` callable); both are passed to
                ``generate_dataloader``.
            models: the two models to compare. They are modified in place:
                switched to train mode and, if ``reset_bias`` is set, passed
                through ``reset_bias``.

        Yields:
            ``("batch_stabilities", scores)`` with one score per evaluated
            batch, then ``("average_stability", mean)`` if at least one batch
            was evaluated.
        """
        loss = task_results[1].generate_loss(self.logdir)
        loss_function = loss.loss_functions["test"]["callable"]
        dataloader = self.generate_dataloader(task_results[0], task_results[1])

        # The attribute is only a flag; ``reset_bias=False`` must not reset.
        if self.reset_bias:
            models[0].apply(reset_bias)
            models[1].apply(reset_bias)

        # NOTE(review): train mode means any dropout is active and any
        # batch-norm layers use batch statistics and update their running
        # buffers, even under no_grad -- confirm this is intended.
        models[0].train()
        models[1].train()

        # islice stops after exactly ``limit`` batches (None = no limit) without
        # pulling an extra batch from the dataloader.
        limit = None if self.n_batches is None else int(self.n_batches)
        batch_stabilities = []

        # Keep the no_grad scope strictly around the computation: yielding from
        # inside it would leave gradient tracking disabled in the caller while
        # this generator is suspended.
        with torch.no_grad():
            for feature0, feature1, target in itertools.islice(dataloader, limit):
                pred_target0 = models[0](feature0)
                pred_target1 = models[1](feature1)
                batch_stabilities.append(
                    self.criterion(models[0], loss_function, (feature0, feature1),
                                   (pred_target0, pred_target1), target))

        yield "batch_stabilities", batch_stabilities
        if batch_stabilities:
            yield "average_stability", sum(batch_stabilities) / len(batch_stabilities)