from __future__ import annotations

from typing import TYPE_CHECKING

import inferno
import lightning as L
import torch
from inferno import bnn, models

from . import _ensemble_hyperparameters
from ._ensemble_model import _EnsembleModel

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class ResNetEnsemble(_EnsembleModel):
    """Ensemble of ResNet members restored from saved checkpoints.

    One member network is instantiated per checkpoint and the members are wrapped
    in ``inferno.models.Ensemble``. The optimizer-related hyperparameters handed
    to the parent class are read from the checkpoints, not supplied by the caller.
    """

    def __init__(
        self,
        resnet_type: type[inferno.models.ResNet],
        resnet_architecture: str,
        out_size: int,
        parametrization: bnn.params.Parametrization,
        checkpoint_paths: list[str] | str,
    ) -> None:
        """Restore one ensemble member per checkpoint.

        Args:
            resnet_type: Model class instantiated once per checkpoint.
            resnet_architecture: Forwarded to ``resnet_type`` as ``architecture``.
            out_size: Forwarded to ``resnet_type`` as ``out_size`` and to the
                parent class as ``num_classes``.
            parametrization: Forwarded to ``resnet_type`` as ``parametrization``.
            checkpoint_paths: Paths of the checkpoints to load, one per member.
                A single path given as a plain string is treated as a
                one-element collection.

        Raises:
            ValueError: If ``checkpoint_paths`` is empty.
        """
        if isinstance(checkpoint_paths, str):
            # A bare string would otherwise be iterated character by character.
            checkpoint_paths = [checkpoint_paths]

        state_dicts = []
        hyper_parameters = None
        for checkpoint_path in checkpoint_paths:
            # Load onto the CPU so that checkpoints written on an accelerator can
            # be restored on hosts without one; ``load_state_dict`` below copies
            # the values into the freshly constructed member.
            checkpoint = torch.load(
                checkpoint_path, map_location="cpu", weights_only=True
            )

            member_state = {}
            for key, value in checkpoint["state_dict"].items():
                # Keep only what follows the first "model." in the key.
                _, separator, member_key = key.partition("model.")
                if not separator:
                    # Keys without "model." used to collapse onto the empty key
                    # and make ``load_state_dict`` fail; skip them instead.
                    # NOTE(review): presumably these are non-network entries of
                    # the training module (e.g. metric state) -- confirm.
                    continue
                member_state[member_key] = value
            state_dicts.append(member_state)

            # NOTE: Assumes all checkpoints have the same hyperparameters; the
            # values of the last checkpoint are the ones that end up being used.
            hyper_parameters = checkpoint["hyper_parameters"]

        if hyper_parameters is None:
            raise ValueError(
                "`checkpoint_paths` must contain at least one checkpoint path."
            )

        super().__init__(
            num_classes=out_size,
            lr=hyper_parameters["lr"],
            momentum=hyper_parameters["momentum"],
            nesterov=hyper_parameters["nesterov"],
            weight_decay=hyper_parameters["weight_decay"],
            max_epochs=hyper_parameters["max_epochs"],
        )

        # Ensemble members
        members = []
        for state_dict in state_dicts:
            model = resnet_type(
                out_size=out_size,
                architecture=resnet_architecture,
                parametrization=parametrization,
                cov=None,
            )
            model.load_state_dict(state_dict)
            members.append(model)

        self.model = models.Ensemble(members)

        self.save_hyperparameters(
            _ensemble_hyperparameters(
                lightning_module=self,
                architecture=resnet_type.__name__,
                out_size=out_size,
                num_members=len(members),
                optimizer=hyper_parameters["optimizer"],
                lr=hyper_parameters["lr"],
                momentum=hyper_parameters["momentum"],
                nesterov=hyper_parameters["nesterov"],
                weight_decay=hyper_parameters["weight_decay"],
                max_epochs=hyper_parameters["max_epochs"],
            ),
            logger=True,
        )


