from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Literal

import inferno
import laplace
import lightning as L
import torch
from inferno.bnn import params
from torch import nn

from . import _laplace_hyperparameters
from ._laplace_model import _LaplaceModel

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class ResNetLaplaceLastLayer(_LaplaceModel):
    """Laplace approximation around a ResNet restored from a training checkpoint.

    Construction is eager: the constructor loads the checkpoint, rebuilds the
    network, fits the Laplace approximation on ``train_dataloader`` and optimizes
    the prior precision before it returns.
    """

    def __init__(
        self,
        resnet_type: type[inferno.models.ResNet],
        resnet_architecture: str,
        out_size: int,
        parametrization: params.Parametrization,
        checkpoint_path: str,
        subset_of_weights: Literal["last_layer", "subnetwork", "all"],
        hessian_structure: Literal["diag", "kron", "full", "lowrank", "gp"],
        train_dataloader: torch.utils.data.DataLoader,
        val_dataloader: torch.utils.data.DataLoader,
        num_samples_test: int,
        pred_type: Literal["glm", "nn", "gp"] = "glm",
        link_approx: Literal["probit", "mc", "bridge", "bridge_norm"] = "probit",
        method_prior_precision_optimization: Literal[
            "gridsearch", "marglik"
        ] = "marglik",
    ) -> None:
        """Load the checkpoint, fit the Laplace approximation and tune its prior precision.

        Args:
            resnet_type: Model class; it is called with ``out_size``,
                ``architecture``, ``parametrization`` and ``cov=None``.
            resnet_architecture: Passed to ``resnet_type`` as ``architecture``.
            out_size: Passed to the parent constructor as ``num_classes`` and to
                ``resnet_type`` as ``out_size``.
            parametrization: Passed to ``resnet_type``.
            checkpoint_path: Checkpoint read with ``torch.load``. It must contain a
                ``"hyper_parameters"`` entry (with ``lr``, ``momentum``,
                ``nesterov``, ``weight_decay``, ``max_epochs`` and ``optimizer``)
                and a ``"state_dict"`` entry.
            subset_of_weights: Forwarded to ``laplace.Laplace``.
            hessian_structure: Forwarded to ``laplace.Laplace``.
            train_dataloader: Data on which the Laplace approximation is fitted.
            val_dataloader: Handed to the prior-precision optimization only when
                ``method_prior_precision_optimization`` is ``"gridsearch"``.
            num_samples_test: Stored on the instance and recorded in the saved
                hyperparameters; not otherwise used by this constructor.
            pred_type: Forwarded to the prior-precision optimization.
            link_approx: Forwarded to the prior-precision optimization.
            method_prior_precision_optimization: ``method`` argument of the
                prior-precision optimization.
        """
        # Deserialize onto the CPU so that checkpoints written on a GPU can also be
        # read on CPU-only hosts; the model is moved to its target device below.
        checkpoint = torch.load(
            checkpoint_path, map_location="cpu", weights_only=True
        )
        hyper_parameters = checkpoint["hyper_parameters"]

        super().__init__(
            num_classes=out_size,
            lr=hyper_parameters["lr"],
            momentum=hyper_parameters["momentum"],
            nesterov=hyper_parameters["nesterov"],
            weight_decay=hyper_parameters["weight_decay"],
            max_epochs=hyper_parameters["max_epochs"],
        )
        self.num_samples_test = num_samples_test
        self.pred_type = pred_type
        self.link_approx = link_approx
        # NOTE(review): if _LaplaceModel derives from a Lightning module, these two
        # attributes shadow the `train_dataloader` / `val_dataloader` hook methods
        # of the same name -- confirm that this is intended.
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader

        # Model
        resnet = resnet_type(
            out_size=out_size,
            architecture=resnet_architecture,
            parametrization=parametrization,
            cov=None,
        )
        # Each checkpoint key is reduced to the part after its first "model.".
        resnet.load_state_dict(
            {k.partition("model.")[2]: v for k, v in checkpoint["state_dict"].items()}
        )

        # NOTE: laplace only supports nn.Linear last layers
        fc = resnet.fc
        nn_linear = nn.Linear(
            in_features=fc.in_features,
            out_features=fc.out_features,
            bias=fc.bias is not None,
            device=fc.weight.device,
            dtype=fc.weight.dtype,
        )
        # Only pass a "bias" entry when the layer has one: a bias-free nn.Linear
        # rejects it as an unexpected key under strict loading.
        fc_state = {"weight": fc.weight}
        if fc.bias is not None:
            fc_state["bias"] = fc.bias
        nn_linear.load_state_dict(fc_state)
        resnet.fc = nn_linear
        self.model = resnet

        # Laplace approximation
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.laplace_approximation = laplace.Laplace(
            self.model.to(device),
            likelihood="classification",
            subset_of_weights=subset_of_weights,
            hessian_structure=hessian_structure,
        )

        # Fit the local Laplace approximation at the parameters of the model,
        # i.e. compute Hessian approximation at MAP
        self.laplace_approximation.fit(self.train_dataloader)

        # Fix for bug in Laplace library, which prevents prediction prior to
        # optimizing the prior precision
        self.laplace_approximation.prior_precision = torch.as_tensor(
            self.laplace_approximation.prior_precision
        )

        self.laplace_approximation.optimize_prior_precision(
            method=method_prior_precision_optimization,
            pred_type=self.pred_type,
            link_approx=self.link_approx,
            val_loader=(
                self.val_dataloader
                if method_prior_precision_optimization == "gridsearch"
                else None
            ),
            progress_bar=True,
        )

        # TODO: time the fit and prior_precision steps separately
        # TODO: implement optimize_prior_precision in lightning learning loop?

        self.save_hyperparameters(
            _laplace_hyperparameters(
                lightning_module=self,
                architecture=resnet_type.__name__,
                out_size=out_size,
                subset_of_weights=subset_of_weights,
                hessian_structure=hessian_structure,
                num_samples_test=num_samples_test,
                pred_type=pred_type,
                link_approx=link_approx,
                method_prior_precision_optimization=method_prior_precision_optimization,
                optimizer=hyper_parameters["optimizer"],
                lr=hyper_parameters["lr"],
                momentum=hyper_parameters["momentum"],
                nesterov=hyper_parameters["nesterov"],
                weight_decay=hyper_parameters["weight_decay"],
                max_epochs=hyper_parameters["max_epochs"],
            ),
            logger=True,
        )


