from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import ClassVar

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.figure import Figure
from torch import Tensor
from torchmetrics import functional as tmf

# Signature shared by the metric functions held in ``Metrics.metrics``: they are
# called as ``fn(preds, targets)`` or, for multi-class metrics,
# ``fn(preds, targets, n_classes)``, and return a plain Python float.
MetricFn = Callable[..., float]


class Metrics(ABC):
    """Base class for a bundle of metrics evaluated on model outputs.

    Subclasses fill ``self.metrics`` (metric name -> metric function), implement
    ``_prepare`` to turn raw outputs/targets into the tensors the metric
    functions consume, and implement ``plot_cm`` to draw a confusion matrix.
    Calling an instance returns ``{metric_name: value}``.
    """

    # Metric name -> metric function; populated by subclass constructors.
    metrics: dict[str, MetricFn]

    @staticmethod
    def _confusion_matrix(preds: Tensor, targets: Tensor, **kwargs) -> Figure:
        """Return a Seaborn heat-map of the confusion matrix.

        Args:
            preds: Predictions, passed unchanged to
                ``torchmetrics.functional.confusion_matrix``.
            targets: Ground-truth labels, passed unchanged as well.
            **kwargs: Forwarded to ``confusion_matrix`` (e.g. ``task``,
                ``num_classes``, ``normalize``).

        Returns:
            A new figure created through ``matplotlib.pyplot``; the caller should
            close it (``plt.close(fig)``) once it is no longer needed.
        """
        cm = tmf.confusion_matrix(preds, targets, **kwargs).numpy(force=True)
        # Raw counts are integers and render with "d"; a normalized matrix is
        # float, for which the "d" format code would raise, so use decimals.
        fmt = "d" if cm.dtype.kind in "iu" else ".2f"
        fig, ax = plt.subplots(figsize=(12, 12))
        sns.heatmap(cm, annot=True, fmt=fmt, cmap="Blues", cbar=False, square=True, ax=ax)
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        ax.set_title("Confusion Matrix")
        fig.tight_layout()
        return fig

    @abstractmethod
    def plot_cm(self, yhs: Tensor, ys: Tensor) -> Figure:
        """Return a confusion-matrix figure for raw outputs ``yhs`` and targets ``ys``."""

    @abstractmethod
    def _prepare(self, yhs: Tensor, ys: Tensor) -> tuple[Tensor, Tensor]:
        """Convert raw outputs/targets into the ``(preds, targets)`` fed to each metric."""

    def _score(self, fn: MetricFn, preds: Tensor, targets: Tensor) -> float:  # noqa: PLR6301
        """Evaluate a single metric function; subclasses may pass extra arguments."""
        return fn(preds, targets)

    def __call__(self, yhs: Tensor, ys: Tensor) -> dict[str, float]:
        """Compute every configured metric and return ``{name: value}``."""
        preds, targets = self._prepare(yhs, ys)
        return {name: self._score(fn, preds, targets) for name, fn in self.metrics.items()}


class BinaryMetrics(Metrics):
    """Binary-classification metrics computed from raw logits.

    ``acc``, ``f1`` and ``recall`` threshold predictions by applying ``sigmoid``
    and rounding; ``roc`` receives the raw scores. ``recall`` is the macro
    average of the per-class recalls over the two classes.
    """

    main_metrics: ClassVar[dict[str, Callable]] = {
        "acc": lambda y_pred, y_true: tmf.accuracy(y_pred.sigmoid().round(), y_true, task="binary").item(),
        "f1": lambda y_pred, y_true: tmf.f1_score(y_pred.sigmoid().round(), y_true, task="binary").item(),
        "roc": lambda y_pred, y_true: tmf.auroc(y_pred, y_true.long(), task="binary").item(),  # type: ignore
        "recall": lambda y_pred, y_true: tmf.recall(
            y_pred.sigmoid().round(), y_true, task="multiclass", num_classes=2, average="macro"
        ).item(),
    }

    def __init__(self, metrics: list[str] | None = None) -> None:
        """Keep only the named metrics (``None`` or empty keeps all; unknown names are ignored)."""
        metrics = metrics or list(BinaryMetrics.main_metrics.keys())
        self.metrics = {k: v for k, v in BinaryMetrics.main_metrics.items() if k in metrics}

    def plot_cm(self, yhs: Tensor, ys: Tensor) -> Figure:
        """Plot the binary confusion matrix for logits ``yhs`` and targets ``ys``.

        Logits are turned into hard 0/1 predictions exactly like the thresholded
        metrics do (sigmoid, then round), so the plot agrees with ``acc``/``f1``.
        """
        preds, targets = self._prepare(yhs, ys)
        # torchmetrics only auto-applies sigmoid when a float input lies outside
        # [0, 1]; logits that all happen to fall inside [0, 1] would otherwise be
        # thresholded raw, disagreeing with the reported metrics.
        return self._confusion_matrix(preds.sigmoid().round(), targets, task="binary")

    def _prepare(self, yhs: Tensor, ys: Tensor) -> tuple[Tensor, Tensor]:  # noqa: PLR6301
        """Binary metrics consume outputs and targets unchanged."""
        return yhs, ys


