from dataclasses import dataclass

import torch
import math

from ...classes import MLP, Hyperparameters, ModelInterface
from .encoder import LatentEncoder, LatentEncoderInterface
from .decoder import Decoder


@dataclass
class ModelHyperparameters(Hyperparameters):
    """Hyperparameters for the encoder/decoder model defined in this module.

    Layer-size fields are strings of comma-separated integers
    (e.g. ``"128,256,512"``). They are presumably parsed into per-layer widths
    by the code that builds the networks, which is not shown here.
    """

    # --- Encoder ---
    encoder_hidden_sizes: str = "128,256,512"  # used for the standard (non-shared) encoder
    encoder_shared_layers_output_size: int = 512  # used for the encoder that shares the hidden layers
    encoder_shared_layers_hidden_sizes: str = "128,256"  # used for the encoder that shares the hidden layers
    encoder_gaussian_params_hidden_sizes: str = "512"  # used for the encoder that shares the hidden layers
    encoder_share_hidden_layers: bool = False  # presumably selects the shared-hidden-layer encoder; confirm
    encoder_with_normal: bool = False  # NOTE(review): meaning not visible in this file; confirm against the encoder
    deterministic_encoder: bool = False  # NOTE(review): presumably disables latent sampling; confirm

    # --- Decoder ---
    decoder_hidden_sizes: str = "512,512,512,512,512,512,512"
    decoder_disable_skip_connection: bool = False
    decoder_softplus_beta: float = 100  # presumably the beta of a Softplus activation in the decoder; confirm

    # NOTE(review): f0 / g0 / min_g are not read in this file; their semantics
    # (min_g is presumably a lower bound on some quantity g) must be confirmed
    # against the encoder/loss code.
    f0: float = 0
    g0: float = 1
    min_g: float = 0.005
    h_dim: int = 512  # presumably the size of the latent code h; confirm


@dataclass
class ModelOutputDescription:
    """Bundle of values returned by ``Model.predict`` and ``Model.forward``.

    The five fields without defaults are passed explicitly by ``Model``. The
    ``*_mean`` / ``*_var`` fields default to None and are not set by ``Model``
    in this module.
    """

    # Decoder output for every context and target point.
    pred_distance: torch.Tensor
    # Points fed to the decoder, viewed as (batchsize, num_points, X_dim).
    decoder_input_points: torch.Tensor
    # Latent code per sample, BEFORE it is repeated once per point.
    decoder_input_h: torch.Tensor
    # Second value returned by the encoder on the context points.
    # NOTE(review): annotated as Tensor, but may be a distribution object; confirm.
    h_prior_dist: torch.Tensor
    # Second value returned by the encoder on context + target points;
    # Model.predict passes None here.
    h_posterior_dist: torch.Tensor
    h_prior_mean: torch.Tensor = None  # optional; None unless a caller sets it
    h_prior_var: torch.Tensor = None  # optional; None unless a caller sets it
    h_posterior_mean: torch.Tensor = None  # optional; None unless a caller sets it
    h_posterior_var: torch.Tensor = None  # optional; None unless a caller sets it


def _repeat(array, num_repeats):
    """Tile a 2-D tensor along a new axis 1.

    A ``(batch, features)`` tensor becomes a ``(batch, num_repeats, features)``
    broadcast view produced by ``expand``. No data is copied, and every repeat
    shares storage with ``array``, so the result must not be written in place.
    """
    batch, features = array.shape[0], array.shape[1]
    return array.unsqueeze(1).expand(batch, num_repeats, features)


