import numpy as np
import torch
from torch import tensor

from .elements.distributions import generate_distribution, JointDistribution
from .elements.nets import get_net
from .utils import SerializableModule, SerializableSequential as Sequential, split_mu_sigma, softclip

"""
Sampling and forward methods based on VDAVAE paper
"""


class _Block(SerializableModule):
    """
    Common base of every block.

    A block resolves its input(s) from the shared ``computed`` dict through an
    :class:`InputPipeline` stored in ``self.input`` and is given the key it
    writes to via :meth:`set_output` (``self.output``).
    """

    def __init__(self, input_id = None):
        super(_Block, self).__init__()
        self.input = InputPipeline(input_id)
        self.output = None

    def forward(self, computed: dict, **kwargs) -> (dict, None):
        # Base implementation: produces nothing.
        return dict(), None

    def sample_from_prior(self, computed: dict, t = None, **kwargs) -> (dict, None):
        # By default sampling is identical to a plain forward pass (t and kwargs unused).
        return self.forward(computed)

    def freeze(self, net_name: str):
        """Disable gradients of every parameter whose qualified name contains ``net_name``."""
        matching = (param for name, param in self.named_parameters() if net_name in name)
        for param in matching:
            param.requires_grad = False

    def set_output(self, output: str) -> None:
        self.output = output

    def serialize(self) -> dict:
        """Return the input description, output key and concrete class of this block."""
        return {
            "input": self.input.serialize(),
            "output": self.output,
            "type": self.__class__,
        }


class InputPipeline(SerializableModule):
    """
    Helper class for preprocessing pipeline.

    Resolves a block's input(s) from the ``computed`` dict. A pipeline
    description is a (possibly nested) combination of:

    * ``str``   -- key of a value stored in ``computed``;
    * ``tuple`` -- several descriptions loaded side by side (yields a tuple);
    * ``list``  -- ``[inputs, op1, op2, ...]`` where ``inputs`` is a key or a
      tuple of keys and every op is either a callable (e.g. a net) or one of the
      strings ``"concat"`` (``torch.cat`` along dim 1), ``"sub"`` / ``"add"``
      (element 0 minus / plus element 1).

    Modules, and objects exposing a ``config`` attribute (built with ``get_net``),
    are registered as submodules so their parameters belong to this module.
    """

    def __init__(self, input_pipeline: str or tuple or list):
        super(InputPipeline, self).__init__()
        self.inputs = self.parse(input_pipeline)

    def forward(self, computed):
        """Load (and preprocess) the inputs described by this pipeline from ``computed``."""
        return self._load(computed, self.inputs)

    def parse(self, input_pipeline):
        """Validate a description, registering every net it contains as a submodule."""
        if isinstance(input_pipeline, str):
            return input_pipeline
        elif isinstance(input_pipeline, tuple):
            return tuple([self.parse(i) for i in input_pipeline])
        elif isinstance(input_pipeline, list):
            return [self.parse(i) for i in input_pipeline]
        elif isinstance(input_pipeline, (SerializableModule, Sequential)):
            self.register_module(str(len(self._modules)), input_pipeline)
            return input_pipeline
        elif hasattr(input_pipeline, "config"):
            net = get_net(input_pipeline)
            self.register_module(str(len(self._modules)), net)
            return net
        else:
            raise ValueError(f"Unknown input pipeline element {input_pipeline}")

    def serialize(self):
        return self._serialize(self.inputs)

    def _serialize(self, item):
        """Recursively convert a description into plain strings/lists/tuples/dicts."""
        if isinstance(item, str):
            return item
        if isinstance(item, (SerializableModule, Sequential)):
            # Modules may appear anywhere parse() accepts them, including inside tuples.
            return item.serialize()
        if isinstance(item, list):
            return [self._serialize(i) for i in item]
        if isinstance(item, tuple):
            return tuple(self._serialize(i) for i in item)
        return None

    @staticmethod
    def deserialize(serialized):
        """Inverse of :meth:`serialize`: rebuild nets from their ``{"type": ...}`` dicts."""
        if isinstance(serialized, str):
            return serialized
        if isinstance(serialized, dict) and "type" in serialized:
            return serialized["type"].deserialize(serialized)
        if isinstance(serialized, list):
            return [InputPipeline.deserialize(i) for i in serialized]
        if isinstance(serialized, tuple):
            return tuple(InputPipeline.deserialize(i) for i in serialized)
        return None

    @staticmethod
    def _load(computed: dict, inputs):
        """
        Resolve ``inputs`` (a parsed description) against ``computed``.

        :raises ValueError: on missing keys, malformed pipelines or unknown operations
        """
        def _validate_get(_inputs):
            if not isinstance(_inputs, str):
                raise ValueError(f"Input {_inputs} must be a string")
            if _inputs not in computed:
                raise ValueError(f"Input {_inputs} not found in computed")
            return computed[_inputs]

        # single input
        if isinstance(inputs, str):
            return _validate_get(inputs)

        # multiple inputs
        if isinstance(inputs, tuple):
            return tuple([InputPipeline._load(computed, i) for i in inputs])

        # list of operations
        if isinstance(inputs, list):
            if len(inputs) < 2:
                raise ValueError(f"Preprocessing pipeline must have at least 2 elements, got {len(inputs)}. "
                                 f"Provide the inputs in [inputs, operation1, operation2, ...] format")
            if not isinstance(inputs[0], (str, tuple)):
                raise ValueError(f"First element of the preprocessing pipeline "
                                 f"must be the input id or tuple of input ids, got {inputs[0]}")
            input_tensors = InputPipeline._load(computed, inputs[0])
            for op in inputs[1:]:
                if callable(op):
                    # A tuple of inputs is unpacked into positional arguments.
                    input_tensors = op(*input_tensors) if isinstance(input_tensors, tuple) \
                                                        else op(input_tensors)
                elif not isinstance(op, str):
                    raise ValueError(f"Preprocessing operation {op} must be callable or a string")
                elif op == "concat":
                    input_tensors = torch.cat(input_tensors, dim=1)
                elif op == "sub":
                    input_tensors = input_tensors[0] - input_tensors[1]
                elif op == "add":
                    input_tensors = input_tensors[0] + input_tensors[1]
                else:
                    # Previously unknown operations were silently skipped, hiding typos.
                    raise ValueError(f"Unknown preprocessing operation {op!r}, "
                                     f"expected one of 'concat', 'sub', 'add' or a callable")
            return input_tensors

        raise ValueError(f"Cannot load inputs from pipeline element {inputs}")


