from __future__ import annotations

from typing import TYPE_CHECKING

import inferno
import lightning as L
from inferno import bnn
from inferno.bnn import params
from torch import nn, optim

from . import _ivi_hyperparameters
from ._ivi_model import _ImplicitVIModel

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class ResNetIVI(_ImplicitVIModel):
    """Implicit-VI Lightning model built around an ``inferno`` ResNet.

    The constructor instantiates the network given by ``resnet_type``
    (optionally through ``from_pretrained_weights``), optionally attaches a
    ``bnn.TemperatureScaler`` and records a set of hyperparameters through
    ``save_hyperparameters``.
    """

    def __init__(
        self,
        resnet_type: type[inferno.models.ResNet],
        resnet_architecture: str,
        out_size: int,
        parametrization: params.Parametrization,
        cov: params.FactorizedCovariance,
        num_samples_train: int,
        num_samples_test: int,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        dataset: L.LightningDataModule,
        temperature_scaling: bool = False,
        freeze_pretrained_weights: bool = False,
    ) -> None:
        """Set up the network, the optional temperature scaler and hyperparameters.

        Args:
            resnet_type: ResNet class to instantiate.
            resnet_architecture: Passed to the network as ``architecture``.
            out_size: Passed to the network as ``out_size`` and to the base
                class as ``num_classes``.
            parametrization: Passed to the network unchanged.
            cov: Covariance object passed to the network unchanged.
            num_samples_train: Stored as ``self.num_samples_train``.
            num_samples_test: Stored as ``self.num_samples_test``.
            lr: Forwarded to the base class.
            momentum: Forwarded to the base class.
            nesterov: Forwarded to the base class.
            weight_decay: Forwarded to the base class.
            max_epochs: Forwarded to the base class.
            pretrained: If true, the network is created with
                ``resnet_type.from_pretrained_weights``; otherwise with
                ``resnet_type`` directly.
            dataset: Stored as ``self.dataset``.
            temperature_scaling: If true, ``self.temperature_scaler`` is
                created; otherwise that attribute is not set here.
            freeze_pretrained_weights: Passed as ``freeze`` to
                ``from_pretrained_weights``; ignored when ``pretrained`` is
                false.
        """
        super().__init__(
            num_classes=out_size,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
        )

        self.num_samples_train = num_samples_train
        self.num_samples_test = num_samples_test
        self.dataset = dataset

        # Keyword arguments common to both ways of constructing the network.
        backbone_kwargs = {
            "out_size": out_size,
            "architecture": resnet_architecture,
            "parametrization": parametrization,
            "cov": cov,
        }
        if not pretrained:
            self.model = resnet_type(**backbone_kwargs)
        else:
            self.model = resnet_type.from_pretrained_weights(
                **backbone_kwargs,
                freeze=freeze_pretrained_weights,
            )

        # The scaler attribute only exists when temperature scaling is on.
        if temperature_scaling:
            scaler_loss = nn.CrossEntropyLoss()
            self.temperature_scaler = bnn.TemperatureScaler(loss_fn=scaler_loss)

        logged_hparams = _ivi_hyperparameters(
            lightning_module=self,
            architecture=resnet_type.__name__,
            out_size=out_size,
            num_samples_train=num_samples_train,
            num_samples_test=num_samples_test,
            cov=cov,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
            temperature_scaling=temperature_scaling,
        )
        self.save_hyperparameters(logged_hparams, logger=True)