class ResNet18LaplaceLastLayerMargLik(ResNetLaplaceLastLayer):
    """ResNet18 with a last-layer, ``"kron"``-structured Laplace approximation.

    The prior precision is optimized with the ``"marglik"`` method.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> ResNet18LaplaceLastLayerMargLik:
        """Build the model from the last ResNet18 checkpoint listed for ``dataset``.

        Args:
            dataset: Data module. Its class name selects the architecture variant
                and is used for the checkpoint lookup; ``setup("fit")`` is called
                on it and it supplies ``num_classes`` and both dataloaders.
            parametrization: Passed to the model; its class name is used for the
                checkpoint lookup.
            lr: Not used by this method.
            momentum: Not used by this method.
            nesterov: Not used by this method.
            weight_decay: Not used by this method.
            max_epochs: Not used by this method.
            pretrained: Not used by this method.
            freeze_pretrained_weights: Not used by this method.
            seed: Forwarded to ``cls.get_checkpoints``.
            root_dir: Forwarded to ``cls.get_checkpoints``.

        Returns:
            An instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset class is not one of ``CIFAR10``,
                ``CIFAR100``, ``TinyImageNet`` or ``ImageNet``.
            IndexError: If ``cls.get_checkpoints`` returns no checkpoint.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            raise NotImplementedError(
                f"No ResNet architecture is defined for dataset '{dataset_name}'."
            )

        # All checkpoints from vanilla neural network
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="ResNet18",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
            seed=seed,
        )
        # Fail with a descriptive message before the dataset is set up.
        try:
            checkpoint_path = checkpoints[-1]
        except IndexError as err:
            raise IndexError(
                f"No ResNet18 checkpoint found for dataset '{dataset_name}' "
                f"(seed={seed}, root_dir='{root_dir}')."
            ) from err

        # Initialize dataloaders for computation of the Hessian in the Laplace
        # approximation
        dataset.setup("fit")

        return cls(
            resnet_type=inferno.models.ResNet18,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_path=checkpoint_path,
            num_samples_test=256,
            subset_of_weights="last_layer",
            hessian_structure="kron",
            method_prior_precision_optimization="marglik",
            train_dataloader=dataset.train_dataloader(),
            val_dataloader=dataset.val_dataloader(),
        )


