import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.distributions as distribution
import math
import numpy as np
import time
from utils import utils_math
from scipy import stats
import scipy.special as special
from scipy.stats import binom
from scipy.stats import norm
from nde import MDN
import optimizer
from copy import deepcopy

        
                
class NGC(nn.Module):
    """
        Neural Gaussian copula.

        Pipeline (as wired below): inputs are first passed through a BNLayer,
        then through a conditional MarginalLayer, and the result is scored
        either by the conditional MDN (`log_probs`) or by an independent
        standard normal (`log_probs_marginal`). Both layers return a log-det
        term which is subtracted from the base log-density.
    """
    def __init__(self, dim_data, n_hidden, dim_cond, K=1):
        super().__init__()
        # training hyper-parameters -- presumably read by optimizer.NNOptimizer (TODO confirm)
        self.bs, self.lr, self.wd = 100, 5e-4, 1e-5
        self.bn = BNLayer(dim_data)
        # NOTE(review): the K argument is accepted but not forwarded; the MDN is always built with K=1
        self.mdn = MDN(n_in=dim_cond, n_hidden=n_hidden, n_out=dim_data, K=1)
        self.magrinal_layer = MarginalLayer(dim_x=dim_data, dim_cond=dim_cond, L=1)      # <-- only L=1 works?

    def _transform(self, inputs, cond_inputs):
        # shared front-end: BNLayer -> MarginalLayer; returns u and both log-dets flattened to 1-D
        normed, logdet_bn = self.bn(inputs)
        u, logdet_marginal = self.magrinal_layer(normed, cond_inputs)
        return u, logdet_marginal.view(-1), logdet_bn.view(-1)

    def log_probs(self, inputs, cond_inputs, marginals=None):
        # joint log-density: MDN log-prob of the transformed data minus the layers' log-dets
        # (`marginals` is accepted but not used here)
        u, logdet_marginal, logdet_bn = self._transform(inputs, cond_inputs)
        base = self.mdn.log_probs(u, cond_inputs)
        return base.view(-1) - logdet_marginal - logdet_bn

    def log_probs_marginal(self, inputs, cond_inputs):
        # same as log_probs, but the base density is an independent N(0, 1) per dimension
        u, logdet_marginal, logdet_bn = self._transform(inputs, cond_inputs)
        log_base_prob = (-0.5 * u.pow(2) - 0.5 * math.log(2 * math.pi)).sum(dim=1, keepdim=True)
        return log_base_prob.view(-1) - logdet_marginal - logdet_bn

    def objective_func(self, inputs, cond_inputs):
        # fixed 0.75 / 0.25 mix of the mean marginal-only and mean joint log-densities
        marginal_term = self.log_probs_marginal(inputs, cond_inputs).mean()
        joint_term = self.log_probs(inputs, cond_inputs).mean()
        return 0.75*marginal_term + 0.25*joint_term

    def learn(self, inputs, cond_inputs):
        # fit the normalisation statistics first, then delegate training to NNOptimizer
        self.bn.learn(inputs)
        optimizer.NNOptimizer.learn(self, inputs, cond_inputs)

    def print(self, cond_inputs):
        # dump the MDN parameters for the given conditioning inputs; V is truncated to 2 decimals
        mu, V = self.mdn.params(cond_inputs)
        print('mu', mu)
        print('V', (V*100).int()/100.0)
        
# ----------------------------------------------------------------------------------------------------------------------------- #
    

class BNLayer(nn.Sequential):

    # fixed (non-trainable) per-dimension standardisation: x = (z - mu) / std
    #
    # `mu` and `std` are registered as buffers so that they are saved in
    # `state_dict()` and moved by `.to()` / `.cuda()` together with the module.
    # Until `learn` is called they default to 0 and 1, i.e. the layer is the
    # identity with a zero log-det.

    def __init__(self, dim_x):
        super().__init__()
        self.register_buffer('mu', torch.zeros(1, dim_x))
        self.register_buffer('std', torch.ones(1, dim_x))

    def forward(self, z):
        # returns the standardised data and log|dz/dx| = sum_d log(std_d), shape (1, 1)
        x = (z - self.mu)/self.std
        log_dzdx = torch.log(self.std).sum(dim=1, keepdim=True)
        return x, log_dzdx

    def learn(self, z):
        # estimate the statistics over dim 0 of z; assigning to a registered
        # buffer name keeps the new tensors registered as buffers
        self.mu = z.mean(dim=0, keepdim=True)
        self.std = z.std(dim=0, keepdim=True)


    
