from __future__ import annotations

import copy
from typing import TYPE_CHECKING

import inferno
import lightning as L
import torch
from inferno import loss_fns
from inferno.bnn import params
from torch import nn

from . import _wsvi_hyperparameters
from ._wsvi_model import _WeightSpaceVIModel

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class ResNetWeightSpaceVI(_WeightSpaceVIModel):
    """Weight-space variational inference model built around an ``inferno`` ResNet.

    The constructor instantiates the ResNet (optionally from pretrained
    weights) with a diagonal covariance, registers the prior mean and scale as
    buffers, and sets up a ``loss_fns.VariationalFreeEnergy`` loss whose
    negative log-likelihood term is ``nn.CrossEntropyLoss``.
    """

    def __init__(
        self,
        resnet_type: type[inferno.models.ResNet],
        resnet_architecture: str,
        out_size: int,
        parametrization: params.Parametrization,
        num_samples_train: int,
        num_samples_test: int,
        kl_weight: float | None,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        max_epochs: int,
    ) -> None:
        """Build the model, the prior buffers and the loss function.

        Args:
            resnet_type: ResNet class to instantiate, either directly or via
                its ``from_pretrained_weights`` constructor.
            resnet_architecture: Forwarded to the ResNet as ``architecture``.
            out_size: Forwarded to the ResNet as ``out_size`` and to the base
                class as ``num_classes``.
            parametrization: Forwarded unchanged to the ResNet.
            num_samples_train: Stored on the instance as ``num_samples_train``
                and recorded in the hyperparameters.
            num_samples_test: Stored on the instance as ``num_samples_test``
                and recorded in the hyperparameters.
            kl_weight: Forwarded to ``loss_fns.VariationalFreeEnergy``. The
                ``from_dataset`` factories in this file pass ``None``; how
                ``None`` is interpreted is decided by the loss function and is
                not visible here.
            lr: Forwarded to the base class initializer.
            momentum: Forwarded to the base class initializer.
            nesterov: Forwarded to the base class initializer.
            weight_decay: Forwarded to the base class initializer.
            pretrained: If true, the ResNet is created with
                ``from_pretrained_weights`` and the prior mean is a detached
                copy of the selected model parameters; otherwise the prior
                mean is all zeros.
            freeze_pretrained_weights: Forwarded as ``freeze`` to
                ``from_pretrained_weights``. Unused when ``pretrained`` is
                false.
            max_epochs: Forwarded to the base class initializer.
        """
        super().__init__(
            num_classes=out_size,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
        )

        self.num_samples_train = num_samples_train
        self.num_samples_test = num_samples_test

        # Model
        cov = params.DiagonalCovariance()
        if pretrained:
            self.model = resnet_type.from_pretrained_weights(
                out_size=out_size,
                architecture=resnet_architecture,
                parametrization=parametrization,
                cov=cov,
                freeze=freeze_pretrained_weights,
            )
        else:
            self.model = resnet_type(
                out_size=out_size,
                architecture=resnet_architecture,
                parametrization=parametrization,
                cov=cov,
            )

        # Select the parameters the prior is defined over exactly once, so the
        # prior's length and (when pretrained) its values always come from the
        # same set: trainable parameters whose name contains "params." but not
        # "cov." (treated here as the mean parameters).
        mean_parameters = [
            param
            for name, param in self.model.named_parameters()
            if param.requires_grad and "params." in name and "cov." not in name
        ]
        numel_mean_parameters = sum(param.numel() for param in mean_parameters)

        # Assigning prior parameters as a buffer counts them in hparams.num_parameters_and_buffers
        if pretrained:
            # detach() already yields a tensor that does not require grad;
            # clone() decouples the prior from the live model weights.
            prior_loc = (
                torch.concat([param.ravel() for param in mean_parameters])
                .detach()
                .clone()
            )
        else:
            prior_loc = torch.zeros((numel_mean_parameters,), requires_grad=False)
        self.prior_loc = nn.Buffer(prior_loc)
        self.prior_scale = nn.Buffer(
            torch.ones((numel_mean_parameters,), requires_grad=False)
        )

        # Loss function
        self.loss_fn = loss_fns.VariationalFreeEnergy(
            nll=nn.CrossEntropyLoss(),
            model=self.model,
            prior_loc=self.prior_loc,
            prior_scale=self.prior_scale,
            kl_weight=kl_weight,
        )

        # The KL weight is read back from the loss function rather than taken
        # from the argument, so the logged value is whatever the loss holds.
        self.save_hyperparameters(
            _wsvi_hyperparameters(
                lightning_module=self,
                architecture=resnet_type.__name__,
                out_size=out_size,
                num_samples_train=num_samples_train,
                num_samples_test=num_samples_test,
                cov=cov,
                kl_weight=self.loss_fn.kl_weight,
                lr=lr,
                momentum=momentum,
                nesterov=nesterov,
                weight_decay=weight_decay,
                max_epochs=max_epochs,
            ),
            logger=True,
        )


