import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.distributions import Distribution


# Reference: https://github.com/DakshIdnani/pytorch-nice

   

    
class AdditiveCouplingLayer(nn.Module):
  """Additive coupling layer: shift one part of the features by an MLP of the other part.

  The input is split along dim 1 into a first part of ``D // 2`` features and a
  second part holding the remaining ``D - D // 2`` features. When ``idx`` is even
  the second part is shifted by ``m(first part)``; when ``idx`` is odd the first
  part is shifted by ``m(second part)``. The transform is volume preserving, so
  ``logdet`` is passed through unchanged.

  Args:
    data_dim: Number of features ``D`` of the inputs (odd values are supported).
    hidden_dim: Width of the hidden layers of the shift network.
    num_layers: The shift network contains ``max(num_layers, 2)`` linear layers.
    idx: Position of the layer in the flow; its parity selects which part is
      transformed.
  """

  def __init__(self, data_dim, hidden_dim, num_layers, idx):
    super().__init__()

    self.idx = idx
    self.act = nn.LeakyReLU()

    # Sizes of the two parts produced by `_split_halves`. They differ by one
    # when data_dim is odd, so the shift network maps "conditioning part" ->
    # "transformed part" instead of assuming both are data_dim / 2 wide.
    first_dim = int(data_dim) // 2
    second_dim = int(data_dim) - first_dim
    if idx % 2 == 0:
      cond_dim, shift_dim = first_dim, second_dim
    else:
      cond_dim, shift_dim = second_dim, first_dim

    # Linear layers keep PyTorch's default initialisation; no custom init is applied.
    modules = [nn.Linear(cond_dim, hidden_dim), self.act]
    for _ in range(num_layers - 2):
      modules.append(nn.Linear(hidden_dim, hidden_dim))
      modules.append(self.act)
    modules.append(nn.Linear(hidden_dim, shift_dim))

    self.m = nn.Sequential(*modules)

  @staticmethod
  def _split_halves(x):
    """Split ``x`` along dim 1 into its first ``D // 2`` and remaining features."""
    cut = x.shape[1] // 2
    return x[:, :cut], x[:, cut:]

  def forward(self, x, logdet):
    """Apply the coupling; returns ``(y, logdet)`` with ``logdet`` untouched."""
    x1, x2 = self._split_halves(x)
    if self.idx % 2 == 0:
      y1, y2 = x1, x2 + self.m(x1)
    else:
      y1, y2 = x1 + self.m(x2), x2

    return torch.cat((y1, y2), dim=1), logdet

  def inverse(self, x, logdet):
    """Undo `forward`; returns ``(x, logdet)`` with ``logdet`` untouched."""
    y1, y2 = self._split_halves(x)
    # The conditioning part passes through unchanged, so the same shift can be
    # recomputed from it and subtracted.
    if self.idx % 2 == 0:
      x1, x2 = y1, y2 - self.m(y1)
    else:
      x1, x2 = y1 - self.m(y2), y2
    return torch.cat((x1, x2), dim=1), logdet


