import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.autograd as autograd
import numpy as np
import scipy
import math
import time
import optimizer
import estimators.layers as layers


class TRE(nn.Module):
    """Telescoping density-ratio estimator (TRE).

    The log density ratio between joint samples P ~ p(x, y) and
    product-of-marginals samples Q ~ p(x)p(y) is written as a sum of ``K``
    bridge log-ratios. Critic head ``k`` scores the pair of intermediate
    distributions (P'_k, Q'_k) built by ``_generate_bridge_samples``.
    ``log_ratio`` sums the ``K`` head outputs, and ``MI`` averages that sum.

    Args:
        architecture_encoder_x: Not used in this class body; presumably
            consumed by callers/subclasses that assign ``encode_layer`` -- verify.
        architecture_encoder_y: Same, for ``encode2_layer`` -- verify.
        architecture_critic: Forwarded to the critic layer constructor.
        hyperparams: Attribute container.
            - Required: ``critic``. The value ``'quadratic'`` selects
              ``layers.QuadraticCriticLayer``; anything else selects
              ``layers.NeuralCriticLayer``.
            - Optional (defaults in ``__init__``): ``bs``, ``lr``, ``wd``,
              ``n_neg``, ``encode_x``, ``encode_y``, ``n_bridges``,
              ``bridge_mode``.
    """
    def __init__(self, architecture_encoder_x, architecture_encoder_y, architecture_critic, hyperparams):
        super().__init__()

        # Hyperparameters: take the value from `hyperparams` when present,
        # otherwise use the default on the right.
        self.bs = getattr(hyperparams, 'bs', 500)
        self.lr = getattr(hyperparams, 'lr', 5e-4)
        self.wd = getattr(hyperparams, 'wd', 0.0)
        self.n_neg = getattr(hyperparams, 'n_neg', 4)          # copies of the batch used to build P and Q
        self.encode_x = getattr(hyperparams, 'encode_x', False)
        self.encode_y = getattr(hyperparams, 'encode_y', False)
        self.K = getattr(hyperparams, 'n_bridges', 4)          # number of bridges == number of critic heads
        self.bridge_mode = getattr(hyperparams, 'bridge_mode', 'intrapolation')  # 'intrapolation' or 'mixture'
        self.max_iteration = 2000

        CriticLayer = layers.QuadraticCriticLayer if hyperparams.critic == 'quadratic' else layers.NeuralCriticLayer

        # Encoders start unset; they must be assigned before enabling
        # encode_x / encode_y (presumably done by the caller -- verify).
        self.encode_layer = None
        self.encode2_layer = None
        self.critic_layer = CriticLayer(architecture_critic, self.K, hyperparams)

    def encode(self, x):
        """Return the summary statistic s(x): ``encode_layer(x)`` if ``encode_x``, else ``x``."""
        return self.encode_layer(x) if self.encode_x else x

    def encode2(self, y):
        """Return the representation h(y) if ``encode_y``, else ``y`` unchanged.

        Uses the dedicated ``encode2_layer``. If no y-specific encoder has been
        assigned, it falls back to ``encode_layer``, which was the previous
        behavior.
        """
        if not self.encode_y:
            return y
        layer = self.encode2_layer if self.encode2_layer is not None else self.encode_layer
        return layer(y)

    def _generate_bridge_samples(self, samples_P, samples_Q, k):
        """Build samples of the k-th bridge pair (P'_k, Q'_k) between P and Q.

        Mixing weights are a1 = k/K for P'_k and a2 = (k+1)/K for Q'_k. So P'_0
        is P and Q'_{K-1} is Q, and each bridge's distributions are closer to
        each other than P is to Q.

        Modes:
            'mixture': the output concatenates a random fraction (1-a) of the
                P rows with a random fraction a of the Q rows.
            'intrapolation': the output is sqrt(1-a^2) * P + a * Q, computed
                elementwise. This requires samples_P and samples_Q to have
                matching (or broadcastable) shapes.

        Raises:
            ValueError: if ``self.bridge_mode`` is not one of the modes above.
        """
        a1 = (1.0/self.K)*k          # weight toward Q for P'; k=0 gives P' = P
        a2 = (1.0/self.K)*(k+1)      # weight toward Q for Q'; k=K-1 gives Q' = Q
        n_P, n_Q = len(samples_P), len(samples_Q)
        if self.bridge_mode == 'mixture':
            # samples for P'
            idx_P1, idx_Q1 = torch.randperm(n_P)[:int(n_P*(1-a1))], torch.randperm(n_Q)[:int(n_Q*a1)]
            samples_new_P = torch.cat([samples_P[idx_P1], samples_Q[idx_Q1]], dim=0)
            # samples for Q'
            idx_P2, idx_Q2 = torch.randperm(n_P)[:int(n_P*(1-a2))], torch.randperm(n_Q)[:int(n_Q*a2)]
            samples_new_Q = torch.cat([samples_P[idx_P2], samples_Q[idx_Q2]], dim=0)
        elif self.bridge_mode == 'intrapolation':
            samples_new_P = samples_P*(1-a1**2)**0.5 + samples_Q*a1
            samples_new_Q = samples_P*(1-a2**2)**0.5 + samples_Q*a2
        else:
            raise ValueError(f"unknown bridge_mode {self.bridge_mode!r}; expected 'mixture' or 'intrapolation'")
        return samples_new_P, samples_new_Q

    def MI(self, x, y):
        """Return the mean telescoped log-ratio over the given (x, y) pairs as a float.

        No autograd graph is built, because only the scalar value is returned.
        """
        with torch.no_grad():
            return self.log_ratio(x, y).mean().item()

    def log_ratio(self, x, y):
        """Return the per-sample sum over all K critic heads, as a 1-D tensor.

        x and y are assumed to be 2-D with equal batch size (they are
        concatenated along dim 1) -- TODO confirm against callers.
        """
        z, y = self.encode(x), self.encode2(y)
        # Every bridge head scores the same input, so build it once.
        zy = torch.cat([z, y], dim=1)
        lr = torch.zeros(len(x), device=x.device)
        for k in range(self.K):
            lr += self.critic_layer(zy, k).view(-1)
        return lr

    def objective_func(self, x, y):
        """Return the training objective, averaged over the K bridges.

        Construction of the samples:
            - P: the observed (x, y) pairs repeated ``n_neg`` times.
            - Q: x and y indexed by independent random permutations, which
              breaks their pairing.

        For each bridge, critic head k is scored with
        log sigmoid(f(P')) + log(1 - sigmoid(f(Q'))), where f is meant to
        approximate log P'/Q'.

        Raises:
            ValueError: if ``self.n_neg`` < 1 (P and Q would be empty).
        """
        if self.n_neg < 1:
            raise ValueError(f"n_neg must be >= 1, got {self.n_neg}")
        z, y = self.encode(x), self.encode2(y)
        m = x.size(0)

        # Integer index tensors. Float indices are rejected by tensor indexing.
        idx_pos = torch.arange(m).repeat(self.n_neg)
        perms_x, perms_y = [], []
        for _ in range(self.n_neg):
            # Keep the x/y draws interleaved, as before, so seeded runs
            # consume the RNG in the same order.
            perms_x.append(torch.randperm(m))
            perms_y.append(torch.randperm(m))
        idx_neg1, idx_neg2 = torch.cat(perms_x), torch.cat(perms_y)

        zy_P = torch.cat([z[idx_pos], y[idx_pos]], dim=1)
        zy_Q = torch.cat([z[idx_neg1], y[idx_neg2]], dim=1)

        bridge_objectives = []
        for k in range(self.K):
            zy_PP, zy_QQ = self._generate_bridge_samples(zy_P, zy_Q, k)
            f_pos = self.critic_layer(zy_PP, k)
            f_neg = self.critic_layer(zy_QQ, k)
            # -softplus(-f) = log sigmoid(f); -softplus(f) = log(1 - sigmoid(f)).
            A, B = -F.softplus(-f_pos), -F.softplus(f_neg)
            bridge_objectives.append(A.mean() + B.mean())
        return torch.stack(bridge_objectives).mean()

    def learn(self, x, y):
        """Train this estimator by delegating to ``optimizer.NNOptimizer.learn`` (project module)."""
        return optimizer.NNOptimizer.learn(self, x, y)