class ResNet18WeightSpaceVIDiagonal(ResNetWeightSpaceVI):
    """Weight-space VI model that uses ``inferno.models.ResNet18``.

    The constructor is inherited from ``ResNetWeightSpaceVI``;
    :meth:`from_dataset` fills in the ResNet type, the architecture variant and
    the sampling settings.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> ResNet18WeightSpaceVIDiagonal:
        """Create the model with settings derived from ``dataset``.

        The architecture variant is chosen from the dataset's class name:
        ``"cifar"`` for ``CIFAR10``, ``CIFAR100`` and ``TinyImageNet``, and
        ``"imagenet"`` for ``ImageNet``. The model is built with
        ``num_samples_train=8``, ``num_samples_test=32`` and ``kl_weight=None``.

        Args:
            dataset: Data module; only its class name and ``num_classes``
                attribute are read.
            parametrization: Forwarded to the constructor.
            lr: Forwarded to the constructor.
            momentum: Forwarded to the constructor.
            nesterov: Forwarded to the constructor.
            weight_decay: Forwarded to the constructor.
            max_epochs: Forwarded to the constructor.
            pretrained: Forwarded to the constructor.
            freeze_pretrained_weights: Forwarded to the constructor.
            seed: Accepted but not used by this method.
            root_dir: Accepted but not used by this method.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset's class name is not one of the
                supported names listed above.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            # Name the offending dataset so the failure is diagnosable.
            raise NotImplementedError(
                f"{cls.__name__}.from_dataset does not support dataset "
                f"{dataset_name!r}; supported datasets are CIFAR10, CIFAR100, "
                "TinyImageNet and ImageNet."
            )

        return cls(
            resnet_type=inferno.models.ResNet18,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            num_samples_train=8,
            num_samples_test=32,
            kl_weight=None,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            pretrained=pretrained,
            freeze_pretrained_weights=freeze_pretrained_weights,
            max_epochs=max_epochs,
        )


class ResNet34WeightSpaceVIDiagonal(ResNetWeightSpaceVI):
    """Weight-space VI model that uses ``inferno.models.ResNet34``.

    The constructor is inherited from ``ResNetWeightSpaceVI``;
    :meth:`from_dataset` fills in the ResNet type, the architecture variant and
    the sampling settings.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> ResNet34WeightSpaceVIDiagonal:
        """Create the model with settings derived from ``dataset``.

        The architecture variant is chosen from the dataset's class name:
        ``"cifar"`` for ``CIFAR10``, ``CIFAR100`` and ``TinyImageNet``, and
        ``"imagenet"`` for ``ImageNet``. The model is built with
        ``num_samples_train=8``, ``num_samples_test=32`` and ``kl_weight=None``.

        Args:
            dataset: Data module; only its class name and ``num_classes``
                attribute are read.
            parametrization: Forwarded to the constructor.
            lr: Forwarded to the constructor.
            momentum: Forwarded to the constructor.
            nesterov: Forwarded to the constructor.
            weight_decay: Forwarded to the constructor.
            max_epochs: Forwarded to the constructor.
            pretrained: Forwarded to the constructor.
            freeze_pretrained_weights: Forwarded to the constructor.
            seed: Accepted but not used by this method.
            root_dir: Accepted but not used by this method.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset's class name is not one of the
                supported names listed above.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            # Name the offending dataset so the failure is diagnosable.
            raise NotImplementedError(
                f"{cls.__name__}.from_dataset does not support dataset "
                f"{dataset_name!r}; supported datasets are CIFAR10, CIFAR100, "
                "TinyImageNet and ImageNet."
            )

        return cls(
            resnet_type=inferno.models.ResNet34,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            num_samples_train=8,
            num_samples_test=32,
            kl_weight=None,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            pretrained=pretrained,
            freeze_pretrained_weights=freeze_pretrained_weights,
            max_epochs=max_epochs,
        )