class AffineCopulingLayer(nn.Module):
  """Affine coupling layer: scale and shift one part of the features using the other part.

  The input is split along dim 1 into a first part of ``D // 2`` features and a
  second part holding the remaining ``D - D // 2`` features. When ``idx`` is even
  the second part becomes ``x2 * exp(s(x1)) + t(x1)``; when ``idx`` is odd the
  first part becomes ``x1 * exp(s(x2)) + t(x2)``. The log-scale network ``s``
  ends in ``Tanh``, so every log-scale lies in ``[-1, 1]``.

  ``forward`` adds ``sum(s)`` to ``logdet`` and ``inverse`` subtracts it, so a
  forward/inverse round trip restores both the input and ``logdet``.

  Args:
    data_dim: Number of features ``D`` of the inputs (odd values are supported).
    hidden_dim: Width of the hidden layers of ``s`` and ``t``.
    num_layers: Each network contains ``max(num_layers, 2)`` linear layers.
    idx: Position of the layer in the flow; its parity selects which part is
      transformed.
  """

  def __init__(self, data_dim, hidden_dim, num_layers, idx):
    super().__init__()

    self.idx = idx
    self.act = nn.LeakyReLU()

    # Sizes of the two parts produced by `_split_halves`; they differ by one
    # when data_dim is odd.
    first_dim = int(data_dim) // 2
    second_dim = int(data_dim) - first_dim
    if idx % 2 == 0:
      cond_dim, out_dim = first_dim, second_dim
    else:
      cond_dim, out_dim = second_dim, first_dim

    # Linear layers keep PyTorch's default initialisation; no custom init is applied.
    # `s` is built before `t` so parameters are drawn in the same RNG order as before.
    self.s = self._build_net(cond_dim, hidden_dim, out_dim, num_layers, bounded=True)
    self.t = self._build_net(cond_dim, hidden_dim, out_dim, num_layers, bounded=False)

  def _build_net(self, in_dim, hidden_dim, out_dim, num_layers, bounded):
    """Build the MLP for ``s`` (``bounded=True`` appends ``Tanh``) or ``t``."""
    layers = [nn.Linear(in_dim, hidden_dim), self.act]
    for _ in range(num_layers - 2):
      layers.append(nn.Linear(hidden_dim, hidden_dim))
      layers.append(nn.LeakyReLU())
    layers.append(nn.Linear(hidden_dim, out_dim))
    if bounded:
      layers.append(nn.Tanh())
    return nn.Sequential(*layers)

  @staticmethod
  def _split_halves(x):
    """Split ``x`` along dim 1 into its first ``D // 2`` and remaining features."""
    cut = x.shape[1] // 2
    return x[:, :cut], x[:, cut:]

  def forward(self, x, logdet):
    """Apply the coupling; returns ``(y, logdet + sum(s))``.

    ``logdet`` is not modified in place; a new value is returned.
    """
    x1, x2 = self._split_halves(x)
    if self.idx % 2 == 0:
      log_scale = self.s(x1)
      y1, y2 = x1, x2 * torch.exp(log_scale) + self.t(x1)
    else:
      log_scale = self.s(x2)
      y1, y2 = x1 * torch.exp(log_scale) + self.t(x2), x2

    # The Jacobian is triangular with exp(log_scale) on the transformed part's
    # diagonal, so its log-determinant is simply the sum of the log-scales.
    logdet = logdet + log_scale.sum(dim=1)
    return torch.cat((y1, y2), dim=1), logdet

  def inverse(self, x, logdet):
    """Undo `forward`; returns ``(x, logdet - sum(s))``.

    ``logdet`` is not modified in place; a new value is returned.
    """
    y1, y2 = self._split_halves(x)
    # The conditioning part passes through unchanged, so s and t evaluated on
    # it are exactly the values used by `forward`.
    if self.idx % 2 == 0:
      log_scale = self.s(y1)
      x1, x2 = y1, (y2 - self.t(y1)) * torch.exp(-log_scale)
    else:
      log_scale = self.s(y2)
      x1, x2 = (y1 - self.t(y2)) * torch.exp(-log_scale), y2

    # Subtract the forward log-determinant (same convention as the other layers).
    logdet = logdet - log_scale.sum(dim=1)
    return torch.cat((x1, x2), dim=1), logdet


# Correctly spelled alias; the original class name is kept for existing callers.
AffineCouplingLayer = AffineCopulingLayer

class ScalingLayer(nn.Module):
  """Learned element-wise rescaling of the features.

  Holds one log-scale per feature in ``log_scale_vector`` (shape
  ``(1, data_dim)``, drawn from a standard normal at construction).
  ``forward`` multiplies by ``exp(log_scale_vector)`` and adds the sum of the
  log-scales to ``logdet``; ``inverse`` multiplies by ``exp(-log_scale_vector)``
  and subtracts that same sum.
  """

  def __init__(self, data_dim):
    super().__init__()
    # nn.Parameter already marks the tensor as trainable.
    initial_log_scale = torch.randn(1, data_dim)
    self.log_scale_vector = nn.Parameter(initial_log_scale)

  def forward(self, x, logdet):
    """Return ``(exp(log_scale_vector) * x, logdet + sum(log_scale_vector))``."""
    factor = self.log_scale_vector.exp()
    total_log_scale = self.log_scale_vector.sum()
    return factor * x, logdet + total_log_scale

  def inverse(self, x, logdet):
    """Return ``(exp(-log_scale_vector) * x, logdet - sum(log_scale_vector))``."""
    factor = (-self.log_scale_vector).exp()
    total_log_scale = self.log_scale_vector.sum()
    return factor * x, logdet - total_log_scale
  

