from __future__ import annotations

import math
import pathlib
import random
from typing import TYPE_CHECKING

import torch
from torch import nn, optim

from .._ood_model import _OODModel

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class _EnsembleModel(_OODModel):
    """Base class for ensembles in the OOD detection experiment.

    Member predictions are stacked along the first axis of ``pred_targets``
    (the axis reduced in the logging helpers). The ensemble's predictive
    distribution is the mean of the members' softmax probabilities.
    """

    @staticmethod
    def get_checkpoints(
        dataset_name: str,
        model_name: str,
        parametrization_name: str,
        root_dir: str,
        resample: bool = True,
    ) -> list[pathlib.Path]:
        """Return the list of checkpoints to use for the ensemble.

        All ``*.ckpt`` files found recursively under
        ``<root_dir>/training_logs/<dataset_name>/<model_name>/<parametrization_name>``
        are collected.

        Args:
            dataset_name: Dataset sub-directory name.
            model_name: Model sub-directory name.
            parametrization_name: Parametrization sub-directory name.
            root_dir: Root directory containing ``training_logs``.
            resample: If True, draw ``len(checkpoints)`` checkpoints with
                replacement, using the global ``random`` state.

        Returns:
            The (possibly resampled) list of checkpoint paths.

        Raises:
            ValueError: If ``root_dir`` cannot be turned into a path, or if no
                checkpoint file exists under the search directory.
        """
        try:
            checkpoint_dir = (
                pathlib.Path(root_dir)
                / f"training_logs/{dataset_name}/{model_name}/{parametrization_name}"
            )
        except TypeError as e:
            # Raised by pathlib.Path for e.g. root_dir=None.
            raise ValueError("No checkpoints found.") from e

        # rglob yields files in filesystem order; sort so that the resampling
        # below is reproducible for a fixed random seed.
        checkpoints = sorted(checkpoint_dir.rglob("*.ckpt"))
        if not checkpoints:
            # rglob does not raise for a missing or empty directory, so the
            # "nothing found" case must be checked explicitly.
            raise ValueError(f"No checkpoints found in {checkpoint_dir}.")

        if resample:
            # Sample checkpoints with replacement
            checkpoints = random.choices(checkpoints, k=len(checkpoints))

        return checkpoints

    def forward(
        self, inputs: Float[Tensor, "batch *in_feature"]
    ) -> Float[Tensor, "batch *out_feature"]:
        """Delegate to ``self.model``.

        ``self.model`` is not defined in this class; presumably subclasses or
        ``_OODModel`` set it (TODO confirm).
        """
        return self.model(inputs)

    def configure_optimizers(self) -> optim.Optimizer | None:
        """Return no optimizer; this class defines no trainable setup.

        NOTE(review): assumes the Lightning convention that ``None`` means
        "run without an optimizer"; confirm against ``_OODModel``'s base.
        """
        return None

    @staticmethod
    def _ensemble_probs(
        pred_targets: Float[Tensor, "member batch *out_feature"],
    ) -> tuple[Float[Tensor, "batch *out_feature"], Float[Tensor, "batch *out_feature"]]:
        """Return ``(probs, log_probs)`` of the member-averaged softmax.

        ``probs`` is the mean over axis 0 of the per-member softmax.
        ``log_probs`` is its logarithm, computed as
        ``logsumexp_m(log_softmax_m) - log(M)``. That avoids ``log(0) = -inf``
        when an averaged probability underflows, which would otherwise make
        NLL infinite and entropy NaN (``0 * -inf``).
        """
        num_members = pred_targets.shape[0]
        pred_probs = nn.functional.softmax(pred_targets, dim=-1).mean(dim=0)
        member_log_probs = nn.functional.log_softmax(pred_targets, dim=-1)
        pred_log_probs = torch.logsumexp(member_log_probs, dim=0) - math.log(
            num_members
        )
        return pred_probs, pred_log_probs

    def _log_validation_step(
        self,
        pred_targets: Float[Tensor, "member batch *out_feature"],
        targets: Float[Tensor, "batch *out_feature"],
    ):
        """Update and log validation metrics for the averaged ensemble prediction."""
        pred_probs, pred_log_probs = self._ensemble_probs(pred_targets)

        self.accuracy(pred_probs, targets)
        if self.top_5_accuracy is not None:
            self.top_5_accuracy(pred_probs, targets)
        self.calibration_error(pred_probs, targets)
        # Logit variance is computed on the raw per-member logits.
        self.logit_variance(pred_targets)

        logging_dict = {
            "Validation Accuracy": self.accuracy,
            "Validation ECE": self.calibration_error,
            "Validation NLL": self.nll(pred_log_probs, targets),
            "Validation Logit Variance": self.logit_variance,
        }
        if self.top_5_accuracy is not None:
            logging_dict["Validation Top-5 Accuracy"] = self.top_5_accuracy
        self.log_dict(logging_dict)

    def _log_test_step(
        self,
        pred_targets: Float[Tensor, "member batch *out_feature"],
        targets: Float[Tensor, "batch *out_feature"],
        dataloader_idx: int = 0,
    ):
        """Update and log test metrics.

        ``dataloader_idx == 0`` logs in-distribution metrics.
        ``dataloader_idx == 1`` logs metrics restricted to OOD samples, plus an
        AUROC that uses normalized predictive entropy as the OOD score.
        """
        pred_probs, pred_log_probs = self._ensemble_probs(pred_targets)
        pred_targets_entropy = -torch.sum(pred_probs * pred_log_probs, dim=-1)
        # Maximum possible entropy; dividing by it normalizes entropy to [0, 1].
        max_entropy = math.log(self.num_classes)

        if dataloader_idx == 0:
            self.accuracy(pred_probs, targets)
            if self.top_5_accuracy is not None:
                self.top_5_accuracy(pred_probs, targets)
            self.calibration_error(pred_probs, targets)
            self.logit_variance(pred_targets)

            logging_dict = {
                "Test Accuracy": self.accuracy,
                "Test ECE": self.calibration_error,
                "Test NLL": self.nll(pred_log_probs, targets),
                "Test Logit Variance": self.logit_variance,
                "Test Norm. Entropy": pred_targets_entropy.mean() / max_entropy,
            }
            if self.top_5_accuracy is not None:
                logging_dict["Test Top-5 Accuracy"] = self.top_5_accuracy
            self.log_dict(logging_dict)

        elif dataloader_idx == 1:
            # Compute OOD metrics. For this loader, targets[0] is used as the
            # class labels and targets[1] as an OOD indicator (1 = OOD).
            target_class = targets[0]
            input_is_ood = targets[1]
            ood_mask = input_is_ood == 1
            ood_probs = pred_probs[ood_mask]
            ood_targets = target_class[ood_mask]
            # NOTE(review): a batch with no OOD samples makes the entropy
            # mean below NaN; confirm batches always contain OOD samples.

            self.accuracy_ood(ood_probs, ood_targets)
            if self.top_5_accuracy is not None:
                self.top_5_accuracy_ood(ood_probs, ood_targets)
            self.calibration_error_ood(ood_probs, ood_targets)
            # Axis 1 of the per-member logits is the batch axis.
            self.logit_variance_ood(pred_targets[:, ood_mask, ...])

            logging_dict = {
                "Test Accuracy (OOD)": self.accuracy_ood,
                "Test ECE (OOD)": self.calibration_error_ood,
                "Test NLL (OOD)": self.nll(pred_log_probs[ood_mask], ood_targets),
                "Test Logit Variance (OOD)": self.logit_variance_ood,
                # Only consider entropy for OOD data
                "Test Norm. Entropy (OOD)": pred_targets_entropy[ood_mask].mean()
                / max_entropy,
            }
            if self.top_5_accuracy is not None:
                logging_dict["Test Top-5 Accuracy (OOD)"] = self.top_5_accuracy_ood

            if torch.backends.mps.is_available():
                # MPS backend causes bugs in binary AUROC
                # See: https://github.com/Lightning-AI/torchmetrics/issues/1727#issuecomment-1999173176
                self.binary_auroc.to(device="cpu")
                pred_targets_entropy = pred_targets_entropy.to(device="cpu")
                input_is_ood = input_is_ood.to(device="cpu")

            # Normalized entropy is the OOD score; all samples (ID and OOD)
            # are scored.
            self.binary_auroc(pred_targets_entropy / max_entropy, input_is_ood)
            logging_dict["Test AUROC"] = self.binary_auroc

            self.log_dict(logging_dict)
