import math

import nf
import torch
import torch.nn as nn
import torch.distributions as td

class IWAE(nn.Module):
    """Importance-weighted autoencoder over sets with a flow decoder.

    Encoder: a stack of self-attention layers whose output is max-pooled over
    dim -2 and split in half into the mean and log-scale of a diagonal
    Gaussian q(z | x).

    Decoder: p(x | z) is a normalizing flow of spline coupling layers on top
    of a Uniform[0, 1]^dim base distribution, conditioned on z.
    """

    # Largest number of importance samples drawn in a single pass of
    # ``forward``; bigger requests are processed in chunks of at most this
    # size to bound peak memory.
    max_sample_batch = 256

    def __init__(self, dim, hidden_dim, num_layers, n_bins, att_layers, n_heads, **kwargs):
        """Build the attention encoder and the coupling-flow decoder.

        Args:
            dim: Dimensionality of each data point.
            hidden_dim: Width of the attention / MLP layers; z has this size.
            num_layers: Number of coupling layers in the flow.
            n_bins: Number of spline bins in each coupling layer.
            att_layers: Number of self-attention layers. Values below 2 still
                build two layers (the first and the last).
            n_heads: Number of attention heads.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim

        # Attention encoder
        layers = []
        layers.append(nf.net.SelfAttention(dim, hidden_dim, hidden_dim, n_heads))
        for _ in range(att_layers - 2):
            layers.append(nf.net.SelfAttention(hidden_dim, hidden_dim, hidden_dim, n_heads))
            layers.append(nn.LayerNorm(hidden_dim))
        # Last layer emits 2 * hidden_dim features: mean and log-scale of q(z|x).
        layers.append(nf.net.SelfAttention(hidden_dim, hidden_dim, hidden_dim * 2, n_heads))

        self.att = nn.Sequential(*layers)

        # Decoder: normalizing flow that models p(x | z)
        # NOTE(review): these tensors are not registered buffers, so they will
        # not follow ``.to(device)`` unless nf.Flow handles that — confirm.
        self.base_dist = td.Uniform(torch.zeros(dim), torch.ones(dim))

        # NOTE(review): a plain list does not register parameters by itself;
        # this relies on nf.Flow registering the transforms — confirm before
        # changing (wrapping in nn.ModuleList would alter state_dict keys).
        self.transforms = []
        for i in range(num_layers):
            self.transforms.append(nf.Coupling(
                flow=nf.Spline(dim, latent_dim=hidden_dim * 2, n_bins=n_bins, lower=0, upper=1),
                net=nf.net.MLP(dim, [hidden_dim], hidden_dim),
                mask=nf.mask.ordered(right=i%2),  # alternate masked halves
            ))

        self.flow = nf.Flow(self.base_dist, self.transforms)

    def encoder(self, x):
        """Return pooled encoder features (mean and log-scale concatenated).

        Applies the attention stack and max-pools over dim -2.

        NOTE(review): no element mask is applied here, so padded set elements
        (if callers pad ``x``) take part in attention and pooling — confirm.
        """
        return self.att(x).max(-2)[0]

    def log_px(self, x, z):
        """Evaluate the flow log-density of ``x`` under each latent sample.

        ``x`` gains a leading sample dimension of size ``z.shape[0]``. Each
        latent in ``z`` is repeated along dim -2 so that every set element
        is conditioned on its set's latent.
        """
        x = x.unsqueeze(0).repeat_interleave(z.shape[0], dim=0)
        z = z.unsqueeze(-2).repeat_interleave(x.shape[-2], dim=-2)
        log_px = self.flow.log_prob(x, latent=z)
        return log_px

    def forward(self, x, m, num_samples=5, **kwargs):
        """Return the negative IWAE bound normalized by the mask total.

        Args:
            x: Input sets; presumably (batch, set_size, dim) — TODO confirm.
            m: Element mask, broadcast as ``m.view(1, *m.shape, 1)`` against
                the per-element log-likelihoods; presumably (batch, set_size).
            num_samples: Number of importance samples per input. They are
                drawn in chunks of at most ``max_sample_batch``, and the
                chunks always add up to exactly ``num_samples``.
            **kwargs: Accepted and ignored.

        Returns:
            Scalar tensor ``-sum_b log_mean_exp_k(log w_bk) / m.sum()``, where
            ``log w = log p(x|z) + log p(z) - log q(z|x)``.

        Raises:
            ValueError: If ``num_samples`` is less than 1.
        """
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")

        # exp() of the second half is passed as the Normal's std-dev, so it
        # is a log standard deviation (log-scale), not a log-variance.
        mu, log_scale = self.encoder(x).chunk(2, dim=-1)
        q_z = td.Normal(mu, torch.exp(log_scale))
        p_z = td.Normal(0, 1)

        # Chunk sizes sum exactly to num_samples, so no requested samples
        # are dropped when num_samples is not a multiple of the chunk size.
        chunk = self.max_sample_batch
        sizes = [chunk] * (num_samples // chunk)
        if num_samples % chunk:
            sizes.append(num_samples % chunk)

        mask = m.view(1, *m.shape, 1)
        log_w = []
        for size in sizes:
            # Reparameterized sample so gradients flow into the encoder.
            z = q_z.rsample((size,))

            # Likelihood p(x | z), masked and summed over elements/features.
            log_px = (self.log_px(x, z) * mask).sum([-1, -2])

            # Prior p(z) and approximate posterior q(z | x).
            log_qz = q_z.log_prob(z).sum(-1)
            log_pz = p_z.log_prob(z).sum(-1)

            w = log_px + log_pz - log_qz
            log_w.append(w if self.training else w.detach())

        log_w = torch.cat(log_w, 0)

        # Numerically stable log-mean-exp over the sample dimension.
        bound = torch.logsumexp(log_w, 0) - math.log(log_w.shape[0])

        return -bound.sum() / m.sum()

    def sample(self, num_samples):
        """Generate one set of ``num_samples`` points from the model.

        A single z is drawn from the standard-normal prior and shared by all
        points, which are then sampled from the flow conditioned on it.

        NOTE(review): z is created on the default (CPU) device regardless of
        where the model lives — confirm this matches nf.Flow's expectations.
        """
        z = td.Normal(0, 1).sample((1, 1, self.hidden_dim)).repeat(1, num_samples, 1)
        x = self.flow.sample((1, num_samples), latent=z)
        return x.squeeze(0)
