import torch.nn

from .analysis_method import _SingleAnalysisMethod, ResultGeneratorType
from path_learning.utils.result import TaskResult


class TaskBatchLossAnalysis(_SingleAnalysisMethod):
    """Evaluate a model batch-by-batch on a task and report its test loss.

    For every ``(feature, target)`` batch produced by
    ``self.generate_dataloader(task_result)``, the task's ``"test"`` loss
    callable is evaluated with ``reduction="mean"``. The analysis yields the
    list of per-batch losses and, when at least one batch was seen, their
    unweighted mean.
    """

    name = "task_batch_loss_analysis"

    def __init__(self, *args, **kwargs):
        # Keyword arguments are accepted but deliberately not forwarded to the
        # base initializer. NOTE(review): presumably the base class does not
        # accept them -- confirm before forwarding.
        super().__init__(*args)

    def analyze_model(self, task_result: TaskResult, model: torch.nn.Module) -> ResultGeneratorType:
        """Compute the mean test loss of ``model`` on each batch of the task.

        Args:
            task_result: Task under analysis. It supplies the loss object via
                ``generate_loss`` and is passed to ``generate_dataloader``.
            model: Model to evaluate. It is switched to eval mode and is left
                in that mode afterwards.

        Yields:
            ``("batch_losses", list of float)``, with one mean loss per batch.
            Then ``("avg_batch_loss", float)``, the unweighted mean over
            batches. The second pair is yielded only if the dataloader
            produced at least one batch.
        """
        # Scratch directory handed to generate_loss.
        logdir = self.logdir / "tmp"
        logdir.mkdir(parents=True, exist_ok=True)
        loss = task_result.generate_loss(logdir)
        dataloader = self.generate_dataloader(task_result)
        # Loop-invariant: resolve the test loss callable once, not per batch.
        loss_function = loss.loss_functions["test"]["callable"]

        model.eval()
        batch_losses = []
        # Only loss values are needed, so building an autograd graph per batch
        # would waste memory. The no_grad context is exited before yielding so
        # the caller's grad mode is untouched while this generator is suspended.
        with torch.no_grad():
            for feature, target in dataloader:
                pred_target: torch.Tensor = model(feature)
                batch_loss = float(loss_function(pred_target, target, reduction="mean").item())
                self.logger.info("Found task loss %s", batch_loss)
                batch_losses.append(batch_loss)

        yield "batch_losses", batch_losses
        if batch_losses:
            yield "avg_batch_loss", sum(batch_losses) / len(batch_losses)
