import functools
import itertools
import logging
from abc import abstractmethod

import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torchmetrics import Accuracy, MeanMetric
from torchmetrics.classification.accuracy import MulticlassAccuracy
from tqdm.autonotebook import tqdm

from .base_task import BaseTask

# Module-level logger named after this module (used for evaluation notices).
log = logging.getLogger(__name__)


class ClassificationTask(BaseTask):
    """
    Abstract base for classification tasks.

    Subclasses provide the number of classes and a default test loader;
    this class supplies a generic evaluation loop that reports multiclass
    accuracy and cross-entropy loss.
    """

    def __init__(self, task_config):
        super().__init__(task_config)

    @property
    @abstractmethod
    def num_classes(self):
        """
        Returns the number of classes in the dataset.
        """
        pass

    @property
    @abstractmethod
    def test_loader(self):
        """
        Returns a test data loader.

        Each item it yields is unpacked by `evaluate` as an
        ``(inputs, targets)`` pair.
        """
        pass

    @torch.no_grad()
    def evaluate(self, classifier: nn.Module, loader=None, device=None):
        """
        Evaluate the model on the specified dataset loader.

        Gradients are disabled for the whole call. The classifier is switched
        to eval mode and is NOT switched back afterwards.

        Args:
            classifier (nn.Module): The classifier model to evaluate. It is
                called on each batch of inputs and must return logits suitable
                for ``F.cross_entropy``.
            loader (DataLoader, optional): Iterable yielding ``(inputs, targets)``
                pairs. Defaults to self.test_loader.
            device (torch.device, optional): If given, each batch's inputs and
                targets are moved to this device before the forward pass. The
                classifier itself is not moved; it is expected to already live
                on this device. Defaults to None (data is used where it is).

        Returns:
            dict: ``{"accuracy": float, "loss": float}``. ``accuracy`` is the
            multiclass accuracy accumulated over every evaluated sample.
            ``loss`` is the mean cross-entropy per sample over all evaluated
            batches.

        Notes:
            If ``self.config`` has a truthy ``"fast_dev_run"`` entry, only the
            first batch of the loader is evaluated.
        """
        # Use provided loader or default to self.test_loader
        test_loader = loader if loader is not None else self.test_loader

        accuracy: MulticlassAccuracy = Accuracy(
            task="multiclass", num_classes=self.num_classes
        )
        classifier.eval()
        loss_metric = MeanMetric()

        # Fast dev mode: Evaluate on a single batch
        if self.config.get("fast_dev_run", False):
            log.info("Running under fast_dev_run mode, evaluating on a single batch.")
            test_loader = itertools.islice(test_loader, 1)

        for batch in (
            pbar := tqdm(
                test_loader, desc="Evaluating", leave=False, dynamic_ncols=True
            )
        ):
            inputs, targets = batch
            if device is not None:
                inputs, targets = inputs.to(device), targets.to(device)
            logits: Tensor = classifier(inputs)

            loss = F.cross_entropy(logits, targets)
            # cross_entropy returns the *batch mean*. Weight it by the batch size
            # so the aggregate is the per-sample mean over the whole loader,
            # even when batches differ in size (e.g. a short final batch).
            loss_metric.update(loss.detach().cpu(), weight=logits.size(0))
            # Metrics are accumulated on CPU. update() only accumulates state;
            # the per-batch value computed by forward() was never used.
            accuracy.update(logits.detach().cpu(), targets.detach().cpu())
            pbar.set_postfix(
                {
                    "accuracy": accuracy.compute().item(),
                    "loss": loss_metric.compute().item(),
                }
            )

        acc = accuracy.compute().item()
        loss = loss_metric.compute().item()
        results = {"accuracy": acc, "loss": loss}
        return results