class SimpleBlock(_Block):
    """
    Deterministic block: applies ``net`` to the resolved input and stores the
    result under ``self.output``. Nothing is sampled.
    """

    def __init__(self, net, input_id):
        super(SimpleBlock, self).__init__(input_id)
        self.net = get_net(net)

    def forward(self, computed: dict, **kwargs) -> (dict, None):
        computed[self.output] = self.net(self.input(computed))
        return computed, None

    def serialize(self) -> dict:
        base = super().serialize()
        return {**base, "net": self.net.serialize()}

    @staticmethod
    def deserialize(serialized: dict):
        net_spec = serialized["net"]
        return SimpleBlock(
            net=net_spec["type"].deserialize(net_spec),
            input_id=InputPipeline.deserialize(serialized["input"]),
        )


class InputBlock(SimpleBlock):
    """
    Entry block of a model.

    Given a tensor, it starts a fresh ``computed`` dict holding the raw tensor
    under ``"inputs"`` and ``net(tensor)`` under ``self.output``; given an
    already-built dict, it passes it through untouched.
    """

    def __init__(self, net=None):
        super(InputBlock, self).__init__(net, "input")

    def forward(self, inputs: dict, **kwargs) -> tuple:
        if isinstance(inputs, dict):
            return inputs, dict()
        if not isinstance(inputs, torch.Tensor):
            raise ValueError(f"Input must be a tensor or a dict got {type(inputs)}")
        computed = {"inputs": inputs}
        computed[self.output] = self.net(inputs)
        return computed, dict()

    @staticmethod
    def deserialize(serialized: dict):
        net_spec = serialized["net"]
        return InputBlock(net=net_spec["type"].deserialize(net_spec))