class ResNet18Ensemble(ResNetEnsemble):
    """Ensemble whose members are ``inferno.models.ResNet18`` networks."""

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: bnn.params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> ResNet18Ensemble:
        """Build the ensemble for ``dataset`` from checkpoints under ``root_dir``.

        Args:
            dataset: Data module; its class name selects the architecture variant
                and its ``num_classes`` attribute sets the output size.
            parametrization: Forwarded to the members; its class name is also
                used to look up the checkpoints.
            lr, momentum, nesterov, weight_decay, max_epochs, pretrained,
                freeze_pretrained_weights, seed: Accepted but not used by this
                method (presumably kept for signature parity with other model
                constructors -- confirm).
            root_dir: Passed to ``cls.get_checkpoints`` to locate the checkpoints.

        Returns:
            An instance of ``cls`` built from the checkpoints that were found.

        Raises:
            NotImplementedError: If the dataset class is not one of ``CIFAR10``,
                ``CIFAR100``, ``TinyImageNet`` or ``ImageNet``.
        """
        dataset_name = dataset.__class__.__name__

        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            raise NotImplementedError(
                f"{cls.__name__} does not support dataset '{dataset_name}'; "
                "expected one of CIFAR10, CIFAR100, TinyImageNet or ImageNet."
            )

        # ``get_checkpoints`` is not defined in this part of the file; it is
        # looked up on ``cls``.
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="ResNet18",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
        )

        return cls(
            resnet_type=inferno.models.ResNet18,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_paths=checkpoints,
        )


class ResNet34Ensemble(ResNetEnsemble):
    """Ensemble whose members are ``inferno.models.ResNet34`` networks."""

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: bnn.params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> ResNet34Ensemble:
        """Build the ensemble for ``dataset`` from checkpoints under ``root_dir``.

        Args:
            dataset: Data module; its class name selects the architecture variant
                and its ``num_classes`` attribute sets the output size.
            parametrization: Forwarded to the members; its class name is also
                used to look up the checkpoints.
            lr, momentum, nesterov, weight_decay, max_epochs, pretrained,
                freeze_pretrained_weights, seed: Accepted but not used by this
                method (presumably kept for signature parity with other model
                constructors -- confirm).
            root_dir: Passed to ``cls.get_checkpoints`` to locate the checkpoints.

        Returns:
            An instance of ``cls`` built from the checkpoints that were found.

        Raises:
            NotImplementedError: If the dataset class is not one of ``CIFAR10``,
                ``CIFAR100``, ``TinyImageNet`` or ``ImageNet``.
        """
        dataset_name = dataset.__class__.__name__

        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            raise NotImplementedError(
                f"{cls.__name__} does not support dataset '{dataset_name}'; "
                "expected one of CIFAR10, CIFAR100, TinyImageNet or ImageNet."
            )

        # ``get_checkpoints`` is not defined in this part of the file; it is
        # looked up on ``cls``.
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="ResNet34",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
        )

        return cls(
            resnet_type=inferno.models.ResNet34,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_paths=checkpoints,
        )


class ResNet50Ensemble(ResNetEnsemble):
    """Ensemble whose members are ``inferno.models.ResNet50`` networks."""

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: bnn.params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> ResNet50Ensemble:
        """Build the ensemble for ``dataset`` from checkpoints under ``root_dir``.

        Args:
            dataset: Data module; its class name selects the architecture variant
                and its ``num_classes`` attribute sets the output size.
            parametrization: Forwarded to the members; its class name is also
                used to look up the checkpoints.
            lr, momentum, nesterov, weight_decay, max_epochs, pretrained,
                freeze_pretrained_weights, seed: Accepted but not used by this
                method (presumably kept for signature parity with other model
                constructors -- confirm).
            root_dir: Passed to ``cls.get_checkpoints`` to locate the checkpoints.

        Returns:
            An instance of ``cls`` built from the checkpoints that were found.

        Raises:
            NotImplementedError: If the dataset class is not one of ``CIFAR10``,
                ``CIFAR100``, ``TinyImageNet`` or ``ImageNet``.
        """
        dataset_name = dataset.__class__.__name__

        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            raise NotImplementedError(
                f"{cls.__name__} does not support dataset '{dataset_name}'; "
                "expected one of CIFAR10, CIFAR100, TinyImageNet or ImageNet."
            )

        # ``get_checkpoints`` is not defined in this part of the file; it is
        # looked up on ``cls``.
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="ResNet50",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
        )

        return cls(
            resnet_type=inferno.models.ResNet50,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_paths=checkpoints,
        )