class ResNet18LaplaceLastLayerGridSearch(ResNetLaplaceLastLayer):
    """ResNet18 with a last-layer, ``"kron"``-structured Laplace approximation.

    The prior precision is optimized with the ``"gridsearch"`` method.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> ResNet18LaplaceLastLayerGridSearch:
        """Build the model from the last ResNet18 checkpoint listed for ``dataset``.

        Args:
            dataset: Data module. Its class name selects the architecture variant
                and is used for the checkpoint lookup; ``setup("fit")`` is called
                on it and it supplies ``num_classes`` and both dataloaders.
            parametrization: Passed to the model; its class name is used for the
                checkpoint lookup.
            lr: Not used by this method.
            momentum: Not used by this method.
            nesterov: Not used by this method.
            weight_decay: Not used by this method.
            max_epochs: Not used by this method.
            pretrained: Not used by this method.
            freeze_pretrained_weights: Not used by this method.
            seed: Forwarded to ``cls.get_checkpoints``.
            root_dir: Forwarded to ``cls.get_checkpoints``.

        Returns:
            An instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset class is not one of ``CIFAR10``,
                ``CIFAR100``, ``TinyImageNet`` or ``ImageNet``.
            IndexError: If ``cls.get_checkpoints`` returns no checkpoint.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            raise NotImplementedError(
                f"No ResNet architecture is defined for dataset '{dataset_name}'."
            )

        # All checkpoints from vanilla neural network
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="ResNet18",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
            seed=seed,
        )
        # Fail with a descriptive message before the dataset is set up.
        try:
            checkpoint_path = checkpoints[-1]
        except IndexError as err:
            raise IndexError(
                f"No ResNet18 checkpoint found for dataset '{dataset_name}' "
                f"(seed={seed}, root_dir='{root_dir}')."
            ) from err

        # Initialize dataloaders for computation of the Hessian in the Laplace
        # approximation
        dataset.setup("fit")

        return cls(
            resnet_type=inferno.models.ResNet18,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_path=checkpoint_path,
            num_samples_test=256,
            subset_of_weights="last_layer",
            hessian_structure="kron",
            method_prior_precision_optimization="gridsearch",
            train_dataloader=dataset.train_dataloader(),
            val_dataloader=dataset.val_dataloader(),
        )


class ResNet34LaplaceLastLayerMargLik(ResNetLaplaceLastLayer):
    """ResNet34 with a last-layer, ``"kron"``-structured Laplace approximation.

    The prior precision is optimized with the ``"marglik"`` method.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> ResNet34LaplaceLastLayerMargLik:
        """Build the model from the last ResNet34 checkpoint listed for ``dataset``.

        Args:
            dataset: Data module. Its class name selects the architecture variant
                and is used for the checkpoint lookup; ``setup("fit")`` is called
                on it and it supplies ``num_classes`` and both dataloaders.
            parametrization: Passed to the model; its class name is used for the
                checkpoint lookup.
            lr: Not used by this method.
            momentum: Not used by this method.
            nesterov: Not used by this method.
            weight_decay: Not used by this method.
            max_epochs: Not used by this method.
            pretrained: Not used by this method.
            freeze_pretrained_weights: Not used by this method.
            seed: Forwarded to ``cls.get_checkpoints``.
            root_dir: Forwarded to ``cls.get_checkpoints``.

        Returns:
            An instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset class is not one of ``CIFAR10``,
                ``CIFAR100``, ``TinyImageNet`` or ``ImageNet``.
            IndexError: If ``cls.get_checkpoints`` returns no checkpoint.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            raise NotImplementedError(
                f"No ResNet architecture is defined for dataset '{dataset_name}'."
            )

        # All checkpoints from vanilla neural network
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="ResNet34",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
            seed=seed,
        )
        # Fail with a descriptive message before the dataset is set up.
        try:
            checkpoint_path = checkpoints[-1]
        except IndexError as err:
            raise IndexError(
                f"No ResNet34 checkpoint found for dataset '{dataset_name}' "
                f"(seed={seed}, root_dir='{root_dir}')."
            ) from err

        # Initialize dataloaders for computation of the Hessian in the Laplace
        # approximation
        dataset.setup("fit")

        return cls(
            resnet_type=inferno.models.ResNet34,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_path=checkpoint_path,
            num_samples_test=256,
            subset_of_weights="last_layer",
            hessian_structure="kron",
            method_prior_precision_optimization="marglik",
            train_dataloader=dataset.train_dataloader(),
            val_dataloader=dataset.val_dataloader(),
        )