class SimpleGenBlock(_Block):
    """
    Takes an input and samples from a prior distribution.

    ``prior_net`` maps the resolved input to the parameters of the distribution
    named by ``output_distribution`` (built with ``generate_distribution``).
    The sample, or the distribution mean when ``use_mean`` is set, is stored
    under ``self.output``.
    """

    def __init__(self, net, input_id, output_distribution: str = 'normal'):
        super(SimpleGenBlock, self).__init__(input_id)
        self.prior_net = get_net(net)
        self.output_distribution: str = output_distribution

    def _sample_uncond(self, x: tensor, t: float or int = None, use_mean=False) -> (tensor, tuple):
        """
        Draw from the prior computed from ``x``.

        :param t: forwarded unchanged to ``generate_distribution``
        :return: ``(z, (prior, None))``; ``z`` uses ``rsample`` (reparameterized)
        """
        x_prior = self.prior_net(x)
        prior = generate_distribution(x_prior, self.output_distribution, t)
        z = prior.rsample() if not use_mean else prior.mean
        return z, (prior, None)

    def forward(self, computed: dict, use_mean=False, **kwargs) -> (dict, tuple):
        x = self.input(computed)
        z, distribution = self._sample_uncond(x, use_mean=use_mean)
        computed[self.output] = z
        return computed, distribution

    def sample_from_prior(self, computed: dict, t: float or int = None, use_mean=False, **kwargs) -> (dict, tuple):
        x = self.input(computed)
        z, dist = self._sample_uncond(x, t, use_mean=use_mean)
        computed[self.output] = z
        return computed, dist

    def serialize(self) -> dict:
        serialized = super().serialize()
        serialized["prior_net"] = self.prior_net.serialize()
        serialized["output_distribution"] = self.output_distribution
        return serialized

    @staticmethod
    def deserialize(serialized: dict):
        prior_net = serialized["prior_net"]["type"].deserialize(serialized["prior_net"])
        return SimpleGenBlock(
            net=prior_net,
            input_id=InputPipeline.deserialize(serialized["input"]),
            output_distribution=serialized["output_distribution"]
        )

    def extra_repr(self) -> str:
        return super().extra_repr() + f"\noutput_distribution={self.output_distribution}\n"


class OutputBlock(SimpleGenBlock):
    """
    Final generative block.

    ``prior_net`` maps the input to the mean of the output distribution. That
    mean is paired with a single standard deviation broadcast to the mean's
    shape:

    * fixed: the ``stddev`` given at construction;
    * ``optimal_sigma=True``: while training (or before any estimate exists) the
      RMS error between ``computed["inputs"]`` and the predicted mean is used.
      It is softclipped and optionally capped by ``max_sigma``, and it is
      tracked in ``self.stddev`` as an exponential moving average
      (``0.9 * old + 0.1 * new``). The running value is used outside training.
    """

    def __init__(self, net, input_id, output_distribution: str = 'normal', stddev = None, optimal_sigma = False, max_sigma=None):
        assert optimal_sigma or stddev is not None, "Either optimal_sigma or stddev must be provided"
        self.optimal_sigma = optimal_sigma
        self.max_sigma = max_sigma
        if max_sigma is not None:
            assert optimal_sigma, "max_sigma can only be used with optimal_sigma"

        # Store stddev as a tensor (or None) so it can be blended with running estimates.
        self.stddev = stddev if isinstance(stddev, torch.Tensor) or stddev is None else torch.tensor(stddev)
        super(OutputBlock, self).__init__(net, input_id, output_distribution)
        assert isinstance(self.output_distribution, str), \
                "Output distribution must be a string for OutputBlock. " \
                "Standard deviation is used with no nonlinearity."

    def _estimate_sigma(self, inputs, pm):
        """
        Return this batch's sigma estimate and update the running ``self.stddev``.

        Falls back to the running value when the estimate is NaN.
        :raises RuntimeError: if the estimate is NaN and no running value exists yet
        """
        pv = torch.nn.MSELoss(reduction='mean')(inputs.detach(), pm.detach()).sqrt()
        if pv.isnan():
            print("NaN in optimal sigma")
            if self.stddev is None:
                raise RuntimeError("Optimal sigma is NaN and no previous stddev estimate is available")
            return self.stddev

        pv = softclip(pv, 1e-6)  # TODO: hard-coded value, without softclip?
        if self.max_sigma is not None:
            pv = torch.clip(pv, max=self.max_sigma)
        self.stddev = 0.9 * self.stddev + 0.1 * pv if self.stddev is not None else pv
        return pv

    def _sample_output(self, computed: dict, pm, pv, t, use_mean) -> (dict, tuple):
        """Build the output distribution from mean ``pm`` and scalar std ``pv``, sample it and store ``z``."""
        if pv is None:
            raise RuntimeError("OutputBlock has no stddev yet: with optimal_sigma=True it is "
                               "estimated by forward(); run forward() before sampling")
        pv = pv * torch.ones_like(pm)
        x_prior = torch.cat([pm, pv], dim=1)
        prior = generate_distribution(x_prior, (self.output_distribution, 'none', 'std'), t)
        z = prior.sample() if not use_mean else prior.mean
        computed[self.output] = z
        return computed, (prior, None)

    def forward(self, computed: dict, use_mean=False, **kwargs) -> (dict, tuple):
        x = self.input(computed)
        pm = self.prior_net(x)
        if self.optimal_sigma and (self.training or self.stddev is None):
            pv = self._estimate_sigma(computed["inputs"], pm)
        else:
            pv = self.stddev
        return self._sample_output(computed, pm, pv, None, use_mean)

    def sample_from_prior(self, computed: dict, t: float or int = None, use_mean=False, **kwargs) -> (dict, tuple):
        x = self.input(computed)
        pm = self.prior_net(x)
        return self._sample_output(computed, pm, self.stddev, t, use_mean)

    def serialize(self) -> dict:
        serialized = super().serialize()
        serialized["optimal_sigma"] = self.optimal_sigma
        serialized["stddev"] = self.stddev
        serialized["max_sigma"] = self.max_sigma
        return serialized

    @staticmethod
    def deserialize(serialized: dict):
        # Read without pop(): the caller's dict must stay intact (re-deserializable).
        prior_net = serialized["prior_net"]
        return OutputBlock(
            net=prior_net["type"].deserialize(prior_net),
            input_id=InputPipeline.deserialize(serialized["input"]),
            output_distribution=serialized["output_distribution"],
            stddev=serialized["stddev"],
            optimal_sigma=serialized.get("optimal_sigma", False),
            max_sigma=serialized.get("max_sigma", None)
        )

    def extra_repr(self) -> str:
        return super().extra_repr() + f"\noptimal_sigma={self.optimal_sigma}, \nstddev={self.stddev}, \nmax_sigma={self.max_sigma}\n"


class GenBlock(SimpleGenBlock):
    """
    Takes an input,
    samples from a prior distribution,
    (takes a condition,
    samples from a posterior distribution),
    and returns the sample

    In ``forward`` the stored sample comes from the posterior computed by
    ``posterior_net`` from the condition. The condition is optionally fused with
    the prior first (see :meth:`fuse`). The distribution tuple returned is
    ``(prior, posterior, kl_loss)``. ``sample_from_prior`` draws from the prior
    only and returns ``(prior, None)``.
    """

    def __init__(self,
                 prior_net,
                 posterior_net,
                 input_id, condition,
                 output_distribution: str = 'normal',
                 posterior_distribution: str = None,
                 fuse_prior: str = None,
                 kl_loss = 'default'):
        # SimpleGenBlock.__init__ already builds self.prior_net from prior_net;
        # building it again here only constructed and discarded a second net.
        super(GenBlock, self).__init__(prior_net, input_id, output_distribution)
        self.posterior_net = get_net(posterior_net)
        self.condition = InputPipeline(condition)
        self.fuse_prior = fuse_prior
        self.kl_loss = kl_loss
        # The posterior family defaults to the prior's family.
        self.posterior_distribution = posterior_distribution if posterior_distribution is not None else output_distribution

    def _sample(self, x: tensor, cond: tensor, variate_mask=None, use_mean=False) -> (tensor, tuple):
        """Sample ``z`` from the posterior (with prior fallback on masked variates)."""
        x_prior = self.prior_net(x)
        prior = generate_distribution(x_prior, self.output_distribution)
        if self.fuse_prior is not None:
            cond = self.fuse(prior, cond, self.fuse_prior)

        x_posterior = self.posterior_net(cond)
        posterior = generate_distribution(x_posterior, self.posterior_distribution)
        z = posterior.rsample() if not use_mean else posterior.mean

        if variate_mask is not None:
            z_prior = prior.rsample() if not use_mean else prior.mean
            z = self.prune(z, z_prior, variate_mask)

        return z, (prior, posterior, self.kl_loss)

    def _sample_uncond(self, x: tensor, t: float or int = None, use_mean=False) -> (tensor, tuple):
        x_prior = self.prior_net(x)
        prior = generate_distribution(x_prior, self.output_distribution, t)
        z = prior.sample() if not use_mean else prior.mean
        return z, (prior, None)

    def forward(self, computed: dict, variate_mask=None, use_mean=False, **kwargs) -> (dict, tuple):
        x = self.input(computed)
        cond = self.condition(computed)
        z, distributions = self._sample(x, cond, variate_mask, use_mean=use_mean)
        computed[self.output] = z
        return computed, distributions

    def sample_from_prior(self, computed: dict, t: float or int = None, use_mean=False, **kwargs) -> (dict, tuple):
        x = self.input(computed)
        z, dist = self._sample_uncond(x, t, use_mean=use_mean)
        computed[self.output] = z
        return computed, dist

    @staticmethod
    def prune(z, z_prior, variate_mask=None):
        """
        Mix posterior and prior samples according to ``variate_mask``.

        Only used in inference mode to prune turned-off variates: the posterior
        sample is kept where the mask is 1 and the prior sample is used where it
        is 0. The NLL should be similar to using z_post without masking if the
        mask is good (not very destructive). The mask is broadcast against ``z``.
        With no mask, ``z`` is returned unchanged.
        """
        if variate_mask is None:
            return z
        # Match z's dtype/device (torch.Tensor(...) always produced a CPU float32 tensor).
        variate_mask = torch.as_tensor(variate_mask, dtype=z.dtype, device=z.device)
        return variate_mask * z + (1. - variate_mask) * z_prior

    @staticmethod
    def fuse(prior, cond, method):
        """Combine the prior with the condition before it is fed to ``posterior_net``."""
        if method == "concat":
            return torch.cat([cond, prior.rsample()], dim=1)
        elif method == "concat_logits":
            mean = prior.mean
            std = prior.stddev
            return torch.cat([cond, mean, std], dim=1)
        elif method == "add":
            # NOTE(review): operates on the distribution object itself; this only works
            # if the distribution type supports arithmetic with tensors — confirm.
            return cond + prior
        elif method == "substract":
            return cond - prior
        else:
            raise ValueError(f"Unknown method {method} for fusing prior and condition")

    def serialize(self) -> dict:
        serialized = super().serialize()
        serialized["prior_net"] = self.prior_net.serialize()
        serialized["posterior_net"] = self.posterior_net.serialize()
        serialized["condition"] = self.condition.serialize()
        serialized["output_distribution"] = self.output_distribution
        serialized["posterior_distribution"] = self.posterior_distribution
        serialized["fuse_prior"] = self.fuse_prior
        serialized["kl_loss"] = self.kl_loss
        return serialized

    @staticmethod
    def deserialize(serialized: dict):
        prior_net = serialized["prior_net"]["type"].deserialize(serialized["prior_net"])
        posterior_net = serialized["posterior_net"]["type"].deserialize(serialized["posterior_net"])
        return GenBlock(
            prior_net=prior_net,
            posterior_net=posterior_net,
            input_id=InputPipeline.deserialize(serialized["input"]),
            condition=InputPipeline.deserialize(serialized["condition"]),
            output_distribution=serialized["output_distribution"],
            # Older serializations lack this key; None falls back to output_distribution.
            posterior_distribution=serialized.get("posterior_distribution"),
            fuse_prior=serialized.get("fuse_prior"),
            kl_loss=serialized.get("kl_loss", 'default')
        )

    def extra_repr(self) -> str:
        return super().extra_repr() + f"\nfuse_prior={self.fuse_prior} \n"
    
    
