from __future__ import annotations

import sys
from typing import TYPE_CHECKING, Callable

import inferno
import lightning as L
import numpy as np
import torch
from inferno import bnn
from inferno.bnn import params
from torch import nn, optim

sys.path.insert(0, "../../experiments")
sys.path.insert(0, "analysis/experiments")

from ood_detection.models.ivi import _ivi_hyperparameters
from ood_detection.models.ivi._ivi_model import _ImplicitVIModel

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class MLPIVI(_ImplicitVIModel):
    """Multi-layer Perceptron trained with implicit variational inference.

    A fully-connected feedforward deep neural network with the same activation function for
    each hidden layer. Inputs are flattened over their last three dimensions (``nn.Flatten(-3, -1)``)
    before being passed to an ``inferno.models.MLP``; both are wrapped in a ``bnn.Sequential`` that
    uses the given parametrization.
    """

    # Dataset class names (matched via ``dataset.__class__.__name__``) that ``from_dataset`` supports.
    _SUPPORTED_DATASETS = frozenset(
        {"MNIST", "FashionMNIST", "CIFAR10", "CIFAR100", "TinyImageNet"}
    )

    def __init__(
        self,
        in_size: int,
        hidden_sizes: list[int],
        out_size: int,
        parametrization: params.Parametrization,
        cov: params.FactorizedCovariance | list[params.FactorizedCovariance | None] | None,
        num_samples_train: int,
        num_samples_test: int,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        dataset: L.LightningDataModule,
        temperature_scaling: bool = False,
        activation_layer: Callable[..., nn.Module] | None = nn.ReLU,
        bias: bool = True,
    ) -> None:
        """Build the network and record its hyperparameters.

        Args:
            in_size: Number of input features after flattening the last three input dimensions.
            hidden_sizes: Width of each hidden layer; the network has ``len(hidden_sizes) + 1`` layers.
            out_size: Number of outputs; also passed to the base class as ``num_classes``.
            parametrization: Parametrization used by the ``bnn.Sequential`` wrapper.
            cov: Covariance specification forwarded unchanged to ``inferno.models.MLP``.
                ``from_dataset`` passes a list with one entry per layer; presumably a single
                covariance object is also accepted — confirm against inferno.
            num_samples_train: Stored as ``self.num_samples_train`` (presumably the number of
                parameter samples per training step — used outside this class).
            num_samples_test: Stored as ``self.num_samples_test``.
            lr: Learning rate, forwarded to the base class and used by ``configure_optimizers``.
            momentum: SGD momentum, forwarded to the base class.
            nesterov: Whether SGD uses Nesterov momentum, forwarded to the base class.
            weight_decay: SGD weight decay, forwarded to the base class.
            max_epochs: Number of training epochs, forwarded to the base class.
            dataset: Data module, kept as ``self.dataset``.
            temperature_scaling: If true, attach a ``bnn.TemperatureScaler`` with a cross-entropy
                loss as ``self.temperature_scaler``; otherwise that attribute is not created.
            activation_layer: Factory for the hidden-layer activation, forwarded to the MLP.
            bias: Whether the linear layers have a bias term.
        """
        super().__init__(
            num_classes=out_size,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
        )
        self.num_samples_train = num_samples_train
        self.num_samples_test = num_samples_test
        self.dataset = dataset

        # Model
        self.model = bnn.Sequential(
            nn.Flatten(-3, -1),
            inferno.models.MLP(
                in_size=in_size,
                hidden_sizes=hidden_sizes,
                out_size=out_size,
                cov=cov,
                activation_layer=activation_layer,
                bias=bias,
            ),
            parametrization=parametrization,
        )

        # Temperature scaling
        if temperature_scaling:
            self.temperature_scaler = bnn.TemperatureScaler(
                loss_fn=nn.CrossEntropyLoss(),
            )

        # ``activation_layer`` may legitimately be ``None`` (see its annotation) or a callable
        # without ``__name__`` (e.g. ``functools.partial``); don't let hyperparameter logging crash.
        if activation_layer is None:
            activation_name = None
        else:
            activation_name = getattr(activation_layer, "__name__", repr(activation_layer))

        self.save_hyperparameters(
            _ivi_hyperparameters(
                lightning_module=self,
                architecture="MLP",
                out_size=out_size,
                cov=cov,
                num_samples_train=num_samples_train,
                num_samples_test=num_samples_test,
                lr=lr,
                momentum=momentum,
                nesterov=nesterov,
                weight_decay=weight_decay,
                max_epochs=max_epochs,
                temperature_scaling=temperature_scaling,
            )
            | {
                "num_layers": len(hidden_sizes) + 1,
                "in_size": in_size,
                "hidden_sizes": hidden_sizes,
                "bias": bias,
                "activation_layer": activation_name,
            },
            logger=True,
        )

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        num_samples_train: int,
        num_samples_test: int,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        seed: int,
        root_dir: str,
        hidden_sizes: list[int] | None = None,
    ) -> MLPIVI:
        """Construct a model sized for one of the supported image datasets.

        The input size is ``prod(dataset.input_shape)`` and the output size is
        ``dataset.num_classes``. The first layer gets ``None`` as its covariance and every
        subsequent layer a fresh ``params.KroneckerCovariance()``; temperature scaling is enabled.

        Args:
            dataset: Data module; its class name must be one of ``_SUPPORTED_DATASETS``.
            parametrization, num_samples_train, num_samples_test, lr, momentum, nesterov,
            weight_decay, max_epochs: Forwarded unchanged to the constructor.
            seed: Accepted for interface compatibility; not used here.
            root_dir: Accepted for interface compatibility; not used here.
            hidden_sizes: Hidden-layer widths. Defaults to ``[128, 128]``.

        Returns:
            A new ``cls`` instance.

        Raises:
            NotImplementedError: If the dataset type is not supported.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name not in cls._SUPPORTED_DATASETS:
            raise NotImplementedError(
                f"{cls.__name__}.from_dataset does not support dataset '{dataset_name}'."
            )

        # Avoid a shared mutable default: build a fresh list on every call.
        hidden_sizes = [128, 128] if hidden_sizes is None else list(hidden_sizes)

        # One covariance entry per layer (``len(hidden_sizes) + 1`` layers): ``None`` for the first
        # layer, Kronecker-factored for the rest. For the default sizes this is exactly
        # ``[None, Kronecker, Kronecker]``; for other sizes the list length now matches the layers.
        cov = [None] + [params.KroneckerCovariance() for _ in hidden_sizes]

        return cls(
            in_size=int(np.prod(dataset.input_shape)),
            hidden_sizes=hidden_sizes,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            cov=cov,
            num_samples_train=num_samples_train,
            num_samples_test=num_samples_test,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
            dataset=dataset,
            temperature_scaling=True,
        )

    def configure_optimizers(self) -> optim.Optimizer:
        """Create the SGD optimizer used for training.

        Parameter groups and their learning rates come from ``self.model.parameters_and_lrs``.
        ``self.lr``, ``self.momentum``, ``self.nesterov`` and ``self.weight_decay`` are read from
        the instance; they are not assigned in this class (presumably the base class stores them
        from the constructor arguments). No learning-rate scheduler is configured.

        Returns:
            The ``torch.optim.SGD`` optimizer.
        """
        return optim.SGD(
            self.model.parameters_and_lrs(lr=self.lr, optimizer="SGD"),
            lr=self.lr,
            momentum=self.momentum,
            nesterov=self.nesterov,
            weight_decay=self.weight_decay,
        )
