from __future__ import annotations

import pathlib
from typing import TYPE_CHECKING, Callable

import inferno
import lightning as L
import numpy as np
import torch
import torchvision
from inferno import bnn, models
from torch import nn, optim

from . import _ensemble_hyperparameters
from ._ensemble_model import _EnsembleModel

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class MLPEnsemble(_EnsembleModel):
    """Multi-layer Perceptron Ensemble.

    An ensemble of fully-connected feedforward deep neural networks with the same activation function for
    each hidden layer. One ensemble member is built per checkpoint and initialized with the weights
    stored in that checkpoint. The optimization hyperparameters are read from the checkpoints.

    :param in_size:             Size of the input.
    :param hidden_sizes:        List of hidden layer sizes.
    :param out_size:            Size of the output (e.g. number of classes).
    :param parametrization:     The parametrization to use.
    :param checkpoint_paths:    List of paths to the checkpoints of the models. A single path is also
        accepted and results in an ensemble with one member.
    :param activation_layer:    Activation function which will be stacked on top of the normalization
        layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used.
    :param bias:                Whether to use bias in the linear layer.
    :raises ValueError:         If no checkpoint path is given.
    """

    # Names of the dataset classes that ``from_dataset`` accepts.
    _SUPPORTED_DATASETS = (
        "MNIST",
        "FashionMNIST",
        "CIFAR10",
        "CIFAR100",
        "TinyImageNet",
    )

    def __init__(
        self,
        in_size: int,
        hidden_sizes: list[int],
        out_size: int,
        parametrization: bnn.params.Parametrization,
        checkpoint_paths: list[str | pathlib.Path] | str | pathlib.Path,
        activation_layer: Callable[..., nn.Module] | None = nn.ReLU,
        bias: bool = True,
    ) -> None:
        # A bare path would otherwise be iterated character by character.
        if isinstance(checkpoint_paths, (str, pathlib.PurePath)):
            checkpoint_paths = [checkpoint_paths]
        checkpoint_paths = list(checkpoint_paths)
        if not checkpoint_paths:
            raise ValueError("MLPEnsemble requires at least one checkpoint path.")

        state_dicts = []
        hyper_parameters = None
        for checkpoint_path in checkpoint_paths:
            # Load onto the CPU so that checkpoints written on an accelerator can be read on any
            # host; ``load_state_dict`` below copies the values into the member's own parameters.
            checkpoint = torch.load(
                checkpoint_path, map_location="cpu", weights_only=True
            )
            # Drop everything up to and including the first "model." in each key, so the keys match
            # those of a bare MLP.
            state_dicts.append(
                {
                    k.partition("model.")[2]: v
                    for k, v in checkpoint["state_dict"].items()
                }
            )
            # NOTE: Assumes all checkpoints have the same hyperparameters; the values of the last
            # checkpoint are the ones that end up being used.
            hyper_parameters = checkpoint["hyper_parameters"]

        super().__init__(
            num_classes=out_size,
            lr=hyper_parameters["lr"],
            momentum=hyper_parameters["momentum"],
            nesterov=hyper_parameters["nesterov"],
            weight_decay=hyper_parameters["weight_decay"],
            max_epochs=hyper_parameters["max_epochs"],
        )

        # Ensemble members
        # TODO: When using vanilla PyTorch models (i.e. ``bnn.params.Standard``), a
        #   ``torchvision.ops.MLP`` with ``hidden_channels=hidden_sizes + [out_size]`` could be
        #   used here instead.
        members = []
        for state_dict in state_dicts:
            model = inferno.models.MLP(
                in_size=in_size,
                hidden_sizes=hidden_sizes,
                out_size=out_size,
                parametrization=parametrization,
                cov=None,
                activation_layer=activation_layer,
                bias=bias,
            )
            model.load_state_dict(state_dict)
            members.append(model)

        self.model = models.Ensemble(members)

        # ``activation_layer`` is typically a class (e.g. ``nn.ReLU``), so its own ``__name__`` is
        # the informative value; fall back to the type name for objects without one (e.g. ``None``).
        activation_name = getattr(
            activation_layer, "__name__", type(activation_layer).__name__
        )

        self.save_hyperparameters(
            _ensemble_hyperparameters(
                lightning_module=self,
                architecture="MLP",
                out_size=out_size,
                num_members=len(members),
                optimizer=hyper_parameters["optimizer"],
                lr=hyper_parameters["lr"],
                momentum=hyper_parameters["momentum"],
                nesterov=hyper_parameters["nesterov"],
                weight_decay=hyper_parameters["weight_decay"],
                max_epochs=hyper_parameters["max_epochs"],
            )
            | {
                "num_layers": len(hidden_sizes) + 1,
                "in_size": in_size,
                "hidden_sizes": hidden_sizes,
                "activation_layer": activation_name,
                "bias": bias,
            },
            logger=True,
        )

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: bnn.params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
        hidden_sizes: list[int] | None = None,
    ) -> MLPEnsemble:
        """Build an ensemble whose input and output sizes are taken from a dataset.

        The checkpoints are looked up via ``cls.get_checkpoints`` using the class names of the
        dataset and of the parametrization.

        :param dataset:                     Data module providing ``input_shape`` and ``num_classes``.
        :param parametrization:             The parametrization to use.
        :param lr:                          Unused here; the value stored in the checkpoints is used.
        :param momentum:                    Unused here; the value stored in the checkpoints is used.
        :param nesterov:                    Unused here; the value stored in the checkpoints is used.
        :param weight_decay:                Unused here; the value stored in the checkpoints is used.
        :param max_epochs:                  Unused here; the value stored in the checkpoints is used.
        :param pretrained:                  Unused by this method.
        :param freeze_pretrained_weights:   Unused by this method.
        :param seed:                        Unused by this method.
        :param root_dir:                    Directory passed on to ``cls.get_checkpoints``.
        :param hidden_sizes:                List of hidden layer sizes. Defaults to ``[128, 128]``.
        :raises NotImplementedError:        If the dataset is not one of the supported datasets.
        """
        dataset_name = dataset.__class__.__name__
        if dataset_name not in cls._SUPPORTED_DATASETS:
            raise NotImplementedError(
                f"MLPEnsemble does not support the dataset '{dataset_name}'."
            )

        # Create a fresh list per call instead of sharing one mutable default between calls.
        hidden_sizes = [128, 128] if hidden_sizes is None else list(hidden_sizes)

        checkpoints = cls.get_checkpoints(
            dataset_name=dataset_name,
            model_name="MLP",
            parametrization_name=parametrization.__class__.__name__,
            root_dir=root_dir,
        )

        return cls(
            # The members operate on flattened inputs, see ``forward``.
            in_size=int(np.prod(dataset.input_shape)),
            hidden_sizes=hidden_sizes,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            checkpoint_paths=checkpoints,
        )

    def forward(
        self, inputs: Float[Tensor, "batch *in_feature"]
    ) -> Float[Tensor, "batch *out_feature"]:
        """Flatten all non-batch dimensions of the inputs and evaluate the ensemble on them.

        :param inputs:  Batch of inputs; the first dimension is kept, all others are flattened.
        :return:        The output of the underlying ensemble.
        """
        # ``reshape`` instead of ``view`` so that non-contiguous inputs are supported as well.
        inputs = inputs.reshape(inputs.size(0), -1)
        return self.model(inputs)
