from __future__ import annotations

import math
from typing import TYPE_CHECKING, Any, Callable, Dict

import inferno
import lightning as L
import numpy as np
import torch
from inferno import bnn
from inferno.bnn import params
from torch import nn

from ._ivi_model import _IVIModel

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


def _ivi_hyperparameters(
    lightning_module,
    architecture: str,
    out_size: int,
    cov: params.FactorizedCovariance | list[params.FactorizedCovariance | None],
    num_samples_train: int,
    num_samples_test: int,
    temperature_scaling: bool,
    lr: float,
    momentum: float,
    nesterov: bool,
    weight_decay: float,
    max_epochs: int,
    optimizer: str = "SGD",
) -> Dict[str, Any]:
    """Common hyperparameters for models trained via implicit VI.

    Args:
        lightning_module: Module being described. Its class name is recorded,
            and ``lightning_module.model`` must provide ``parameters()``,
            ``buffers()`` and a ``parametrization`` attribute.
        architecture: Label stored under ``"architecture"``.
        out_size: Stored under ``"out_size"``.
        cov: Covariance object, or a list/tuple of them. For a sequence only
            the last entry is reported.
        num_samples_train: Stored under ``"num_samples_train"``.
        num_samples_test: Stored under ``"num_samples_test"``.
        temperature_scaling: Stored under ``"temperature_scaling"``.
        lr: Stored under ``"lr"``.
        momentum: Stored under ``"momentum"``.
        nesterov: Stored under ``"nesterov"``.
        weight_decay: Stored under ``"weight_decay"``.
        max_epochs: Stored under ``"max_epochs"``.
        optimizer: Optimizer label stored under ``"optimizer"``.

    Returns:
        A flat dict of hyperparameters, including parameter/buffer counts of
        ``lightning_module.model``.
    """

    VARIATIONAL_FAMILY = {
        "DiagonalCovariance": "Mean-field",
        "KroneckerCovariance": "Kronecker",
        "LowRankCovariance": "Low-rank",
    }
    if isinstance(cov, (list, tuple)):
        cov = cov[-1]

    cov_name = cov.__class__.__name__
    # Unknown covariance classes are reported under their own class name
    # rather than aborting model construction with a KeyError.
    family = VARIATIONAL_FAMILY.get(cov_name, cov_name)

    # Count all and trainable parameter elements in a single pass.
    num_parameters = 0
    num_trainable = 0
    for p in lightning_module.model.parameters():
        count = p.numel()
        num_parameters += count
        if p.requires_grad:
            num_trainable += count
    num_buffers = sum(b.numel() for b in lightning_module.model.buffers())

    return {
        "model": lightning_module.__class__.__name__,
        "inference_method": f"Implicit VI ({family})",
        "architecture": architecture,
        "out_size": out_size,
        "num_trainable_parameters": num_trainable,
        "num_parameters_and_buffers": num_parameters + num_buffers,
        "cov": cov_name,
        "parametrization": lightning_module.model.parametrization.__class__.__name__,
        "num_samples_train": num_samples_train,
        "num_samples_test": num_samples_test,
        "temperature_scaling": temperature_scaling,
        "optimizer": optimizer,
        "lr": lr,
        "momentum": momentum,
        "nesterov": nesterov,
        "weight_decay": weight_decay,
        "max_epochs": max_epochs,
    }