class ResNet50WeightSpaceVIDiagonal(ResNetWeightSpaceVI):
    """Weight-space VI model that uses ``inferno.models.ResNet50``.

    The constructor is inherited from ``ResNetWeightSpaceVI``;
    :meth:`from_dataset` fills in the ResNet type, the architecture variant and
    the sampling settings.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> ResNet50WeightSpaceVIDiagonal:
        """Create the model with settings derived from ``dataset``.

        The architecture variant is chosen from the dataset's class name:
        ``"cifar"`` for ``CIFAR10``, ``CIFAR100`` and ``TinyImageNet``, and
        ``"imagenet"`` for ``ImageNet``. The model is built with
        ``num_samples_train=8``, ``num_samples_test=32`` and ``kl_weight=None``.

        Args:
            dataset: Data module; only its class name and ``num_classes``
                attribute are read.
            parametrization: Forwarded to the constructor.
            lr: Forwarded to the constructor.
            momentum: Forwarded to the constructor.
            nesterov: Forwarded to the constructor.
            weight_decay: Forwarded to the constructor.
            max_epochs: Forwarded to the constructor.
            pretrained: Forwarded to the constructor.
            freeze_pretrained_weights: Forwarded to the constructor.
            seed: Accepted but not used by this method.
            root_dir: Accepted but not used by this method.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset's class name is not one of the
                supported names listed above.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            # Name the offending dataset so the failure is diagnosable.
            raise NotImplementedError(
                f"{cls.__name__}.from_dataset does not support dataset "
                f"{dataset_name!r}; supported datasets are CIFAR10, CIFAR100, "
                "TinyImageNet and ImageNet."
            )

        return cls(
            resnet_type=inferno.models.ResNet50,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            num_samples_train=8,
            num_samples_test=32,
            kl_weight=None,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            pretrained=pretrained,
            freeze_pretrained_weights=freeze_pretrained_weights,
            max_epochs=max_epochs,
        )


class ResNet101WeightSpaceVIDiagonal(ResNetWeightSpaceVI):
    """Weight-space VI model that uses ``inferno.models.ResNet101``.

    The constructor is inherited from ``ResNetWeightSpaceVI``;
    :meth:`from_dataset` fills in the ResNet type, the architecture variant and
    the sampling settings.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> ResNet101WeightSpaceVIDiagonal:
        """Create the model with settings derived from ``dataset``.

        The architecture variant is chosen from the dataset's class name:
        ``"cifar"`` for ``CIFAR10``, ``CIFAR100`` and ``TinyImageNet``, and
        ``"imagenet"`` for ``ImageNet``. The model is built with
        ``num_samples_train=8``, ``num_samples_test=32`` and ``kl_weight=None``.

        Args:
            dataset: Data module; only its class name and ``num_classes``
                attribute are read.
            parametrization: Forwarded to the constructor.
            lr: Forwarded to the constructor.
            momentum: Forwarded to the constructor.
            nesterov: Forwarded to the constructor.
            weight_decay: Forwarded to the constructor.
            max_epochs: Forwarded to the constructor.
            pretrained: Forwarded to the constructor.
            freeze_pretrained_weights: Forwarded to the constructor.
            seed: Accepted but not used by this method.
            root_dir: Accepted but not used by this method.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset's class name is not one of the
                supported names listed above.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            # Name the offending dataset so the failure is diagnosable.
            raise NotImplementedError(
                f"{cls.__name__}.from_dataset does not support dataset "
                f"{dataset_name!r}; supported datasets are CIFAR10, CIFAR100, "
                "TinyImageNet and ImageNet."
            )

        return cls(
            resnet_type=inferno.models.ResNet101,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            num_samples_train=8,
            num_samples_test=32,
            kl_weight=None,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            pretrained=pretrained,
            freeze_pretrained_weights=freeze_pretrained_weights,
            max_epochs=max_epochs,
        )


