from __future__ import annotations

import math
from typing import TYPE_CHECKING, Callable

import inferno
import lightning as L
import numpy as np
import torch
from inferno import bnn
from inferno.bnn import params
from models.ivi._ivi_model import _ImplicitVIModel
from torch import nn

from ..custom_covariances.low_rank import CustomLowRankCovariance
from ..custom_inferno_mlp import CustomMLP
from . import _ivi_hyperparameters

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class MLPIVI(_ImplicitVIModel):
    """Multi-layer Perceptron.

    A fully-connected feedforward deep neural network with the same activation function for
    each hidden layer.

    Uses hyperparameters from tensor programs V example.

    :param in_size:             Size of the input.
    :param hidden_sizes:        List of hidden layer sizes.
    :param out_size:            Size of the output (e.g. number of classes).
    :param parametrization:     The parametrization to use.
    :param cov:                 Covariance structure of the weights.
    :param num_samples_train:   Number of samples to draw during training.
    :param num_samples_test:    Number of samples to draw during testing.
    :param lr:                  Learning rate of the optimizer.
    :param momentum:            Momentum of the optimizer.
    :param nesterov:            Whether to use Nesterov momentum.
    :param weight_decay:        Weight decay of the optimizer.
    :param max_epochs:          Number of training epochs; also used as ``T_max`` of the
                                cosine-annealing learning-rate schedule.
    :param dataset:             Data module, stored as ``self.dataset``.
    :param temperature_scaling: Whether to create ``self.temperature_scaler``
                                (a ``bnn.TemperatureScaler`` with cross-entropy loss).
    :param activation_layer:    Activation function following a linear layer (may be ``None``).
    :param bias:                Whether to use bias in the linear layer.
    :param scale_mean_*:        Init / learning-rate / forward scaling factors of the input and
                                output layers, forwarded to ``CustomMLP``.
    :param scale_cov_*:         Covariance scaling factors. In this constructor they are only
                                recorded as hyperparameters; a ``from_dataset`` factory may use
                                them to build ``cov``.
    """

    def __init__(
        self,
        in_size: int,
        hidden_sizes: list[int],
        out_size: int,
        parametrization: params.Parametrization,
        cov: params.FactorizedCovariance,
        num_samples_train: int,
        num_samples_test: int,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        dataset: L.LightningDataModule,
        temperature_scaling: bool = False,
        activation_layer: Callable[..., nn.Module] | None = nn.ReLU,
        bias: bool = True,
        scale_mean_input_init_weight: float = 1.0,
        scale_mean_input_init_bias: float = 1.0,
        scale_mean_input_lr_weight: float = 1.0,
        scale_mean_input_lr_bias: float = 1.0,
        scale_mean_input_forward_weight: float = 1.0,
        scale_mean_input_forward_bias: float = 1.0,
        scale_mean_output_init_weight: float = 1.0,
        scale_mean_output_init_bias: float = 1.0,
        scale_mean_output_lr_weight: float = 1.0,
        scale_mean_output_lr_bias: float = 1.0,
        scale_mean_output_forward_weight: float = 1.0,
        scale_mean_output_forward_bias: float = 1.0,
        scale_cov_input_init_weight: float = 1.0,
        scale_cov_input_init_bias: float = 1.0,
        scale_cov_input_lr_weight: float = 1.0,
        scale_cov_input_lr_bias: float = 1.0,
        scale_cov_input_forward_weight: float = 1.0,
        scale_cov_input_forward_bias: float = 1.0,
        scale_cov_output_init_weight: float = 1.0,
        scale_cov_output_init_bias: float = 1.0,
        scale_cov_output_lr_weight: float = 1.0,
        scale_cov_output_lr_bias: float = 1.0,
        scale_cov_output_forward_weight: float = 1.0,
        scale_cov_output_forward_bias: float = 1.0,
    ) -> None:
        super().__init__(
            num_classes=out_size,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
        )
        self.num_samples_train = num_samples_train
        self.num_samples_test = num_samples_test
        self.dataset = dataset

        # Model: flatten the last three input dimensions into one, then apply the MLP.
        self.model = bnn.Sequential(
            nn.Flatten(-3, -1),
            CustomMLP(
                in_size=in_size,
                hidden_sizes=hidden_sizes,
                out_size=out_size,
                cov=cov,
                activation_layer=activation_layer,
                bias=bias,
                scale_input_init_weight=scale_mean_input_init_weight,
                scale_input_init_bias=scale_mean_input_init_bias,
                scale_input_lr_weight=scale_mean_input_lr_weight,
                scale_input_lr_bias=scale_mean_input_lr_bias,
                scale_input_forward_weight=scale_mean_input_forward_weight,
                scale_input_forward_bias=scale_mean_input_forward_bias,
                scale_output_init_weight=scale_mean_output_init_weight,
                scale_output_init_bias=scale_mean_output_init_bias,
                scale_output_lr_weight=scale_mean_output_lr_weight,
                scale_output_lr_bias=scale_mean_output_lr_bias,
                scale_output_forward_weight=scale_mean_output_forward_weight,
                scale_output_forward_bias=scale_mean_output_forward_bias,
            ),
            parametrization=parametrization,
        )
        self.model.reset_parameters()

        # Temperature scaling
        if temperature_scaling:
            self.temperature_scaler = bnn.TemperatureScaler(
                loss_fn=nn.CrossEntropyLoss(),
            )

        # ``activation_layer`` may be ``None`` (allowed by the annotation) or a callable
        # without ``__name__`` (e.g. ``functools.partial``). Both used to raise
        # AttributeError here while logging hyperparameters.
        if activation_layer is None:
            activation_name = None
        else:
            activation_name = getattr(
                activation_layer, "__name__", repr(activation_layer)
            )

        self.save_hyperparameters(
            _ivi_hyperparameters(
                lightning_module=self,
                architecture="MLP",
                out_size=out_size,
                cov=cov,
                num_samples_train=num_samples_train,
                num_samples_test=num_samples_test,
                lr=lr,
                momentum=momentum,
                nesterov=nesterov,
                weight_decay=weight_decay,
                max_epochs=max_epochs,
                temperature_scaling=temperature_scaling,
                scale_mean_input_init_weight=scale_mean_input_init_weight,
                scale_mean_input_init_bias=scale_mean_input_init_bias,
                scale_mean_input_lr_weight=scale_mean_input_lr_weight,
                scale_mean_input_lr_bias=scale_mean_input_lr_bias,
                scale_mean_input_forward_weight=scale_mean_input_forward_weight,
                scale_mean_input_forward_bias=scale_mean_input_forward_bias,
                scale_mean_output_init_weight=scale_mean_output_init_weight,
                scale_mean_output_init_bias=scale_mean_output_init_bias,
                scale_mean_output_lr_weight=scale_mean_output_lr_weight,
                scale_mean_output_lr_bias=scale_mean_output_lr_bias,
                scale_mean_output_forward_weight=scale_mean_output_forward_weight,
                scale_mean_output_forward_bias=scale_mean_output_forward_bias,
                scale_cov_input_init_weight=scale_cov_input_init_weight,
                scale_cov_input_init_bias=scale_cov_input_init_bias,
                scale_cov_input_lr_weight=scale_cov_input_lr_weight,
                scale_cov_input_lr_bias=scale_cov_input_lr_bias,
                scale_cov_input_forward_weight=scale_cov_input_forward_weight,
                scale_cov_input_forward_bias=scale_cov_input_forward_bias,
                scale_cov_output_init_weight=scale_cov_output_init_weight,
                scale_cov_output_init_bias=scale_cov_output_init_bias,
                scale_cov_output_lr_weight=scale_cov_output_lr_weight,
                scale_cov_output_lr_bias=scale_cov_output_lr_bias,
                scale_cov_output_forward_weight=scale_cov_output_forward_weight,
                scale_cov_output_forward_bias=scale_cov_output_forward_bias,
            )
            | {
                "num_layers": len(hidden_sizes) + 1,
                "in_size": in_size,
                "hidden_sizes": hidden_sizes,
                "bias": bias,
                "activation_layer": activation_name,
            },
            logger=True,
        )

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        seed: int,
        root_dir: str,
        hidden_sizes: list[int] | None = None,
    ) -> MLPIVI:
        """Build the default configuration for a supported image-classification dataset.

        Uses ``cov=[None, KroneckerCovariance(), KroneckerCovariance()]``, 32 training and
        256 test samples, and enables temperature scaling.

        :param hidden_sizes: Hidden layer sizes; defaults to ``[128, 128]``.
        :param seed:         Not used by this method.
        :param root_dir:     Not used by this method.
        :raises NotImplementedError: If the dataset class is not supported.
        """
        dataset_name = type(dataset).__name__
        if dataset_name not in (
            "MNIST",
            "FashionMNIST",
            "CIFAR10",
            "CIFAR100",
            "TinyImageNet",
        ):
            raise NotImplementedError(
                f"{cls.__name__}.from_dataset does not support dataset '{dataset_name}'."
            )

        # Fresh list per call instead of a shared mutable default argument.
        hidden_sizes = [128, 128] if hidden_sizes is None else list(hidden_sizes)

        # NOTE(review): the cov list has three entries, which matches
        # ``len(hidden_sizes) + 1`` layers only for two hidden layers. Confirm how
        # CustomMLP handles other depths.
        return cls(
            in_size=int(np.prod(dataset.input_shape)),
            hidden_sizes=hidden_sizes,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            cov=[None, params.KroneckerCovariance(), params.KroneckerCovariance()],
            num_samples_train=32,
            num_samples_test=256,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
            dataset=dataset,
            temperature_scaling=True,
        )

    def _log_test_step(
        self,
        pred_targets: Float[Tensor, "batch *out_feature"],
        targets: Float[Tensor, "batch *out_feature"],
        dataloader_idx: int = 0,
    ):
        """Compute and log test-time metrics for one batch.

        :param pred_targets:   Model outputs (logits). Dimension 0 is averaged over before
                               computing metrics. It is presumably the posterior-sample
                               dimension (TODO confirm; the annotation labels it ``batch``).
        :param targets:        Labels for ``dataloader_idx`` 0 and 2. For ``dataloader_idx``
                               1 it is indexed as ``targets[0]`` (class) and ``targets[1]``
                               (OOD flag, 1 = OOD).
        :param dataloader_idx: 0 = in-distribution test metrics, 1 = OOD metrics and AUROC,
                               2 = final validation NLL. Other values log nothing.
        """
        # NOTE(review): this averages *log*-probabilities over dim 0. ``pred_probs`` is
        # therefore a geometric mean that in general does not sum to 1 per row, which
        # affects ECE and entropy. A Bayesian model average would be
        # ``logsumexp(log_softmax, dim=0) - log(S)``. Behavior is kept unchanged; confirm
        # this is intended.
        pred_log_probs = nn.functional.log_softmax(pred_targets, dim=-1).mean(dim=0)
        pred_probs = torch.exp(pred_log_probs)
        pred_targets_entropy = -torch.sum(pred_probs * pred_log_probs, dim=-1)
        # Maximum entropy over num_classes outcomes; used to normalize entropy to [0, 1].
        log_num_classes = math.log(self.num_classes)

        if dataloader_idx == 0:
            self.accuracy(pred_probs, targets)
            if self.top_5_accuracy is not None:
                self.top_5_accuracy(pred_probs, targets)
            self.calibration_error(pred_probs, targets)

            logging_dict = {
                "Test Accuracy": self.accuracy,
                "Test ECE": self.calibration_error,
                "Test NLL": self.nll(pred_log_probs, targets),
                "Test Norm. Entropy": pred_targets_entropy.mean() / log_num_classes,
            }
            if self.top_5_accuracy is not None:
                logging_dict["Test Top-5 Accuracy"] = self.top_5_accuracy
            self.log_dict(logging_dict)

        elif dataloader_idx == 1:
            # Compute OOD metrics.
            target_class = targets[0]
            input_is_ood = targets[1]
            # Classification metrics are computed on the OOD inputs only.
            ood_mask = input_is_ood == 1

            self.accuracy_ood(pred_probs[ood_mask], target_class[ood_mask])
            if self.top_5_accuracy is not None:
                self.top_5_accuracy_ood(pred_probs[ood_mask], target_class[ood_mask])
            self.calibration_error_ood(pred_probs[ood_mask], target_class[ood_mask])

            logging_dict = {
                "Test Accuracy (OOD)": self.accuracy_ood,
                "Test ECE (OOD)": self.calibration_error_ood,
                "Test NLL (OOD)": self.nll(
                    pred_log_probs[ood_mask], target_class[ood_mask]
                ),
                # Only consider entropy for OOD data.
                "Test Norm. Entropy (OOD)": pred_targets_entropy[ood_mask].mean()
                / log_num_classes,
            }
            if self.top_5_accuracy is not None:
                logging_dict["Test Top-5 Accuracy (OOD)"] = self.top_5_accuracy_ood

            if torch.backends.mps.is_available():
                # MPS backend causes bugs in binary AUROC
                # See: https://github.com/Lightning-AI/torchmetrics/issues/1727#issuecomment-1999173176
                self.binary_auroc.to(device="cpu")
                pred_targets_entropy = pred_targets_entropy.to(device="cpu")
                input_is_ood = input_is_ood.to(device="cpu")

            # Normalized entropy (in [0, 1]) is the score; the OOD flag is the label.
            self.binary_auroc(pred_targets_entropy / log_num_classes, input_is_ood)
            logging_dict["Test AUROC"] = self.binary_auroc

            self.log_dict(logging_dict)

        elif dataloader_idx == 2:
            # NLL of the softmax of the sample-averaged logits.
            pred_mean_log_probs = nn.functional.log_softmax(
                pred_targets.mean(dim=0), dim=-1
            )

            self.log_dict(
                {
                    "Final Validation NLL": self.nll(pred_log_probs, targets),
                    "Final Validation NLL of mean": self.nll(
                        pred_mean_log_probs, targets
                    ),
                    # "Train Loss": self.loss_fn(pred_targets.reshape(-1, self.hparams.out_size), targets.expand(self.num_samples_test, len(targets)).reshape(-1)), # same as "Train NLL"
                    # "Train Loss of mean": self.loss_fn(pred_targets.mean(dim=0).reshape(-1, self.hparams.out_size), targets), # same as "Train NLL of mean"
                }
            )

    def configure_optimizers(self) -> dict[str, object]:
        """Create SGD with per-parameter learning rates and a per-epoch cosine schedule.

        :returns: Lightning optimizer configuration dict with keys ``"optimizer"`` and
                  ``"lr_scheduler"``.
        """
        optimizer = torch.optim.SGD(
            # Parameter groups with learning rates supplied by the model.
            self.model.parameters_and_lrs(lr=self.lr, optimizer="SGD"),
            lr=self.lr,
            momentum=self.momentum,
            nesterov=self.nesterov,
            weight_decay=self.weight_decay,
        )
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=self.max_epochs
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "epoch",
                "frequency": 1,
            },
        }