class MLPIVI(_IVIModel):
    """Multi-layer perceptron trained via implicit variational inference.

    The network is an ``inferno.models.MLP`` followed by ``nn.Flatten(-2, -1)``,
    wrapped in a ``bnn.Sequential`` with the given parametrization.
    """

    def __init__(
        self,
        in_size: int,
        hidden_sizes: list[int],
        out_size: int,
        parametrization: params.Parametrization,
        cov: params.FactorizedCovariance | list[params.FactorizedCovariance | None],
        num_samples_train: int,
        num_samples_test: int,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        dataset: L.LightningDataModule,
        temperature_scaling: bool = False,
        activation_layer: Callable[..., nn.Module] | None = nn.ReLU,
        bias: bool = True,
    ) -> None:
        """Build the model and record its hyperparameters.

        Args:
            in_size: Input size forwarded to the MLP.
            hidden_sizes: Hidden layer sizes forwarded to the MLP.
            out_size: Output size forwarded to the MLP. A value of 1 is
                reported to the base class as ``num_classes=2``.
            parametrization: Parametrization passed to ``bnn.Sequential``.
            cov: Covariance specification forwarded to the MLP.
            num_samples_train: Stored on ``self.num_samples_train``.
            num_samples_test: Stored on ``self.num_samples_test``.
            lr: Forwarded to the base class.
            momentum: Forwarded to the base class.
            nesterov: Forwarded to the base class.
            weight_decay: Forwarded to the base class.
            max_epochs: Forwarded to the base class.
            dataset: Data module stored on ``self.dataset``.
            temperature_scaling: If true, attach a ``bnn.TemperatureScaler``
                as ``self.temperature_scaler``.
            activation_layer: Activation constructor forwarded to the MLP.
            bias: Bias flag forwarded to the MLP.
        """
        super().__init__(
            num_classes=2 if out_size == 1 else out_size,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
        )
        self.num_samples_train = num_samples_train
        self.num_samples_test = num_samples_test
        self.dataset = dataset

        # Model
        self.model = bnn.Sequential(
            inferno.models.MLP(
                in_size=in_size,
                hidden_sizes=hidden_sizes,
                out_size=out_size,
                cov=cov,
                activation_layer=activation_layer,
                bias=bias,
            ),
            nn.Flatten(-2, -1),
            parametrization=parametrization,
        )

        # Temperature scaling
        # NOTE(review): the scaler always uses BCEWithLogitsLoss, independent
        # of out_size -- confirm this is intended when out_size > 1.
        if temperature_scaling:
            self.temperature_scaler = bnn.TemperatureScaler(
                loss_fn=nn.BCEWithLogitsLoss(), lr=0.01
            )

        # ``activation_layer`` may be None or a callable without ``__name__``
        # (e.g. a functools.partial); derive a loggable label without crashing.
        if activation_layer is None:
            activation_name = None
        else:
            activation_name = getattr(
                activation_layer, "__name__", repr(activation_layer)
            )

        self.save_hyperparameters(
            _ivi_hyperparameters(
                lightning_module=self,
                architecture="MLP",
                out_size=out_size,
                cov=cov,
                num_samples_train=num_samples_train,
                num_samples_test=num_samples_test,
                lr=lr,
                momentum=momentum,
                nesterov=nesterov,
                weight_decay=weight_decay,
                max_epochs=max_epochs,
                temperature_scaling=temperature_scaling,
            )
            | {
                "num_layers": len(hidden_sizes) + 1,
                "in_size": in_size,
                "hidden_sizes": hidden_sizes,
                "bias": bias,
                "activation_layer": activation_name,
            },
            logger=True,
        )

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
        hidden_sizes: list[int] | None = None,
    ) -> MLPIVI:
        """Construct the model with sizes derived from ``dataset``.

        The input size is the product of ``dataset.input_shape``; the output
        size is 1 when ``dataset.num_classes == 2`` and ``dataset.num_classes``
        otherwise. Temperature scaling is disabled.

        Args:
            dataset: Data module providing ``input_shape`` and ``num_classes``.
            parametrization: Forwarded to the constructor.
            lr: Forwarded to the constructor.
            momentum: Forwarded to the constructor.
            nesterov: Forwarded to the constructor.
            weight_decay: Forwarded to the constructor.
            max_epochs: Forwarded to the constructor.
            pretrained: Unused by this method.
            freeze_pretrained_weights: Unused by this method.
            seed: Unused by this method.
            root_dir: Unused by this method.
            hidden_sizes: Hidden layer sizes; defaults to ``[12, 12]``.

        Returns:
            A new instance of ``cls``.
        """
        # Fresh list per call: the value ends up in the saved hyperparameters,
        # so a shared mutable default would be aliased across instances.
        if hidden_sizes is None:
            hidden_sizes = [12, 12]

        return cls(
            in_size=int(np.prod(dataset.input_shape)),
            hidden_sizes=hidden_sizes,
            out_size=1 if dataset.num_classes == 2 else dataset.num_classes,
            parametrization=parametrization,
            cov=[params.LowRankCovariance(20), None, params.LowRankCovariance(20)],
            num_samples_train=1,
            num_samples_test=32,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
            dataset=dataset,
            temperature_scaling=False,
        )


