import nf
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as td

class Autoregressive(nn.Module):
    """
    Normalizing-flow likelihood model for variable-size sets of points in [0, 1]^dim.

    The flow is built from, in list order: an ``nf.Logit`` transform,
    ``num_autoreg_layers`` ``nf.AutoregressiveSetAffine`` layers, an
    ``nf.Sigmoid`` transform, and ``num_coupling_layers`` rounds of spline
    couplings. Each round adds one coupling conditioned by
    ``nf.net.EquivariantNet`` (``set_data=True``, parity mask) and one
    conditioned by ``nf.net.MLP`` (ordered mask). The masks alternate between
    rounds via ``i % 2``. The base distribution is uniform on [0, 1]^dim.

    Args:
        dim: Input dimension (dimension of each set element)
        hidden_dim: Size of the hidden layer; also used as the spline latent size
        num_autoreg_layers: Number of autoregressive set-affine layers
        num_coupling_layers: Number of coupling rounds; each round adds two
            coupling layers (one set coupling and one regular coupling)
        n_bins: Number of bins in spline flows
        **kwargs: Accepted and ignored.
    """
    def __init__(self, dim, hidden_dim, num_autoreg_layers, num_coupling_layers, n_bins, **kwargs):
        super().__init__()

        self.dim = dim
        # NOTE(review): a torch Distribution is not an nn.Module, so these
        # tensors are not moved by ``.to(device)``; confirm nf.Flow copes.
        self.base_dist = torch.distributions.Uniform(torch.zeros(self.dim), torch.ones(self.dim))
        # Plain list: the transforms are handed to nf.Flow below, which is
        # presumably what registers them as submodules -- TODO confirm.
        self.transforms = []

        # Presumably the Logit/Sigmoid pair lets the affine layers act on
        # unconstrained values while the splines act on [0, 1] -- TODO confirm.
        self.transforms.append(nf.Logit())

        for _ in range(num_autoreg_layers):
            self.transforms.append(nf.AutoregressiveSetAffine(dim, hidden_dim))

        self.transforms.append(nf.Sigmoid())

        for i in range(num_coupling_layers):
            # Set coupling: conditioner sees the whole set (set_data=True).
            self.transforms.append(nf.Coupling(
                flow=nf.Spline(dim, latent_dim=hidden_dim, n_bins=n_bins, lower=0, upper=1),
                net=nf.net.EquivariantNet(dim, [hidden_dim], hidden_dim),
                mask=nf.mask.parity(even_zero=i%2),
                set_data=True
            ))

            # Regular coupling with an MLP conditioner.
            self.transforms.append(nf.Coupling(
                flow=nf.Spline(dim, latent_dim=hidden_dim, n_bins=n_bins, lower=0, upper=1),
                net=nf.net.MLP(dim, [hidden_dim], hidden_dim),
                mask=nf.mask.ordered(right=i%2),
            ))

        self.flow = nf.Flow(self.base_dist, self.transforms)

    def forward(self, x, m, **kwargs):
        """
        Return the mean per-element negative log-likelihood of a batch of sets.

        Args:
            x: Set elements, passed unchanged to ``self.flow.log_prob``
                (presumably (batch, max_set_size, dim) -- TODO confirm).
            m: Element mask whose last axis runs over set elements;
                ``m.sum(-1)`` is used as the set size (presumably 0/1 with
                shape (batch, max_set_size) -- TODO confirm).
            **kwargs: Accepted and ignored.

        Returns:
            Scalar tensor ``-sum_sets(log p(set)) / m.sum()``. Each set's
            log-density is the masked sum of its element log-densities minus
            ``log(n!)``, with ``n = m.sum(-1)``. An all-padding batch gives 0.
        """
        # Drop the trailing singleton axis so the per-element log-density lines
        # up with the mask (assumes log_prob returns (..., 1) -- TODO confirm).
        elem_log_prob = self.flow.log_prob(x).squeeze(-1)
        weighted = elem_log_prob * m
        # Select rather than rely on multiplication alone: a padded element
        # whose log-density is +/-inf would give inf * 0 = NaN and poison the
        # whole set. NOTE(review): this fixes the forward value; gradients can
        # still go non-finite if padded inputs make the flow itself blow up.
        weighted = torch.where(m > 0, weighted, torch.zeros_like(weighted))

        set_size = m.sum(-1)
        log_n_fact = torch.lgamma(set_size + 1)
        set_log_prob = weighted.sum(-1) - log_n_fact

        # Normalise by the number of valid elements. Only an exactly-zero
        # count is replaced (by 1), so an empty batch yields 0 instead of 0/0.
        num_valid = m.sum()
        num_valid = torch.where(num_valid > 0, num_valid, torch.ones_like(num_valid))
        loss = -set_log_prob.sum() / num_valid
        return loss

    def sample(self, num_samples):
        """
        Draw a single set of ``num_samples`` elements from the model.

        Base points are drawn from the uniform base distribution, giving shape
        (num_samples, dim). They are ordered by their first coordinate and
        wrapped in a leading batch axis of size 1, so the flow treats them as
        one set. They are then pushed through ``self.flow.forward``. The first
        element of the flow's output is returned with that batch axis removed.
        """
        base = self.base_dist.sample((num_samples,))
        # Order by first coordinate before the flow sees the points
        # (presumably what the autoregressive set layers expect -- TODO confirm).
        order = base[:, 0].sort().indices
        ordered_set = base[order].unsqueeze(0)
        return self.flow.forward(ordered_set)[0].squeeze(0)
