from __future__ import annotations

import math
from typing import TYPE_CHECKING

import torch
from torch import optim
from torch import nn
from .._ood_model import _OODModel

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class _WeightSpaceVIModel(_OODModel):
    """Base class for models trained via variational inference in weight space.

    The wrapped ``self.model`` is called with a ``sample_shape`` and returns
    predictions with the sample dimensions prepended to the batch dimension.
    Training draws ``num_samples_train`` samples; evaluation draws several
    samples and averages the logits across the sample dimension before
    computing metrics.
    """

    def forward(
        self,
        inputs: Float[Tensor, "batch *in_feature"],
        sample_shape: torch.Size = torch.Size([]),
        generator: torch.Generator | None = None,
    ) -> Float[Tensor, "*sample batch *out_feature"]:
        """Run the wrapped model, drawing ``sample_shape`` samples.

        Args:
            inputs: Batch of inputs.
            sample_shape: Shape of the leading sample dimensions of the output.
            generator: Optional random number generator forwarded to the model.

        Returns:
            Predictions of shape ``(*sample_shape, batch, *out_feature)``.
        """
        return self.model(inputs, sample_shape=sample_shape, generator=generator)

    def training_step(
        self,
        batch: tuple[
            Float[Tensor, "batch *in_feature"], Float[Tensor, "batch *out_feature"]
        ],
        batch_idx: int,
    ) -> Float[Tensor, ""]:
        """Compute and log the training loss over ``num_samples_train`` samples.

        Args:
            batch: Pair of (inputs, targets).
            batch_idx: Index of the batch (unused).

        Returns:
            The loss returned by ``self.loss_fn``.
        """
        inputs, targets = batch

        # (sample, batch, out) -> (sample * batch, out); sample-major order.
        pred_targets = self(inputs, sample_shape=(self.num_samples_train,)).reshape(
            -1, self.hparams.out_size
        )
        # Repeat the targets along a new leading sample axis and flatten in the
        # same sample-major order so that every prediction row lines up with
        # its target.
        loss = self.loss_fn(
            pred_targets,
            targets.expand(self.num_samples_train, len(targets)).reshape(-1),
        )

        self._log_training_step(loss)

        return loss

    def validation_step(
        self,
        batch: tuple[
            Float[Tensor, "batch *in_feature"], Float[Tensor, "batch *out_feature"]
        ],
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        """Draw samples for a validation batch and log validation metrics.

        Note: validation uses ``num_samples_train`` samples (not
        ``num_samples_test``).
        """
        inputs, targets = batch
        pred_targets = self(inputs, sample_shape=(self.num_samples_train,))

        self._log_validation_step(pred_targets, targets)

    def _log_validation_step(
        self,
        pred_targets: Float[Tensor, "sample batch out_feature"],
        targets: Float[Tensor, "batch *out_feature"],
    ):
        """Update and log validation metrics from sampled logits.

        Args:
            pred_targets: Logits with a leading sample dimension.
            targets: Class targets for the batch.
        """
        # Average logits over the sample dimension.
        pred_targets_avg = pred_targets.mean(dim=0)

        # Compute metrics on batch
        self.accuracy(pred_targets_avg, targets)
        if self.top_5_accuracy is not None:
            self.top_5_accuracy(pred_targets_avg, targets)
        self.calibration_error(pred_targets_avg, targets)
        # Logit variance is computed from the per-sample logits, not the average.
        self.logit_variance(pred_targets)

        # Log metrics
        # Note: Lightning accumulates torchmetrics over the entire test set by calling metric.compute()
        # at the end of the test epoch.
        logging_dict = {
            "Validation Accuracy": self.accuracy,
            "Validation ECE": self.calibration_error,
            "Validation NLL": self.nll(
                nn.functional.log_softmax(pred_targets_avg, dim=-1), targets
            ),
            "Validation Logit Variance": self.logit_variance,
        }
        if self.top_5_accuracy is not None:
            logging_dict["Validation Top-5 Accuracy"] = self.top_5_accuracy
        self.log_dict(logging_dict)

    def test_step(
        self,
        batch: tuple[
            Float[Tensor, "batch *in_feature"], Float[Tensor, "batch *out_feature"]
        ],
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        """Draw ``num_samples_test`` samples for a test batch and log metrics.

        ``dataloader_idx`` 0 is treated as the in-distribution test set and 1 as
        the out-of-distribution test set (see ``_log_test_step``).
        """
        inputs, targets = batch
        pred_targets = self(inputs, sample_shape=(self.num_samples_test,))

        self._log_test_step(pred_targets, targets, dataloader_idx)

    def _log_test_step(
        self,
        pred_targets: Float[Tensor, "sample batch *out_feature"],
        targets: Float[Tensor, "batch *out_feature"],
        dataloader_idx: int = 0,
    ):
        """Update and log in-distribution or OOD test metrics.

        Args:
            pred_targets: Logits with a leading sample dimension.
            targets: For ``dataloader_idx == 0``, the class targets. For
                ``dataloader_idx == 1``, a pair ``(class_targets, input_is_ood)``
                where ``input_is_ood == 1`` marks out-of-distribution inputs.
            dataloader_idx: 0 for the ID test set, 1 for the OOD test set; any
                other value is ignored.
        """
        # TODO: Refactor this to move all logging logic into OODModel which should just handle if the model produces samples (unifies all sampling-based models)

        # Average logits over the sample dimension.
        pred_targets_avg = pred_targets.mean(dim=0)

        # Predictive entropy of the averaged logits, reused for the normalized
        # entropy metric and as the OOD score for AUROC.
        log_p = nn.functional.log_softmax(pred_targets_avg, dim=-1)
        p = nn.functional.softmax(pred_targets_avg, dim=-1)
        pred_targets_entropy = -torch.sum(p * log_p, dim=-1)
        # Maximum possible entropy; dividing by it maps entropy to [0, 1].
        max_entropy = math.log(self.num_classes)

        if dataloader_idx == 0:
            # Compute ID Metrics

            self.accuracy(pred_targets_avg, targets)
            if self.top_5_accuracy is not None:
                self.top_5_accuracy(pred_targets_avg, targets)
            self.calibration_error(pred_targets_avg, targets)
            self.logit_variance(pred_targets)

            # Log metrics
            # Note: Lightning accumulates torchmetrics over the entire test set by calling metric.compute()
            # at the end of the test epoch.
            logging_dict = {
                "Test Accuracy": self.accuracy,
                "Test ECE": self.calibration_error,
                "Test NLL": self.nll(log_p, targets),
                "Test Logit Variance": self.logit_variance,
                "Test Norm. Entropy": pred_targets_entropy.mean() / max_entropy,
            }
            if self.top_5_accuracy is not None:
                logging_dict["Test Top-5 Accuracy"] = self.top_5_accuracy
            self.log_dict(logging_dict)

        elif dataloader_idx == 1:
            # Compute OOD Metrics

            target_class = targets[0]
            input_is_ood = targets[1]
            # Mask selecting the OOD inputs; built once and reused below.
            is_ood = input_is_ood == 1

            logging_dict = {}

            # A batch may contain no OOD inputs at all (e.g. an unshuffled mix of
            # ID and OOD data). The mean normalized entropy of an empty subset is
            # NaN (and so is a mean-reduced NLL), which would poison the
            # epoch-level aggregate. Only update/log the OOD-subset metrics
            # when the subset is non-empty.
            if is_ood.any():
                pred_targets_avg_ood = pred_targets_avg[is_ood]
                target_class_ood = target_class[is_ood]

                self.accuracy_ood(pred_targets_avg_ood, target_class_ood)
                if self.top_5_accuracy is not None:
                    self.top_5_accuracy_ood(pred_targets_avg_ood, target_class_ood)
                self.calibration_error_ood(pred_targets_avg_ood, target_class_ood)
                self.logit_variance_ood(pred_targets[:, is_ood, ...])

                # Note: Lightning accumulates the metrics over the entire test set by calling metric.compute()
                # at the end of the test epoch.
                logging_dict.update(
                    {
                        "Test Accuracy (OOD)": self.accuracy_ood,
                        "Test ECE (OOD)": self.calibration_error_ood,
                        "Test NLL (OOD)": self.nll(log_p[is_ood], target_class_ood),
                        "Test Logit Variance (OOD)": self.logit_variance_ood,
                        "Test Norm. Entropy (OOD)": pred_targets_entropy[is_ood].mean()
                        / max_entropy,
                    }
                )
                if self.top_5_accuracy is not None:
                    logging_dict["Test Top-5 Accuracy (OOD)"] = self.top_5_accuracy_ood

            if torch.backends.mps.is_available():
                # MPS backend causes bugs in binary AUROC
                # See: https://github.com/Lightning-AI/torchmetrics/issues/1727#issuecomment-1999173176
                self.binary_auroc.to(device="cpu")
                pred_targets_entropy = pred_targets_entropy.to(device="cpu")
                input_is_ood = input_is_ood.to(device="cpu")

            # AUROC of normalized entropy as a score for detecting OOD inputs,
            # computed over the whole batch (ID and OOD inputs together).
            self.binary_auroc(
                pred_targets_entropy / max_entropy,  # Normalize entropy to [0, 1]
                input_is_ood,
            )
            logging_dict["Test AUROC"] = self.binary_auroc

            # Log metrics
            self.log_dict(logging_dict)

    def configure_optimizers(self) -> optim.Optimizer:
        """Build an SGD optimizer over the model's parameter groups.

        Per-group learning rates come from ``self.model.parameters_and_lrs``;
        ``self.lr`` is also passed as the optimizer's default learning rate.
        No learning-rate scheduler is configured.
        """
        optimizer = optim.SGD(
            self.model.parameters_and_lrs(lr=self.lr, optimizer="SGD"),
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )
        return optimizer