class ResNet34LaplaceLastLayerGridSearch(ResNetLaplaceLastLayer):
    """ResNet34 with a last-layer, ``"kron"``-structured Laplace approximation.

    The prior precision is optimized with the ``"gridsearch"`` method.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> ResNet34LaplaceLastLayerGridSearch:
        """Build the model from the last ResNet34 checkpoint listed for ``dataset``.

        Args:
            dataset: Data module. Its class name selects the architecture variant
                and is used for the checkpoint lookup; ``setup("fit")`` is called
                on it and it supplies ``num_classes`` and both dataloaders.
            parametrization: Passed to the model; its class name is used for the
                checkpoint lookup.
            lr: Not used by this method.
            momentum: Not used by this method.
            nesterov: Not used by this method.
            weight_decay: Not used by this method.
            max_epochs: Not used by this method.
            pretrained: Not used by this method.
            freeze_pretrained_weights: Not used by this method.
            seed: Forwarded to ``cls.get_checkpoints``.
            root_dir: Forwarded to ``cls.get_checkpoints``.

        Returns:
            An instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset class is not one of ``CIFAR10``,
                ``CIFAR100``, ``TinyImageNet`` or ``ImageNet``.
            IndexError: If ``cls.get_checkpoints`` returns no checkpoint.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            raise NotImplementedError(
                f"No ResNet architecture is defined for dataset '{dataset_name}'."
            )

        # All checkpoints from vanilla neural network
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="ResNet34",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
            seed=seed,
        )
        # Fail with a descriptive message before the dataset is set up.
        try:
            checkpoint_path = checkpoints[-1]
        except IndexError as err:
            raise IndexError(
                f"No ResNet34 checkpoint found for dataset '{dataset_name}' "
                f"(seed={seed}, root_dir='{root_dir}')."
            ) from err

        # Initialize dataloaders for computation of the Hessian in the Laplace
        # approximation
        dataset.setup("fit")

        return cls(
            resnet_type=inferno.models.ResNet34,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_path=checkpoint_path,
            num_samples_test=256,
            subset_of_weights="last_layer",
            hessian_structure="kron",
            method_prior_precision_optimization="gridsearch",
            train_dataloader=dataset.train_dataloader(),
            val_dataloader=dataset.val_dataloader(),
        )


class ResNet50LaplaceLastLayerMargLik(ResNetLaplaceLastLayer):
    """ResNet50 with a last-layer, ``"kron"``-structured Laplace approximation.

    The prior precision is optimized with the ``"marglik"`` method.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> ResNet50LaplaceLastLayerMargLik:
        """Build the model from the last ResNet50 checkpoint listed for ``dataset``.

        Args:
            dataset: Data module. Its class name selects the architecture variant
                and is used for the checkpoint lookup; ``setup("fit")`` is called
                on it and it supplies ``num_classes`` and both dataloaders.
            parametrization: Passed to the model; its class name is used for the
                checkpoint lookup.
            lr: Not used by this method.
            momentum: Not used by this method.
            nesterov: Not used by this method.
            weight_decay: Not used by this method.
            max_epochs: Not used by this method.
            pretrained: Not used by this method.
            freeze_pretrained_weights: Not used by this method.
            seed: Forwarded to ``cls.get_checkpoints``.
            root_dir: Forwarded to ``cls.get_checkpoints``.

        Returns:
            An instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset class is not one of ``CIFAR10``,
                ``CIFAR100``, ``TinyImageNet`` or ``ImageNet``.
            IndexError: If ``cls.get_checkpoints`` returns no checkpoint.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            raise NotImplementedError(
                f"No ResNet architecture is defined for dataset '{dataset_name}'."
            )

        # All checkpoints from vanilla neural network
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="ResNet50",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
            seed=seed,
        )
        # Fail with a descriptive message before the dataset is set up.
        try:
            checkpoint_path = checkpoints[-1]
        except IndexError as err:
            raise IndexError(
                f"No ResNet50 checkpoint found for dataset '{dataset_name}' "
                f"(seed={seed}, root_dir='{root_dir}')."
            ) from err

        # Initialize dataloaders for computation of the Hessian in the Laplace
        # approximation
        dataset.setup("fit")

        return cls(
            resnet_type=inferno.models.ResNet50,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_path=checkpoint_path,
            num_samples_test=256,
            subset_of_weights="last_layer",
            hessian_structure="kron",
            method_prior_precision_optimization="marglik",
            train_dataloader=dataset.train_dataloader(),
            val_dataloader=dataset.val_dataloader(),
        )


