from dataclasses import dataclass

import torch

from ...classes import MLP, Hyperparameters, ModelInterface


class DeterministicEncoder(torch.nn.Module):
    """Encoder that maps its input to a deterministic representation.

    The input is passed through an ``MLP`` and the result is averaged over
    dimensions 1 and 2.
    """

    def __init__(self, input_size: int, output_size: int, hidden_sizes: list):
        """Create the underlying MLP.

        Args:
            input_size: Passed to ``MLP`` as its first positional argument.
            output_size: Passed to ``MLP`` as its second positional argument.
            hidden_sizes: Passed to ``MLP`` as its third positional argument.
        """
        super().__init__()
        self.mlp = MLP(input_size, output_size, hidden_sizes)

    # Implemented as ``forward`` (not ``__call__``) so that
    # ``torch.nn.Module.__call__`` stays in charge of dispatch and registered
    # hooks are honoured; ``encoder(X)`` keeps working exactly as before.
    def forward(self, X: torch.Tensor):
        """Encode ``X``.

        Args:
            X: Input tensor; the MLP output must have at least three
                dimensions because dimensions 1 and 2 are averaged out.

        Returns:
            A tuple ``(h, None)`` where ``h`` is the MLP output averaged over
            dimensions 1 and 2. The second element is always ``None``; it
            keeps the return shape aligned with the ``(h, distribution)``
            pair returned by the latent encoders in this file.
        """
        h = self.mlp(X).mean(dim=(1, 2))
        return h, None