class ResNet101Ensemble(ResNetEnsemble):
    """Ensemble whose members are ``inferno.models.ResNet101`` networks."""

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: bnn.params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> ResNet101Ensemble:
        """Build the ensemble for ``dataset`` from checkpoints under ``root_dir``.

        Args:
            dataset: Data module; its class name selects the architecture variant
                and its ``num_classes`` attribute sets the output size.
            parametrization: Forwarded to the members; its class name is also
                used to look up the checkpoints.
            lr, momentum, nesterov, weight_decay, max_epochs, pretrained,
                freeze_pretrained_weights, seed: Accepted but not used by this
                method (presumably kept for signature parity with other model
                constructors -- confirm).
            root_dir: Passed to ``cls.get_checkpoints`` to locate the checkpoints.

        Returns:
            An instance of ``cls`` built from the checkpoints that were found.

        Raises:
            NotImplementedError: If the dataset class is not one of ``CIFAR10``,
                ``CIFAR100``, ``TinyImageNet`` or ``ImageNet``.
        """
        dataset_name = dataset.__class__.__name__

        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            raise NotImplementedError(
                f"{cls.__name__} does not support dataset '{dataset_name}'; "
                "expected one of CIFAR10, CIFAR100, TinyImageNet or ImageNet."
            )

        # ``get_checkpoints`` is not defined in this part of the file; it is
        # looked up on ``cls``.
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="ResNet101",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
        )

        return cls(
            resnet_type=inferno.models.ResNet101,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_paths=checkpoints,
        )


class WideResNet50Ensemble(ResNetEnsemble):
    """Ensemble whose members are ``inferno.models.WideResNet50`` networks."""

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: bnn.params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> WideResNet50Ensemble:
        """Build the ensemble for ``dataset`` from checkpoints under ``root_dir``.

        Args:
            dataset: Data module; its class name selects the architecture variant
                and its ``num_classes`` attribute sets the output size.
            parametrization: Forwarded to the members; its class name is also
                used to look up the checkpoints.
            lr, momentum, nesterov, weight_decay, max_epochs, pretrained,
                freeze_pretrained_weights, seed: Accepted but not used by this
                method (presumably kept for signature parity with other model
                constructors -- confirm).
            root_dir: Passed to ``cls.get_checkpoints`` to locate the checkpoints.

        Returns:
            An instance of ``cls`` built from the checkpoints that were found.

        Raises:
            NotImplementedError: If the dataset class is not one of ``CIFAR10``,
                ``CIFAR100``, ``TinyImageNet`` or ``ImageNet``.
        """
        dataset_name = dataset.__class__.__name__

        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            raise NotImplementedError(
                f"{cls.__name__} does not support dataset '{dataset_name}'; "
                "expected one of CIFAR10, CIFAR100, TinyImageNet or ImageNet."
            )

        # ``get_checkpoints`` is not defined in this part of the file; it is
        # looked up on ``cls``.
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="WideResNet50",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
        )

        return cls(
            resnet_type=inferno.models.WideResNet50,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_paths=checkpoints,
        )


class WideResNet101Ensemble(ResNetEnsemble):
    """Ensemble whose members are ``inferno.models.WideResNet101`` networks."""

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: bnn.params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
    ) -> WideResNet101Ensemble:
        """Build the ensemble for ``dataset`` from checkpoints under ``root_dir``.

        Args:
            dataset: Data module; its class name selects the architecture variant
                and its ``num_classes`` attribute sets the output size.
            parametrization: Forwarded to the members; its class name is also
                used to look up the checkpoints.
            lr, momentum, nesterov, weight_decay, max_epochs, pretrained,
                freeze_pretrained_weights, seed: Accepted but not used by this
                method (presumably kept for signature parity with other model
                constructors -- confirm).
            root_dir: Passed to ``cls.get_checkpoints`` to locate the checkpoints.

        Returns:
            An instance of ``cls`` built from the checkpoints that were found.

        Raises:
            NotImplementedError: If the dataset class is not one of ``CIFAR10``,
                ``CIFAR100``, ``TinyImageNet`` or ``ImageNet``.
        """
        dataset_name = dataset.__class__.__name__

        if dataset_name in ("CIFAR10", "CIFAR100", "TinyImageNet"):
            resnet_architecture = "cifar"
        elif dataset_name == "ImageNet":
            resnet_architecture = "imagenet"
        else:
            raise NotImplementedError(
                f"{cls.__name__} does not support dataset '{dataset_name}'; "
                "expected one of CIFAR10, CIFAR100, TinyImageNet or ImageNet."
            )

        # ``get_checkpoints`` is not defined in this part of the file; it is
        # looked up on ``cls``.
        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="WideResNet101",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
        )

        return cls(
            resnet_type=inferno.models.WideResNet101,
            resnet_architecture=resnet_architecture,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_paths=checkpoints,
        )