class ResNet18IVIKronecker(ResNetIVI):
    """ResNet-18 implicit-VI model that always uses a Kronecker covariance."""

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ):
        """Build the model for one of the supported datasets.

        The dataset is identified by its class name. The covariance is always
        ``params.KroneckerCovariance()``; the model is created with
        ``num_samples_train=1``, ``num_samples_test=32`` and
        ``temperature_scaling=True``.

        Args:
            dataset: Data module whose class name must be ``"CIFAR10"``,
                ``"CIFAR100"`` or ``"TinyImageNet"``; its ``num_classes``
                attribute is used as the output size.
            parametrization: Forwarded unchanged to the constructor.
            lr: Forwarded unchanged to the constructor.
            momentum: Forwarded unchanged to the constructor.
            nesterov: Forwarded unchanged to the constructor.
            weight_decay: Forwarded unchanged to the constructor.
            max_epochs: Forwarded unchanged to the constructor.
            pretrained: Forwarded unchanged to the constructor.
            freeze_pretrained_weights: Forwarded unchanged to the constructor.
            seed: Not used by this method.
            root_dir: Not used by this method.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset's class name is not supported.
        """
        # Every supported dataset uses the "cifar" architecture variant.
        supported_datasets = ("CIFAR10", "CIFAR100", "TinyImageNet")
        dataset_name = dataset.__class__.__name__
        if dataset_name not in supported_datasets:
            raise NotImplementedError(
                f"{cls.__name__} does not support dataset {dataset_name!r}."
            )

        return cls(
            resnet_type=inferno.models.ResNet18,
            resnet_architecture="cifar",
            out_size=dataset.num_classes,
            parametrization=parametrization,
            cov=params.KroneckerCovariance(),
            num_samples_train=1,
            num_samples_test=32,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            pretrained=pretrained,
            max_epochs=max_epochs,
            dataset=dataset,
            temperature_scaling=True,
            # Previously dropped, so the caller's choice had no effect.
            freeze_pretrained_weights=freeze_pretrained_weights,
        )


class ResNet18IVILowRank(ResNetIVI):
    """ResNet-18 implicit-VI model with a low-rank covariance."""

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ):
        """Build the model with a covariance rank chosen from the dataset.

        The dataset is identified by its class name. The model is created
        with ``num_samples_train=1``, ``num_samples_test=32`` and
        ``temperature_scaling=True``.

        Args:
            dataset: Data module whose class name must be ``"CIFAR10"``,
                ``"CIFAR100"`` or ``"TinyImageNet"``; its ``num_classes``
                attribute is used as the output size.
            parametrization: Forwarded unchanged to the constructor.
            lr: Forwarded unchanged to the constructor.
            momentum: Forwarded unchanged to the constructor.
            nesterov: Forwarded unchanged to the constructor.
            weight_decay: Forwarded unchanged to the constructor.
            max_epochs: Forwarded unchanged to the constructor.
            pretrained: Forwarded unchanged to the constructor.
            freeze_pretrained_weights: Forwarded unchanged to the constructor.
            seed: Not used by this method.
            root_dir: Not used by this method.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset's class name is not supported.
        """
        # Low-rank covariance rank per supported dataset; every entry uses
        # the "cifar" architecture variant.
        rank_by_dataset = {"CIFAR10": 20, "CIFAR100": 20, "TinyImageNet": 10}
        dataset_name = dataset.__class__.__name__
        if dataset_name not in rank_by_dataset:
            raise NotImplementedError(
                f"{cls.__name__} does not support dataset {dataset_name!r}."
            )
        cov = params.LowRankCovariance(rank=rank_by_dataset[dataset_name])

        return cls(
            resnet_type=inferno.models.ResNet18,
            resnet_architecture="cifar",
            out_size=dataset.num_classes,
            parametrization=parametrization,
            cov=cov,
            num_samples_train=1,
            num_samples_test=32,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            pretrained=pretrained,
            max_epochs=max_epochs,
            dataset=dataset,
            temperature_scaling=True,
            # Previously dropped, so the caller's choice had no effect.
            freeze_pretrained_weights=freeze_pretrained_weights,
        )


