from __future__ import annotations

import copy
from typing import TYPE_CHECKING, Callable

import inferno
import lightning as L
import numpy as np
import torch
from inferno import bnn, loss_fns
from inferno.bnn import params
from torch import nn

from . import _wsvi_hyperparameters
from ._wsvi_model import _WeightSpaceVIModel

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class MLPWeightSpaceVI(_WeightSpaceVIModel):
    """Multi-layer Perceptron.

    A fully-connected feedforward deep neural network with the same activation function for
    each hidden layer.

    :param in_size:             Size of the input.
    :param hidden_sizes:        List of hidden layer sizes.
    :param out_size:            Size of the output (e.g. number of classes).
    :param parametrization:     The parametrization to use. Defines the initialization
        and learning rate scaling for the parameters of the module.
    :param cov:                 Covariance structure of the weights.
    :param num_samples_train:   Number of samples to draw during training.
    :param num_samples_test:    Number of samples to draw during testing.
    :param kl_weight:           Weight of the KL regularization term in the ELBO.
    :param lr:                  Learning rate of the optimizer.
    :param momentum:            Momentum of the optimizer.
    :param weight_decay:        Weight decay of the optimizer.
    :param activation_layer:    Activation function following a linear layer.
    :param bias:                Whether to use bias in the linear layer.``
    """

    def __init__(
        self,
        in_size: int,
        hidden_sizes: list[int],
        out_size: int,
        parametrization: params.Parametrization,
        num_samples_train: int,
        num_samples_test: int,
        kl_weight: float,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        cov: params.FactorizedCovariance,
        activation_layer: Callable[..., nn.Module] | None = nn.ReLU,
        bias: bool = True,
    ) -> None:
        super().__init__(
            num_classes=out_size,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
        )
        self.num_samples_train = num_samples_train
        self.num_samples_test = num_samples_test

        # Model
        self.model = inferno.models.MLP(
            in_size=in_size,
            hidden_sizes=hidden_sizes,
            out_size=out_size,
            parametrization=parametrization,
            cov=cov,
            activation_layer=activation_layer,
            bias=bias,
        )
        # Ensure every layer has covariance parameters.
        if not isinstance(self.model[0].params.cov, params.DiagonalCovariance):
            self.model[0] = bnn.Linear(
                in_features=in_size,
                out_features=hidden_sizes[0],
                bias=bias,
                layer_type="input",
                cov=copy.deepcopy(cov),
                parametrization=parametrization,
            )

        # Assigning prior parameters as a buffer counts them in hparams.num_parameters_and_buffers
        numel_mean_parameters = sum(
            param.numel()
            for name, param in self.model.named_parameters()
            if param.requires_grad and "cov." not in name
        )
        self.prior_loc = nn.Buffer(
            torch.zeros((numel_mean_parameters,), requires_grad=False)
        )
        self.prior_scale = nn.Buffer(
            torch.ones((numel_mean_parameters,), requires_grad=False)
        )

        # Loss function
        self.loss_fn = loss_fns.VariationalFreeEnergy(
            nll=nn.CrossEntropyLoss(),
            model=self.model,
            prior_loc=self.prior_loc,
            prior_scale=self.prior_scale,
            kl_weight=kl_weight,
            reduction="mean",
        )

        self.save_hyperparameters(
            _wsvi_hyperparameters(
                lightning_module=self,
                architecture="MLP",
                out_size=out_size,
                num_samples_train=num_samples_train,
                num_samples_test=num_samples_test,
                cov=cov,
                kl_weight=self.loss_fn.kl_weight,
                lr=lr,
                momentum=momentum,
                nesterov=nesterov,
                weight_decay=weight_decay,
                max_epochs=max_epochs,
            )
            | {
                "num_layers": len(hidden_sizes) + 1,
                "in_size": in_size,
                "hidden_sizes": hidden_sizes,
                "bias": bias,
                "activation_layer": activation_layer.__name__,
            },
            logger=True,
        )

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
        hidden_sizes: list[int] = [128, 128],
    ) -> MLPWeightSpaceVI:
        if dataset.__class__.__name__ in [
            "MNIST",
            "FashionMNIST",
            "CIFAR10",
            "CIFAR100",
            "TinyImageNet",
        ]:
            return cls(
                in_size=int(np.prod(dataset.input_shape)),
                hidden_sizes=hidden_sizes,
                out_size=dataset.num_classes,
                parametrization=parametrization,
                num_samples_train=32,
                num_samples_test=256,
                cov=params.DiagonalCovariance(),
                kl_weight=None,
                lr=lr,
                momentum=momentum,
                nesterov=nesterov,
                weight_decay=weight_decay,
                max_epochs=max_epochs,
            )
        else:
            raise NotImplementedError()

    def forward(
        self,
        inputs: Float[Tensor, "batch *in_feature"],
        sample_shape: torch.Size = torch.Size([]),
        generator: torch.Generator | None = None,
    ) -> Float[Tensor, "*sample batch *out_feature"]:
        inputs = inputs.view(inputs.size(0), -1)  # Flatten the input
        return self.model(inputs, sample_shape=sample_shape, generator=generator)


MLPWeightSpaceVIDiagonal = MLPWeightSpaceVI


class MLPWeightSpaceVIKronecker(MLPWeightSpaceVI):

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
        hidden_sizes: list[int] = [128, 128],
    ) -> MLPWeightSpaceVI:
        if dataset.__class__.__name__ in [
            "MNIST",
            "FashionMNIST",
            "CIFAR10",
            "CIFAR100",
            "TinyImageNet",
        ]:
            return cls(
                in_size=int(np.prod(dataset.input_shape)),
                hidden_sizes=hidden_sizes,
                out_size=dataset.num_classes,
                parametrization=parametrization,
                num_samples_train=32,
                num_samples_test=32,
                cov=params.Kronecker(),
                kl_weight=None,
                lr=lr,
                momentum=momentum,
                nesterov=nesterov,
                weight_decay=weight_decay,
                max_epochs=max_epochs,
            )
        else:
            raise NotImplementedError()
