import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.autograd as autograd
import numpy as np
import scipy
import math
import time
import optimizer
import estimators.layers as layers
from copy import deepcopy
from nde.MAF import MAF
from nde.FMGGC import FMGGC


class NAE4(nn.Module):
    """ 
        Neural Adaptive Mutual Information Estimate, ver 4 (== two KL method)

        A (K+2)-way classifier (``critic_layer``) is trained with cross-entropy
        to tell apart: class 0 = paired samples P ~ p(x, y); classes 1..K =
        samples from Gaussian bridges built by ``set_nde``; class K+1 =
        shuffled samples Q ~ p(x)p(y). Differences of its logits are used as
        log density ratios (``log_ratio``), which ``MI`` combines with an
        analytic Gaussian term.
    """
    def __init__(self, architecture_encoder_x, architecture_encoder_y, architecture_critic, hyperparams):
        super().__init__()

        # default hyperparameters: any attribute missing on `hyperparams` falls back to the default
        self.bs = getattr(hyperparams, 'bs', 500)
        self.lr = getattr(hyperparams, 'lr', 5e-4)
        self.wd = getattr(hyperparams, 'wd', 0.0)
        self.n_neg = getattr(hyperparams, 'n_neg', 4)
        self.encode_x = getattr(hyperparams, 'encode_x', False)
        self.encode_y = getattr(hyperparams, 'encode_y', False)
        self.K = getattr(hyperparams, 'n_bridges', 4)                # number of Gaussian bridge classes
        self.bridge_mode = getattr(hyperparams, 'bridge_mode', 'intrapolation')
        self.max_iteration = 3000

        # layers; the encoders are not built here (architecture_encoder_x/y are unused)
        # and stay None unless assigned from outside
        self.encode_layer = None
        self.encode2_layer = None
        self.critic_layer = layers.SoftmaxLayer(architecture_critic, self.K+2, hyperparams)
        print('NAE4')

    def encode(self, x):
        """ s = s(x): summary statistic of x (identity unless `encode_x` is set). """
        return self.encode_layer(x) if self.encode_x else x

    def encode2(self, y):
        """ theta = h(y): representation of y (identity unless `encode_y` is set).

        Uses the dedicated `encode2_layer` when one has been assigned and
        otherwise falls back to the shared `encode_layer` (the old behaviour).
        """
        if not self.encode_y:
            return y
        layer = self.encode2_layer if self.encode2_layer is not None else self.encode_layer
        return layer(y)

    def set_nde(self, nde1, nde2):
        """ Build the K Gaussian bridge distributions between nde1 and nde2.

        `nde1` / `nde2` must expose `.loc` and `.covariance_matrix`. Fills
        ``self.nde_array`` (indexed by class label) with K+2 entries:
        entries 0 and K+1 are None (those classes are the data samples P and
        Q), and entry k+1 (k = 0..K-1) is a MultivariateNormal whose mean and
        covariance are interpolated linearly with weight a = k/(K-1). Entry 1
        therefore has nde1's parameters and entry K has nde2's.

        Raises:
            ValueError: if K < 2. The interpolation needs two endpoints, and
                `kl_analytic` reads both of them.
        """
        if self.K < 2:
            raise ValueError('set_nde requires n_bridges >= 2, got {}'.format(self.K))
        mu1, mu2 = nde1.loc, nde2.loc
        cov1, cov2 = nde1.covariance_matrix, nde2.covariance_matrix
        self.nde_array = [None]
        for k in range(self.K):
            # k/(K-1) (not 1/(K-1)*k) so the last weight is exactly 1.0
            a = k / (self.K - 1)
            mu, cov = (1-a)*mu1 + a*mu2, (1-a)*cov1 + a*cov2
            nde = torch.distributions.multivariate_normal.MultivariateNormal(mu, cov)
            self.nde_array.append(nde)
        self.nde_array.append(None)

    def _generate_bridge_samples(self, samples_P, samples_Q, k):
        """ Return the samples for class k of the critic.

        k == 0 -> `samples_P` itself; k == K+1 -> `samples_Q` itself;
        otherwise len(samples_P) fresh draws from the bridge ``nde_array[k]``.
        """
        if k == 0:
            return samples_P
        if k == self.K + 1:
            return samples_Q
        return self.nde_array[k].sample([len(samples_P)])

    def MI(self, x, y, mode='2KL'):
        """ Estimate MI between x and y in the latent space of the fitted `gc`.

        mode='2KL' returns kl1 - kl2 + kl_analytic on the latents (v, w).
        Any other mode returns the mean critic log-ratio between class 0 (P)
        and class K+1 (Q). Returns a Python float.
        """
        with torch.no_grad(): 
            v, w = self.gc.forward(x, y)
            if mode == '2KL':
                # evaluation order kl1, kl2, kl_analytic (kl2 consumes RNG)
                return self.kl1(v, w) - self.kl2(v, w) + self.kl_analytic(v, w)
            return self.log_ratio(v, w).mean().item()

    def kl1(self, v, w):
        """ KL[p(x, y), q(x,y)]: mean critic log-ratio of class 0 vs class 1
        over the given pairs (v, w). Returns a Python float. """
        with torch.no_grad(): 
            return self.log_ratio(v, w, 0, 1).mean().item()

    def kl2(self, v, w, n_rep=50):
        """ KL[p(x)p(y), q(x)q(y)]: mean critic log-ratio of class K+1 vs class K.

        Each row of `v` is paired with `w` under `n_rep` independent random
        permutations (n_rep * len(v) pairs in total). Returns a Python float.
        """
        with torch.no_grad(): 
            m = v.size(0)
            idx_pos = torch.arange(m).repeat(n_rep)
            idx_neg = torch.cat([torch.randperm(m) for _ in range(n_rep)])
            vv, ww = v[idx_pos.to(v.device)], w[idx_neg.to(w.device)]
            return self.log_ratio(vv, ww, -1, -2).mean().item()

    def kl_analytic(self, v, w):
        """ Gaussian term of the 2KL decomposition (the author labels it
        KL[q(x, y), q(x)q(y)]).

        Computed as the average, over the supplied pairs (v, w), of
        log nde_array[1](v, w) - log nde_array[K](v, w). The expectation is
        therefore taken over the given samples, not over q. Returns a Python
        float.
        """
        with torch.no_grad(): 
            normal1, normal2 = self.nde_array[1], self.nde_array[-2]
            vw = torch.cat([v, w], dim=1)
            A = normal1.log_prob(vw)
            B = normal2.log_prob(vw)
            return (A-B).mean().item()

    def log_ratio(self, x, y, i=0, j=-1):
        """ Per-row log density ratio between classes i and j.

        This is the difference of critic logits (the critic is trained with
        F.cross_entropy in `objective_func`) on the encoded and concatenated
        inputs.
        """
        z, y = self.encode(x), self.encode2(y)
        zy = torch.cat([z, y], dim=1)
        softmax_score = self.critic_layer(zy)
        return softmax_score[:, i] - softmax_score[:, j]

    def objective_func(self, x, y):
        """ Return the negative mean (K+2)-way cross-entropy of the critic
        (higher is better).

        Class 0 are paired samples P (each row of (z, y) repeated n_neg
        times). Class K+1 are Q samples that pair z and y through two
        independent random permutations per repetition. Classes 1..K are
        drawn from the Gaussian bridges in `nde_array`, with the same count
        as P, so the bridges' dimension must equal dim(z) + dim(y).
        Gradients reach the critic and, through P and Q, the encoders;
        bridge samples carry no gradient.
        """
        # compute representation of z, y
        z, y = self.encode(x), self.encode2(y)
        m = x.size(0)

        # construct samples P = p(x,y), Q = p(x)p(y)
        idx_pos = torch.arange(m).repeat(self.n_neg).to(z.device)
        perms1, perms2 = [], []
        for _ in range(self.n_neg):
            # interleaved draws: same RNG consumption order as before
            perms1.append(torch.randperm(m))
            perms2.append(torch.randperm(m))
        idx_neg1 = torch.cat(perms1).to(z.device)
        idx_neg2 = torch.cat(perms2).to(y.device)
        zy_P = torch.cat([z[idx_pos], y[idx_pos]], dim=1)
        zy_Q = torch.cat([z[idx_neg1], y[idx_neg2]], dim=1)

        # generate bridge + normal samples (class 0 is P, class K+1 is Q)
        n = len(zy_P)
        xy_array, label_array = [], []
        with torch.no_grad():
            for k in range(self.K+2):
                xy_array.append(self._generate_bridge_samples(zy_P, zy_Q, k))
                label_array.append(torch.full((n,), k, dtype=torch.long, device=zy_P.device))
        xy_array, label_array = torch.cat(xy_array, dim=0), torch.cat(label_array, dim=0)

        # optimize softmax score
        softmax_score = self.critic_layer(xy_array)
        return -F.cross_entropy(softmax_score, label_array)

    def learn(self, x, y):
        """ Fit the estimator on paired samples (x, y).

        A. fit an FMGGC model (`learn_nde`) and build the Gaussian bridges from
           its `normal` / `normal2` attributes;
        B. map (x, y) to latents (v, w) with the fitted model;
        C. train this estimator on the latents via `optimizer.NNOptimizer.learn`;
        D. restore the FMGGC weights saved after step A.
        """
        # A. learn GGC
        gc = self.learn_nde(x, y)
        self.gc = gc
        self.gc_state_dict = deepcopy(gc.state_dict())
        self.set_nde(gc.normal, gc.normal2)
        # B. forward to get latent (detached, so step C does not backprop into gc)
        with torch.no_grad():
            v, w = gc.forward(x, y)
            v, w = v.clone().detach(), w.clone().detach()
        # C. estimate MI on latent
        optimizer.NNOptimizer.learn(self, v, w)
        # D. reload the state saved in A. NOTE(review): presumably needed because
        #    gc is attached to self and step C may modify its parameters — confirm.
        self.gc.load_state_dict(self.gc_state_dict)

    def learn_nde(self, x, y):
        """ Build an FMGGC with n_inputs = x's feature dimension, move it to x's
        device, fit it on (x, y) and return it. """
        _, d = x.size()
        gc = FMGGC(n_inputs=d)
        gc.to(x.device)
        gc.learn(x, y)
        return gc