class MarginalLayer(nn.Sequential):

    # composition of L marginal blocks, applied one after another
    #
    # forward returns the output of the last block together with the SUM of
    # the log-dets reported by every block (shape: batch x 1).

    def __init__(self, dim_x, dim_cond, L):
        super().__init__(*[MarginalBlock(dim_x, dim_cond) for _ in range(L)])
        # set after Module initialisation so attribute registration is well defined
        self.L = L

    def forward(self, inputs, cond_inputs, marginals=None):
        self.num_inputs = inputs.size(-1)
        sum_logdet = torch.zeros(inputs.size(0), 1, device=inputs.device)
        for block in self:
            inputs, logdet = block(inputs, cond_inputs, marginals)
            # out-of-place add: no in-place dtype down-cast of the accumulator
            sum_logdet = sum_logdet + logdet
        # return the accumulated log-det, not just the last block's
        return inputs, sum_logdet
    
    
class ContraintBlock(nn.Module):

    # squash x into the per-dimension ranges given by x_ranges (a constraint layer)
    #
    # x_ranges is indexed as x_ranges[:, 0] (lower) and x_ranges[:, 1] (upper).

    def __init__(self, dim_x, x_ranges):
        super().__init__()
        self.x_ranges = x_ranges
        # linear layer whose output is multiplied by 0 in forward, so it never
        # changes the result; presumably kept so the module owns parameters (TODO confirm)
        self.dummy = nn.Linear(dim_x, dim_x)

    def _tanh_derivative(self, a):
        # d tanh(a) / da
        return 1-torch.tanh(a)**2

    def forward(self, x):
        x = self.dummy(x)*0 + x
        lo = self.x_ranges[:, 0].view(1, -1)
        hi = self.x_ranges[:, 1].view(1, -1)
        span = hi - lo
        # tanh -> [-1, 1] -> [0, 1] -> [0.0025, 0.9975] -> [lo, hi] (shrunk slightly inside the range)
        unit = (torch.tanh(x)/2 + 0.5)*0.995 + 0.0025
        y = lo + unit*span
        # element-wise dy/dx following the same chain of affine maps
        slope = self._tanh_derivative(x)/2*0.995*span
        # 1e-12 guards the log against a slope that underflows to zero
        log_det_dxdy = -(slope.abs()+1e-12).log().sum(dim=1, keepdim=True)
        return y, log_det_dxdy
        
        
class MarginalBlock(nn.Module):

    # conditional element-wise transform x = f(z | cond):
    #   x_d = sum_s A_ds * tanh(B_ds * (z_d + C_ds)) + V_d * z_d + E_d
    # where A, B, V are softplus outputs (hence positive) and all five
    # coefficient sets are produced from `cond` by small MLPs.

    def __init__(self, dim_x, dim_cond, S=10):
        super().__init__()
        self.dim_x, self.S = dim_x, S
        # one 2-layer tanh MLP per coefficient set, registered in the order
        # A, B, C (S*dim_x outputs each) then V, E (dim_x outputs each)
        head_widths = (('main_A', S*dim_x), ('main_B', S*dim_x), ('main_C', S*dim_x),
                       ('main_V', dim_x), ('main_E', dim_x))
        for attr, width in head_widths:
            setattr(self, attr, nn.Sequential(
                nn.Linear(dim_cond, 100),
                nn.Tanh(),
                nn.Linear(100, width)
            ))

    def ABC(self, cond):
        # evaluate the five heads; A, B and V go through softplus, C and E stay raw
        heads = (self.main_A, self.main_B, self.main_C, self.main_V, self.main_E)
        a, b, c, v, e = [head(cond) for head in heads]
        return F.softplus(a), F.softplus(b), c, F.softplus(v), e

    def _tanh_derivative(self, a):
        # d tanh(a) / da
        return 1-torch.tanh(a)**2

    def forward(self, z, cond, marginals=None):
        # returns (x, log|dz/dx|) with shapes (n, D) and (n, 1); `marginals` is unused here
        A, B, C, V, E = self.ABC(cond)
        n = len(A)
        per_component = (n, self.dim_x, self.S)
        per_dim = (n, self.dim_x)
        A, B, C = A.reshape(per_component), B.reshape(per_component), C.reshape(per_component)
        V, E = V.reshape(per_dim), E.reshape(per_dim)
        # x: z gets a trailing axis so it broadcasts against the S components
        z = z.reshape(n, self.dim_x, 1)
        v = B*(z+C)
        x = (A*torch.tanh(v)).sum(dim=2) + V*z.view(per_dim) + E
        # dx/dz, element-wise: sum_s A*B*tanh'(v) + V
        slope = (A*self._tanh_derivative(v)*B).sum(dim=2) + V
        log_det_dxdz = slope.abs().log()
        log_det_dzdx = -log_det_dxdz.sum(dim=1, keepdim=True)
        return x, log_det_dzdx





# ----------------------------------------------------------------------------------------------------------------------------- #
    



