import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.autograd as autograd
import numpy as np
import scipy
import math
import time
import optimizer
import estimators.layers as layers
from copy import deepcopy
from nde.MAF import MAF
from nde.FMGGC import FMGGC
from nde.GGC import GGC


class NAE(nn.Module):
    """Neural Adaptive Mutual Information Estimate.

    Workflow as implemented by `learn`:

    1. fit a `GGC` model on ``(x, y)`` (`learn_nde`);
    2. build ``K`` Gaussian "bridge" distributions that interpolate between
       ``gc.normal`` and ``gc.normal2`` (`set_nde`);
    3. map the data to the latent pair ``(v, w)`` with ``gc.forward``;
    4. train a ``K + 2``-class softmax critic on the latent pair, where
       class 0 holds paired samples, class ``K + 1`` holds shuffled samples
       and classes ``1..K`` hold samples from the bridges.

    The log density ratio is read off as the difference between the critic
    score of class 0 and the score of the last class (`log_ratio`).
    """

    def __init__(self, architecture_encoder_x, architecture_encoder_y, architecture_critic, hyperparams):
        """Read hyperparameters (with defaults) and build the critic.

        Args:
            architecture_encoder_x: accepted for interface compatibility;
                not used by this constructor.
            architecture_encoder_y: accepted for interface compatibility;
                not used by this constructor.
            architecture_critic: forwarded to `layers.SoftmaxLayer`.
            hyperparams: object whose optional attributes ``bs``, ``lr``,
                ``wd``, ``n_neg``, ``encode_x``, ``encode_y``, ``n_bridges``
                and ``bridge_mode`` override the defaults below. It is also
                forwarded to `layers.SoftmaxLayer`.
        """
        super().__init__()

        # Hyperparameters: use the attribute when present, else the default.
        self.bs = getattr(hyperparams, 'bs', 500)
        self.lr = getattr(hyperparams, 'lr', 5e-4)
        self.wd = getattr(hyperparams, 'wd', 0e-5)
        self.n_neg = getattr(hyperparams, 'n_neg', 4)
        self.encode_x = getattr(hyperparams, 'encode_x', False)
        self.encode_y = getattr(hyperparams, 'encode_y', False)
        self.K = getattr(hyperparams, 'n_bridges', 4)       # number of bridge distributions
        self.bridge_mode = getattr(hyperparams, 'bridge_mode', 'intrapolation')
        self.max_iteration = 3000

        # Layers. The encoders are not built here: with encode_x / encode_y
        # enabled, the caller must assign encode_layer / encode2_layer first.
        self.encode_layer = None
        self.encode2_layer = None
        # One output per class: P, the K bridges, and Q.
        self.critic_layer = layers.SoftmaxLayer(architecture_critic, self.K+2, hyperparams)

        print('self.K', self.K)

    def encode(self, x):
        """Return ``encode_layer(x)`` when ``encode_x`` is set, else ``x`` unchanged."""
        return self.encode_layer(x) if self.encode_x else x

    def encode2(self, y):
        """Return ``encode2_layer(y)`` when ``encode_y`` is set, else ``y`` unchanged."""
        # Uses the dedicated y-encoder (previously this called encode_layer,
        # i.e. the x-encoder, by mistake).
        return self.encode2_layer(y) if self.encode_y else y

    def set_nde(self, nde1, nde2):
        """Build the K Gaussian bridges interpolating between two Gaussians.

        Bridge ``k`` (``k = 0..K-1``) has mean ``(1-a)*mu1 + a*mu2`` and
        covariance ``(1-a)*cov1 + a*cov2`` with ``a = k / (K-1)``, so the
        first bridge equals `nde1` and the last equals `nde2`. With a single
        bridge (``K == 1``) the midpoint ``a = 0.5`` is used, which avoids a
        division by zero.

        The result is stored in ``self.nde_array`` as
        ``[None, bridge_0, ..., bridge_{K-1}, None]`` so that list index
        ``k`` matches class label ``k``; the ``None`` slots stand for the
        empirical classes 0 and K+1, which are never sampled from a bridge.

        Args:
            nde1: distribution exposing ``loc`` and ``covariance_matrix``.
            nde2: distribution exposing ``loc`` and ``covariance_matrix``.
        """
        mu1, mu2 = nde1.loc, nde2.loc
        cov1, cov2 = nde1.covariance_matrix, nde2.covariance_matrix
        bridges = []
        for k in range(self.K):
            a = k / (self.K - 1) if self.K > 1 else 0.5
            mu = (1 - a) * mu1 + a * mu2
            cov = (1 - a) * cov1 + a * cov2
            bridges.append(torch.distributions.multivariate_normal.MultivariateNormal(mu, cov))
        self.nde_array = [None] + bridges + [None]

    def _generate_bridge_samples(self, samples_P, samples_Q, k):
        """Return the samples that belong to class ``k``.

        ``k == 0`` returns `samples_P` itself, ``k == K + 1`` returns
        `samples_Q` itself, and any other ``k`` draws ``len(samples_P)``
        fresh samples from ``self.nde_array[k]`` (requires `set_nde`).
        """
        if k == 0:
            return samples_P
        if k == self.K + 1:
            return samples_Q
        return self.nde_array[k].sample([len(samples_P)])

    def MI(self, x, y, mode='mc'):
        """Estimate the mutual information of ``(x, y)`` as a Python float.

        The data is first mapped through ``self.gc.forward`` (so `learn`
        must have been called). With ``mode == 'mc'`` the estimate is the
        mean of `log_ratio`; any other mode uses the `kv` estimator.
        """
        with torch.no_grad():
            v, w = self.gc.forward(x, y)
            if mode == 'mc':
                return self.log_ratio(v, w).mean().item()
            return self.kv(v, w, self.log_ratio).item()

    def log_ratio(self, x, y):
        """Return, per row, the critic score of class 0 minus that of the last class."""
        z, y = self.encode(x), self.encode2(y)
        zy = torch.cat([z, y], dim=1)
        softmax_score = self.critic_layer(zy)
        return softmax_score[:, 0] - softmax_score[:, -1]

    def objective_func(self, x, y):
        """Return the negative cross-entropy of the (K+2)-class critic.

        Paired rows form class 0, independently shuffled rows form class
        K+1, and classes 1..K are sampled from the bridges. Each class
        contributes ``n_neg * m`` rows, where ``m`` is the batch size.
        """
        z, y = self.encode(x), self.encode2(y)
        m, _ = x.size()

        # Construct samples P = p(x,y), Q = p(x)p(y). Integer index lists are
        # used (not floats) and each list covers the batch n_neg times.
        idx_pos = list(range(m)) * self.n_neg
        idx_neg1, idx_neg2 = [], []
        for _ in range(self.n_neg):
            idx_neg1.extend(torch.randperm(m).tolist())
            idx_neg2.extend(torch.randperm(m).tolist())
        zy_P = torch.cat([z[idx_pos], y[idx_pos]], dim=1)
        zy_Q = torch.cat([z[idx_neg1], y[idx_neg2]], dim=1)

        # Gather the samples and labels of every class; no gradient is
        # tracked through the sample construction.
        n_per_class = len(zy_P)
        xy_array, label_array = [], []
        with torch.no_grad():
            for k in range(self.K+2):
                # class 0 is P, class K+1 is Q
                xy_array.append(self._generate_bridge_samples(zy_P, zy_Q, k))
                label_array.append(
                    torch.full((n_per_class,), k, dtype=torch.long, device=zy_P.device)
                )
        xy_array, label_array = torch.cat(xy_array, dim=0), torch.cat(label_array, dim=0)

        # Score all samples and return the negated classification loss.
        softmax_score = self.critic_layer(xy_array)
        return -F.cross_entropy(softmax_score, label_array)

    def kv(self, x, y, critic):
        """Return ``mean(f_pos) - log(mean(exp(f_neg)) + 1e-30)``.

        ``f_pos`` is `critic` evaluated on paired rows and ``f_neg`` on rows
        where ``y`` has been shuffled. The batch is repeated ``n_neg`` times
        in training mode and ``min(m, 50)`` times otherwise.

        Args:
            x: 2-D tensor of samples.
            y: tensor of samples paired row-wise with `x`.
            critic: callable invoked as ``critic(x=..., y=...)``.
        """
        m, _ = x.size()
        # NOTE(review): inputs are encoded here and `log_ratio` (the critic
        # passed by `MI`) encodes again; harmless while encode_x / encode_y
        # are False -- confirm the intent before enabling the encoders.
        z, y = self.encode(x), self.encode2(y)
        n_neg = self.n_neg if self.training else min(m, 50)
        idx_pos = list(range(m)) * n_neg
        idx_neg = []
        for _ in range(n_neg):
            idx_neg.extend(torch.randperm(m).tolist())

        f_pos = critic(x=z[idx_pos], y=y[idx_pos])
        f_neg = critic(x=z[idx_pos], y=y[idx_neg])
        # 1e-30 keeps the logarithm finite when exp(f_neg) underflows to 0.
        return f_pos.mean() - (f_neg.exp().mean()+1e-30).log()

    def learn(self, x, y):
        """Fit the GGC model, then train the critic on its latent outputs."""
        # A. learn GGC and snapshot its weights
        gc = self.learn_nde(x, y)
        self.gc = gc
        self.gc_state_dict = deepcopy(gc.state_dict())
        self.set_nde(gc.normal, gc.normal2)
        # B. forward to get latent
        with torch.no_grad():
            v, w = gc.forward(x, y)
            v, w = v.clone().detach(), w.clone().detach()
        # C. estimate MI on latent
        optimizer.NNOptimizer.learn(self, v, w)
        # D. restore the GGC weights saved in step A (presumably because
        #    step C may modify them -- confirm against NNOptimizer)
        self.gc.load_state_dict(self.gc_state_dict)

    def learn_nde(self, x, y):
        """Create a `GGC` model, fit it on ``(x, y)`` and return it.

        The model is sized from the second dimension of `x` and is moved to
        the device of `x` before fitting.
        """
        _, d = x.size()

        # Neural density estimate
        gc = GGC(n_blocks=2, n_inputs=d, n_hidden=500, n_cond_inputs=2)
        gc.to(x.device)
        gc.maf1.max_iteration = 200
        gc.maf2.max_iteration = 200
        gc.max_iteration = 100
        gc.bs = 250

        gc.learn(x, y)
        return gc

    