from __future__ import annotations

import math
from typing import TYPE_CHECKING

import torch
from torch import nn, optim

from .._ood_model import _OODModel

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class _ImplicitVIModel(_OODModel):
    """Base class for models trained via implicit variational inference."""

    def forward(
        self,
        inputs: Float[Tensor, "batch *in_feature"],
        sample_shape: torch.Size = torch.Size([]),
        generator: torch.Generator | None = None,
    ) -> Float[Tensor, "*sample batch *out_feature"]:
        return self.model(inputs, sample_shape=sample_shape, generator=generator)

    def on_test_start(self):
        if hasattr(self, "temperature_scaler"):
            with torch.inference_mode(False):
                self.temperature_scaler.optimize(
                    model=self.model,
                    dataloader=self.dataset.val_dataloader(),
                )
        return super().on_test_start()

    def training_step(
        self,
        batch: tuple[
            Float[Tensor, "batch *in_feature"], Float[Tensor, "batch *out_feature"]
        ],
        batch_idx: int,
    ) -> Float[Tensor, ""]:

        inputs, targets = batch
        pred_targets = self(inputs, sample_shape=(self.num_samples_train,)).reshape(
            -1, self.hparams.out_size
        )
        loss = self.loss_fn(
            pred_targets,
            targets.expand(self.num_samples_train, len(targets)).reshape(-1),
        )

        self._log_training_step(loss)

        return loss

    def validation_step(
        self,
        batch: tuple[
            Float[Tensor, "batch *in_feature"], Float[Tensor, "batch *out_feature"]
        ],
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        inputs, targets = batch
        pred_latents = self(
            inputs, sample_shape=(self.num_samples_test,)
        )  # TODO: change back to num_samples_train
        # pred_targets = pred_latents.mean(dim=0) / torch.sqrt(
        #     1.0 + math.pi / 8 * pred_latents.std(dim=0) ** 2
        # )
        pred_targets = pred_latents

        self._log_validation_step(pred_targets, targets)

    def test_step(
        self,
        batch: tuple[
            Float[Tensor, "batch *in_feature"], Float[Tensor, "batch *out_feature"]
        ],
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        inputs, targets = batch
        pred_latents = self(inputs, sample_shape=(self.num_samples_test,))
        # pred_targets = pred_latents.mean(dim=0) / torch.sqrt(
        #     1.0 + math.pi / 8 * pred_latents.std(dim=0) ** 2
        # )
        pred_targets = pred_latents

        self._log_test_step(pred_targets, targets, dataloader_idx)

    def _log_validation_step(
        self,
        pred_targets: Float[Tensor, "sample batch out_feature"],
        targets: Float[Tensor, "batch *out_feature"],
    ):
        pred_log_probs = nn.functional.log_softmax(pred_targets, dim=-1).mean(dim=0)
        pred_probs = torch.exp(pred_log_probs)

        self.accuracy(pred_probs, targets)
        if self.top_5_accuracy is not None:
            self.top_5_accuracy(pred_probs, targets)
        self.calibration_error(pred_probs, targets)
        self.logit_variance(pred_targets)

        logging_dict = {
            "Validation Accuracy": self.accuracy,
            "Validation ECE": self.calibration_error,
            "Validation NLL": self.nll(pred_log_probs, targets),
            "Validation Logit Variance": self.logit_variance,
        }
        if self.top_5_accuracy is not None:
            logging_dict["Validation Top-5 Accuracy"] = self.top_5_accuracy
        self.log_dict(logging_dict)

    def _log_test_step(
        self,
        pred_targets: Float[Tensor, "sample batch out_feature"],
        targets: Float[Tensor, "batch *out_feature"],
        dataloader_idx: int = 0,
    ):
        pred_log_probs = nn.functional.log_softmax(pred_targets, dim=-1).mean(dim=0)
        pred_probs = torch.exp(pred_log_probs)
        pred_targets_entropy = -torch.sum(pred_probs * pred_log_probs, dim=-1)

        if dataloader_idx == 0:
            self.accuracy(pred_probs, targets)
            if self.top_5_accuracy is not None:
                self.top_5_accuracy(pred_probs, targets)
            self.calibration_error(pred_probs, targets)
            self.logit_variance(pred_targets)

            logging_dict = {
                "Test Accuracy": self.accuracy,
                "Test ECE": self.calibration_error,
                "Test NLL": self.nll(pred_log_probs, targets),
                "Test Logit Variance": self.logit_variance,
                "Test Norm. Entropy": pred_targets_entropy.mean()
                / math.log(self.num_classes),
            }
            if self.top_5_accuracy is not None:
                logging_dict["Test Top-5 Accuracy"] = self.top_5_accuracy
            self.log_dict(logging_dict)

        elif dataloader_idx == 1:
            # Compute OOD Metrics

            target_class = targets[0]
            input_is_ood = targets[1]

            self.accuracy_ood(
                pred_probs[input_is_ood == 1], target_class[input_is_ood == 1]
            )
            if self.top_5_accuracy is not None:
                self.top_5_accuracy_ood(
                    pred_probs[input_is_ood == 1], target_class[input_is_ood == 1]
                )
            self.calibration_error_ood(
                pred_probs[input_is_ood == 1], target_class[input_is_ood == 1]
            )
            self.logit_variance_ood(pred_targets[:, input_is_ood == 1, ...])

            logging_dict = {
                "Test Accuracy (OOD)": self.accuracy_ood,
                "Test ECE (OOD)": self.calibration_error_ood,
                "Test NLL (OOD)": self.nll(
                    pred_log_probs[input_is_ood == 1], target_class[input_is_ood == 1]
                ),
                "Test Logit Variance (OOD)": self.logit_variance_ood,
                "Test Norm. Entropy (OOD)": pred_targets_entropy[
                    input_is_ood == 1
                ].mean()  # Only consider entropy for OOD Data
                / math.log(self.num_classes),
            }
            if self.top_5_accuracy is not None:
                logging_dict["Test Top-5 Accuracy (OOD)"] = self.top_5_accuracy_ood

            if torch.backends.mps.is_available():
                # MPS backend causes bugs in binary AUROC
                # See: https://github.com/Lightning-AI/torchmetrics/issues/1727#issuecomment-1999173176
                self.binary_auroc.to(device="cpu")
                pred_targets_entropy = pred_targets_entropy.to(device="cpu")
                input_is_ood = input_is_ood.to(device="cpu")

            self.binary_auroc(
                pred_targets_entropy
                / math.log(self.num_classes),  # Normalize entropy to [0, 1]
                input_is_ood,
            )
            logging_dict["Test AUROC"] = self.binary_auroc

            self.log_dict(logging_dict)

    def configure_optimizers(self) -> optim.Optimizer:
        """Build the SGD optimizer used for training.

        Parameter groups come from ``self.model.parameters_and_lrs``; the
        learning rate, momentum, Nesterov flag and weight decay are read from
        the corresponding attributes of ``self``.

        Returns:
            The configured ``optim.SGD`` instance. No learning-rate scheduler
            is attached.
        """
        # A per-epoch CosineAnnealingLR schedule (T_max=self.max_epochs) used
        # to be sketched here as disabled code; only the optimizer is returned.
        param_groups = self.model.parameters_and_lrs(lr=self.lr, optimizer="SGD")
        return optim.SGD(
            param_groups,
            lr=self.lr,
            momentum=self.momentum,
            nesterov=self.nesterov,
            weight_decay=self.weight_decay,
        )