class MultiClassMetrics(Metrics):
    """Multi-class metrics computed from flattened per-(sample, class) outputs.

    ``yhs`` and ``ys`` hold one score / target per (sample, class) pair, laid out
    so that reshaping to ``(-1, n_classes)`` yields one row per sample. The
    predicted and true class of each row is its argmax.
    """

    main_metrics: ClassVar[dict[str, Callable]] = {
        "f1": lambda y_pred, y_true, n_classes: tmf.f1_score(
            y_pred,
            y_true,
            task="multiclass",
            num_classes=n_classes,
            average="macro",
        ).item(),
        # Macro-averaged recall (mean of per-class recalls).
        "recall": lambda y_pred, y_true, n_classes: tmf.recall(
            y_pred,
            y_true,
            task="multiclass",
            num_classes=n_classes,
            average="macro",
        ).item(),
        "acc": lambda y_pred, y_true, n_classes: tmf.accuracy(
            y_pred,
            y_true,
            task="multiclass",
            num_classes=n_classes,
            average="micro",
        ).item(),
    }

    # Class count given to the constructor; None means "infer from the targets
    # on every call" via get_n_classes.
    _n_classes_override: int | None = None

    def __init__(self, metrics: list[str] | None = None, n_classes: int | None = None) -> None:
        """Select the metrics to compute.

        Args:
            metrics: Names from ``main_metrics`` to keep; ``None`` or an empty
                list keeps all of them. Unknown names are ignored.
            n_classes: Number of classes. When given, it is used instead of
                inferring the count from the targets with ``get_n_classes``.

        Raises:
            ValueError: If ``n_classes`` is given and is smaller than 1.
        """
        if n_classes is not None and n_classes < 1:
            msg = f"n_classes must be a positive integer, got {n_classes}"
            raise ValueError(msg)
        metrics = metrics or list(MultiClassMetrics.main_metrics.keys())
        self.metrics = {k: v for k, v in MultiClassMetrics.main_metrics.items() if k in metrics}
        self._n_classes_override = n_classes

    @staticmethod
    def decode_binary_to_multiclass(yhs: Tensor, ys: Tensor, n_classes: int) -> tuple[Tensor, Tensor]:
        """Collapse flattened per-class outputs/targets into class indices.

        Both tensors are reshaped to ``(n_samples, n_classes)``; each row's argmax
        is taken as its class.

        Returns:
            ``(pred_ids, true_ids)``, each of shape ``(n_samples,)``.
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        ys_2d = ys.reshape(-1, n_classes)
        yhs_2d = yhs.reshape(-1, n_classes)

        true_ids = ys_2d.argmax(dim=1)
        pred_ids = yhs_2d.argmax(dim=1)

        return pred_ids, true_ids

    @staticmethod
    def get_n_classes(ys: Tensor) -> int:
        """Heuristically infer the class count from flattened targets.

        Returns the index of the second element of ``ys`` equal to ``ys[0]``.

        NOTE(review): this equals the class count only when the value at
        ``ys[0]`` next reappears at the start of the second sample's row (e.g.
        one-hot targets whose first two samples are both class 0). Otherwise it
        is wrong -- ``[1,0,0, 0,1,0]`` yields 4 -- and it raises ``IndexError``
        if ``ys[0]``'s value occurs only once. Pass ``n_classes`` to the
        constructor to avoid relying on it.
        """
        return int((ys == ys[0]).nonzero(as_tuple=True)[0][1].item())

    def _resolve_n_classes(self, ys: Tensor) -> int:
        """Return the configured class count, or infer it from ``ys``."""
        if self._n_classes_override is not None:
            return self._n_classes_override
        return self.get_n_classes(ys)

    def plot_cm(self, yhs: Tensor, ys: Tensor) -> Figure:
        """Plot the multi-class confusion matrix of the decoded class indices."""
        n_classes = self._resolve_n_classes(ys)
        preds, targets = self.decode_binary_to_multiclass(yhs, ys, n_classes=n_classes)
        return self._confusion_matrix(preds, targets, task="multiclass", num_classes=n_classes)

    def _prepare(self, yhs: Tensor, ys: Tensor) -> tuple[Tensor, Tensor]:
        """Decode to class indices and remember the class count for ``_score``."""
        # Stored on the instance because _score must pass it to every metric.
        self.n_classes = self._resolve_n_classes(ys)
        return self.decode_binary_to_multiclass(yhs, ys, n_classes=self.n_classes)

    def _score(self, fn: MetricFn, preds: Tensor, targets: Tensor) -> float:
        """Evaluate one metric, supplying the class count found by ``_prepare``."""
        return fn(preds, targets, self.n_classes)
