import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers import GaussianDistribution, AdditiveCouplingLayer, AffineCopulingLayer, ScalingLayer, InvertibleLinear

# Reference: https://github.com/DakshIdnani/pytorch-nice

class NICE(nn.Module):
  """NICE-style normalizing flow built from additive coupling layers.

  The forward map ``f`` sends data ``x`` to latent ``z`` through
  ``num_coupling_layers`` additive coupling layers (each optionally preceded
  by an ``InvertibleLinear`` layer) followed by a final ``ScalingLayer``.
  The model log-likelihood is ``prior.log_prob(z) + log|det J|``.
  """

  def __init__(self, data_dim, hiddem_dim,  num_layers, num_coupling_layers, invlayer):
    """Build the flow.

    Args:
      data_dim: dimensionality of each sample (passed to every layer).
      hiddem_dim: hidden width forwarded to each coupling layer as
        ``hidden_dim`` (misspelled name kept for caller compatibility).
      num_layers: forwarded to each coupling layer as ``num_layers``.
      num_coupling_layers: number of coupling layers in the flow.
      invlayer: if truthy, an ``InvertibleLinear`` layer is placed before
        each coupling layer.
    """
    super().__init__()

    self.data_dim = data_dim
    self.invlayer = invlayer
    if invlayer:
      self.linear_invertible = nn.ModuleList(
          [InvertibleLinear(data_dim) for _ in range(num_coupling_layers)])
    # Each coupling layer receives its position ``idx``; presumably the layer
    # uses it to alternate mask orientation between consecutive layers
    # (TODO confirm in layers.py).
    self.coupling_layers = nn.ModuleList(
        [AdditiveCouplingLayer(data_dim=data_dim, hidden_dim=hiddem_dim,
                               num_layers=num_layers, idx=i)
         for i in range(num_coupling_layers)])

    self.scaling_layer = ScalingLayer(data_dim=data_dim)

    self.prior = GaussianDistribution()

  def forward(self, x):
    """Map ``x`` to latent space and score it.

    Returns:
      ``(z, log_likelihood)`` with
      ``log_likelihood = prior.log_prob(z) + log|det J_f(x)|``.
    """
    z, log_det_jacobian = self.f(x)
    log_likelihood = self.prior.log_prob(z) + log_det_jacobian
    return z, log_likelihood

  def volume_extract(self, x):
    """Return only the per-sample log-determinant of the forward Jacobian."""
    _, log_det_jacobian = self.f(x)
    return log_det_jacobian

  def f(self, x):
    """Forward flow ``x -> z``.

    Returns:
      ``(z, log_det_jacobian)`` where ``log_det_jacobian`` has one entry per
      sample (length ``x.shape[0]``), accumulated across all layers.
    """
    z = x
    log_det_jacobian = torch.zeros(x.shape[0], device=x.device)
    if self.invlayer:
      for linear_invertible, coupling_layer in zip(self.linear_invertible, self.coupling_layers):
        z, log_det_jacobian = linear_invertible(z, log_det_jacobian)
        z, log_det_jacobian = coupling_layer(z, log_det_jacobian)
    else:
      for coupling_layer in self.coupling_layers:
        z, log_det_jacobian = coupling_layer(z, log_det_jacobian)
    z, log_det_jacobian = self.scaling_layer(z, log_det_jacobian)
    return z, log_det_jacobian

  def inverse(self, z):
    """Inverse flow ``z -> x``.

    Layers are undone in the reverse order of ``f``. The scaling layer goes
    first, then each coupling layer followed by its ``InvertibleLinear``
    layer (when enabled). A log-determinant is threaded through the layers
    because their APIs require it, but only ``x`` is returned.
    """
    x = z
    # Local scratch accumulator (previously stored on ``self.dummy``).
    logdet = torch.zeros(x.shape[0], device=x.device)
    x, logdet = self.scaling_layer.inverse(x, logdet)

    if self.invlayer:
      pairs = list(zip(self.linear_invertible, self.coupling_layers))
      for linear_invertible, coupling_layer in reversed(pairs):
        x, logdet = coupling_layer.inverse(x, logdet)
        # Third positional argument ``True`` presumably requests the inverse
        # direction of the linear layer (TODO confirm in layers.py).
        x, logdet = linear_invertible(x, logdet, True)
    else:
      for coupling_layer in reversed(list(self.coupling_layers)):
        x, logdet = coupling_layer.inverse(x, logdet)
    return x

  def sampling(self, num_samples):
    """Draw ``num_samples`` data samples by inverting prior samples.

    Previously this called a nonexistent ``f_inverse`` and always raised.
    """
    z = self.prior.sample(num_samples, self.data_dim)
    return self.inverse(z)