class LatentEncoderInterface(torch.nn.Module):
    """Base class for encoders that produce a Normal distribution over ``h``.

    Subclasses implement :meth:`compute_gaussian_params` and assign
    ``min_g``, ``prior_mu`` and ``prior_sigma``. This class turns the raw
    per-datapoint outputs into scales, aggregates them along dimension 1
    and samples from the resulting Normal distribution.
    """

    # Offset added to every per-datapoint scale before it is rescaled.
    min_g: float
    # Mean used in the aggregation formula of compute_mean_and_variance.
    prior_mu: float
    # Scale used as the upper bound for g_n and in the aggregation formula.
    prior_sigma: float
    # When True, compute_f_and_g returns prior_sigma for every scale.
    # Defaulted here so subclasses that never assign it still work.
    is_g_fixed: bool = False

    def compute_gaussian_params(self, X: torch.Tensor):
        """Return the raw pair ``(f_n, log_g_n)`` for ``X``.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def compute_f_and_g(self, X: torch.Tensor):
        """Compute per-datapoint locations ``f_n`` and scales ``g_n``.

        Args:
            X: Input forwarded unchanged to :meth:`compute_gaussian_params`.

        Returns:
            A tuple ``(f_n, g_n)``. ``f_n`` is returned as produced by
            :meth:`compute_gaussian_params`. ``g_n`` is:

            * filled with ``prior_sigma`` when ``is_g_fixed`` is set;
            * ``coeff * (sigmoid(log_g_n) + min_g)`` when ``prior_sigma == 1``;
            * otherwise ``coeff * (h + min_g)`` where ``h`` combines
              ``exp(log_g_n)`` and ``prior_sigma`` as ``1 / (1/g + 1/sigma)``.

            In the last two cases ``coeff`` rescales the result so that it
            stays below ``prior_sigma``.
        """
        f_n, log_g_n = self.compute_gaussian_params(X)

        if self.is_g_fixed:
            return f_n, torch.full_like(f_n, self.prior_sigma)

        eps = 1e-6

        if self.prior_sigma == 1:
            coeff = 1 / (1 + self.min_g + eps)
            g_n = coeff * (torch.sigmoid(log_g_n) + self.min_g)
            return f_n, g_n

        # Cap the log-scale so that exp() cannot overflow.
        log_scale_max = 5
        log_g_n = torch.clamp(log_g_n, max=log_scale_max)
        g_n = torch.exp(log_g_n)
        # g_n should be smaller than g_0 (= prior_sigma).
        coeff = self.prior_sigma / (self.prior_sigma + self.min_g + eps)
        # 1 / (1/g + 1/sigma) rewritten as g * sigma / (g + sigma): identical
        # algebraically, but it never forms 1/g, so a g_n that underflowed to
        # zero no longer yields 0 * inf = NaN in the backward pass.
        bounded = g_n * self.prior_sigma / (g_n + self.prior_sigma)
        g_n = coeff * (bounded + self.min_g)
        return f_n, g_n

    def compute_mean_and_variance(self, f_n: torch.Tensor, g_n: torch.Tensor):
        """Aggregate per-datapoint parameters along dimension 1.

        With ``N = f_n.shape[1]`` the result is::

            variance = 1 / (sum_n 1 / g_n**2 - (N - 1) / prior_sigma**2)
            mu = variance * (sum_n f_n / g_n**2
                             - (N - 1) * prior_mu / prior_sigma**2)

        A small epsilon is added to ``g_n**2`` before dividing.

        Args:
            f_n: Per-datapoint locations; datapoints are indexed by dim 1.
            g_n: Per-datapoint scales, same shape as ``f_n``.

        Returns:
            A tuple ``(mu, variance)`` with dimension 1 reduced away.
        """
        num_datapoints = f_n.shape[1]
        g2_n = g_n**2
        eps = 1e-12  # guards the divisions against g_n == 0

        variance = 1 / (torch.sum(1 / (g2_n + eps), dim=1) -
                        (num_datapoints - 1) / self.prior_sigma**2)
        mu = variance * (
            torch.sum(f_n / (g2_n + eps), dim=1) -
            (num_datapoints - 1) * self.prior_mu / self.prior_sigma**2)
        return mu, variance

    def h_distribution(self, X: torch.Tensor):
        """Return the Normal distribution over ``h`` induced by ``X``.

        The scale of the distribution is the square root of the variance
        computed by :meth:`compute_mean_and_variance`.
        """
        f_n, g_n = self.compute_f_and_g(X)
        mu, variance = self.compute_mean_and_variance(f_n, g_n)
        return torch.distributions.Normal(loc=mu, scale=torch.sqrt(variance))

    def forward(self, X: torch.Tensor):
        """Draw a reparameterised sample of ``h``.

        Returns:
            A tuple ``(h, h_distribution)`` where ``h`` is obtained with
            ``rsample()`` from ``h_distribution``.
        """
        h_distribution = self.h_distribution(X)
        h = h_distribution.rsample()
        return h, h_distribution

    def fix_g(self, flag):
        """Set ``is_g_fixed``; see :meth:`compute_f_and_g` for its effect."""
        self.is_g_fixed = flag


class LatentEncoderSharedHiddenLayers(LatentEncoderInterface):
    """Latent encoder whose two heads share a common MLP.

    The input goes through ``self.mlp`` once; the result feeds two separate
    MLP heads that produce ``f_n`` and ``log_g_n`` respectively.
    """

    def __init__(self, input_size: int, output_size: int,
                 shared_layers_output_size: int,
                 shared_layers_hidden_sizes: list,
                 gaussian_params_hidden_sizes: list, min_g: float,
                 prior_mu: float, prior_sigma: float):
        """Build the shared MLP and the two heads.

        Args:
            input_size: Input size of the shared MLP.
            output_size: Output size of both heads.
            shared_layers_output_size: Output size of the shared MLP and
                input size of both heads.
            shared_layers_hidden_sizes: Hidden sizes of the shared MLP.
            gaussian_params_hidden_sizes: Hidden sizes used by both heads.
            min_g: Stored as ``self.min_g``.
            prior_mu: Stored as ``self.prior_mu``.
            prior_sigma: Stored as ``self.prior_sigma``.
        """
        super().__init__()
        self.min_g = min_g
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma
        self.is_g_fixed = False
        # NOTE(review): output_activation=True presumably applies an
        # activation after the shared MLP's last layer - confirm against MLP.
        self.mlp = MLP(input_size,
                       shared_layers_output_size,
                       shared_layers_hidden_sizes,
                       output_activation=True)
        self.gaussian_params_f = MLP(shared_layers_output_size, output_size,
                                     gaussian_params_hidden_sizes)
        self.gaussian_params_g = MLP(shared_layers_output_size, output_size,
                                     gaussian_params_hidden_sizes)

    def compute_gaussian_params(self, X: torch.Tensor):
        """Return ``(f_n, log_g_n)`` computed from the shared hidden features.

        Args:
            X: Tensor with at least two dimensions; everything after the
                first two dimensions is flattened into one.

        Returns:
            The outputs of the ``f`` head and the ``g`` head, in that order.
        """
        batchsize, num_tuples = X.shape[:2]
        # reshape (not view) so non-contiguous inputs, e.g. the result of a
        # transpose/permute, are accepted instead of raising RuntimeError.
        X = X.reshape(batchsize, num_tuples, -1)

        hidden = self.mlp(X)
        f_n = self.gaussian_params_f(hidden)
        log_g_n = self.gaussian_params_g(hidden)
        return f_n, log_g_n


class LatentEncoder(LatentEncoderInterface):
    """Latent encoder with two independent MLPs for ``f_n`` and ``log_g_n``."""

    def __init__(self, input_size: int, output_size: int, hidden_sizes: list,
                 min_g: float, prior_mu: float, prior_sigma: float):
        """Build the two MLPs.

        Args:
            input_size: Input size of both MLPs.
            output_size: Output size of both MLPs.
            hidden_sizes: Hidden sizes of both MLPs.
            min_g: Stored as ``self.min_g``.
            prior_mu: Stored as ``self.prior_mu``.
            prior_sigma: Stored as ``self.prior_sigma``.
        """
        super().__init__()
        self.min_g = min_g
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma
        self.is_g_fixed = False
        self.mlp_f = MLP(input_size, output_size, hidden_sizes)
        self.mlp_g = MLP(input_size, output_size, hidden_sizes)

    def compute_gaussian_params(self, X: torch.Tensor):
        """Return ``(f_n, log_g_n)`` computed by the two MLPs.

        Args:
            X: Tensor with at least two dimensions; everything after the
                first two dimensions is flattened into one.

        Returns:
            The outputs of ``mlp_f`` and ``mlp_g``, in that order.
        """
        batchsize, num_tuples = X.shape[:2]
        # reshape (not view) so non-contiguous inputs, e.g. the result of a
        # transpose/permute, are accepted instead of raising RuntimeError.
        X = X.reshape(batchsize, num_tuples, -1)

        f_n = self.mlp_f(X)
        log_g_n = self.mlp_g(X)

        return f_n, log_g_n
