from __future__ import annotations

from typing import TYPE_CHECKING

import inferno
import lightning as L
from inferno import bnn
from inferno.bnn import params
from torch import nn

from . import _ivi_hyperparameters
from ._ivi_model import _ImplicitVIModel

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class LeNet5IVI(_ImplicitVIModel):
    """LeNet5 model built on top of the ``_ImplicitVIModel`` base class.

    The network itself is an ``inferno.models.LeNet5`` stored as ``self.model``.
    Optionally a ``bnn.TemperatureScaler`` is attached as
    ``self.temperature_scaler``.
    """

    def __init__(
        self,
        out_size: int,
        parametrization: params.Parametrization,
        cov: params.FactorizedCovariance,
        num_samples_train: int,
        num_samples_test: int,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        dataset: L.LightningDataModule,
        temperature_scaling: bool = False,
    ) -> None:
        """Build the LeNet5 model and record the hyperparameters.

        Args:
            out_size: Output size of the network; also passed to the base
                class as ``num_classes``.
            parametrization: Forwarded to ``inferno.models.LeNet5``.
            cov: Covariance structure forwarded to ``inferno.models.LeNet5``.
            num_samples_train: Stored as ``self.num_samples_train``.
            num_samples_test: Stored as ``self.num_samples_test``.
            lr: Forwarded to the base class constructor.
            momentum: Forwarded to the base class constructor.
            nesterov: Forwarded to the base class constructor.
            weight_decay: Forwarded to the base class constructor.
            max_epochs: Forwarded to the base class constructor.
            dataset: Data module, stored as ``self.dataset``.
            temperature_scaling: If true, a ``bnn.TemperatureScaler`` with a
                cross-entropy loss is attached as ``self.temperature_scaler``.
                If false, this constructor does not create that attribute.
        """
        super().__init__(
            num_classes=out_size,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
        )

        self.num_samples_train = num_samples_train
        self.num_samples_test = num_samples_test
        self.dataset = dataset

        # Model
        self.model = inferno.models.LeNet5(
            out_size=out_size,
            parametrization=parametrization,
            cov=cov,
        )

        # Temperature scaling (only attached on request)
        if temperature_scaling:
            self.temperature_scaler = bnn.TemperatureScaler(
                loss_fn=nn.CrossEntropyLoss(),
            )

        # Record the configuration; ``logger=True`` also sends it to the logger.
        self.save_hyperparameters(
            _ivi_hyperparameters(
                lightning_module=self,
                architecture="LeNet5",
                out_size=out_size,
                num_samples_train=num_samples_train,
                num_samples_test=num_samples_test,
                cov=cov,
                lr=lr,
                momentum=momentum,
                nesterov=nesterov,
                weight_decay=weight_decay,
                max_epochs=max_epochs,
                temperature_scaling=temperature_scaling,
            ),
            logger=True,
        )

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> LeNet5IVI:
        """Create a model configured for the given dataset.

        The dataset is identified by the name of its class. For ``MNIST`` and
        ``FashionMNIST`` the model uses ``params.KroneckerCovariance()``,
        1 training sample, 32 test samples and temperature scaling enabled.

        Args:
            dataset: Data module; its class name selects the configuration and
                its ``num_classes`` attribute sets the output size.
            parametrization: Forwarded to the constructor.
            lr: Forwarded to the constructor.
            momentum: Forwarded to the constructor.
            nesterov: Forwarded to the constructor.
            weight_decay: Forwarded to the constructor.
            max_epochs: Forwarded to the constructor.
            pretrained: Accepted but not used by this method.
            freeze_pretrained_weights: Accepted but not used by this method.
            seed: Accepted but not used by this method.
            root_dir: Accepted but not used by this method.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset class is neither ``MNIST`` nor
                ``FashionMNIST``.
        """
        supported_datasets = ("MNIST", "FashionMNIST")
        dataset_name = dataset.__class__.__name__

        if dataset_name not in supported_datasets:
            # Name the offending dataset so that misconfigurations are easy to
            # diagnose from the traceback alone.
            raise NotImplementedError(
                f"{cls.__name__} does not support dataset {dataset_name!r}; "
                f"supported datasets: {', '.join(supported_datasets)}."
            )

        return cls(
            out_size=dataset.num_classes,
            parametrization=parametrization,
            cov=params.KroneckerCovariance(),
            num_samples_train=1,
            num_samples_test=32,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
            dataset=dataset,
            temperature_scaling=True,
        )


