import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.autograd as autograd
import numpy as np
import scipy
import math
import time
import optimizer
import estimators.layers as layers
from copy import deepcopy
from nde.MAF import MAF
from nde.NGC import GaussianCopula


class NAE3(nn.Module):
    """
        Neural Adaptive Mutual Information Estimate, ver. 3 (==copula + MRE)

        The estimator first fits a `GaussianCopula` on (x, y) and maps the data
        through it (see `learn`), then trains a (K+2)-way classifier
        (`critic_layer`) on the mapped data. Class 0 holds samples of the joint
        P = p(x, y), class K+1 holds samples of the product Q = p(x)p(y), and
        classes 1..K hold "bridge" samples lying between P and Q. The score
        difference between class 0 and class K+1 is used as the log density
        ratio (see `log_ratio`).

        Hyperparameters are read from `hyperparams` when present, otherwise the
        defaults below are used:
            bs (500), lr (5e-4), wd (0), n_neg (4), encode_x (False),
            encode_y (False), n_bridges -> K (4), bridge_mode ('intrapolation').
    """

    def __init__(self, architecture_encoder_x, architecture_encoder_y, architecture_critic, hyperparams):
        """
            architecture_encoder_x / architecture_encoder_y are accepted for
            interface compatibility but are not used here: both encoder layers
            start as None and must be assigned by the caller before enabling
            encode_x / encode_y.
        """
        super().__init__()

        # hyperparameters, with defaults for any field `hyperparams` lacks
        self.bs = getattr(hyperparams, 'bs', 500)
        self.lr = getattr(hyperparams, 'lr', 5e-4)
        self.wd = getattr(hyperparams, 'wd', 0.0)
        self.n_neg = getattr(hyperparams, 'n_neg', 4)
        self.encode_x = getattr(hyperparams, 'encode_x', False)
        self.encode_y = getattr(hyperparams, 'encode_y', False)
        self.K = getattr(hyperparams, 'n_bridges', 4)
        self.bridge_mode = getattr(hyperparams, 'bridge_mode', 'intrapolation')
        self.max_iteration = 3000

        # layers
        self.encode_layer = None
        self.encode2_layer = None
        # one output per class: P, the K bridges, and Q
        self.critic_layer = layers.SoftmaxLayer(architecture_critic, self.K + 2, hyperparams)

    def encode(self, x):
        """ s = s(x): summary statistic of x (identity unless encode_x is set). """
        return self.encode_layer(x) if self.encode_x else x

    def encode2(self, y):
        """ theta = h(y): representation of y (identity unless encode_y is set). """
        if not self.encode_y:
            return y
        # Use the dedicated y-encoder. Fall back to encode_layer only when no
        # y-encoder was assigned, which keeps older setups that configured a
        # single shared encoder working.
        layer = self.encode2_layer if self.encode2_layer is not None else self.encode_layer
        return layer(y)

    def _generate_bridge_samples(self, samples_P, samples_Q, k):
        """
            Generate samples ~ P', where P' moves from P (k=0) to Q (k=K+1).

            'mixture'       : a random subset of P rows followed by a random
                              subset of Q rows, in proportions (1-a, a).
            'intrapolation' : the element-wise blend (1-a)*P + a*Q.

            Raises ValueError for any other bridge_mode.
        """
        # k/(K+1) is exactly 0.0 at k=0 and exactly 1.0 at k=K+1; the previous
        # form 1.0/(K+1)*k accumulates rounding error (e.g. 0.2*3 != 0.6).
        a1 = k / (self.K + 1)
        n_P, n_Q = len(samples_P), len(samples_Q)
        if self.bridge_mode == 'mixture':
            # Round instead of truncating, and derive the P count as the
            # complement, so that the result has n_P rows when n_P == n_Q.
            n_from_Q = round(n_Q * a1)
            n_from_P = n_P - round(n_P * a1)
            idx_P1 = torch.randperm(n_P)[:n_from_P]
            idx_Q1 = torch.randperm(n_Q)[:n_from_Q]
            return torch.cat([samples_P[idx_P1], samples_Q[idx_Q1]], dim=0)
        if self.bridge_mode == 'intrapolation':
            return samples_P * (1 - a1) + samples_Q * a1
        raise ValueError(
            "unknown bridge_mode %r (expected 'mixture' or 'intrapolation')" % (self.bridge_mode,)
        )

    def MI(self, x, y, mode='mc'):
        """
            Return the mutual information estimate for (x, y) as a Python float.

            The data is first mapped through the copula stored by `learn`
            (self.gc), so `learn` must have been called beforehand.

            mode == 'mc' : mean of log_ratio over the mapped samples.
            otherwise    : the estimate computed by `kv` with log_ratio as critic.
        """
        with torch.no_grad():
            latent = self.gc.forward(data=torch.cat([x, y], dim=1))
            d = x.size(1)
            # first d columns belong to x, the remaining ones to y
            v, w = latent[:, :d], latent[:, d:]
            if mode == 'mc':
                return self.log_ratio(v, w).mean().item()
            return self.kv(v, w, self.log_ratio).item()

    def log_ratio(self, x, y):
        """
            Per-sample score difference between class 0 (P) and the last
            class (Q) of the critic, used as the log density ratio.
        """
        z, y = self.encode(x), self.encode2(y)
        softmax_score = self.critic_layer(torch.cat([z, y], dim=1))
        return softmax_score[:, 0] - softmax_score[:, -1]

    def objective_func(self, x, y):
        """
            Return the negative cross-entropy of the (K+2)-way classification
            of P / bridge / Q samples built from the batch (x, y).

            NOTE(review): the sign suggests the external optimizer maximises
            this value -- confirm against optimizer.NNOptimizer.
        """
        # compute representation of z, y
        z, y = self.encode(x), self.encode2(y)
        m = x.size(0)

        # construct samples P = p(x,y), Q = p(x)p(y): the batch is repeated
        # n_neg times; Q rows pair independently shuffled x and y rows.
        idx_neg1, idx_neg2 = [], []
        for _ in range(self.n_neg):
            # same draw order as before: neg1 then neg2 for each repetition
            idx_neg1.extend(torch.randperm(m).tolist())
            idx_neg2.extend(torch.randperm(m).tolist())
        idx_pos = torch.arange(m, device=z.device).repeat(self.n_neg)
        idx_neg1 = torch.tensor(idx_neg1, dtype=torch.long, device=z.device)
        idx_neg2 = torch.tensor(idx_neg2, dtype=torch.long, device=y.device)
        zy_P = torch.cat([z[idx_pos], y[idx_pos]], dim=1)
        zy_Q = torch.cat([z[idx_neg1], y[idx_neg2]], dim=1)

        # generate bridge + normal samples
        # NOTE(review): this runs under no_grad, so no gradient reaches the
        # encoders through the classifier inputs -- confirm this is intended.
        xy_array, label_array = [], []
        with torch.no_grad():
            for k in range(self.K + 2):
                xy_k = self._generate_bridge_samples(zy_P, zy_Q, k)            # class 0 is P, class K+1 is Q
                # size the labels from the bridge itself so samples and
                # labels always line up, whatever the bridge mode returns
                label_k = torch.full((len(xy_k),), k, dtype=torch.long, device=xy_k.device)
                xy_array.append(xy_k)
                label_array.append(label_k)
        xy_array, label_array = torch.cat(xy_array, dim=0), torch.cat(label_array, dim=0)

        # optimize softmax score
        softmax_score = self.critic_layer(xy_array)
        return -F.cross_entropy(softmax_score, label_array)

    def kv(self, x, y, critic):
        """
            Return  mean(f(pos)) - log(mean(exp(f(neg))))  as a 0-dim tensor,
            where f = critic, positives are aligned (x, y) rows and negatives
            pair each x row with a randomly permuted y row.

            critic is called with keyword arguments x= and y=. The batch is
            repeated n_neg times in training mode and min(m, 50) times otherwise.

            NOTE(review): inputs are encoded here before being handed to
            critic; if critic is self.log_ratio it encodes again -- confirm
            before enabling encode_x / encode_y.
        """
        m = x.size(0)
        z, y = self.encode(x), self.encode2(y)
        n_neg = self.n_neg if self.training else min(m, 50)

        idx_neg = []
        for _ in range(n_neg):
            idx_neg.extend(torch.randperm(m).tolist())
        idx_pos = torch.arange(m, device=z.device).repeat(n_neg)
        idx_neg = torch.tensor(idx_neg, dtype=torch.long, device=y.device)

        f_pos = critic(x=z[idx_pos], y=y[idx_pos])
        f_neg = critic(x=z[idx_pos], y=y[idx_neg])
        # the 1e-30 offset keeps log() finite if every exp(f_neg) underflows
        mi = f_pos.mean() - (f_neg.exp().mean() + 1e-30).log()
        return mi

    def learn(self, x, y):
        """
            Fit the estimator on paired samples x, y (2-D tensors, one row per sample).

            Fits a GaussianCopula with marginal_only=True, stores it as
            self.gc, maps the data through it, and then trains this module on
            the mapped data via optimizer.NNOptimizer.learn.
        """
        d = x.size(1)

        # A. learn GC
        gc = GaussianCopula()
        gc.learn(x, y, marginal_only=True)
        self.gc = gc

        # B. forward to get latent
        with torch.no_grad():
            latent = gc.forward(data=torch.cat([x, y], dim=1))
            v, w = latent[:, :d], latent[:, d:]

        # C. estimate MI on latent
        optimizer.NNOptimizer.learn(self, v, w)

    