class ResNet50LaplaceLastLayerGridSearch(ResNetLaplaceLastLayer):
    """ResNet50 with a last-layer, ``"kron"``-structured Laplace approximation.

    The prior precision is optimized with the ``"gridsearch"`` method.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> ResNet50LaplaceLastLayerGridSearch:
        """Build the model from the last ResNet50 checkpoint listed for ``dataset``.

        Args:
            dataset: Data module. Its class name selects the architecture variant
                and is used for the checkpoint lookup; ``setup("fit")`` is called
                on it and it supplies ``num_classes`` and both dataloaders.
            parametrization: Passed to the model; its class name is used for the
                checkpoint lookup.
            lr: Not used by this method.
            momentum: Not used by this method.
            nesterov: Not used by this method.
            weight_decay: Not used by this method.
            max_epochs: Not used by this method.
            pretrained: Not used by this method.
            freeze_pretrained_weights: Not used by this method.
            seed: Forwarded to ``cls.get_checkpoints``.
            root_dir: Forwarded to ``cls.get_checkpoints``.

        Returns:
            An instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset class is not one of ``CIFAR10``,
                ``CIFAR100``, ``TinyImageNet`` or ``ImageNet``.
            IndexError: If ``cls.get_checkpoints`` returns no checkpoint.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            raise NotImplementedError(
                f"No ResNet architecture is defined for dataset '{dataset_name}'."
            )

        # All checkpoints from vanilla neural network
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="ResNet50",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
            seed=seed,
        )
        # Fail with a descriptive message before the dataset is set up.
        try:
            checkpoint_path = checkpoints[-1]
        except IndexError as err:
            raise IndexError(
                f"No ResNet50 checkpoint found for dataset '{dataset_name}' "
                f"(seed={seed}, root_dir='{root_dir}')."
            ) from err

        # Initialize dataloaders for computation of the Hessian in the Laplace
        # approximation
        dataset.setup("fit")

        return cls(
            resnet_type=inferno.models.ResNet50,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_path=checkpoint_path,
            num_samples_test=256,
            subset_of_weights="last_layer",
            hessian_structure="kron",
            method_prior_precision_optimization="gridsearch",
            train_dataloader=dataset.train_dataloader(),
            val_dataloader=dataset.val_dataloader(),
        )