# Alias for the base class: ``LeNet5IVI.from_dataset`` builds its model with
# ``params.KroneckerCovariance()``.
LeNet5IVIKronecker = LeNet5IVI


class LeNet5IVILowRank(LeNet5IVI):
    """``LeNet5IVI`` variant whose ``from_dataset`` uses a low-rank covariance."""

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> LeNet5IVILowRank:
        """Create a model configured for the given dataset.

        The dataset is identified by the name of its class. For ``MNIST`` and
        ``FashionMNIST`` the model uses ``params.LowRankCovariance(10)``,
        1 training sample, 32 test samples and temperature scaling enabled.

        Args:
            dataset: Data module; its class name selects the configuration and
                its ``num_classes`` attribute sets the output size.
            parametrization: Forwarded to the constructor.
            lr: Forwarded to the constructor.
            momentum: Forwarded to the constructor.
            nesterov: Forwarded to the constructor.
            weight_decay: Forwarded to the constructor.
            max_epochs: Forwarded to the constructor.
            pretrained: Accepted but not used by this method.
            freeze_pretrained_weights: Accepted but not used by this method.
            seed: Accepted but not used by this method.
            root_dir: Accepted but not used by this method.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset class is neither ``MNIST`` nor
                ``FashionMNIST``.
        """
        supported_datasets = ("MNIST", "FashionMNIST")
        dataset_name = dataset.__class__.__name__

        if dataset_name not in supported_datasets:
            # Name the offending dataset so that misconfigurations are easy to
            # diagnose from the traceback alone.
            raise NotImplementedError(
                f"{cls.__name__} does not support dataset {dataset_name!r}; "
                f"supported datasets: {', '.join(supported_datasets)}."
            )

        return cls(
            out_size=dataset.num_classes,
            parametrization=parametrization,
            # NOTE(review): 10 is presumably the rank of the low-rank factor;
            # confirm against the inferno ``LowRankCovariance`` API.
            cov=params.LowRankCovariance(10),
            num_samples_train=1,
            num_samples_test=32,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
            dataset=dataset,
            temperature_scaling=True,
        )


class LeNet5IVICustomVF(LeNet5IVI):
    """``LeNet5IVI`` variant with its own ``from_dataset`` configuration.

    NOTE(review): the configuration built here (``params.LowRankCovariance(10)``,
    1 training sample, 32 test samples, temperature scaling) is the same one
    ``LeNet5IVILowRank.from_dataset`` builds; confirm whether this class is
    meant to differ.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> LeNet5IVICustomVF:
        """Create a model configured for the given dataset.

        The dataset is identified by the name of its class. For ``MNIST`` and
        ``FashionMNIST`` the model uses ``params.LowRankCovariance(10)``,
        1 training sample, 32 test samples and temperature scaling enabled.

        Args:
            dataset: Data module; its class name selects the configuration and
                its ``num_classes`` attribute sets the output size.
            parametrization: Forwarded to the constructor.
            lr: Forwarded to the constructor.
            momentum: Forwarded to the constructor.
            nesterov: Forwarded to the constructor.
            weight_decay: Forwarded to the constructor.
            max_epochs: Forwarded to the constructor.
            pretrained: Accepted but not used by this method.
            freeze_pretrained_weights: Accepted but not used by this method.
            seed: Accepted but not used by this method.
            root_dir: Accepted but not used by this method.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset class is neither ``MNIST`` nor
                ``FashionMNIST``.
        """
        supported_datasets = ("MNIST", "FashionMNIST")
        dataset_name = dataset.__class__.__name__

        if dataset_name not in supported_datasets:
            # Name the offending dataset so that misconfigurations are easy to
            # diagnose from the traceback alone.
            raise NotImplementedError(
                f"{cls.__name__} does not support dataset {dataset_name!r}; "
                f"supported datasets: {', '.join(supported_datasets)}."
            )

        return cls(
            out_size=dataset.num_classes,
            parametrization=parametrization,
            # NOTE(review): 10 is presumably the rank of the low-rank factor;
            # confirm against the inferno ``LowRankCovariance`` API.
            cov=params.LowRankCovariance(10),
            num_samples_train=1,
            num_samples_test=32,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
            dataset=dataset,
            temperature_scaling=True,
        )