class ResNet18IVICustom(ResNetIVI):
    """ResNet-18 implicit-VI model with a low-rank covariance (custom variant)."""

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ):
        """Build the model with a covariance rank chosen from the dataset.

        The dataset is identified by its class name. The model is created
        with ``num_samples_train=1``, ``num_samples_test=32`` and
        ``temperature_scaling=True``.

        Args:
            dataset: Data module whose class name must be ``"CIFAR10"``,
                ``"CIFAR100"`` or ``"TinyImageNet"``; its ``num_classes``
                attribute is used as the output size.
            parametrization: Forwarded unchanged to the constructor.
            lr: Forwarded unchanged to the constructor.
            momentum: Forwarded unchanged to the constructor.
            nesterov: Forwarded unchanged to the constructor.
            weight_decay: Forwarded unchanged to the constructor.
            max_epochs: Forwarded unchanged to the constructor.
            pretrained: Forwarded unchanged to the constructor.
            freeze_pretrained_weights: Forwarded unchanged to the constructor.
            seed: Not used by this method.
            root_dir: Not used by this method.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset's class name is not supported.
        """
        # Low-rank covariance rank per supported dataset; every entry uses
        # the "cifar" architecture variant.
        rank_by_dataset = {"CIFAR10": 20, "CIFAR100": 20, "TinyImageNet": 10}
        dataset_name = dataset.__class__.__name__
        if dataset_name not in rank_by_dataset:
            raise NotImplementedError(
                f"{cls.__name__} does not support dataset {dataset_name!r}."
            )
        cov = params.LowRankCovariance(rank=rank_by_dataset[dataset_name])

        return cls(
            resnet_type=inferno.models.ResNet18,
            resnet_architecture="cifar",
            out_size=dataset.num_classes,
            parametrization=parametrization,
            cov=cov,
            num_samples_train=1,
            num_samples_test=32,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            pretrained=pretrained,
            max_epochs=max_epochs,
            dataset=dataset,
            temperature_scaling=True,
            # Previously dropped, so the caller's choice had no effect.
            freeze_pretrained_weights=freeze_pretrained_weights,
        )

    # def configure_optimizers(self) -> optim.Optimizer:
    #     optimizer = optim.SGD(
    #         self.model.parameters_and_lrs(lr=self.lr, optimizer="SGD"),
    #         lr=self.lr,
    #         momentum=self.momentum,
    #         nesterov=self.nesterov,
    #         weight_decay=self.weight_decay,
    #     )
    #     lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
    #         optimizer=optimizer, T_max=self.max_epochs
    #     )
    #     return {
    #         "optimizer": optimizer,
    #         "lr_scheduler": {
    #             "scheduler": lr_scheduler,
    #             "interval": "epoch",
    #             "frequency": 1,
    #         },
    #     }

    # def configure_optimizers(self) -> optim.Optimizer:
    #     optimizer = optim.SGD(
    #         self.model.parameters_and_lrs(lr=self.lr, optimizer="SGD"),
    #         lr=self.lr,
    #         momentum=self.momentum,
    #         nesterov=self.nesterov,
    #         weight_decay=self.weight_decay,
    #     )
    #     if self.momentum == 0.0:
    #         base_momentum = 0.0
    #         max_momentum = 0.0
    #     else:
    #         base_momentum = self.momentum - 0.1
    #         max_momentum = self.momentum
    #     lr_scheduler = optim.lr_scheduler.OneCycleLR(
    #         optimizer=optimizer,
    #         max_lr=self.lr,
    #         epochs=self.max_epochs,
    #         steps_per_epoch=len(self.dataset.train_dataloader()),
    #         base_momentum=base_momentum,
    #         max_momentum=max_momentum,
    #     )
    #     return {
    #         "optimizer": optimizer,
    #         "lr_scheduler": {
    #             "scheduler": lr_scheduler,
    #             "interval": (
    #                 "step"
    #                 if isinstance(lr_scheduler, optim.lr_scheduler.OneCycleLR)
    #                 else "epoch"
    #             ),
    #             "frequency": 1,
    #         },
    #     }


class ResNet34IVILowRank(ResNetIVI):
    """ResNet-34 implicit-VI model with a low-rank covariance."""

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ):
        """Build the model with settings looked up from the dataset's class name.

        The model is created with ``num_samples_train=1``,
        ``num_samples_test=32`` and ``temperature_scaling=True``.

        Args:
            dataset: Data module whose class name must be ``"CIFAR10"``,
                ``"CIFAR100"`` or ``"TinyImageNet"``; its ``num_classes``
                attribute is used as the output size.
            parametrization: Forwarded unchanged to the constructor.
            lr: Forwarded unchanged to the constructor.
            momentum: Forwarded unchanged to the constructor.
            nesterov: Forwarded unchanged to the constructor.
            weight_decay: Forwarded unchanged to the constructor.
            max_epochs: Forwarded unchanged to the constructor.
            pretrained: Forwarded unchanged to the constructor.
            freeze_pretrained_weights: Forwarded unchanged to the constructor.
            seed: Not used by this method.
            root_dir: Not used by this method.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset's class name is not supported.
        """
        # Dataset class name -> (architecture variant, low-rank covariance rank).
        settings_by_dataset = {
            "CIFAR10": ("cifar", 20),
            "CIFAR100": ("cifar", 20),
            "TinyImageNet": ("cifar", 10),
        }
        settings = settings_by_dataset.get(dataset.__class__.__name__)
        if settings is None:
            raise NotImplementedError()
        architecture_variant, covariance_rank = settings
        low_rank_cov = params.LowRankCovariance(rank=covariance_rank)

        return cls(
            resnet_type=inferno.models.ResNet34,
            resnet_architecture=architecture_variant,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            cov=low_rank_cov,
            num_samples_train=1,
            num_samples_test=32,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
            dataset=dataset,
            temperature_scaling=True,
            pretrained=pretrained,
            freeze_pretrained_weights=freeze_pretrained_weights,
        )


