from __future__ import annotations

import math
from typing import TYPE_CHECKING

import torch

from inferno.bnn.params.parameter import BNNParameter

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor

    from .covariances.factorized import FactorizedCovariance


class CustomGaussianParameter(BNNParameter):
    """Parameter of a BNNModule with Gaussian distribution.

    When sampling, each mean parameter is multiplied by a fixed forward scale
    that is looked up by parameter name before it is combined with the
    covariance factor.

    :param mean: Mean of the Gaussian distribution. A single tensor is stored
        under the name ``"mean"``; a dictionary is passed on as given.
    :param cov: Covariance of the Gaussian distribution.
    :param scale_forward_weight: Factor multiplied onto the mean parameter
        named ``"weight"`` when sampling.
    :param scale_forward_bias: Factor multiplied onto the mean parameter
        named ``"bias"`` when sampling.
    """

    def __init__(
        self,
        mean: Float[Tensor, "parameter"] | dict[str, Float[Tensor, "parameter"]],
        cov: FactorizedCovariance,
        scale_forward_weight: float,
        scale_forward_bias: float,
    ):
        if not isinstance(mean, dict):
            mean = {"mean": mean}

        super().__init__(hyperparameters=mean)

        self.cov = cov
        self.cov.initialize_parameters(mean)

        # Forward scales keyed by mean parameter name. Names without an entry
        # are left unscaled by ``sample``.
        self.scales_forward: dict[str, float] = {
            "weight": scale_forward_weight,
            "bias": scale_forward_bias,
        }

    def sample(
        self,
        sample_shape: torch.Size = torch.Size([]),
        generator: torch.Generator | None = None,
    ) -> (
        Float[Tensor, "*sample parameter"]
        | dict[str, Float[Tensor, "*sample parameter"]]
    ):
        """Sample from the Gaussian distribution of the parameter.

        A standard normal sample of shape ``sample_shape + (cov.rank,)`` is
        passed to ``cov.factor_matmul`` together with the flattened, scaled
        mean as the additive constant.

        :param sample_shape: Leading shape of the standard normal sample.
        :param generator: Random number generator forwarded to ``torch.randn``.
        :return: The result of ``cov.factor_matmul``.
        """
        # The covariance parameters determine dtype and device of the sample.
        reference = next(self.cov.parameters())
        standard_normal_sample = torch.randn(
            (*sample_shape, self.cov.rank),
            dtype=reference.dtype,
            device=reference.device,
            generator=generator,
        )

        # Collect the mean parameters, i.e. everything that is neither a
        # covariance parameter nor the temperature, and apply the forward
        # scale. Names without a configured scale (such as the "mean" key
        # created for a single tensor in ``__init__``) use a scale of 1.0
        # rather than raising a KeyError.
        scaled_means = [
            self.scales_forward.get(name, 1.0) * tens
            for name, tens in self.named_parameters()
            if "cov." not in name and "temperature" not in name
        ]

        # ``reshape`` instead of ``view`` so non-contiguous tensors also work.
        mean_params_stacked = torch.hstack(
            [tens.reshape(-1) for tens in scaled_means]
        )

        # Outside of training, divide the mean by the temperature if this
        # parameter has one.
        if hasattr(self, "temperature") and not self.training:
            mean_params_stacked = mean_params_stacked / self.temperature

        return self.cov.factor_matmul(
            standard_normal_sample,
            additive_constant=mean_params_stacked,
        )
