from __future__ import annotations

from typing import TYPE_CHECKING

import inferno
import lightning as L
import torch
from inferno import bnn
from inferno.bnn import params
from torch import nn

from . import _ts_hyperparameters
from ._ts_model import _TemperatureScaledModel

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class LeNet5TemperatureScaling(_TemperatureScaledModel):
    """LeNet5 restored from a checkpoint and calibrated by temperature scaling.

    Network weights and training hyperparameters are read from a Lightning
    checkpoint. A temperature is then optimized on a validation dataloader with
    ``inferno.bnn.TemperatureScaler``, and the result is recorded alongside the
    original training hyperparameters via ``save_hyperparameters``.
    """

    # Dataset class names for which ``from_dataset`` looks up LeNet5 checkpoints.
    _SUPPORTED_DATASETS = ("MNIST", "FashionMNIST")

    def __init__(
        self,
        out_size: int,
        parametrization: params.Parametrization,
        checkpoint_path: str,
        val_dataloader: torch.utils.data.DataLoader,
    ) -> None:
        """Load the checkpoint, fit the temperature and record hyperparameters.

        Args:
            out_size: Number of output classes of the network.
            parametrization: Parametrization passed to ``inferno.models.LeNet5``.
            checkpoint_path: Path to a checkpoint whose ``"state_dict"`` stores
                the network parameters under the ``"model."`` prefix and whose
                ``"hyper_parameters"`` contain ``optimizer``, ``lr``,
                ``momentum``, ``nesterov``, ``weight_decay`` and ``max_epochs``.
            val_dataloader: Dataloader on which the temperature is optimized.
        """
        # map_location="cpu" allows checkpoints saved on a GPU to be loaded on
        # CPU-only hosts; the model is moved to the compute device further below.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        hyper_parameters = checkpoint["hyper_parameters"]

        super().__init__(
            num_classes=out_size,
            lr=hyper_parameters["lr"],
            momentum=hyper_parameters["momentum"],
            nesterov=hyper_parameters["nesterov"],
            weight_decay=hyper_parameters["weight_decay"],
            max_epochs=hyper_parameters["max_epochs"],
        )
        # NOTE(review): if the base class is a LightningModule, this attribute
        # shadows its ``val_dataloader()`` hook method. Kept as-is because code
        # outside this file may read ``self.val_dataloader`` — confirm.
        self.val_dataloader = val_dataloader

        # Model
        self.model = inferno.models.LeNet5(
            out_size=out_size,
            parametrization=parametrization,
            cov=None,
        )

        # Strip exactly the leading "model." prefix from the saved keys. Keys
        # without that prefix are skipped instead of being collapsed to "" (as
        # ``str.partition`` would do); ``load_state_dict`` is strict, so any
        # network parameter missing from the checkpoint still raises.
        prefix = "model."
        self.model.load_state_dict(
            {
                key[len(prefix) :]: value
                for key, value in checkpoint["state_dict"].items()
                if key.startswith(prefix)
            }
        )

        # Temperature scaling
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        temperature_scaler = bnn.TemperatureScaler(loss_fn=nn.CrossEntropyLoss())
        temperature_scaler.optimize(
            model=self.model.to(device),
            dataloader=val_dataloader,
        )

        # Read the fitted temperature back from the first model parameter whose
        # name contains "temperature" (presumably registered by the scaler —
        # confirm). ``None`` is recorded if no such parameter exists.
        temperature = next(
            (
                param.item()
                for name, param in self.model.named_parameters()
                if "temperature" in name
            ),
            None,
        )

        self.save_hyperparameters(
            _ts_hyperparameters(
                lightning_module=self,
                architecture="LeNet5",
                out_size=out_size,
                optimizer=hyper_parameters["optimizer"],
                lr=hyper_parameters["lr"],
                momentum=hyper_parameters["momentum"],
                nesterov=hyper_parameters["nesterov"],
                weight_decay=hyper_parameters["weight_decay"],
                max_epochs=hyper_parameters["max_epochs"],
                temperature=temperature,
            ),
            logger=True,
        )

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> LeNet5TemperatureScaling:
        """Build a temperature-scaled LeNet5 from the last matching checkpoint.

        Only ``dataset``, ``parametrization``, ``seed`` and ``root_dir`` are used.
        ``lr``, ``momentum``, ``nesterov``, ``weight_decay``, ``max_epochs``,
        ``pretrained`` and ``freeze_pretrained_weights`` are not used here; the
        training hyperparameters are read from the checkpoint instead.

        Raises:
            ValueError: If ``dataset`` is not an MNIST or FashionMNIST datamodule.
            FileNotFoundError: If ``get_checkpoints`` returns no checkpoints.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name not in cls._SUPPORTED_DATASETS:
            raise ValueError(
                f"{cls.__name__} supports datasets {cls._SUPPORTED_DATASETS}, "
                f"got {dataset_name!r}."
            )

        # All checkpoints from vanilla neural network
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="LeNet5",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
            seed=seed,
        )
        if not checkpoints:
            raise FileNotFoundError(
                f"No LeNet5 checkpoints found for dataset {dataset_name!r}, "
                f"parametrization {parametrization.__class__.__name__!r}, "
                f"seed {seed} under {root_dir!r}."
            )

        # Set up the "fit" stage so that ``val_dataloader()`` can be built.
        dataset.setup("fit")

        return cls(
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_path=checkpoints[-1],
            val_dataloader=dataset.val_dataloader(),
        )
