from __future__ import annotations

from typing import TYPE_CHECKING, Callable

import inferno
import lightning as L
import numpy as np
from inferno.bnn import params
from torch import nn
from .swag import SWAG

from . import _swag_hyperparameters
from ._swag_model import _SWAGModel

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class MLPSWAG(_SWAGModel):
    """Multi-layer perceptron wrapped in a ``SWAG`` model.

    A fully-connected feedforward deep neural network with the same activation function for
    each hidden layer. The underlying ``inferno.models.MLP`` is built without normalization
    layers (``norm_layer=None``) and wrapped in ``SWAG`` with ``diag_covariance=False``.
    """

    # Dataset class names (``type(dataset).__name__``) that ``from_dataset`` can configure.
    _SUPPORTED_DATASETS = frozenset(
        {"MNIST", "FashionMNIST", "CIFAR10", "CIFAR100", "TinyImageNet"}
    )

    def __init__(
        self,
        in_size: int,
        hidden_sizes: list[int],
        out_size: int,
        parametrization: params.Parametrization,
        cycle_start: int,
        cycle_length: int,
        max_num_models: int,
        num_estimators: int,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        activation_layer: Callable[..., nn.Module] | None = nn.ReLU,
        bias: bool = True,
    ) -> None:
        """Build the MLP, wrap it in ``SWAG`` and record hyperparameters.

        Args:
            in_size: Number of input features; ``forward`` flattens every non-batch
                dimension before calling the model.
            hidden_sizes: Width of each hidden layer; ``len(hidden_sizes) + 1`` is logged
                as ``num_layers``.
            out_size: Output dimension; also passed to the base class as ``num_classes``.
            parametrization: Forwarded to ``inferno.models.MLP``.
            cycle_start: Forwarded to ``SWAG`` (presumably the epoch at which weight
                collection begins -- confirm against ``SWAG``).
            cycle_length: Forwarded to ``SWAG``.
            max_num_models: Forwarded to ``SWAG``.
            num_estimators: Forwarded to ``SWAG``.
            lr: Forwarded to the base class and recorded as a hyperparameter.
            momentum: Forwarded to the base class and recorded as a hyperparameter.
            nesterov: Forwarded to the base class and recorded as a hyperparameter.
            weight_decay: Forwarded to the base class and recorded as a hyperparameter.
            max_epochs: Forwarded to the base class and recorded as a hyperparameter.
            activation_layer: Forwarded to ``inferno.models.MLP``; may be ``None``.
                Its name is logged as the ``activation_layer`` hyperparameter.
            bias: Forwarded to ``inferno.models.MLP`` as ``bias``.
        """
        super().__init__(
            num_classes=out_size,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
        )

        # Copy so that later mutation of the caller's list cannot silently change the
        # model's recorded hyperparameters.
        hidden_sizes = list(hidden_sizes)

        # The signature permits ``None``, and callables such as ``functools.partial``
        # have no ``__name__``; fall back instead of raising AttributeError.
        if activation_layer is None:
            activation_name = "None"
        else:
            activation_name = getattr(activation_layer, "__name__", repr(activation_layer))

        # Model
        base_model = inferno.models.MLP(
            in_size=in_size,
            hidden_sizes=hidden_sizes,
            out_size=out_size,
            norm_layer=None,
            activation_layer=activation_layer,
            bias=bias,
            parametrization=parametrization,
            cov=None,
        )

        self.model = SWAG(
            model=base_model,
            cycle_start=cycle_start,
            cycle_length=cycle_length,
            diag_covariance=False,
            max_num_models=max_num_models,
            num_estimators=num_estimators,
        )

        self.save_hyperparameters(
            _swag_hyperparameters(
                lightning_module=self,
                architecture="MLP",
                out_size=out_size,
                lr=lr,
                momentum=momentum,
                nesterov=nesterov,
                weight_decay=weight_decay,
                max_epochs=max_epochs,
            )
            | {
                "num_layers": len(hidden_sizes) + 1,
                "in_size": in_size,
                "hidden_sizes": hidden_sizes,
                "bias": bias,
                "activation_layer": activation_name,
            },
            logger=True,
        )

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
        hidden_sizes: list[int] | None = None,
    ) -> MLPSWAG:
        """Construct a model sized for ``dataset``.

        The input size is the product of ``dataset.input_shape`` and the output size is
        ``dataset.num_classes``. SWAG settings are fixed here: ``cycle_start`` is
        ``max_epochs // 2``, ``cycle_length=1``, ``max_num_models=20`` and
        ``num_estimators=16``.

        Args:
            dataset: Data module; only its class name, ``input_shape`` and ``num_classes``
                are read.
            parametrization: Forwarded to the constructor.
            lr: Forwarded to the constructor.
            momentum: Forwarded to the constructor.
            nesterov: Forwarded to the constructor.
            weight_decay: Forwarded to the constructor.
            max_epochs: Forwarded to the constructor; also determines ``cycle_start``.
            pretrained: Unused here (presumably kept for a uniform ``from_dataset``
                signature across models -- confirm).
            freeze_pretrained_weights: Unused here.
            seed: Unused here.
            root_dir: Unused here.
            hidden_sizes: Hidden layer widths; defaults to ``[128, 128]``.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If ``type(dataset).__name__`` is not a supported dataset.
        """
        dataset_name = type(dataset).__name__
        if dataset_name not in cls._SUPPORTED_DATASETS:
            raise NotImplementedError(
                f"{cls.__name__}.from_dataset does not support dataset {dataset_name!r}; "
                f"supported datasets: {sorted(cls._SUPPORTED_DATASETS)}."
            )

        # A list literal as the default would be one object shared by every call;
        # build a fresh one instead.
        if hidden_sizes is None:
            hidden_sizes = [128, 128]

        return cls(
            in_size=int(np.prod(dataset.input_shape)),
            hidden_sizes=hidden_sizes,
            out_size=dataset.num_classes,
            parametrization=parametrization,
            cycle_start=max_epochs // 2,
            cycle_length=1,
            max_num_models=20,
            num_estimators=16,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
        )

    def forward(
        self, inputs: Float[Tensor, "batch *in_feature"]
    ) -> Float[Tensor, "batch *out_feature"]:
        """Flatten all non-batch dimensions of ``inputs`` and run the SWAG-wrapped MLP.

        Args:
            inputs: Tensor whose first dimension is the batch.

        Returns:
            The output of ``self.model`` on the flattened ``(batch, -1)`` inputs.
        """
        # ``reshape`` matches ``view`` for contiguous tensors but, unlike ``view``, also
        # accepts non-contiguous inputs (e.g. permuted or channels-last images).
        return self.model(inputs.reshape(inputs.size(0), -1))
