from __future__ import annotations

import abc
import math
from typing import TYPE_CHECKING, Optional, Any

import lightning as L
import torch
import torchmetrics
import torchmetrics.classification
from torch import nn, optim

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class LogitVariance(torchmetrics.Metric):
    """Measures the variance in the logits of a Bayesian neural network.

    Predictions are accumulated over all ``update`` calls. ``compute`` takes the
    (unbiased) variance across the sample dimension and then reduces it over the
    batch and output dimensions. Predictions without a sample dimension (2-D
    input) are assigned a variance of zero.

    :param reduction: Whether to sum or average the variance per test point and output dimension.
    """

    def __init__(self, reduction: str = "mean"):
        super().__init__()
        self.reduction = reduction

        # State variables
        # Logits are stored *batch-first* (see ``update``) so that the "cat"
        # reduction, which concatenates along dim 0, joins them along the batch
        # axis rather than the sample axis.
        self.add_state("logits", default=[], dist_reduce_fx="cat")

    def update(self, preds: Float[Tensor, "*samples batch output_dim"]):
        """Store the logits of one batch.

        :param preds: Logits of shape ``(samples, batch, output_dim)`` or, without a
            sample dimension, ``(batch, output_dim)``.
        """
        logits = preds.detach()
        if logits.ndim >= 2:
            # Move the batch axis (second to last) to the front. This is a no-op
            # for 2-D input and yields (batch, samples, output_dim) for 3-D input.
            logits = logits.movedim(-2, 0)
        # Copy into fresh contiguous memory so the state does not alias ``preds``.
        self.logits.append(logits.clone(memory_format=torch.contiguous_format))

    def compute(self):
        """Return the reduced variance of the accumulated logits.

        :raises NotImplementedError: If the accumulated logits are neither 2-D nor 3-D,
            or if ``reduction`` is neither ``"mean"`` nor ``"sum"``.
        """
        # The state is a list of per-batch tensors, or a single tensor if it has
        # already been concatenated (e.g. by the "cat" reduction during a sync).
        accumulated = self.logits
        if isinstance(accumulated, torch.Tensor):
            stacked_logits = accumulated
        else:
            stacked_logits = torch.cat(accumulated, dim=0)

        if stacked_logits.ndim == 2:
            # No sample dimension: the variance is zero. Allocate the scalar on
            # the same device / dtype as the logits instead of the CPU default.
            variance = stacked_logits.new_zeros(())
        elif stacked_logits.ndim == 3:
            # (batch, samples, output_dim): variance across the sample dimension.
            variance = torch.var(stacked_logits, dim=1, unbiased=True)
        else:
            raise NotImplementedError

        # Apply reduction over batch dimension and output dimension
        if self.reduction == "mean":
            return variance.mean()
        elif self.reduction == "sum":
            return variance.sum()
        else:
            raise NotImplementedError

    def reset(self) -> None:
        """Reset the metric state."""
        # The base class restores every registered state to its default, which
        # for ``logits`` is a fresh empty list; no manual clearing is needed.
        super().reset()