class GaussianCopula(nn.Module):

    # Gaussian copula fitted from data: every column is mapped to normal scores
    # through its empirical CDF (ranks), and the latent covariance V is
    # estimated from those scores.

    def __init__(self):
        super().__init__()

    def learn(self, x, y, marginal_only=False):
        # Fit the copula on the column-wise concatenation [x, y].
        #
        # Always stores `sorted_xy` (the per-column sorted data, i.e. the
        # empirical quantile tables). Unless `marginal_only` is set it also
        # stores V (joint covariance of the scores), its x/y diagonal blocks
        # Vx and Vy, the block-diagonal V2, their inverses, and the two
        # zero-mean multivariate normals built from V and V2.
        xy = torch.cat([x, y], dim=1)
        d = xy.size(1)
        # split at the real width of x (not d//2) so x and y may differ in dimension
        dx = x.size(1)
        # calculate the latent Z
        z, sorted_xy = self.forward(xy, return_sorted_data=True)
        self.sorted_xy = sorted_xy
        if marginal_only:
            return
        V = torch.matmul(z.t(), z)/(len(z)+1)
        # assign values
        self.V = V
        self.Vx = V[:dx, :dx]
        self.Vy = V[dx:, dx:]
        # V2: same marginal blocks as V, with the x-y cross-covariance zeroed
        self.V2 = torch.eye(d).to(x.device)
        self.V2[:dx, :dx] = self.Vx
        self.V2[dx:, dx:] = self.Vy
        self.V_inv, self.Vx_inv, self.Vy_inv = torch.inverse(self.V), torch.inverse(self.Vx), torch.inverse(self.Vy)
        zero_mean = torch.zeros(d).to(x.device)
        self.normal = distribution.multivariate_normal.MultivariateNormal(zero_mean, self.V)
        self.normal2 = distribution.multivariate_normal.MultivariateNormal(zero_mean, self.V2)

    def forward(self, data, return_sorted_data=False):
        # Map each column of `data` to normal scores.
        # rank -> u = rank/(n+1) in (0, 1) -> z = Phi^{-1}(u)
        sorted_data, order = torch.sort(data, dim=0)
        # sorting the sort-indices yields each entry's 0-based rank in its column
        _, ranks = torch.sort(order, dim=0)
        u = (ranks.float()+1)/(len(data)+1)
        normal = distribution.Normal(torch.zeros_like(u), torch.ones_like(u))
        # calculate the latent Z
        z = normal.icdf(u)
        if return_sorted_data:
            return z, sorted_data
        return z

    def sample(self, n=10000, inner=True):
        # Draw `n` samples from the fitted copula (requires a prior full `learn`).
        #
        # inner=True  -> latent samples z ~ N(0, V), shape (n, D)
        # inner=False -> samples mapped back to data space by looking up the
        #                empirical quantiles stored in `sorted_xy`, shape (n, D)
        sorted_xy = self.sorted_xy
        N, D = sorted_xy.size()
        # sample z ~ N(0, V); draw exactly n rows so n may exceed the training size N
        mvn = distribution.multivariate_normal.MultivariateNormal(torch.zeros(D).to(sorted_xy.device), self.V)
        z = mvn.rsample([n])
        if inner:
            return z
        # convert z to u
        normal = distribution.Normal(torch.zeros_like(z), torch.ones_like(z))
        u = normal.cdf(z).clamp(0.00001, 0.99999)
        # convert u to a row index into the sorted data (kept inside [0, N-1])
        idx = (N*u).long().clamp(max=N-1)
        # idx to x: out[i, j] = sorted_xy[idx[i, j], j]
        return torch.gather(sorted_xy, 0, idx)

    @staticmethod
    def log_copula_density(z, V):
        # Row-wise value of  -0.5 * z V^{-1} z^T - 0.5 * log det V.
        #
        # NOTE(review): the textbook Gaussian-copula log-density uses
        # (V^{-1} - I) in the quadratic form; the identity term is not included
        # here. It cancels in KL_joint_marginal because the joint squared norm
        # equals the sum of the x and y squared norms.
        V_inv = torch.inverse(V)
        # row-wise quadratic form without materialising the n x n product
        inside_exp = ((z @ V_inv) * z).sum(dim=1)
        return -0.5*inside_exp - 0.5*torch.logdet(V)

    def KL_joint_marginal(self, x, y):                                     # E[log q(x ,y)/q(x)q(y)]
        # x and y are scored directly against V, Vx and Vy (no rank transform
        # is applied here) -- presumably callers pass latent scores; verify against caller
        xy = torch.cat([x, y], dim=1)
        log_copula_density_xy = GaussianCopula.log_copula_density(xy, self.V)
        log_copula_density_x = GaussianCopula.log_copula_density(x, self.Vx)
        log_copula_density_y = GaussianCopula.log_copula_density(y, self.Vy)
        mi = log_copula_density_xy - log_copula_density_x - log_copula_density_y
        return mi.mean()