class ResNet50IVILowRank(ResNetIVI):
    """ResNet-50 implicit-VI model with a low-rank covariance."""

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ):
        """Build the model with a covariance rank chosen from the dataset.

        The dataset is identified by its class name. The model is created
        with ``num_samples_train=1``, ``num_samples_test=32`` and
        ``temperature_scaling=True``.

        Args:
            dataset: Data module whose class name must be ``"CIFAR10"``,
                ``"CIFAR100"`` or ``"TinyImageNet"``; its ``num_classes``
                attribute is used as the output size.
            parametrization: Forwarded unchanged to the constructor.
            lr: Forwarded unchanged to the constructor.
            momentum: Forwarded unchanged to the constructor.
            nesterov: Forwarded unchanged to the constructor.
            weight_decay: Forwarded unchanged to the constructor.
            max_epochs: Forwarded unchanged to the constructor.
            pretrained: Forwarded unchanged to the constructor.
            freeze_pretrained_weights: Forwarded unchanged to the constructor.
            seed: Not used by this method.
            root_dir: Not used by this method.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset's class name is not supported.
        """
        # Low-rank covariance rank per supported dataset; every entry uses
        # the "cifar" architecture variant.
        rank_by_dataset = {"CIFAR10": 20, "CIFAR100": 20, "TinyImageNet": 10}
        dataset_name = dataset.__class__.__name__
        if dataset_name not in rank_by_dataset:
            raise NotImplementedError(
                f"{cls.__name__} does not support dataset {dataset_name!r}."
            )
        cov = params.LowRankCovariance(rank=rank_by_dataset[dataset_name])

        return cls(
            resnet_type=inferno.models.ResNet50,
            resnet_architecture="cifar",
            out_size=dataset.num_classes,
            parametrization=parametrization,
            cov=cov,
            num_samples_train=1,
            num_samples_test=32,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            pretrained=pretrained,
            max_epochs=max_epochs,
            dataset=dataset,
            temperature_scaling=True,
            # Previously dropped, so the caller's choice had no effect.
            freeze_pretrained_weights=freeze_pretrained_weights,
        )


class ResNet101IVILowRank(ResNetIVI):
    """ResNet-101 implicit-VI model with a low-rank covariance."""

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ):
        """Build the model with a covariance rank chosen from the dataset.

        The dataset is identified by its class name. The model is created
        with ``num_samples_train=1``, ``num_samples_test=32`` and
        ``temperature_scaling=True``.

        Args:
            dataset: Data module whose class name must be ``"CIFAR10"``,
                ``"CIFAR100"`` or ``"TinyImageNet"``; its ``num_classes``
                attribute is used as the output size.
            parametrization: Forwarded unchanged to the constructor.
            lr: Forwarded unchanged to the constructor.
            momentum: Forwarded unchanged to the constructor.
            nesterov: Forwarded unchanged to the constructor.
            weight_decay: Forwarded unchanged to the constructor.
            max_epochs: Forwarded unchanged to the constructor.
            pretrained: Forwarded unchanged to the constructor.
            freeze_pretrained_weights: Forwarded unchanged to the constructor.
            seed: Not used by this method.
            root_dir: Not used by this method.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset's class name is not supported.
        """
        # Low-rank covariance rank per supported dataset; every entry uses
        # the "cifar" architecture variant.
        rank_by_dataset = {"CIFAR10": 20, "CIFAR100": 20, "TinyImageNet": 10}
        dataset_name = dataset.__class__.__name__
        if dataset_name not in rank_by_dataset:
            raise NotImplementedError(
                f"{cls.__name__} does not support dataset {dataset_name!r}."
            )
        cov = params.LowRankCovariance(rank=rank_by_dataset[dataset_name])

        return cls(
            resnet_type=inferno.models.ResNet101,
            resnet_architecture="cifar",
            out_size=dataset.num_classes,
            parametrization=parametrization,
            cov=cov,
            num_samples_train=1,
            num_samples_test=32,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            pretrained=pretrained,
            max_epochs=max_epochs,
            dataset=dataset,
            temperature_scaling=True,
            # Previously dropped, so the caller's choice had no effect.
            freeze_pretrained_weights=freeze_pretrained_weights,
        )