class ResNet101LaplaceLastLayerMargLik(ResNetLaplaceLastLayer):
    """ResNet101 with a last-layer, ``"kron"``-structured Laplace approximation.

    The prior precision is optimized with the ``"marglik"`` method.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> ResNet101LaplaceLastLayerMargLik:
        """Build the model from the last ResNet101 checkpoint listed for ``dataset``.

        Args:
            dataset: Data module. Its class name selects the architecture variant
                and is used for the checkpoint lookup; ``setup("fit")`` is called
                on it and it supplies ``num_classes`` and both dataloaders.
            parametrization: Passed to the model; its class name is used for the
                checkpoint lookup.
            lr: Not used by this method.
            momentum: Not used by this method.
            nesterov: Not used by this method.
            weight_decay: Not used by this method.
            max_epochs: Not used by this method.
            pretrained: Not used by this method.
            freeze_pretrained_weights: Not used by this method.
            seed: Forwarded to ``cls.get_checkpoints``.
            root_dir: Forwarded to ``cls.get_checkpoints``.

        Returns:
            An instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset class is not one of ``CIFAR10``,
                ``CIFAR100``, ``TinyImageNet`` or ``ImageNet``.
            IndexError: If ``cls.get_checkpoints`` returns no checkpoint.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            raise NotImplementedError(
                f"No ResNet architecture is defined for dataset '{dataset_name}'."
            )

        # All checkpoints from vanilla neural network
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="ResNet101",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
            seed=seed,
        )
        # Fail with a descriptive message before the dataset is set up.
        try:
            checkpoint_path = checkpoints[-1]
        except IndexError as err:
            raise IndexError(
                f"No ResNet101 checkpoint found for dataset '{dataset_name}' "
                f"(seed={seed}, root_dir='{root_dir}')."
            ) from err

        # Initialize dataloaders for computation of the Hessian in the Laplace
        # approximation
        dataset.setup("fit")

        return cls(
            resnet_type=inferno.models.ResNet101,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_path=checkpoint_path,
            num_samples_test=256,
            subset_of_weights="last_layer",
            hessian_structure="kron",
            method_prior_precision_optimization="marglik",
            train_dataloader=dataset.train_dataloader(),
            val_dataloader=dataset.val_dataloader(),
        )


class ResNet101LaplaceLastLayerGridSearch(ResNetLaplaceLastLayer):
    """ResNet101 with a last-layer, ``"kron"``-structured Laplace approximation.

    The prior precision is optimized with the ``"gridsearch"`` method.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> ResNet101LaplaceLastLayerGridSearch:
        """Build the model from the last ResNet101 checkpoint listed for ``dataset``.

        Args:
            dataset: Data module. Its class name selects the architecture variant
                and is used for the checkpoint lookup; ``setup("fit")`` is called
                on it and it supplies ``num_classes`` and both dataloaders.
            parametrization: Passed to the model; its class name is used for the
                checkpoint lookup.
            lr: Not used by this method.
            momentum: Not used by this method.
            nesterov: Not used by this method.
            weight_decay: Not used by this method.
            max_epochs: Not used by this method.
            pretrained: Not used by this method.
            freeze_pretrained_weights: Not used by this method.
            seed: Forwarded to ``cls.get_checkpoints``.
            root_dir: Forwarded to ``cls.get_checkpoints``.

        Returns:
            An instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset class is not one of ``CIFAR10``,
                ``CIFAR100``, ``TinyImageNet`` or ``ImageNet``.
            IndexError: If ``cls.get_checkpoints`` returns no checkpoint.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            raise NotImplementedError(
                f"No ResNet architecture is defined for dataset '{dataset_name}'."
            )

        # All checkpoints from vanilla neural network
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="ResNet101",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
            seed=seed,
        )
        # Fail with a descriptive message before the dataset is set up.
        try:
            checkpoint_path = checkpoints[-1]
        except IndexError as err:
            raise IndexError(
                f"No ResNet101 checkpoint found for dataset '{dataset_name}' "
                f"(seed={seed}, root_dir='{root_dir}')."
            ) from err

        # Initialize dataloaders for computation of the Hessian in the Laplace
        # approximation
        dataset.setup("fit")

        return cls(
            resnet_type=inferno.models.ResNet101,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_path=checkpoint_path,
            num_samples_test=256,
            subset_of_weights="last_layer",
            hessian_structure="kron",
            method_prior_precision_optimization="gridsearch",
            train_dataloader=dataset.train_dataloader(),
            val_dataloader=dataset.val_dataloader(),
        )