# Reference: https://github.com/y0ast/Glow-PyTorch/blob/master/model.py
class InvertibleLinear(nn.Module):
    """Learned invertible linear map ``z = input @ weight.T`` over the feature dim.

    ``weight`` is a square ``(num_channels, num_channels)`` parameter initialised
    to the orthogonal Q factor of a random Gaussian matrix. The forward pass adds
    ``log|det(weight)|`` to ``logdet``; the reverse pass applies ``weight``'s
    inverse and subtracts the same quantity.

    Args:
        num_channels: Number of features of the 2-D ``(batch, features)`` inputs.
    """

    def __init__(self, num_channels):
        super().__init__()
        w_shape = [num_channels, num_channels]
        # torch.linalg.qr replaces the deprecated torch.qr; for a square matrix
        # the default "reduced" mode returns the full orthogonal Q.
        w_init = torch.linalg.qr(torch.randn(*w_shape))[0]
        self.weight = nn.Parameter(w_init)
        self.w_shape = w_shape

    def get_weight(self, input, reverse):
        """Return ``(matrix_to_apply, log|det(weight)|)`` on ``input``'s device.

        The matrix is ``weight`` itself, or its inverse when ``reverse`` is true.
        The log-determinant is always that of the forward ``weight``.

        Raises:
            ValueError: If ``input`` is not 2-D.
        """
        if len(input.shape) != 2:
            raise ValueError(
                "InvertibleLinear expects a 2-D (batch, features) input, got shape "
                + str(tuple(input.shape))
            )

        dlogdet = torch.slogdet(self.weight)[1]
        weight = torch.inverse(self.weight) if reverse else self.weight

        return weight.to(input.device), dlogdet.to(input.device)

    def forward(self, input, logdet=None, reverse=False):
        """Apply the map (or its inverse when ``reverse``) and update ``logdet``.

        Returns ``(z, logdet)``; ``logdet`` stays ``None`` when ``None`` is passed.
        """
        weight, dlogdet = self.get_weight(input, reverse)
        z = F.linear(input, weight)
        if logdet is not None:
            logdet = logdet - dlogdet if reverse else logdet + dlogdet
        return z, logdet

    def inverse(self, input, logdet=None):
        """Inverse pass with the same ``inverse(x, logdet)`` shape as the other layers."""
        return self.forward(input, logdet, reverse=True)

class GaussianDistribution(Distribution):
  """Standard normal prior (zero mean, identity covariance) over latent vectors."""

  def __init__(self):
    super().__init__()

  def log_prob(self, z):
    """Return the standard-normal log-density of each row of ``z``.

    ``z`` is indexed as ``(batch, features)``; the result has one entry per row:
    ``-(0.5 * ||z||^2 + 0.5 * features * log(2 * pi))``.
    """
    num_features = z.size(1)
    log_two_pi = torch.log(torch.tensor(2 * np.pi))
    squared_norm = torch.sum(torch.pow(z, 2), dim=1)
    return -(0.5 * squared_norm + num_features * 0.5 * log_two_pi)

  def sample(self, latent_size, sample_size):
    """Draw ``sample_size`` latent vectors of length ``latent_size``.

    Samples come from NumPy's global RNG (``np.random``), not torch's, and the
    returned tensor keeps NumPy's float64 dtype.
    """
    draws = np.random.multivariate_normal(
        np.zeros(latent_size), np.eye(latent_size), sample_size)
    return torch.tensor(draws)

