import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as td

class Flow(nn.Module):
    """Normalizing flow: a base distribution pushed through invertible transforms.

    ``forward`` maps base-space tensors to data space by applying ``transforms``
    in order. ``inverse`` maps data back to base space by calling each
    transform's ``inverse`` in reverse order.

    Each transform is called as ``transform(x, latent=..., mask=...)`` and
    ``transform.inverse(x, latent=..., mask=...)``, and must return a pair
    ``(y, log_det)`` where ``log_det`` is broadcastable against ``x``.

    Args:
        base_dist: Base distribution. ``log_prob`` sums
            ``base_dist.log_prob(z)`` over the last dimension, so it presumably
            should return per-dimension log-probabilities (e.g. ``td.Normal``
            rather than ``td.Independent``) -- TODO confirm for other bases.
            ``sample`` requires it to support ``rsample``.
        transforms: Sequence of invertible transform modules, applied in order
            by ``forward``.
    """

    def __init__(self, base_dist, transforms):
        super().__init__()
        self.base_dist = base_dist
        self.transforms = nn.ModuleList(transforms)

    def forward(self, x, latent=None, mask=None, *args, **kwargs):
        """Apply all transforms in order (base space -> data space).

        Args:
            x: Input (..., dim).
            latent: Optional conditioning tensor passed to every transform.
            mask: Optional tensor broadcastable with ``x``; defaults to ones.
                Each transform receives ``x * mask``, and only masked-in
                entries accumulate log-Jacobian terms.
            *args, **kwargs: Accepted for interface compatibility; unused.

        Returns:
            x: Transformed input (..., dim).
            log_jac: Diagonal logarithm of the Jacobian, summed over all
                transforms (..., dim).
        """
        if mask is None:
            mask = torch.ones_like(x)

        log_jac = torch.zeros_like(x)
        for transform in self.transforms:
            x, ld = transform(x * mask, latent=latent, mask=mask)
            # Out-of-place so that `ld` is allowed to broadcast against `log_jac`.
            log_jac = log_jac + ld * mask
        return x, log_jac

    def inverse(self, x, latent=None, mask=None, *args, **kwargs):
        """Apply all transform inverses in reverse order (data -> base space).

        Args:
            x: Input in data space (..., dim).
            latent: Optional conditioning tensor passed to every transform.
            mask: Optional tensor broadcastable with ``x``; defaults to ones.
                Each inverse receives ``x * mask``, and only masked-in
                entries accumulate log-Jacobian terms.
            *args, **kwargs: Accepted for interface compatibility; unused.

        Returns:
            x: Transformed input (..., dim).
            log_jac: Diagonal logarithm of the inverse Jacobian, summed over
                all transforms (..., dim).
        """
        if mask is None:
            mask = torch.ones_like(x)

        log_jac = torch.zeros_like(x)
        for transform in self.transforms[::-1]:
            x, ld = transform.inverse(x * mask, latent=latent, mask=mask)
            log_jac = log_jac + ld * mask
        return x, log_jac

    def log_prob(self, x, latent=None, mask=None, *args, **kwargs):
        """Log-probability of ``x`` under the flow (change of variables).

        Args:
            x: Input (..., dim).
            latent: Latent (..., latent dim). All transforms need to know
                about the latent dim.
            mask: Optional tensor broadcastable with ``x``. Masked-out entries
                contribute neither base log-probability nor log-Jacobian terms.
            *args, **kwargs: Forwarded to ``inverse``.

        Returns:
            log_prob: Log-probability of the input (..., 1).
        """
        # Pass latent/mask positionally so extra *args land in inverse's *args
        # instead of colliding with the `latent` parameter.
        z, log_jac = self.inverse(x, latent, mask, *args, **kwargs)
        base_log_prob = self.base_dist.log_prob(z)
        if mask is not None:
            # `log_jac` is already masked inside `inverse`; mask the base term
            # too so padded entries do not leak into the sum.
            base_log_prob = base_log_prob * mask
        return (base_log_prob + log_jac).sum(-1, keepdim=True)

    def sample(self, num_samples, latent=None, mask=None, *args, **kwargs):
        """Draw samples by pushing base samples through ``forward``.

        Uses ``base_dist.rsample``, so samples are reparameterized when the
        base distribution supports it.

        Args:
            num_samples: (tuple or int) Shape of samples.
            latent: Context for conditional sampling (dim).
            mask: Optional mask forwarded to ``forward``.
            *args, **kwargs: Forwarded to ``forward``.

        Returns:
            x: Samples from target distribution (*num_samples, dim).
        """
        if isinstance(num_samples, int):
            num_samples = (num_samples,)

        z = self.base_dist.rsample(num_samples)
        x, _ = self.forward(z, latent, mask, *args, **kwargs)
        return x