class WideResNet50LaplaceLastLayerMargLik(ResNetLaplaceLastLayer):
    """WideResNet50 with a last-layer, ``"kron"``-structured Laplace approximation.

    The prior precision is optimized with the ``"marglik"`` method.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> WideResNet50LaplaceLastLayerMargLik:
        """Build the model from the last WideResNet50 checkpoint listed for ``dataset``.

        Args:
            dataset: Data module. Its class name selects the architecture variant
                and is used for the checkpoint lookup; ``setup("fit")`` is called
                on it and it supplies ``num_classes`` and both dataloaders.
            parametrization: Passed to the model; its class name is used for the
                checkpoint lookup.
            lr: Not used by this method.
            momentum: Not used by this method.
            nesterov: Not used by this method.
            weight_decay: Not used by this method.
            max_epochs: Not used by this method.
            pretrained: Not used by this method.
            freeze_pretrained_weights: Not used by this method.
            seed: Forwarded to ``cls.get_checkpoints``.
            root_dir: Forwarded to ``cls.get_checkpoints``.

        Returns:
            An instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset class is not one of ``CIFAR10``,
                ``CIFAR100``, ``TinyImageNet`` or ``ImageNet``.
            IndexError: If ``cls.get_checkpoints`` returns no checkpoint.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            raise NotImplementedError(
                f"No ResNet architecture is defined for dataset '{dataset_name}'."
            )

        # All checkpoints from vanilla neural network
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="WideResNet50",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
            seed=seed,
        )
        # Fail with a descriptive message before the dataset is set up.
        try:
            checkpoint_path = checkpoints[-1]
        except IndexError as err:
            raise IndexError(
                f"No WideResNet50 checkpoint found for dataset '{dataset_name}' "
                f"(seed={seed}, root_dir='{root_dir}')."
            ) from err

        # Initialize dataloaders for computation of the Hessian in the Laplace
        # approximation
        dataset.setup("fit")

        return cls(
            resnet_type=inferno.models.WideResNet50,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_path=checkpoint_path,
            num_samples_test=256,
            subset_of_weights="last_layer",
            hessian_structure="kron",
            method_prior_precision_optimization="marglik",
            train_dataloader=dataset.train_dataloader(),
            val_dataloader=dataset.val_dataloader(),
        )


class WideResNet50LaplaceLastLayerGridSearch(ResNetLaplaceLastLayer):
    """WideResNet50 with a last-layer, ``"kron"``-structured Laplace approximation.

    The prior precision is optimized with the ``"gridsearch"`` method.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> WideResNet50LaplaceLastLayerGridSearch:
        """Build the model from the last WideResNet50 checkpoint listed for ``dataset``.

        Args:
            dataset: Data module. Its class name selects the architecture variant
                and is used for the checkpoint lookup; ``setup("fit")`` is called
                on it and it supplies ``num_classes`` and both dataloaders.
            parametrization: Passed to the model; its class name is used for the
                checkpoint lookup.
            lr: Not used by this method.
            momentum: Not used by this method.
            nesterov: Not used by this method.
            weight_decay: Not used by this method.
            max_epochs: Not used by this method.
            pretrained: Not used by this method.
            freeze_pretrained_weights: Not used by this method.
            seed: Forwarded to ``cls.get_checkpoints``.
            root_dir: Forwarded to ``cls.get_checkpoints``.

        Returns:
            An instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset class is not one of ``CIFAR10``,
                ``CIFAR100``, ``TinyImageNet`` or ``ImageNet``.
            IndexError: If ``cls.get_checkpoints`` returns no checkpoint.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            raise NotImplementedError(
                f"No ResNet architecture is defined for dataset '{dataset_name}'."
            )

        # All checkpoints from vanilla neural network
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="WideResNet50",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
            seed=seed,
        )
        # Fail with a descriptive message before the dataset is set up.
        try:
            checkpoint_path = checkpoints[-1]
        except IndexError as err:
            raise IndexError(
                f"No WideResNet50 checkpoint found for dataset '{dataset_name}' "
                f"(seed={seed}, root_dir='{root_dir}')."
            ) from err

        # Initialize dataloaders for computation of the Hessian in the Laplace
        # approximation
        dataset.setup("fit")

        return cls(
            resnet_type=inferno.models.WideResNet50,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_path=checkpoint_path,
            num_samples_test=256,
            subset_of_weights="last_layer",
            hessian_structure="kron",
            method_prior_precision_optimization="gridsearch",
            train_dataloader=dataset.train_dataloader(),
            val_dataloader=dataset.val_dataloader(),
        )