class _OODModel(L.LightningModule, abc.ABC):
    """Base class for all out-of-distribution detection models.

    Provides the cross-entropy training step and the metric bookkeeping for the
    validation and test loops. Subclasses are expected to supply ``forward``,
    ``configure_optimizers`` and ``from_dataset``, which raise
    ``NotImplementedError`` here.

    :param num_classes:         Number of classes in the dataset.
    :param lr:                  Learning rate of the optimizer.
    :param momentum:            Momentum of the optimizer.
    :param nesterov:            Whether to use Nesterov momentum.
    :param weight_decay:        Weight decay of the optimizer.
    :param max_epochs:          Maximum number of epochs to train the model.
    """

    def __init__(
        self,
        num_classes: int,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
    ) -> None:
        super().__init__()

        # Loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # Validation / test metrics
        # Each metric exists twice: once for in-distribution (ID) data and once,
        # with the ``_ood`` suffix, for out-of-distribution (OOD) data.
        self.num_classes = num_classes
        # NOTE(review): the "binary" task is selected for two classes while the steps
        # below feed per-class logits to these metrics - confirm that the binary
        # metrics accept that input shape.
        task = "binary" if num_classes == 2 else "multiclass"
        self.accuracy = torchmetrics.classification.Accuracy(
            task=task, num_classes=num_classes
        )
        self.accuracy_ood = torchmetrics.classification.Accuracy(
            task=task, num_classes=num_classes
        )
        # Top-5 accuracy is only tracked when there are more than five classes.
        self.top_5_accuracy = None
        self.top_5_accuracy_ood = None
        if num_classes > 5:
            self.top_5_accuracy = torchmetrics.classification.Accuracy(
                task="multiclass", top_k=5, num_classes=num_classes
            )
            self.top_5_accuracy_ood = torchmetrics.classification.Accuracy(
                task="multiclass", top_k=5, num_classes=num_classes
            )

        self.calibration_error = torchmetrics.classification.CalibrationError(
            task=task, num_classes=num_classes
        )
        self.calibration_error_ood = torchmetrics.classification.CalibrationError(
            task=task, num_classes=num_classes
        )
        self.nll = nn.NLLLoss()
        self.binary_auroc = torchmetrics.classification.BinaryAUROC(thresholds=None)
        self.logit_variance = LogitVariance()
        self.logit_variance_ood = LogitVariance()

        # Optimizer
        self.lr = lr
        self.momentum = momentum
        self.nesterov = nesterov
        self.weight_decay = weight_decay
        self.max_epochs = max_epochs

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        lr: float,
        momentum: float,
        weight_decay: float,
        seed: int,
        root_dir: str,
    ) -> _OODModel:
        """Alternate constructor from a data module; not implemented in the base class.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError

    def forward(
        self, inputs: Float[Tensor, "batch *in_feature"]
    ) -> Float[Tensor, "batch *out_feature"]:
        """Map inputs to predictions; not implemented in the base class.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError

    def training_step(
        self,
        batch: tuple[
            Float[Tensor, "batch *in_feature"], Float[Tensor, "batch *out_feature"]
        ],
        batch_idx: int,
    ) -> Float[Tensor, ""]:
        """Compute and log the training loss for one batch.

        :param batch:       Pair ``(inputs, targets)``.
        :param batch_idx:   Index of the batch (unused).
        :return:            The loss ``self.loss_fn(self(inputs), targets)``.
        """
        inputs, targets = batch
        pred_targets = self(inputs)
        loss = self.loss_fn(pred_targets, targets)

        self._log_training_step(loss)

        return loss

    def _log_training_step(self, loss: Float[Tensor, ""]):
        """Log the training loss as an epoch-level value under ``"Train Loss"``."""
        self.log(
            "Train Loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )

    def validation_step(
        self,
        batch: tuple[
            Float[Tensor, "batch *in_feature"], Float[Tensor, "batch *out_feature"]
        ],
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        """Predict on one validation batch and log the validation metrics.

        :param batch:           Pair ``(inputs, targets)``.
        :param batch_idx:       Index of the batch (unused).
        :param dataloader_idx:  Index of the dataloader (unused).
        """
        inputs, targets = batch
        pred_targets = self(inputs)

        self._log_validation_step(pred_targets, targets)

    def _log_validation_step(
        self,
        pred_targets: Float[Tensor, "batch *out_feature"],
        targets: Float[Tensor, "batch *out_feature"],
    ):
        """Update the in-distribution metrics with one batch and log them.

        :param pred_targets:    Predicted logits; softmax is applied over the last dimension.
        :param targets:         Targets passed to the metrics and the NLL loss.
        """
        # Compute metrics on batch
        self.accuracy(pred_targets, targets)
        if self.top_5_accuracy is not None:
            self.top_5_accuracy(pred_targets, targets)
        self.calibration_error(pred_targets, targets)
        self.logit_variance(pred_targets)

        # Log metrics
        # Note: Lightning accumulates torchmetrics over the entire validation set by calling
        # metric.compute() at the end of the validation epoch.
        logging_dict = {
            "Validation Accuracy": self.accuracy,
            "Validation ECE": self.calibration_error,
            "Validation NLL": self.nll(
                nn.functional.log_softmax(pred_targets, dim=-1), targets
            ),
            "Validation Logit Variance": self.logit_variance,
        }
        if self.top_5_accuracy is not None:
            logging_dict["Validation Top-5 Accuracy"] = self.top_5_accuracy
        self.log_dict(logging_dict)

    def test_step(
        self,
        batch: tuple[
            Float[Tensor, "batch *in_feature"], Float[Tensor, "batch *out_feature"]
        ],
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        """Predict on one test batch and log the test metrics.

        :param batch:           Pair ``(inputs, targets)``.
        :param batch_idx:       Index of the batch (unused).
        :param dataloader_idx:  ``0`` for the in-distribution dataloader, ``1`` for the
            out-of-distribution dataloader.
        """
        inputs, targets = batch
        pred_targets = self(inputs)

        self._log_test_step(pred_targets, targets, dataloader_idx)

    def _log_test_step(
        self,
        pred_targets: Float[Tensor, "batch *out_feature"],
        targets: Float[Tensor, "batch *out_feature"],
        dataloader_idx: int = 0,
    ):
        """Update and log the test metrics for one batch.

        For ``dataloader_idx == 0`` the in-distribution metrics are updated. For
        ``dataloader_idx == 1`` the metrics are computed on the subset of the batch
        flagged as OOD, and the AUROC of separating OOD from non-OOD inputs by
        normalized predictive entropy is updated. Any other index logs nothing.

        :param pred_targets:    Predicted logits; softmax is applied over the last dimension.
        :param targets:         For dataloader 0, the targets. For dataloader 1, an indexable
            whose element ``0`` holds the class targets and element ``1`` an indicator
            that equals ``1`` for OOD inputs.
        :param dataloader_idx:  Index of the dataloader the batch came from.
        """
        # Predictive entropy per input: -sum_c p_c * log p_c
        log_p = nn.functional.log_softmax(pred_targets, dim=-1)
        p = nn.functional.softmax(pred_targets, dim=-1)
        pred_targets_entropy = -torch.sum(p * log_p, dim=-1)
        # Entropy of the uniform distribution; dividing by it maps entropy to [0, 1].
        max_entropy = math.log(self.num_classes)

        if dataloader_idx == 0:
            # Compute ID Metrics

            self.accuracy(pred_targets, targets)
            if self.top_5_accuracy is not None:
                self.top_5_accuracy(pred_targets, targets)
            self.calibration_error(pred_targets, targets)
            self.logit_variance(pred_targets)

            # Log metrics
            # Note: Lightning accumulates torchmetrics over the entire test set by calling metric.compute()
            # at the end of the test epoch.
            logging_dict = {
                "Test Accuracy": self.accuracy,
                "Test ECE": self.calibration_error,
                "Test NLL": self.nll(log_p, targets),
                "Test Logit Variance": self.logit_variance,
                "Test Norm. Entropy": pred_targets_entropy.mean() / max_entropy,
            }
            if self.top_5_accuracy is not None:
                logging_dict["Test Top-5 Accuracy"] = self.top_5_accuracy
            self.log_dict(logging_dict)

        elif dataloader_idx == 1:
            # Compute OOD Metrics

            target_class = targets[0]
            input_is_ood = targets[1]

            # Subset data to select OOD data only
            ood_mask = input_is_ood == 1
            ood_pred_targets = pred_targets[ood_mask]
            ood_target_class = target_class[ood_mask]

            self.accuracy_ood(ood_pred_targets, ood_target_class)
            if self.top_5_accuracy_ood is not None:
                self.top_5_accuracy_ood(ood_pred_targets, ood_target_class)
            self.calibration_error_ood(ood_pred_targets, ood_target_class)
            self.logit_variance_ood(ood_pred_targets)

            # Note: Lightning accumulates the metrics over the entire test set by calling metric.compute()
            # at the end of the test epoch.
            logging_dict = {
                "Test Accuracy (OOD)": self.accuracy_ood,
                "Test ECE (OOD)": self.calibration_error_ood,
                # log-softmax acts row-wise, so masking ``log_p`` equals recomputing
                # it on the masked logits.
                "Test NLL (OOD)": self.nll(log_p[ood_mask], ood_target_class),
                "Test Logit Variance (OOD)": self.logit_variance_ood,
                "Test Norm. Entropy (OOD)": pred_targets_entropy[ood_mask].mean()
                / max_entropy,
            }
            if self.top_5_accuracy_ood is not None:
                logging_dict["Test Top-5 Accuracy (OOD)"] = self.top_5_accuracy_ood

            if torch.backends.mps.is_available():
                # MPS backend causes bugs in binary AUROC
                # See: https://github.com/Lightning-AI/torchmetrics/issues/1727#issuecomment-1999173176
                self.binary_auroc.to(device="cpu")
                pred_targets_entropy = pred_targets_entropy.to(device="cpu")
                input_is_ood = input_is_ood.to(device="cpu")

            # AUROC uses the whole batch: the normalized entropy is the score and
            # the OOD indicator is the label.
            self.binary_auroc(
                pred_targets_entropy / max_entropy,  # Normalize entropy to [0, 1]
                input_is_ood,
            )
            logging_dict["Test AUROC"] = self.binary_auroc

            # Log metrics
            self.log_dict(logging_dict)

    def configure_optimizers(self) -> optim.Optimizer:
        """Create the optimizer; not implemented in the base class.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError
