import numpy as np
import torch
from torch import tensor

from hvae_backbone.block import OutputBlock, InputPipeline, SimpleGenBlock
from hvae_backbone.block import get_net
from hvae_backbone.block import JointDistribution
from hvae_backbone.utils import split_mu_sigma, SerializableSequential as Sequential
from hvae_backbone.elements.distributions import generate_distribution


# class ContrastiveOutputBlock(OutputBlock):

#     # only for 1D inputs
#     def __init__(self, net, input_id, contrast_dims: int, output_distribution: str = 'normal'):
#         super().__init__(net, input_id, output_distribution)
#         self.contrast_dims = contrast_dims

#     def _sample_uncond(self, y: tensor, t: float or int = None, use_mean=False) -> tensor:
#         y_input = y[:, :-self.contrast_dims]
#         contrast = y[:, -self.contrast_dims:]
#         y_prior = self.prior_net(y_input)
#         pm, pv = split_mu_sigma(y_prior)
#         pm_shape = pm.shape
#         pm_flattened = torch.flatten(pm, start_dim=1)
#         pm = pm_flattened * contrast
#         pm = pm.reshape(pm_shape)
#         if t is not None:
#             pv = pv + torch.ones_like(pv) * np.log(t)
#         prior = generate_distribution(pm, pv, self.output_distribution)
#         z = prior.sample() if not use_mean else prior.mean
#         return z, (prior, None)

#     def serialize(self) -> dict:
#         serialized = super().serialize()
#         serialized["contrast_dims"] = self.contrast_dims
#         return serialized

#     @staticmethod
#     def deserialize(serialized: dict):
#         prior_net = Sequential.deserialize(serialized["prior_net"])
#         return ContrastiveOutputBlock(
#             net=prior_net,
#             input_id=InputPipeline.deserialize(serialized["input"]),
#             contrast_dims=serialized["contrast_dims"],
#             output_distribution=serialized["output_distribution"]
#         )

class ContrastiveOutputBlock(OutputBlock):
    """Output block whose decoder input is scaled by trailing contrast columns.

    The block input is split along dim 1: the last ``contrast_dims`` columns
    are used as a multiplicative contrast and the remaining columns as the
    signal.  The signal is multiplied by the contrast, passed through
    ``self.prior_net`` and used as the first half of the output-distribution
    parameters; the second half is the constant ``self.stddev``.

    NOTE(review): the product relies on broadcasting between the flattened
    signal ``(batch, features)`` and the contrast ``(batch, contrast_dims)``,
    so presumably the input is 2-D and ``contrast_dims`` is 1 -- confirm
    against callers.  ``contrast_dims`` must be >= 1: with 0 the ``-0``
    slices would select an empty signal and the whole input as contrast.
    NOTE(review): ``self.stddev`` is assumed to be set by
    ``OutputBlock.__init__`` from the ``stddev`` argument -- confirm.
    """

    def __init__(self, net, input_id, contrast_dims: int = 1, output_distribution: str = 'normal', stddev=None):
        super().__init__(net, input_id, output_distribution, stddev)
        self.contrast_dims = contrast_dims

    def _prior_params(self, computed: dict) -> tensor:
        """Return the concatenated (decoded mean, constant stddev) parameters.

        Shared by ``forward`` and ``sample_from_prior`` so that both decode
        the same contrast-scaled signal.
        """
        x = self.input(computed)
        signal = x[:, :-self.contrast_dims]
        contrast = x[:, -self.contrast_dims:]

        # Flatten so the contrast broadcasts over the features, then restore
        # the original signal shape before decoding.
        scaled = torch.flatten(signal, start_dim=1) * contrast
        scaled = scaled.reshape(signal.shape)

        pm = self.prior_net(scaled)  # decoder
        # ones_like already inherits dtype and device from pm.
        pv = self.stddev * torch.ones_like(pm)
        return torch.cat([pm, pv], dim=1)

    def forward(self, computed: dict, use_mean=False, **kwargs) -> (dict, tuple):
        """Decode the contrast-scaled input and store a draw under ``self.output``.

        :param computed: dict of previously computed tensors; updated in place.
        :param use_mean: if True, use the distribution mean instead of a sample.
        :return: ``(computed, (prior, None))``.
        """
        x_prior = self._prior_params(computed)
        prior = generate_distribution(x_prior, (self.output_distribution, 'none', 'std'))
        z = prior.mean if use_mean else prior.sample()
        computed[self.output] = z
        return computed, (prior, None)

    def sample_from_prior(self, computed: dict, t: float or int = None, use_mean=False, **kwargs) -> (dict, tuple):
        """Same as ``forward`` but forwards the temperature ``t`` to the distribution.

        Previously this method decoded the *unscaled* signal (the contrast
        product was computed and then discarded) and returned the sample
        instead of ``computed``; both now match ``forward``.

        :param computed: dict of previously computed tensors; updated in place.
        :param t: temperature passed as third argument to ``generate_distribution``.
        :param use_mean: if True, use the distribution mean instead of a sample.
        :return: ``(computed, (prior, None))``.
        """
        x_prior = self._prior_params(computed)
        prior = generate_distribution(x_prior, (self.output_distribution, 'none', 'std'), t)
        z = prior.mean if use_mean else prior.sample()
        computed[self.output] = z
        return computed, (prior, None)

    def serialize(self) -> dict:
        """Extend the parent's serialized dict with ``contrast_dims``."""
        serialized = super().serialize()
        serialized["contrast_dims"] = self.contrast_dims
        return serialized

    @staticmethod
    def deserialize(serialized: dict):
        """Rebuild a block from the dict produced by ``serialize``.

        Reads the dict without removing keys, so the caller's dict is left
        intact.  Raises ``KeyError`` if a required key is missing.
        """
        net_spec = serialized["prior_net"]
        # The serialized net carries its own class under "type".
        prior_net = net_spec["type"].deserialize(net_spec)
        return ContrastiveOutputBlock(
            net=prior_net,
            input_id=InputPipeline.deserialize(serialized["input"]),
            contrast_dims=serialized["contrast_dims"],
            output_distribution=serialized["output_distribution"],
            stddev=serialized["stddev"],
        )

class ContrastiveGenBlock(SimpleGenBlock):
    '''
        Enables having multiple distributions on different latent dimensions.

        The regular output distribution can be: 'normal', 'laplace'
        The contrast distribution can be: 'lognormal', 'softlaplace', 'loglaplace'

        The network outputs are read as two halves along dim 1 (first half /
        second half); within each half the last ``contrast_dims`` columns
        parameterize the contrast distribution and the rest parameterize the
        regular one.  ``contrast_dims`` must be >= 1, since a ``-0`` slice
        would not select the intended columns.
    '''

    def __init__(self,
                 prior_net,
                 posterior_net,
                 input_id, condition,
                 output_distribution: str = 'normal',
                 contrast_distribution: str = 'lognormal',
                 contrast_dims: int = 1,
                 kl_loss = 'default'):
        super(ContrastiveGenBlock, self).__init__(prior_net, input_id, output_distribution)
        self.prior_net = get_net(prior_net)
        self.posterior_net = get_net(posterior_net)
        self.condition = InputPipeline(condition)
        self.contrast_distribution = contrast_distribution
        self.contrast_dims = contrast_dims
        self.kl_loss = kl_loss

    def generate_concatenated(self, z, z_distribution, contrast_distribution, temperature=None):
        """Build a joint distribution over the regular and the contrast dims.

        :param z: network output; its two halves along dim 1 are regrouped
            per distribution before being handed to ``generate_distribution``.
        :param z_distribution: distribution spec for the regular dims.
        :param contrast_distribution: distribution spec for the contrast dims.
        :param temperature: forwarded unchanged to ``generate_distribution``.
        :return: ``JointDistribution([regular, contrast])``.
        """
        half = z.shape[1] // 2
        first_half = z[:, :half]
        second_half = z[:, half:]
        k = self.contrast_dims

        # Re-pair the two halves separately for each group of dimensions.
        z_params = torch.cat((first_half[:, :-k], second_half[:, :-k]), dim=1)
        contrast_params = torch.cat((first_half[:, -k:], second_half[:, -k:]), dim=1)

        p = generate_distribution(z_params, z_distribution, temperature)  # z dims
        q = generate_distribution(contrast_params, contrast_distribution, temperature)  # s (contrast) dims
        return JointDistribution([p, q])

    def _sample(self, x: tensor, cond: tensor, variate_mask=None, use_mean=False) -> (tensor, tuple):
        """Draw ``z`` from the posterior and return it with both distributions.

        :param x: input of the prior network.
        :param cond: input of the posterior network.
        :param variate_mask: if given, ``self.prune`` combines the posterior
            draw with a prior draw using this mask.
        :param use_mean: if True, use distribution means instead of samples.
        :return: ``(z, (prior, posterior, self.kl_loss))``.
        """
        prior = self.generate_concatenated(
            self.prior_net(x), self.output_distribution, self.contrast_distribution)
        posterior = self.generate_concatenated(
            self.posterior_net(cond), self.output_distribution, self.contrast_distribution)
        z = posterior.mean if use_mean else posterior.rsample()

        if variate_mask is not None:
            z_prior = prior.mean if use_mean else prior.rsample()
            z = self.prune(z, z_prior, variate_mask)

        return z, (prior, posterior, self.kl_loss)

    def _sample_uncond(self, x: tensor, t: float or int = None, use_mean=False) -> (tensor, tuple):
        """Draw ``z`` from the prior only, with optional temperature ``t``.

        :return: ``(z, (prior, None))``.
        """
        prior = self.generate_concatenated(
            self.prior_net(x), self.output_distribution, self.contrast_distribution, t)
        z = prior.mean if use_mean else prior.sample()
        return z, (prior, None)

    def forward(self, computed: dict, variate_mask=None, use_mean=False, **kwargs) -> (dict, tuple):
        """Sample ``z`` from the posterior and record posterior statistics.

        Besides ``computed[self.output]``, the keys ``z_posterior_mean``,
        ``z_posterior_std``, ``z_posterior_sample`` and ``z_posterior_loc``
        are written.  ``z_posterior_sample`` is a separate draw, independent
        of the returned ``z``.

        :return: ``(computed, (prior, posterior, self.kl_loss))``.
        """
        x = self.input(computed)
        cond = self.condition(computed)
        z, distributions = self._sample(x, cond, variate_mask, use_mean=use_mean)
        posterior = distributions[1]
        computed['z_posterior_mean'] = posterior.mean
        computed['z_posterior_std'] = posterior.stddev
        computed['z_posterior_sample'] = posterior.sample()
        computed['z_posterior_loc'] = posterior.loc
        computed[self.output] = z
        return computed, distributions

    def sample_from_prior(self, computed: dict, t: float or int = None, use_mean=False, **kwargs) -> (dict, tuple):
        """Sample ``z`` from the prior and store it under ``self.output``.

        :return: ``(computed, (prior, None))``.
        """
        x = self.input(computed)
        z, dist = self._sample_uncond(x, t, use_mean=use_mean)
        computed[self.output] = z
        return computed, dist

    def serialize(self) -> dict:
        """Serialize everything ``deserialize`` needs to rebuild this block.

        ``contrast_dims`` and ``kl_loss`` were previously dropped, so a
        round-trip silently reset them to the constructor defaults.
        """
        serialized = super().serialize()
        # NOTE(review): the parent's serialize() is not visible from here;
        # these two entries are required by deserialize(), so add them only
        # when the parent did not already provide them.
        if "posterior_net" not in serialized:
            serialized["posterior_net"] = self.posterior_net.serialize()
        if "condition" not in serialized:
            serialized["condition"] = self.condition.serialize()
        serialized["contrast_distribution"] = self.contrast_distribution
        serialized["contrast_dims"] = self.contrast_dims
        serialized["kl_loss"] = self.kl_loss
        return serialized

    @staticmethod
    def deserialize(serialized: dict):
        """Rebuild a block from the dict produced by ``serialize``.

        ``contrast_dims`` and ``kl_loss`` fall back to the constructor
        defaults so dicts written before these keys existed still load.
        """
        prior_spec = serialized["prior_net"]
        posterior_spec = serialized["posterior_net"]
        # Each serialized net carries its own class under "type".
        prior_net = prior_spec["type"].deserialize(prior_spec)
        posterior_net = posterior_spec["type"].deserialize(posterior_spec)
        return ContrastiveGenBlock(
            prior_net=prior_net,
            posterior_net=posterior_net,
            input_id=InputPipeline.deserialize(serialized["input"]),
            condition=InputPipeline.deserialize(serialized["condition"]),
            output_distribution=serialized["output_distribution"],
            contrast_distribution=serialized["contrast_distribution"],
            contrast_dims=serialized.get("contrast_dims", 1),
            kl_loss=serialized.get("kl_loss", 'default'),
        )

    def extra_repr(self) -> str:
        """Delegate to the parent's representation."""
        return super().extra_repr()