# Alias, not a subclass: ``MLPIVI.from_dataset`` already builds Kronecker-factored
# covariances for every layer but the first, so the "Kronecker" variant is the same class.
MLPIVIKronecker = MLPIVI


class MLPIVILowRank(MLPIVI):
    """``MLPIVI`` variant whose ``from_dataset`` builds low-rank weight covariances."""

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        seed: int,
        root_dir: str,
        hidden_sizes: list[int] | None = None,
        bias: bool = True,
        scale_mean_input_init_weight: float = 1.0,
        scale_mean_input_init_bias: float = 1.0,
        scale_mean_input_lr_weight: float = 1.0,
        scale_mean_input_lr_bias: float = 1.0,
        scale_mean_input_forward_weight: float = 1.0,
        scale_mean_input_forward_bias: float = 1.0,
        scale_mean_output_init_weight: float = 1.0,
        scale_mean_output_init_bias: float = 1.0,
        scale_mean_output_lr_weight: float = 1.0,
        scale_mean_output_lr_bias: float = 1.0,
        scale_mean_output_forward_weight: float = 1.0,
        scale_mean_output_forward_bias: float = 1.0,
        scale_cov_input_init_weight: float = 1.0,
        scale_cov_input_init_bias: float = 1.0,
        scale_cov_input_lr_weight: float = 1.0,
        scale_cov_input_lr_bias: float = 1.0,
        scale_cov_input_forward_weight: float = 1.0,
        scale_cov_input_forward_bias: float = 1.0,
        scale_cov_output_init_weight: float = 1.0,
        scale_cov_output_init_bias: float = 1.0,
        scale_cov_output_lr_weight: float = 1.0,
        scale_cov_output_lr_bias: float = 1.0,
        scale_cov_output_forward_weight: float = 1.0,
        scale_cov_output_forward_bias: float = 1.0,
    ) -> MLPIVI:
        """Build the low-rank configuration for a supported image-classification dataset.

        The covariance list is ``[None, params.LowRankCovariance(10),
        CustomLowRankCovariance(rank=10, ...)]``. The last entry is configured with the
        ``scale_cov_output_*`` factors. It uses 32 training and 256 test samples and
        enables temperature scaling. ``bias`` and every ``scale_*`` argument are
        forwarded unchanged to the constructor.

        :param hidden_sizes: Hidden layer sizes; defaults to ``[128, 128]``.
        :param seed:         Not used by this method.
        :param root_dir:     Not used by this method.
        :raises NotImplementedError: If the dataset class is not supported.
        """
        dataset_name = type(dataset).__name__
        if dataset_name not in (
            "MNIST",
            "FashionMNIST",
            "CIFAR10",
            "CIFAR100",
            "TinyImageNet",
        ):
            raise NotImplementedError(
                f"{cls.__name__}.from_dataset does not support dataset '{dataset_name}'."
            )

        # Fresh list per call instead of a shared mutable default argument.
        hidden_sizes = [128, 128] if hidden_sizes is None else list(hidden_sizes)

        # NOTE(review): the first layer gets no covariance, so the
        # ``scale_cov_input_*`` factors are only passed through to the constructor
        # (which records them as hyperparameters). An alternative for this entry is
        # ``CustomLowRankCovariance(rank=10, scale_*=scale_cov_input_*)``.
        cov = [
            None,
            params.LowRankCovariance(10),
            CustomLowRankCovariance(
                rank=10,
                scale_init_weight=scale_cov_output_init_weight,
                scale_init_bias=scale_cov_output_init_bias,
                scale_lr_weight=scale_cov_output_lr_weight,
                scale_lr_bias=scale_cov_output_lr_bias,
                scale_forward_weight=scale_cov_output_forward_weight,
                scale_forward_bias=scale_cov_output_forward_bias,
            ),
        ]

        return cls(
            in_size=int(np.prod(dataset.input_shape)),
            hidden_sizes=hidden_sizes,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            cov=cov,
            num_samples_train=32,
            num_samples_test=256,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
            dataset=dataset,
            temperature_scaling=True,
            bias=bias,
            scale_mean_input_init_weight=scale_mean_input_init_weight,
            scale_mean_input_init_bias=scale_mean_input_init_bias,
            scale_mean_input_lr_weight=scale_mean_input_lr_weight,
            scale_mean_input_lr_bias=scale_mean_input_lr_bias,
            scale_mean_input_forward_weight=scale_mean_input_forward_weight,
            scale_mean_input_forward_bias=scale_mean_input_forward_bias,
            scale_mean_output_init_weight=scale_mean_output_init_weight,
            scale_mean_output_init_bias=scale_mean_output_init_bias,
            scale_mean_output_lr_weight=scale_mean_output_lr_weight,
            scale_mean_output_lr_bias=scale_mean_output_lr_bias,
            scale_mean_output_forward_weight=scale_mean_output_forward_weight,
            scale_mean_output_forward_bias=scale_mean_output_forward_bias,
            scale_cov_input_init_weight=scale_cov_input_init_weight,
            scale_cov_input_init_bias=scale_cov_input_init_bias,
            scale_cov_input_lr_weight=scale_cov_input_lr_weight,
            scale_cov_input_lr_bias=scale_cov_input_lr_bias,
            scale_cov_input_forward_weight=scale_cov_input_forward_weight,
            scale_cov_input_forward_bias=scale_cov_input_forward_bias,
            scale_cov_output_init_weight=scale_cov_output_init_weight,
            scale_cov_output_init_bias=scale_cov_output_init_bias,
            scale_cov_output_lr_weight=scale_cov_output_lr_weight,
            scale_cov_output_lr_bias=scale_cov_output_lr_bias,
            scale_cov_output_forward_weight=scale_cov_output_forward_weight,
            scale_cov_output_forward_bias=scale_cov_output_forward_bias,
        )