class WideResNet50IVILowRank(ResNetIVI):
    """Wide ResNet-50 implicit-VI model with a low-rank covariance."""

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ):
        """Build the model with a covariance rank chosen from the dataset.

        The dataset is identified by its class name. The model is created
        with ``num_samples_train=1``, ``num_samples_test=32`` and
        ``temperature_scaling=True``.

        Args:
            dataset: Data module whose class name must be ``"CIFAR10"``,
                ``"CIFAR100"`` or ``"TinyImageNet"``; its ``num_classes``
                attribute is used as the output size.
            parametrization: Forwarded unchanged to the constructor.
            lr: Forwarded unchanged to the constructor.
            momentum: Forwarded unchanged to the constructor.
            nesterov: Forwarded unchanged to the constructor.
            weight_decay: Forwarded unchanged to the constructor.
            max_epochs: Forwarded unchanged to the constructor.
            pretrained: Forwarded unchanged to the constructor.
            freeze_pretrained_weights: Forwarded unchanged to the constructor.
            seed: Not used by this method.
            root_dir: Not used by this method.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset's class name is not supported.
        """
        # Low-rank covariance rank per supported dataset; every entry uses
        # the "cifar" architecture variant.
        rank_by_dataset = {"CIFAR10": 20, "CIFAR100": 20, "TinyImageNet": 10}
        dataset_name = dataset.__class__.__name__
        if dataset_name not in rank_by_dataset:
            raise NotImplementedError(
                f"{cls.__name__} does not support dataset {dataset_name!r}."
            )
        cov = params.LowRankCovariance(rank=rank_by_dataset[dataset_name])

        return cls(
            resnet_type=inferno.models.WideResNet50,
            resnet_architecture="cifar",
            out_size=dataset.num_classes,
            parametrization=parametrization,
            cov=cov,
            num_samples_train=1,
            num_samples_test=32,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            pretrained=pretrained,
            max_epochs=max_epochs,
            dataset=dataset,
            temperature_scaling=True,
            # Previously dropped, so the caller's choice had no effect.
            freeze_pretrained_weights=freeze_pretrained_weights,
        )


class WideResNet101IVILowRank(ResNetIVI):
    """Wide ResNet-101 implicit-VI model with a low-rank covariance."""

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ):
        """Build the model with a covariance rank chosen from the dataset.

        The dataset is identified by its class name. The model is created
        with ``num_samples_train=1``, ``num_samples_test=32`` and
        ``temperature_scaling=True``.

        Args:
            dataset: Data module whose class name must be ``"CIFAR10"``,
                ``"CIFAR100"`` or ``"TinyImageNet"``; its ``num_classes``
                attribute is used as the output size.
            parametrization: Forwarded unchanged to the constructor.
            lr: Forwarded unchanged to the constructor.
            momentum: Forwarded unchanged to the constructor.
            nesterov: Forwarded unchanged to the constructor.
            weight_decay: Forwarded unchanged to the constructor.
            max_epochs: Forwarded unchanged to the constructor.
            pretrained: Forwarded unchanged to the constructor.
            freeze_pretrained_weights: Forwarded unchanged to the constructor.
            seed: Not used by this method.
            root_dir: Not used by this method.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset's class name is not supported.
        """
        # Low-rank covariance rank per supported dataset; every entry uses
        # the "cifar" architecture variant.
        rank_by_dataset = {"CIFAR10": 20, "CIFAR100": 20, "TinyImageNet": 10}
        dataset_name = dataset.__class__.__name__
        if dataset_name not in rank_by_dataset:
            raise NotImplementedError(
                f"{cls.__name__} does not support dataset {dataset_name!r}."
            )
        cov = params.LowRankCovariance(rank=rank_by_dataset[dataset_name])

        return cls(
            resnet_type=inferno.models.WideResNet101,
            resnet_architecture="cifar",
            out_size=dataset.num_classes,
            parametrization=parametrization,
            cov=cov,
            num_samples_train=1,
            num_samples_test=32,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            pretrained=pretrained,
            max_epochs=max_epochs,
            dataset=dataset,
            temperature_scaling=True,
            # Previously dropped, so the caller's choice had no effect.
            freeze_pretrained_weights=freeze_pretrained_weights,
        )