class WideResNet50WeightSpaceVIDiagonal(ResNetWeightSpaceVI):
    """Weight-space VI model that uses ``inferno.models.WideResNet50``.

    The constructor is inherited from ``ResNetWeightSpaceVI``;
    :meth:`from_dataset` fills in the ResNet type, the architecture variant and
    the sampling settings.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> WideResNet50WeightSpaceVIDiagonal:
        """Create the model with settings derived from ``dataset``.

        The architecture variant is chosen from the dataset's class name:
        ``"cifar"`` for ``CIFAR10``, ``CIFAR100`` and ``TinyImageNet``, and
        ``"imagenet"`` for ``ImageNet``. The model is built with
        ``num_samples_train=8``, ``num_samples_test=32`` and ``kl_weight=None``.

        Args:
            dataset: Data module; only its class name and ``num_classes``
                attribute are read.
            parametrization: Forwarded to the constructor.
            lr: Forwarded to the constructor.
            momentum: Forwarded to the constructor.
            nesterov: Forwarded to the constructor.
            weight_decay: Forwarded to the constructor.
            max_epochs: Forwarded to the constructor.
            pretrained: Forwarded to the constructor.
            freeze_pretrained_weights: Forwarded to the constructor.
            seed: Accepted but not used by this method.
            root_dir: Accepted but not used by this method.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset's class name is not one of the
                supported names listed above.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            # Name the offending dataset so the failure is diagnosable.
            raise NotImplementedError(
                f"{cls.__name__}.from_dataset does not support dataset "
                f"{dataset_name!r}; supported datasets are CIFAR10, CIFAR100, "
                "TinyImageNet and ImageNet."
            )

        # Keyword order follows the other ResNet factories in this file.
        return cls(
            resnet_type=inferno.models.WideResNet50,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            num_samples_train=8,
            num_samples_test=32,
            kl_weight=None,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            pretrained=pretrained,
            freeze_pretrained_weights=freeze_pretrained_weights,
            max_epochs=max_epochs,
        )


class WideResNet101WeightSpaceVIDiagonal(ResNetWeightSpaceVI):
    """Weight-space VI model that uses ``inferno.models.WideResNet101``.

    The constructor is inherited from ``ResNetWeightSpaceVI``;
    :meth:`from_dataset` fills in the ResNet type, the architecture variant and
    the sampling settings.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> WideResNet101WeightSpaceVIDiagonal:
        """Create the model with settings derived from ``dataset``.

        The architecture variant is chosen from the dataset's class name:
        ``"cifar"`` for ``CIFAR10``, ``CIFAR100`` and ``TinyImageNet``, and
        ``"imagenet"`` for ``ImageNet``. The model is built with
        ``num_samples_train=8``, ``num_samples_test=32`` and ``kl_weight=None``.

        Args:
            dataset: Data module; only its class name and ``num_classes``
                attribute are read.
            parametrization: Forwarded to the constructor.
            lr: Forwarded to the constructor.
            momentum: Forwarded to the constructor.
            nesterov: Forwarded to the constructor.
            weight_decay: Forwarded to the constructor.
            max_epochs: Forwarded to the constructor.
            pretrained: Forwarded to the constructor.
            freeze_pretrained_weights: Forwarded to the constructor.
            seed: Accepted but not used by this method.
            root_dir: Accepted but not used by this method.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset's class name is not one of the
                supported names listed above.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            # Name the offending dataset so the failure is diagnosable.
            raise NotImplementedError(
                f"{cls.__name__}.from_dataset does not support dataset "
                f"{dataset_name!r}; supported datasets are CIFAR10, CIFAR100, "
                "TinyImageNet and ImageNet."
            )

        return cls(
            resnet_type=inferno.models.WideResNet101,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            num_samples_train=8,
            num_samples_test=32,
            kl_weight=None,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            pretrained=pretrained,
            freeze_pretrained_weights=freeze_pretrained_weights,
            max_epochs=max_epochs,
        )