class Model(ModelInterface):
    """Encoder/decoder model conditioned on a per-sample latent code.

    The encoder maps a set of points to a latent code ``h``, plus a second
    value that is stored as ``*_dist`` in the output. The decoder is then
    evaluated at every point, with ``h`` repeated once per point.

    Parameters
    ----------
    encoder : LatentEncoderInterface
        Called as ``encoder(points)``. It must return an ``(h, h_dist)`` pair
        where ``h`` is 2-D ``(batchsize, latent_size)``, one row per sample.
    decoder : Decoder
        Called as ``decoder(X=points, h=h_per_point)``.
    X_dim : int, default 3
        Size of the last axis of every point tensor.
    """

    def __init__(self,
                 encoder: LatentEncoderInterface,
                 decoder: Decoder,
                 X_dim=3):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.X_dim = X_dim

    def fix_g(self, flag: bool):
        """Forward ``flag`` to the encoder's ``fix_g``."""
        self.encoder.fix_g(flag)

    def h_distribution(self, context_xy: torch.Tensor,
                       context_depth: torch.Tensor):
        """Build a Normal distribution from the encoder's ``(mu, variance)``.

        NOTE(review): this calls the encoder with two arguments and unpacks
        ``(mu, variance)``, whereas ``predict``/``forward`` call it with one
        argument and unpack ``(h, h_dist)``. Confirm that this method still
        matches the encoder's interface before relying on it.
        """
        mu, variance = self.encoder(context_xy, context_depth)
        return torch.distributions.Normal(loc=mu, scale=torch.sqrt(variance))

    def _decode_points(self, X: torch.Tensor, h: torch.Tensor,
                       requires_grad: bool = False):
        """Evaluate the decoder at every point of ``X`` using latent ``h``.

        Parameters
        ----------
        X : (batchsize, num_points, X_dim)
        h : (batchsize, latent_size), one latent code per sample
        requires_grad : bool
            If True, mark the decoder input points as requiring grad.

        Returns
        -------
        (pred_distance, points)
            ``points`` is ``X`` viewed as ``(batchsize, -1, X_dim)``; it is the
            exact tensor passed to the decoder.
        """
        batchsize = X.shape[0]
        points = X.view((batchsize, -1, self.X_dim))
        if requires_grad:
            # Lets autograd differentiate pred_distance w.r.t. the points.
            points.requires_grad_(True)
        # One copy of each sample's latent per point (a broadcast view).
        h_per_point = _repeat(h, points.shape[1])
        pred_distance = self.decoder(X=points, h=h_per_point)
        return pred_distance, points

    def predict(self, context_X: torch.Tensor, target_X: torch.Tensor):
        """Decode context and target points using the context-only latent.

        The latent code comes from encoding ``context_X`` alone. The decoder is
        then evaluated at all context and target points, context first.

        Parameters
        ----------
        context_X : (batchsize, num_context_tuples, X_dim)
        target_X : (batchsize, num_target_tuples, X_dim)

        Returns
        -------
        output : ModelOutputDescription
            ``decoder_input_h`` is the latent before per-point repetition;
            ``h_posterior_dist`` is None.
        """
        h_prior, h_prior_dist = self.encoder(context_X)

        X = torch.cat((context_X, target_X), dim=1)
        pred_distance, points = self._decode_points(X, h_prior)

        return ModelOutputDescription(pred_distance=pred_distance,
                                      decoder_input_points=points,
                                      decoder_input_h=h_prior,
                                      h_prior_dist=h_prior_dist,
                                      h_posterior_dist=None)

    def forward(self, context_X: torch.Tensor, target_X: torch.Tensor):
        """Decode context and target points using the context+target latent.

        The encoder is run twice: once on ``context_X`` (the prior) and once
        on the concatenation of context and target points (the posterior).
        The decoder is evaluated at all points with the posterior latent. Its
        input points have ``requires_grad`` enabled, so the prediction can be
        differentiated with respect to them.

        Parameters
        ----------
        context_X : (batchsize, num_context_tuples, X_dim)
        target_X : (batchsize, num_target_tuples, X_dim)

        Returns
        -------
        output : ModelOutputDescription
            ``decoder_input_h`` is the posterior latent before per-point
            repetition.
        """
        h_prior, h_prior_dist = self.encoder(context_X)

        X = torch.cat((context_X, target_X), dim=1)
        h_posterior, h_posterior_dist = self.encoder(X)

        pred_distance, points = self._decode_points(X, h_posterior,
                                                    requires_grad=True)

        return ModelOutputDescription(pred_distance=pred_distance,
                                      decoder_input_points=points,
                                      decoder_input_h=h_posterior,
                                      h_prior_dist=h_prior_dist,
                                      h_posterior_dist=h_posterior_dist)
