"""
Masked Autoregressive Flow for Density Estimation
arXiv:1705.07057v4
"""
import torch.nn as nn
from normalizing_flows.maf import MAF


def reshape(orig_shape, t):
    """
    Fold a flat (B * N, C') tensor back into channel-first layout.

    `orig_shape` is the shape (B, C, ...) of the tensor that was flattened;
    only its batch size and trailing dims are used. The channel count of the
    result is taken from `t` itself, so it may differ from `orig_shape[1]`.

    Returns a view of `t` with shape (B, C', *orig_shape[2:]).
    """
    batch_size = orig_shape[0]
    trailing_dims = orig_shape[2:]
    n_channels = t.shape[1]

    # rows of `t` are grouped per batch element: (B * N, C') -> (B, N, C')
    per_batch = t.view(batch_size, -1, n_channels)
    # move channels in front of the flattened positions: (B, C', N)
    channels_first = per_batch.transpose(1, 2)
    # finally split N back into the original trailing dims
    return channels_first.view(batch_size, n_channels, *trailing_dims)

def unshape(t):
    """
    Flatten a channel-first tensor (B, C, ...) into rows of shape (B * N, C),
    where N is the product of the trailing dims.

    Rows are ordered batch-major, then row-major over the trailing dims, which
    is the layout `reshape` expects when folding the result back.

    A plain (B, C) tensor has no trailing dims to fold and is returned as is
    (previously this raised IndexError from `flatten(start_dim=2)`).
    """
    if t.dim() == 2:
        return t
    # (B, C, ...) -> (B, C, N) -> (B, N, C) -> (B * N, C)
    return t.flatten(start_dim=2).permute(0, 2, 1).flatten(end_dim=1)


class ConditionalFlow(nn.Module):
    """
    Wrapper around a conditional MAF that adds sampling with log-probabilities
    and variants of sample / log_prob for inputs with trailing dimensions.

    The plain methods (`sample`, `log_prob`) work on 2D tensors (B, C). The
    `*_trailing_dims` methods accept channel-first tensors (B, C, ...) and
    treat every trailing position as an independent batch element.
    """
    def __init__(self, input_size, cond_size, n_blocks=10, hidden_size=100,
                 n_hidden=1, batch_norm=False):
        """
        input_size -- dimensionality of the modelled variable x
        cond_size  -- dimensionality of the conditioning vector
        The remaining arguments are forwarded unchanged to `MAF`.
        """
        super().__init__()
        self.input_size = input_size
        self.cond_size = cond_size
        self.flow = MAF(n_blocks=n_blocks, input_size=input_size,
                        hidden_size=hidden_size, n_hidden=n_hidden,
                        cond_label_size=cond_size, batch_norm=batch_norm)

    def sample_trailing_dims(self, cond, return_logprob=False):
        """
        equivalent of sample when cond has dimension (B, C, ...) instead of just (B, C)

        Returns x of shape (B, input_size, ...) and, if `return_logprob` is
        set, additionally the log-probabilities of shape (B, ...).
        """
        orig_shape = cond.shape
        r = self.sample(unshape(cond), return_logprob=return_logprob)
        if not return_logprob:
            return reshape(orig_shape, r)

        flat_x, flat_lp = r
        x = reshape(orig_shape, flat_x)
        # reshape() needs a channel dim, so add a dummy one and drop it again
        lp = reshape(orig_shape, flat_lp.unsqueeze(1)).squeeze(1)
        return x, lp

    def log_prob_trailing_dims(self, x, cond):
        """
        equivalent of log_prob when x and cond have dimension (B, C, ...).

        Returns the log-probabilities with shape (B, ...), i.e. the shape of
        `x` without its channel dimension.
        """
        orig_shape_x = x.shape
        r = self.log_prob(unshape(x), unshape(cond))
        r = r.unsqueeze(1)  # add back a channel dimension to make logic work
        r = reshape(orig_shape_x, r)
        r = r.squeeze(1)  # remove channel dimension again
        return r

    def sample(self, cond, return_logprob=False):
        """
        Draw one sample x per row of `cond`.

        cond -- tensor of shape (B, cond_size)
        return_logprob -- if True, return the tuple (x, log_px) where log_px
            has shape (B,); otherwise return only x of shape (B, input_size).

        Raises ValueError if `cond` does not have shape (B, cond_size).
        """
        # Explicit check instead of `assert`: asserts vanish under `python -O`,
        # and a wrong rank already surfaced as ValueError (tuple unpacking).
        if cond.dim() != 2 or cond.shape[1] != self.cond_size:
            raise ValueError(
                "cond must have shape (B, {}), got {}".format(
                    self.cond_size, tuple(cond.shape)))
        # B - batch size
        B = cond.shape[0]

        u = self.flow.base_dist.sample((B,))
        log_pu = self.flow.base_dist.log_prob(u).sum(dim=-1)
        # internal sanity check on the base distribution, not input validation
        assert u.shape == (B, self.input_size)
        x, sum_log_abs_det_jacobians = self.flow.inverse(u=u, y=cond)

        # The flow returns one log|det| term per input dimension. For an
        # autoregressive transform the Jacobian is triangular, so the full
        # log-determinant is the sum of these per-dimension terms.
        sum_log_abs_det_jacobians = sum_log_abs_det_jacobians.sum(dim=-1)
        # Change of variables; assumes `inverse` reports log|det dx/du|
        # -- TODO confirm against the MAF implementation.
        log_px = log_pu - sum_log_abs_det_jacobians

        if return_logprob:
            return x, log_px
        return x

    def log_prob(self, x, cond):
        """Return log p(x | cond) as computed by the underlying flow."""
        return self.flow.log_prob(x, cond)

    def entropy_trailing_dims(self, cond, n_samples):
        """
        Monte-Carlo estimate of the entropy of p(x | cond).

        For every conditioning vector in `cond` (shape (B, C, ...)),
        `n_samples` samples are drawn and the negative mean of their
        log-probabilities is returned, with shape (B, ...).
        """
        # replicate every conditioning vector along a new last axis
        cond = cond.unsqueeze(-1).expand(*cond.shape, n_samples)
        _, lp = self.sample_trailing_dims(cond, return_logprob=True)
        return -lp.mean(dim=-1)