class RealNVP(nn.Module):
  """RealNVP-style normalizing flow built from affine coupling layers.

  The forward map ``f`` sends data ``x`` to latent ``z`` through
  ``num_coupling_layers`` affine coupling layers, each optionally preceded
  by an ``InvertibleLinear`` layer. Unlike ``NICE`` there is no final
  scaling layer. The model log-likelihood is
  ``prior.log_prob(z) + log|det J|``.
  """

  def __init__(self, data_dim, hiddem_dim,  num_layers, num_coupling_layers, invlayer):
    """Build the flow.

    Args:
      data_dim: dimensionality of each sample (passed to every layer).
      hiddem_dim: hidden width forwarded to each coupling layer as
        ``hidden_dim`` (misspelled name kept for caller compatibility).
      num_layers: forwarded to each coupling layer as ``num_layers``.
      num_coupling_layers: number of coupling layers in the flow.
      invlayer: if truthy, an ``InvertibleLinear`` layer is placed before
        each coupling layer.
    """
    super().__init__()

    self.data_dim = data_dim
    self.invlayer = invlayer

    if invlayer:
      self.linear_invertible = nn.ModuleList(
          [InvertibleLinear(data_dim) for _ in range(num_coupling_layers)])
    # Each coupling layer receives its position ``idx``; presumably the layer
    # uses it to alternate mask orientation (TODO confirm in layers.py).
    self.coupling_layers = nn.ModuleList(
        [AffineCopulingLayer(data_dim=data_dim, hidden_dim=hiddem_dim,
                             num_layers=num_layers, idx=i)
         for i in range(num_coupling_layers)])

    self.prior = GaussianDistribution()

  def forward(self, x):
    """Map ``x`` to latent space and score it.

    Returns:
      ``(z, log_likelihood)`` with
      ``log_likelihood = prior.log_prob(z) + log|det J_f(x)|``.
    """
    z, log_det_jacobian = self.f(x)
    log_likelihood = self.prior.log_prob(z) + log_det_jacobian
    return z, log_likelihood

  def volume_extract(self, x):
    """Return only the per-sample log-determinant of the forward Jacobian."""
    _, log_det_jacobian = self.f(x)
    return log_det_jacobian

  def f(self, x):
    """Forward flow ``x -> z``.

    Returns:
      ``(z, log_det_jacobian)`` where ``log_det_jacobian`` has one entry per
      sample (length ``x.shape[0]``), accumulated across all layers.
    """
    z = x
    log_det_jacobian = torch.zeros(x.shape[0], device=x.device)
    if self.invlayer:
      for linear_invertible, coupling_layer in zip(self.linear_invertible, self.coupling_layers):
        z, log_det_jacobian = linear_invertible(z, log_det_jacobian)
        z, log_det_jacobian = coupling_layer(z, log_det_jacobian)
    else:
      for coupling_layer in self.coupling_layers:
        z, log_det_jacobian = coupling_layer(z, log_det_jacobian)
    return z, log_det_jacobian

  def inverse(self, z):
    """Inverse flow ``z -> x``.

    Layers are undone in the reverse order of ``f``: each coupling layer,
    then its ``InvertibleLinear`` layer (when enabled). A log-determinant is
    threaded through the layers because their APIs require it, but only
    ``x`` is returned.
    """
    x = z
    logdet = torch.zeros(x.shape[0], device=x.device)

    if self.invlayer:
      pairs = list(zip(self.linear_invertible, self.coupling_layers))
      for linear_invertible, coupling_layer in reversed(pairs):
        x, logdet = coupling_layer.inverse(x, logdet)
        # Third positional argument ``True`` presumably requests the inverse
        # direction of the linear layer (TODO confirm in layers.py).
        x, logdet = linear_invertible(x, logdet, True)
    else:
      for coupling_layer in reversed(list(self.coupling_layers)):
        # Thread logdet consistently with the branch above.
        x, logdet = coupling_layer.inverse(x, logdet)
    return x

  def sampling(self, num_samples):
    """Draw ``num_samples`` data samples by inverting prior samples.

    Previously this called a nonexistent ``f_inverse`` and always raised.
    """
    z = self.prior.sample(num_samples, self.data_dim)
    return self.inverse(z)