class MLPIVITemperatureScaling(MLPIVI):
    """``MLPIVI`` variant whose ``from_dataset`` enables temperature scaling."""

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
        hidden_sizes: list[int] | None = None,
    ) -> MLPIVITemperatureScaling:
        """Construct the model from ``dataset`` with temperature scaling on.

        The input size is the product of ``dataset.input_shape``; the output
        size is 1 when ``dataset.num_classes == 2`` and ``dataset.num_classes``
        otherwise.

        Args:
            dataset: Data module providing ``input_shape`` and ``num_classes``.
            parametrization: Forwarded to the constructor.
            lr: Forwarded to the constructor.
            momentum: Forwarded to the constructor.
            nesterov: Forwarded to the constructor.
            weight_decay: Forwarded to the constructor.
            max_epochs: Forwarded to the constructor.
            pretrained: Unused by this method.
            freeze_pretrained_weights: Unused by this method.
            seed: Unused by this method.
            root_dir: Unused by this method.
            hidden_sizes: Hidden layer sizes; defaults to ``[12, 12]``.

        Returns:
            A new instance of ``cls``.
        """
        # Avoid a shared mutable default; build a fresh list per call.
        if hidden_sizes is None:
            hidden_sizes = [12, 12]

        return cls(
            in_size=int(np.prod(dataset.input_shape)),
            hidden_sizes=hidden_sizes,
            out_size=1 if dataset.num_classes == 2 else dataset.num_classes,
            parametrization=parametrization,
            cov=[params.LowRankCovariance(20), None, params.LowRankCovariance(20)],
            num_samples_train=1,
            num_samples_test=32,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
            dataset=dataset,
            temperature_scaling=True,
        )


class MLPIVITheoreticalScaling(MLPIVI):
    """``MLPIVI`` variant that sets a layer temperature before validation.

    The temperature is derived from the number of optimizer steps, the
    learning rate and the training set size (see ``on_validation_start``).
    """

    @classmethod
    def from_dataset(
        cls,
        dataset: L.LightningDataModule,
        parametrization: params.Parametrization,
        lr: float,
        momentum: float,
        nesterov: bool,
        weight_decay: float,
        max_epochs: int,
        pretrained: bool,
        freeze_pretrained_weights: bool,
        seed: int,
        root_dir: str,
        hidden_sizes: list[int] | None = None,
    ) -> MLPIVITheoreticalScaling:
        """Construct the model with sizes derived from ``dataset``.

        The input size is the product of ``dataset.input_shape``; the output
        size is 1 when ``dataset.num_classes == 2`` and ``dataset.num_classes``
        otherwise. The learned temperature scaler is disabled.

        Args:
            dataset: Data module providing ``input_shape`` and ``num_classes``.
            parametrization: Forwarded to the constructor.
            lr: Forwarded to the constructor.
            momentum: Forwarded to the constructor.
            nesterov: Forwarded to the constructor.
            weight_decay: Forwarded to the constructor.
            max_epochs: Forwarded to the constructor.
            pretrained: Unused by this method.
            freeze_pretrained_weights: Unused by this method.
            seed: Unused by this method.
            root_dir: Unused by this method.
            hidden_sizes: Hidden layer sizes; defaults to ``[12, 12]``.

        Returns:
            A new instance of ``cls``.
        """
        # Avoid a shared mutable default; build a fresh list per call.
        if hidden_sizes is None:
            hidden_sizes = [12, 12]

        return cls(
            in_size=int(np.prod(dataset.input_shape)),
            hidden_sizes=hidden_sizes,
            out_size=1 if dataset.num_classes == 2 else dataset.num_classes,
            parametrization=parametrization,
            cov=[params.LowRankCovariance(20), None, params.LowRankCovariance(20)],
            num_samples_train=1,
            num_samples_test=32,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
            dataset=dataset,
            temperature_scaling=False,
        )

    def on_validation_start(self):
        """Set the layer temperature, then defer to the parent hook.

        Computes ``log(num_optim_steps * lr / train_set_size)`` once at least
        one optimizer step has been taken. If that value is >= 1, it replaces
        ``self.model[-2][-2].params.temperature`` with a non-trainable
        parameter holding it; otherwise the temperature is left untouched.
        """
        # ``num_optim_steps`` and ``lr`` are not defined in this class;
        # presumably they are provided by the base class -- TODO confirm.
        normalization = 0.0
        if self.num_optim_steps > 0:
            normalization = math.log(
                self.num_optim_steps * self.lr / self.dataset.train_set_size
            )

        if normalization >= 1.0:
            layer_params = self.model[-2][-2].params
            current = layer_params.temperature
            # Keep dtype and device of the existing temperature.
            scaled = (
                torch.ones(1, dtype=current.dtype, device=current.device)
                * normalization
            )
            layer_params.temperature = nn.Parameter(scaled, requires_grad=False)
        return super().on_validation_start()