class OnlyFirst(_Block):
    """
    Wrapper that delegates ``forward`` to another block.

    It shares the wrapped block's input pipeline. Its own output key may end
    with ``_first``; that suffix is stripped before the key is handed to the
    wrapped block, so the wrapped block writes under the un-suffixed key.
    NOTE(review): ``sample_from_prior`` is inherited from ``_Block`` and so also
    runs the wrapped block's ``forward``; confirm this is intended.
    """

    def __init__(self, block):
        super(OnlyFirst, self).__init__(block.input)
        self.block = block

    def forward(self, computed: dict, **kwargs) -> (dict, None):
        return self.block.forward(computed, **kwargs)

    def set_output(self, output: str) -> None:
        self.output = output
        suffix = "_first"
        inner_output = output[:-len(suffix)] if output.endswith(suffix) else output
        self.block.set_output(inner_output)

    def serialize(self) -> dict:
        base = super().serialize()
        return {**base, "block": self.block.serialize()}

    @staticmethod
    def deserialize(serialized: dict):
        block_spec = serialized["block"]
        return OnlyFirst(block_spec["type"].deserialize(block_spec))


"""
------------------------
CUSTOM BLOCKS
------------------------
"""


class ResidualGenBlock(GenBlock):
    """
    Architecture from VDVAE paper

    ``prior_net`` output is split into a residual added to the input stream
    ``y`` and the prior parameters. ``posterior_net`` receives
    ``cat([y, cond], dim=1)``. The sample ``z`` is projected with
    ``z_projection``, added to ``y``, and the result passed through ``net``
    is stored under ``self.output``.
    """

    def __init__(self, net,
                 prior_net,
                 posterior_net,
                 z_projection,
                 input_id, condition,
                 concat_posterior: bool,
                 prior_layer=None,
                 posterior_layer=None,
                 output_distribution: str = 'normal'):
        # Pass output_distribution by keyword: GenBlock's positional order is
        # (..., output_distribution, posterior_distribution, ...), so passing
        # concat_posterior positionally used it as the output distribution.
        super(ResidualGenBlock, self).__init__(
            prior_net, posterior_net, input_id, condition, output_distribution=output_distribution)
        # NOTE(review): kept for serialization only; _sample always concatenates y and cond.
        self.concat_posterior = concat_posterior
        self.net = get_net(net)
        self.z_projection = get_net(z_projection)
        self.prior_layer = get_net(prior_layer)
        self.posterior_layer = get_net(posterior_layer)

    def _sample(self, y: tensor, cond: tensor, variate_mask=None, use_mean=False) -> (tensor, tensor, tuple):
        y_prior = self.prior_net(y)
        kl_residual, y_prior = split_mu_sigma(y_prior, chunks=2)
        y_prior = self.prior_layer(y_prior)
        prior = generate_distribution(y_prior, self.output_distribution)

        # Note: y and cond are concatenated in the reverse order compared to elsewhere.
        y_posterior = self.posterior_net(torch.cat([y, cond], dim=1))
        y_posterior = self.posterior_layer(y_posterior)
        posterior = generate_distribution(y_posterior, self.output_distribution)
        z = posterior.rsample() if not use_mean else posterior.mean

        if variate_mask is not None:
            z_prior = prior.rsample() if not use_mean else prior.mean
            z = self.prune(z, z_prior, variate_mask)

        y = y + kl_residual
        return z, y, (prior, posterior)

    def _sample_uncond(self, y: tensor, t: float or int = None, use_mean=False) -> (tensor, tensor, tuple):
        y_prior = self.prior_net(y)
        kl_residual, y_prior = split_mu_sigma(y_prior, chunks=2)
        y_prior = self.prior_layer(y_prior)
        prior = generate_distribution(y_prior, self.output_distribution, t)
        z = prior.sample() if not use_mean else prior.mean
        y = y + kl_residual
        return z, y, (prior, None)

    def forward(self, computed: dict, variate_mask=None, use_mean=False, **kwargs) -> (dict, tuple):
        x = self.input(computed)
        cond = self.condition(computed)
        z, y, distributions = self._sample(x, cond, variate_mask, use_mean)
        y = y + self.z_projection(z)
        y = self.net(y)
        computed[self.output] = y
        return computed, distributions

    def sample_from_prior(self, computed: dict, t: float or int = None, use_mean=False, **kwargs) -> (dict, tuple):
        x = self.input(computed)
        z, y, dist = self._sample_uncond(x, t, use_mean=use_mean)
        y = y + self.z_projection(z)
        y = self.net(y)
        computed[self.output] = y
        return computed, dist

    def serialize(self) -> dict:
        serialized = super().serialize()
        serialized["net"] = self.net.serialize()
        serialized["z_projection"] = self.z_projection.serialize()
        serialized["prior_layer"] = self.prior_layer.serialize()
        serialized["posterior_layer"] = self.posterior_layer.serialize()
        serialized["concat_posterior"] = self.concat_posterior
        return serialized

    @staticmethod
    def deserialize(serialized: dict):
        def _build(key):
            # Optional nets may be missing from older serializations.
            spec = serialized.get(key)
            return None if spec is None else spec["type"].deserialize(spec)

        net = serialized["net"]["type"].deserialize(serialized["net"])
        prior_net = serialized["prior_net"]["type"].deserialize(serialized["prior_net"])
        posterior_net = serialized["posterior_net"]["type"].deserialize(serialized["posterior_net"])
        z_projection = serialized["z_projection"]["type"].deserialize(serialized["z_projection"])
        # serialize() (via _Block) writes "input"; accept the legacy "input_id" key too.
        input_spec = serialized["input"] if "input" in serialized else serialized["input_id"]
        return ResidualGenBlock(
            net=net,
            prior_net=prior_net,
            posterior_net=posterior_net,
            z_projection=z_projection,
            input_id=InputPipeline.deserialize(input_spec),
            condition=InputPipeline.deserialize(serialized["condition"]),
            concat_posterior=serialized.get("concat_posterior", False),
            prior_layer=_build("prior_layer"),
            posterior_layer=_build("posterior_layer"),
            output_distribution=serialized["output_distribution"]
        )

