from __future__ import annotations

import abc
import math
from typing import TYPE_CHECKING

import lightning as L
import torch
import torchmetrics
import torchmetrics.classification
from torch import nn, optim

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class _IVIModel(L.LightningModule, abc.ABC):
    """Base class for all models.

    This class reads several attributes that it does not set itself; subclasses
    must provide them:

    * ``model`` -- the network called in :meth:`forward`. It is called with
      ``sample_shape`` and ``generator`` keyword arguments and must expose
      ``parameters_and_lrs(lr=..., optimizer=...)`` (see
      :meth:`configure_optimizers`).
    * ``num_samples_train`` / ``num_samples_test`` -- number of forward samples
      drawn per batch in training and in validation/test.
    * ``temperature_scaler`` and ``dataset`` (optional) -- if a
      ``temperature_scaler`` attribute exists, it is re-optimized on
      ``dataset.val_dataloader()`` at the start of every validation run.

    :param num_classes:         Number of classes in the dataset.
    :param lr:                  Learning rate of the optimizer.
    :param momentum:            Momentum of the optimizer.
    :param nesterov:            Whether to use Nesterov momentum.
    :param weight_decay:        Weight decay of the optimizer.
    :param max_epochs:          Maximum number of epochs to train the model.
    """

    # Maximum number of per-datapoint predictive variances logged during
    # validation (keys "Variance val x{i}" / "Variance train x{i}").
    _num_logged_variances: int = 10

    def __init__(
        self,
        num_classes: int,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
    ) -> None:
        super().__init__()

        # Loss function (applied to logits)
        self.loss_fn = nn.BCEWithLogitsLoss()

        # Validation / test metrics
        self.num_classes = num_classes
        self.accuracy = torchmetrics.classification.Accuracy(
            task="binary" if num_classes == 2 else "multiclass",
            num_classes=num_classes,
        )
        # Applied to probabilities (not logits), see _predictive_probs.
        self.nll = nn.BCELoss()
        self.calibration_error = torchmetrics.classification.CalibrationError(
            task="binary" if num_classes == 2 else "multiclass",
            num_classes=num_classes,
        )

        # Optimizer
        self.lr = lr
        self.momentum = momentum
        self.nesterov = nesterov
        self.weight_decay = weight_decay
        self.max_epochs = max_epochs
        self.num_optim_steps = 0

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        lr: float,
        momentum: float,
        weight_decay: float,
        seed: int,
        root_dir: str,
    ):
        """Alternate constructor; must be implemented by subclasses."""
        raise NotImplementedError

    def forward(
        self,
        inputs: Float[Tensor, "batch *in_feature"],
        sample_shape: torch.Size = torch.Size([]),
        generator: torch.Generator | None = None,
    ) -> Float[Tensor, "*sample batch *out_feature"]:
        """Delegate to ``self.model``, drawing ``sample_shape`` samples."""
        return self.model(inputs, sample_shape=sample_shape, generator=generator)

    def on_validation_start(self):
        """Re-fit the optional temperature scaler before each validation run."""
        if hasattr(self, "temperature_scaler"):
            # Validation runs under inference mode; the scaler's optimization
            # needs autograd, so inference mode is disabled locally.
            with torch.inference_mode(False):
                self.temperature_scaler.optimize(
                    model=self.model,
                    dataloader=self.dataset.val_dataloader(),
                )
        return super().on_validation_start()

    def training_step(
        self,
        batch: tuple[
            Float[Tensor, "batch *in_feature"], Float[Tensor, "batch *out_feature"]
        ],
        batch_idx: int,
    ) -> Float[Tensor, ""]:
        """Return the BCE-with-logits loss over ``num_samples_train`` samples."""
        inputs, targets = batch
        pred_targets = self(inputs, sample_shape=(self.num_samples_train,))
        # Prepend a sample dimension to the targets so each sampled prediction
        # is scored against the same labels; works for targets of any rank.
        loss = self.loss_fn(
            pred_targets,
            targets.expand(self.num_samples_train, *targets.shape),
        )

        self.num_optim_steps += 1

        self._log_training_step(loss)

        return loss

    def _log_training_step(self, loss: Float[Tensor, ""]):
        """Log the epoch-averaged training loss."""
        self.log(
            "Train Loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )

    @staticmethod
    def _predictive_probs(
        pred_targets: Float[Tensor, "sample batch *out_feature"],
    ) -> Float[Tensor, "batch *out_feature"]:
        """Return the Monte-Carlo predictive probability ``mean_s sigmoid(f_s)``.

        The mean is taken over the leading sample dimension. It is computed in
        log space as ``logsumexp_s(log sigmoid(f_s)) - log(S)`` for numerical
        stability. Averaging the log-probabilities directly would give the
        geometric mean of the per-sample probabilities, which is not the
        model-averaged predictive distribution.
        """
        num_samples = pred_targets.shape[0]
        pred_log_probs = torch.logsumexp(
            nn.functional.logsigmoid(pred_targets), dim=0
        ) - math.log(num_samples)
        return torch.exp(pred_log_probs)

    def _variance_logs(self, prefix: str, f_var: Tensor) -> dict[str, Tensor]:
        """Map ``f"{prefix}{i}"`` to ``f_var[i]`` for the first few datapoints.

        At most ``_num_logged_variances`` entries are returned, and fewer if the
        batch is smaller, so small batches do not raise ``IndexError``.
        """
        count = min(self._num_logged_variances, f_var.shape[0])
        return {f"{prefix}{i}": f_var[i] for i in range(count)}

    def validation_step(
        self,
        batch: tuple[
            Float[Tensor, "batch *in_feature"], Float[Tensor, "batch *out_feature"]
        ],
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        """Draw ``num_samples_test`` predictions and log validation metrics."""
        inputs, targets = batch
        pred_latents = self(inputs, sample_shape=(self.num_samples_test,))
        pred_targets = pred_latents

        self._log_validation_step(pred_targets, targets, batch_idx, dataloader_idx)

    def _log_validation_step(
        self,
        pred_targets: Float[Tensor, "sample batch *out_feature"],
        targets: Float[Tensor, "batch *out_feature"],
        batch_idx: int = 0,
        dataloader_idx: int = 0,
    ):
        """Log metrics for dataloader 0 (validation) or 1 (training set).

        ``pred_targets`` carries a leading sample dimension. The per-datapoint
        variance is taken over that dimension, so it is NaN when only one
        sample is drawn.
        """
        if dataloader_idx == 0:
            # Validation set
            # NOTE: Make sure validation batch size is equal to the entire dataset
            pred_probs = self._predictive_probs(pred_targets)
            f_var = pred_targets.var(dim=0)

            # Compute metrics on batch
            self.accuracy(pred_probs, targets)
            self.calibration_error(pred_probs, targets)

            model_params_vector = torch.nn.utils.parameters_to_vector(
                [
                    param
                    for name, param in self.model.named_parameters()
                    if ("cov" not in name) and ("temperature" not in name)
                ]
            )

            # Log metrics
            # NOTE(review): "Mean Parameter Norm" is the L2 norm of the
            # non-covariance, non-temperature parameters divided by the
            # temperature at self.model[-2][-2]; it depends on the subclass's
            # model layout. Confirm this indexing for every subclass.
            logging_dict = {
                "Validation Accuracy": self.accuracy,
                "Validation ECE": self.calibration_error,
                "Validation NLL": self.nll(pred_probs, targets),
                "Mean Parameter Norm": torch.linalg.norm(model_params_vector, ord=2)
                / self.model[-2][-2].params.temperature,
                **self._variance_logs("Variance val x", f_var),
            }
            self.log_dict(logging_dict)
        elif dataloader_idx == 1:
            # Training set
            # NOTE: Make sure validation batch size is equal to the entire dataset
            f_var = pred_targets.var(dim=0)
            self.log_dict(self._variance_logs("Variance train x", f_var))

    def test_step(
        self,
        batch: tuple[
            Float[Tensor, "batch *in_feature"], Float[Tensor, "batch *out_feature"]
        ],
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        """Draw ``num_samples_test`` predictions and log test metrics."""
        inputs, targets = batch
        pred_latents = self(inputs, sample_shape=(self.num_samples_test,))
        pred_targets = pred_latents

        self._log_test_step(pred_targets, targets, dataloader_idx)

    def _log_test_step(
        self,
        pred_targets: Float[Tensor, "sample batch *out_feature"],
        targets: Float[Tensor, "batch *out_feature"],
        dataloader_idx: int = 0,
    ):
        """Log accuracy, ECE and NLL of the predictive probabilities (dataloader 0)."""
        if dataloader_idx == 0:
            pred_probs = self._predictive_probs(pred_targets)

            self.accuracy(pred_probs, targets)
            self.calibration_error(pred_probs, targets)

            # Log metrics
            # Note: Lightning accumulates torchmetrics over the entire test set by calling metric.compute()
            # at the end of the test epoch.
            logging_dict = {
                "Test Accuracy": self.accuracy,
                "Test ECE": self.calibration_error,
                "Test NLL": self.nll(pred_probs, targets),
            }
            self.log_dict(logging_dict)

    def configure_optimizers(self) -> optim.Optimizer:
        """Build SGD over the parameter groups returned by ``model.parameters_and_lrs``."""
        optimizer = optim.SGD(
            self.model.parameters_and_lrs(lr=self.lr, optimizer="SGD"),
            lr=self.lr,
            momentum=self.momentum,
            nesterov=self.nesterov,
            weight_decay=self.weight_decay,
        )
        return optimizer
