import torch
import numpy as np

from src.methods.models.layers.stochastic.stochastic import StochasticLayer

# Supported ways of handling the standard deviation in GaussianLayer:
#   "static"     - log-std stored as a parameter with requires_grad=False,
#   "trainable"  - log-std stored as a learnable parameter,
#   "contextual" - log-std predicted from the layer input by a linear layer.
MODES = ["static", "trainable", "contextual"]


class InvalidGaussianModeException(Exception):
    """Raised when a GaussianLayer is given a mode that is not listed in ``MODES``."""


class GaussianLayer(StochasticLayer):
    """Stochastic layer that outputs a Gaussian with diagonal covariance.

    The mean comes from the wrapped network (or is supplied by the caller). The
    per-dimension standard deviation is handled according to ``mode``:

    * ``"static"``: log-std parameter initialised from ``std`` with
      ``requires_grad=False``.
    * ``"trainable"``: same initial value, but ``requires_grad=True``.
    * ``"contextual"``: log-std predicted from the input by a linear layer
      (``input_dim -> output_dim``); ``std`` is not used in this mode.
    """

    def __init__(self, network: torch.nn.Module, input_dim: int, output_dim: int, mode: str, std: float):
        """Build the layer.

        Args:
            network: Module forwarded to ``StochasticLayer``; used here to compute the mean.
            input_dim: Input feature size of the contextual std layer.
            output_dim: Dimensionality of the Gaussian (also forwarded to ``StochasticLayer``).
            mode: One of ``MODES``.
            std: Initial standard deviation for the ``"static"`` and ``"trainable"`` modes.

        Raises:
            InvalidGaussianModeException: If ``mode`` is not one of ``MODES``.
        """
        # Validate first so that nothing is built for an unusable configuration.
        if mode not in MODES:
            raise InvalidGaussianModeException(f"mode must be one of {MODES}, got {mode!r}")

        super().__init__(network, output_dim)

        self._input_dim = input_dim
        self._mode = mode
        self._std = std

        self._initialize_std()

    def initialize(self) -> None:
        """Run the parent initialisation, then rebuild the std parameter / std layer.

        NOTE(review): ``_initialize_std`` creates new ``Parameter`` / ``Linear`` objects,
        so anything holding references to the previous ones (e.g. an optimizer) would
        presumably need to be rebuilt afterwards -- confirm against callers.
        """
        super().initialize()
        self._initialize_std()

    def build_distribution(self, x: torch.Tensor,
                           y_mean: torch.Tensor | None = None) -> torch.distributions.Distribution:
        """Return the Gaussian for input ``x``.

        Args:
            x: Input passed to the network (when ``y_mean`` is ``None``) and, in
                ``"contextual"`` mode, to the std layer.
            y_mean: Optional precomputed mean; when given, the network is not called.

        Returns:
            A ``MultivariateNormal`` whose covariance is ``diag(std ** 2)``.
        """
        mean = self._network(x) if y_mean is None else y_mean

        if self._mode == "contextual":
            log_std = self._std_layer(x)
        else:
            log_std = self._log_std

        std = torch.exp(log_std)

        # diag(std) is the Cholesky factor of the covariance diag(std ** 2). Passing it as
        # ``scale_tril`` makes the per-dimension standard deviation really equal ``std``
        # (passing it as ``covariance_matrix`` would make the *variance* equal ``std``)
        # and avoids an internal Cholesky decomposition.
        scale_tril = torch.diag_embed(std)

        return torch.distributions.MultivariateNormal(loc=mean, scale_tril=scale_tril)

    def _initialize_std(self) -> None:
        """Create the log-std parameter or the contextual std layer according to the mode."""
        if self._mode == "contextual":
            self._log_std = None
            self._std_layer = torch.nn.Linear(self._input_dim, self._output_dim)
            return

        # "static" and "trainable" share the same initial value; they differ only in
        # whether the parameter receives gradients.
        std = np.ones(shape=(self._output_dim,), dtype=np.float32) * self._std
        train = self._mode == "trainable"
        self._log_std = torch.nn.Parameter(torch.log(torch.tensor(std)), requires_grad=train)
        self._std_layer = None