class ContrastiveOutputBlock(OutputBlock):
    """
    Output block whose input carries trailing "contrast" channels.

    The last ``contrast_dims`` entries along dim 1 of the input are split off.
    The remaining features are flattened, multiplied by them, and reshaped back.
    The result is decoded by ``prior_net`` into the mean of the output
    distribution. The standard deviation is the fixed ``stddev``.
    """

    def __init__(self, net, input_id, contrast_dims: int = 1, output_distribution: str = 'normal', stddev=None):
        super().__init__(net, input_id, output_distribution, stddev)
        self.contrast_dims = contrast_dims

    def _contrast_scaled_input(self, x):
        """Return the non-contrast part of ``x`` scaled by the contrast part."""
        # NOTE(review): the product relies on broadcasting of (B, D) with the contrast
        # slice; this presumably expects 2-D inputs and contrast_dims == 1 — confirm.
        features = x[:, :-self.contrast_dims]
        contrast = x[:, -self.contrast_dims:]
        shape = features.shape
        scaled = torch.flatten(features, start_dim=1) * contrast
        return scaled.reshape(shape)

    def _decode_and_sample(self, computed: dict, t, use_mean) -> (dict, tuple):
        """Decode the contrast-scaled input, sample the output and store it in ``computed``."""
        x = self.input(computed)
        pm = self.prior_net(self._contrast_scaled_input(x))  # decoder
        pv = self.stddev * torch.ones_like(pm)
        x_prior = torch.cat([pm, pv], dim=1)
        prior = generate_distribution(x_prior, (self.output_distribution, 'none', 'std'), t)
        z = prior.sample() if not use_mean else prior.mean
        computed[self.output] = z
        return computed, (prior, None)

    def forward(self, computed: dict, use_mean=False, **kwargs) -> (dict, tuple):
        return self._decode_and_sample(computed, None, use_mean)

    def sample_from_prior(self, computed: dict, t: float or int = None, use_mean=False, **kwargs) -> (dict, tuple):
        # Previously the contrast-scaled mean was computed and then discarded
        # (prior_net was applied to the unscaled input), and z was returned
        # instead of the computed dict expected from every block.
        return self._decode_and_sample(computed, t, use_mean)

    def serialize(self) -> dict:
        serialized = super().serialize()
        serialized["contrast_dims"] = self.contrast_dims
        return serialized

    @staticmethod
    def deserialize(serialized: dict):
        # Read without pop() so the caller's dict is left intact.
        prior_net = serialized["prior_net"]
        return ContrastiveOutputBlock(
            net=prior_net["type"].deserialize(prior_net),
            input_id=InputPipeline.deserialize(serialized["input"]),
            contrast_dims=serialized["contrast_dims"],
            output_distribution=serialized["output_distribution"],
            stddev=serialized["stddev"],
        )

