import nf
import torch
import torch.nn as nn
import torch.nn.functional as F


class AutoregressiveAffine(nn.Module):
    """
    Autoregressive affine layer over the last dimension of data.

    Element ``i`` of the last dimension is shifted and scaled with parameters
    produced by an RNN that has consumed the transformed elements
    ``0 .. i-1`` (a zero is fed in as the very first RNN input).

    Args:
        dim: Dimension of data
        hidden_dim: Dimension of RNN hidden state
    """
    def __init__(self, dim, hidden_dim, **kwargs):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim

        # Every scalar element is one RNN time step, hence input_size=1.
        self.rnn = nn.RNN(input_size=1, hidden_size=hidden_dim, batch_first=True)
        # Two outputs per step: log-scale and shift.
        self.proj = nn.Linear(hidden_dim, 2)

    def forward(self, x, latent=None, **kwargs):
        """
        Transform ``x`` element by element: ``y_i = (x_i - shift_i) * exp(-log_scale_i)``.

        The loop is sequential because the parameters of step ``i`` depend on
        the outputs ``y`` of the previous steps.

        Args:
            x: Tensor of shape (..., D); leading dimensions are flattened
                into a single batch dimension internally.
            latent: Unused.

        Returns:
            Tuple ``(y, log_jac)``, both with the shape of ``x``; ``log_jac``
            holds the elementwise value ``-log_scale``.
        """
        x_shape = x.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(-1, x_shape[-1])
        batch = x.shape[0]

        # Allocate on the input's device/dtype so the layer also works on
        # GPU or in double precision.
        y = x.new_zeros(batch, 1)  # First RNN input
        h = x.new_zeros(1, batch, self.hidden_dim)  # First state

        ys = []
        log_jac = []
        for i in range(x.shape[1]):
            # RNN: one step on the previous output; h is (1, batch, hidden_dim)
            _, h = self.rnn(y.unsqueeze(-1), h)
            params = self.proj(h[0])  # (batch, 2)

            # Flow
            log_scale = params[..., :1]
            shift = params[..., 1:]

            y = (x[:, i, None] - shift) * (-log_scale).exp()

            ys.append(y)
            log_jac.append(-log_scale)

        ys = torch.cat(ys, 1).reshape(x_shape)
        log_jac = torch.cat(log_jac, 1).reshape(x_shape)

        return ys, log_jac

    def inverse(self, y, latent=None, **kwargs):
        """
        Invert :meth:`forward`: ``x_i = y_i * exp(log_scale_i) + shift_i``.

        All RNN inputs (``y`` shifted right by one, zero first) are known up
        front, so the RNN runs over the whole sequence in a single call.

        Args:
            y: Tensor of shape (..., D).
            latent: Unused.

        Returns:
            Tuple ``(x, log_jac)``, both with the shape of ``y``; ``log_jac``
            holds the elementwise value ``log_scale``.
        """
        y_shape = y.shape
        y = y.reshape(-1, y_shape[-1])

        # y: (N, D) -> (N, D, 1) -> run RNN over D
        rnn_input = F.pad(y, (1, 0))[..., :-1].unsqueeze(-1)  # Add zero to beginning

        # RNN; explicit zero initial state on the input's device/dtype
        h0 = y.new_zeros(1, y.shape[0], self.hidden_dim)
        h, _ = self.rnn(rnn_input, h0)
        h = self.proj(h)  # (N, D, 2)

        # Flow
        log_scale = h[..., 0]
        shift = h[..., 1]

        x = y * log_scale.exp() + shift
        return x.reshape(y_shape), log_scale.reshape(y_shape)

class AutoregressiveSetAffine(nn.Module):
    """
    Autoregressive affine layer over sets.

    Set element ``i`` (a vector of size ``dim``) is shifted and scaled with
    parameters produced by an RNN that has consumed the transformed elements
    ``0 .. i-1`` (a zero vector is fed in as the very first RNN input).

    Args:
        dim: Dimension of set element
        hidden_dim: Dimension of RNN hidden state
    """
    def __init__(self, dim, hidden_dim, **kwargs):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim

        # Every set element is one RNN time step.
        self.rnn = nn.RNN(input_size=dim, hidden_size=hidden_dim, batch_first=True)
        # Per step: `dim` log-scales followed by `dim` shifts.
        self.proj = nn.Linear(hidden_dim, 2 * dim)

    def forward(self, x, latent=None, **kwargs):
        """
        Transform ``x`` element by element: ``y_i = (x_i - shift_i) * exp(-log_scale_i)``.

        The loop is sequential because the parameters of step ``i`` depend on
        the outputs ``y`` of the previous steps.

        Args:
            x: Tensor of shape (batch, set_size, dim).
            latent: Unused.

        Returns:
            Tuple ``(y, log_jac)``, both of shape (batch, set_size, dim);
            ``log_jac`` holds the elementwise value ``-log_scale``.
        """
        batch = x.shape[0]

        # Allocate on the input's device/dtype so the layer also works on
        # GPU or in double precision.
        y = x.new_zeros(batch, 1, self.dim)  # First RNN input
        h = x.new_zeros(1, batch, self.hidden_dim)  # First state

        ys = []
        log_jac = []
        for i in range(x.shape[1]):
            # RNN: one step on the previous output; h is (1, batch, hidden_dim)
            _, h = self.rnn(y, h)
            params = self.proj(h[0]).unsqueeze(1)  # (batch, 1, 2 * dim)

            # Flow
            log_scale = params[..., :self.dim]
            shift = params[..., self.dim:]

            y = (x[:, i, None] - shift) * (-log_scale).exp()

            ys.append(y)
            log_jac.append(-log_scale)

        ys = torch.cat(ys, 1)
        log_jac = torch.cat(log_jac, 1)

        return ys, log_jac

    def inverse(self, y, latent=None, **kwargs):
        """
        Invert :meth:`forward`: ``x_i = y_i * exp(log_scale_i) + shift_i``.

        All RNN inputs (``y`` shifted by one set element, zero vector first)
        are known up front, so the RNN runs over the whole set in one call.

        Args:
            y: Tensor of shape (batch, set_size, dim).
            latent: Unused.

        Returns:
            Tuple ``(x, log_jac)``, both of shape (batch, set_size, dim);
            ``log_jac`` holds the elementwise value ``log_scale``.
        """
        # Prepend a zero element along the set axis and drop the last one.
        rnn_input = F.pad(y, (0, 0, 1, 0))[..., :-1, :]

        # RNN; explicit zero initial state on the input's device/dtype
        h0 = y.new_zeros(1, y.shape[0], self.hidden_dim)
        h, _ = self.rnn(rnn_input, h0)
        h = self.proj(h)  # (batch, set_size, 2 * dim)

        # Flow
        log_scale = h[..., :self.dim]
        shift = h[..., self.dim:]

        x = y * log_scale.exp() + shift
        return x, log_scale