class WideResNet101LaplaceLastLayerMargLik(ResNetLaplaceLastLayer):
    """WideResNet101 with a last-layer, ``"kron"``-structured Laplace approximation.

    The prior precision is optimized with the ``"marglik"`` method.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> WideResNet101LaplaceLastLayerMargLik:
        """Build the model from the last WideResNet101 checkpoint listed for ``dataset``.

        Args:
            dataset: Data module. Its class name selects the architecture variant
                and is used for the checkpoint lookup; ``setup("fit")`` is called
                on it and it supplies ``num_classes`` and both dataloaders.
            parametrization: Passed to the model; its class name is used for the
                checkpoint lookup.
            lr: Not used by this method.
            momentum: Not used by this method.
            nesterov: Not used by this method.
            weight_decay: Not used by this method.
            max_epochs: Not used by this method.
            pretrained: Not used by this method.
            freeze_pretrained_weights: Not used by this method.
            seed: Forwarded to ``cls.get_checkpoints``.
            root_dir: Forwarded to ``cls.get_checkpoints``.

        Returns:
            An instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset class is not one of ``CIFAR10``,
                ``CIFAR100``, ``TinyImageNet`` or ``ImageNet``.
            IndexError: If ``cls.get_checkpoints`` returns no checkpoint.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            raise NotImplementedError(
                f"No ResNet architecture is defined for dataset '{dataset_name}'."
            )

        # All checkpoints from vanilla neural network
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="WideResNet101",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
            seed=seed,
        )
        # Fail with a descriptive message before the dataset is set up.
        try:
            checkpoint_path = checkpoints[-1]
        except IndexError as err:
            raise IndexError(
                f"No WideResNet101 checkpoint found for dataset '{dataset_name}' "
                f"(seed={seed}, root_dir='{root_dir}')."
            ) from err

        # Initialize dataloaders for computation of the Hessian in the Laplace
        # approximation
        dataset.setup("fit")

        return cls(
            resnet_type=inferno.models.WideResNet101,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_path=checkpoint_path,
            num_samples_test=256,
            subset_of_weights="last_layer",
            hessian_structure="kron",
            method_prior_precision_optimization="marglik",
            train_dataloader=dataset.train_dataloader(),
            val_dataloader=dataset.val_dataloader(),
        )


class WideResNet101LaplaceLastLayerGridSearch(ResNetLaplaceLastLayer):
    """WideResNet101 with a last-layer, ``"kron"``-structured Laplace approximation.

    The prior precision is optimized with the ``"gridsearch"`` method.
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> WideResNet101LaplaceLastLayerGridSearch:
        """Build the model from the last WideResNet101 checkpoint listed for ``dataset``.

        Args:
            dataset: Data module. Its class name selects the architecture variant
                and is used for the checkpoint lookup; ``setup("fit")`` is called
                on it and it supplies ``num_classes`` and both dataloaders.
            parametrization: Passed to the model; its class name is used for the
                checkpoint lookup.
            lr: Not used by this method.
            momentum: Not used by this method.
            nesterov: Not used by this method.
            weight_decay: Not used by this method.
            max_epochs: Not used by this method.
            pretrained: Not used by this method.
            freeze_pretrained_weights: Not used by this method.
            seed: Forwarded to ``cls.get_checkpoints``.
            root_dir: Forwarded to ``cls.get_checkpoints``.

        Returns:
            An instance of ``cls``.

        Raises:
            NotImplementedError: If the dataset class is not one of ``CIFAR10``,
                ``CIFAR100``, ``TinyImageNet`` or ``ImageNet``.
            IndexError: If ``cls.get_checkpoints`` returns no checkpoint.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            raise NotImplementedError(
                f"No ResNet architecture is defined for dataset '{dataset_name}'."
            )

        # All checkpoints from vanilla neural network
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="WideResNet101",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
            seed=seed,
        )
        # Fail with a descriptive message before the dataset is set up.
        try:
            checkpoint_path = checkpoints[-1]
        except IndexError as err:
            raise IndexError(
                f"No WideResNet101 checkpoint found for dataset '{dataset_name}' "
                f"(seed={seed}, root_dir='{root_dir}')."
            ) from err

        # Initialize dataloaders for computation of the Hessian in the Laplace
        # approximation
        dataset.setup("fit")

        return cls(
            resnet_type=inferno.models.WideResNet101,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_path=checkpoint_path,
            num_samples_test=256,
            subset_of_weights="last_layer",
            hessian_structure="kron",
            method_prior_precision_optimization="gridsearch",
            train_dataloader=dataset.train_dataloader(),
            val_dataloader=dataset.val_dataloader(),
        )