class ContrastiveGenBlock(SimpleGenBlock):
    '''
        Enables having multiple distributions on different latent dimensions.

        The network outputs are split into means (first half along dim 1) and
        scales (second half). The last ``contrast_dims`` of each half
        parameterize the contrast distribution; the rest parameterize the
        regular output distribution. Both are combined in a JointDistribution.

        The regular output distribution can be: 'normal', 'laplace'
        The contrast distribution can be: 'lognormal', 'softlaplace', 'loglaplace'
    '''

    def __init__(self,
                 prior_net,
                 posterior_net,
                 input_id, condition,
                 output_distribution: str = 'normal',
                 contrast_distribution: str = 'lognormal',
                 contrast_dims: int = 1,
                 kl_loss = 'default'):
        # SimpleGenBlock.__init__ already builds self.prior_net; do not build it twice.
        super(ContrastiveGenBlock, self).__init__(prior_net, input_id, output_distribution)
        self.posterior_net = get_net(posterior_net)
        self.condition = InputPipeline(condition)
        self.contrast_distribution = contrast_distribution
        self.contrast_dims = contrast_dims
        self.kl_loss = kl_loss

    def generate_concatenated(self, z, z_distribution, contrast_distribution, temperature=None):
        """Split net output ``z`` into regular and contrast parameters and build their joint distribution."""
        length = z.shape[1]
        mean_values = z[:, :length//2]
        sigma_values = z[:, length//2:]
        z_dims = torch.cat((mean_values[:, :-self.contrast_dims],
                            sigma_values[:, :-self.contrast_dims]),
                            dim=1)
        contrast_dims = torch.cat((mean_values[:, -self.contrast_dims:],
                                   sigma_values[:, -self.contrast_dims:]),
                                   dim=1)
        p = generate_distribution(z_dims, z_distribution, temperature)  # z dims
        q = generate_distribution(contrast_dims, contrast_distribution, temperature)  # s (contrast) dims
        return JointDistribution([p, q])

    @staticmethod
    def prune(z, z_prior, variate_mask=None):
        """
        Use the posterior sample where ``variate_mask`` is 1 and the prior sample
        where it is 0 (mask broadcast against ``z``). With no mask, return ``z``.
        """
        # SimpleGenBlock provides no prune(); _sample previously raised AttributeError
        # whenever a variate mask was passed.
        if variate_mask is None:
            return z
        variate_mask = torch.as_tensor(variate_mask, dtype=z.dtype, device=z.device)
        return variate_mask * z + (1. - variate_mask) * z_prior

    def _sample(self, x: tensor, cond: tensor, variate_mask=None, use_mean=False) -> (tensor, tuple):
        x_prior = self.prior_net(x)
        prior = self.generate_concatenated(x_prior, self.output_distribution, self.contrast_distribution)
        x_posterior = self.posterior_net(cond)
        posterior = self.generate_concatenated(x_posterior, self.output_distribution, self.contrast_distribution)
        z = posterior.rsample() if not use_mean else posterior.mean

        if variate_mask is not None:
            z_prior = prior.rsample() if not use_mean else prior.mean
            z = self.prune(z, z_prior, variate_mask)

        return z, (prior, posterior, self.kl_loss)

    def _sample_uncond(self, x: tensor, t: float or int = None, use_mean=False) -> (tensor, tuple):
        x_prior = self.prior_net(x)
        prior = self.generate_concatenated(x_prior, self.output_distribution, self.contrast_distribution, t)
        z = prior.sample() if not use_mean else prior.mean
        return z, (prior, None)

    def forward(self, computed: dict, variate_mask=None, use_mean=False, **kwargs) -> (dict, tuple):
        x = self.input(computed)
        cond = self.condition(computed)
        z, distributions = self._sample(x, cond, variate_mask, use_mean=use_mean)
        # Posterior statistics exposed for downstream use/diagnostics.
        posterior = distributions[1]
        computed['z_posterior_mean'] = posterior.mean
        computed['z_posterior_std'] = posterior.stddev
        computed['z_posterior_sample'] = posterior.sample()
        computed['z_posterior_loc'] = posterior.loc
        computed[self.output] = z
        return computed, distributions

    def sample_from_prior(self, computed: dict, t: float or int = None, use_mean=False, **kwargs) -> (dict, tuple):
        x = self.input(computed)
        z, dist = self._sample_uncond(x, t, use_mean=use_mean)
        computed[self.output] = z
        return computed, dist

    def serialize(self) -> dict:
        serialized = super().serialize()
        # posterior_net and condition are required by deserialize() but were never written.
        serialized["posterior_net"] = self.posterior_net.serialize()
        serialized["condition"] = self.condition.serialize()
        serialized["contrast_distribution"] = self.contrast_distribution
        serialized["contrast_dims"] = self.contrast_dims
        serialized["kl_loss"] = self.kl_loss
        return serialized

    @staticmethod
    def deserialize(serialized: dict):
        prior_net = serialized["prior_net"]["type"].deserialize(serialized["prior_net"])
        posterior_net = serialized["posterior_net"]["type"].deserialize(serialized["posterior_net"])
        return ContrastiveGenBlock(
            prior_net=prior_net,
            posterior_net=posterior_net,
            input_id=InputPipeline.deserialize(serialized["input"]),
            condition=InputPipeline.deserialize(serialized["condition"]),
            output_distribution=serialized["output_distribution"],
            contrast_distribution=serialized["contrast_distribution"],
            contrast_dims=serialized.get("contrast_dims", 1),
            kl_loss=serialized.get("kl_loss", 'default'),
        )

    def extra_repr(self) -> str:
        return super().extra_repr() + \
            f"contrast_distribution={self.contrast_distribution}, contrast_dims={self.contrast_dims}\n"