import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.autograd as autograd
import numpy as np
import scipy
import math
import time
import optimizer
import estimators.layers as layers
from nde.MAF import MAF


class SQE(nn.Module):
    """
    Squeeze estimator of mutual information.

    I(X; Y) is estimated from three density estimators attached with
    ``set_nde``: ``nde`` for the concatenated pair ``[x, y]`` and ``nde_x`` /
    ``nde_y`` for the two marginals. Each is only required to expose
    ``log_prob(tensor)``.

    ``MI`` reports the midpoint of ``lower_bound`` (mean of
    log p(x, y) - log p(x) - log p(y)) and ``upper_bound`` (a contrast of
    conditional log-densities between paired and all-pairs samples).
    """

    def __init__(self, architecture_encoder_x, architecture_encoder_y, architecture_critic, hyperparams):
        super().__init__()

        # NOTE(review): the architecture_* arguments are not used by this
        # class; presumably kept for a constructor signature shared with
        # other estimators -- confirm.

        # Hyperparameters: any attribute missing from `hyperparams` (including
        # hyperparams=None) falls back to the default given here.
        self.bs = getattr(hyperparams, 'bs', 500)
        self.lr = getattr(hyperparams, 'lr', 5e-4)
        self.wd = getattr(hyperparams, 'wd', 0.0)
        self.n_neg = getattr(hyperparams, 'n_neg', 4)
        self.encode_x = getattr(hyperparams, 'encode_x', False)
        self.encode_y = getattr(hyperparams, 'encode_y', False)
        self.max_iteration = 3000

    def encode(self, x):
        """Return the summary statistic s(x): ``encode_layer(x)`` when ``encode_x`` is set, else ``x`` itself."""
        # NOTE(review): `encode_layer` is never created in this class; it must
        # be assigned elsewhere before encode_x=True is used -- confirm.
        return self.encode_layer(x) if self.encode_x else x

    def encode2(self, y):
        """Return the representation theta = h(y): ``encode_layer(y)`` when ``encode_y`` is set, else ``y`` itself."""
        # NOTE(review): this reuses the same `encode_layer` as `encode`;
        # verify whether y was meant to have its own encoder.
        return self.encode_layer(y) if self.encode_y else y

    def set_nde(self, nde, nde_x, nde_y):
        """Attach the joint (``nde``) and marginal (``nde_x``, ``nde_y``) density estimators."""
        self.nde = nde
        self.nde_x = nde_x
        self.nde_y = nde_y

    def MI(self, x, y, n_ub=2000):
        """
        Estimate I(X; Y) as the midpoint of ``lower_bound`` and ``upper_bound``.

        Args:
            x, y: 2-D tensors of paired samples (row i of x pairs with row i of y).
            n_ub: number of leading samples used for ``upper_bound``, whose
                cost is quadratic in the sample count. Defaults to 2000.

        Returns:
            float: ``(lb + ub) / 2``. All three values are also printed.
        """
        with torch.no_grad():
            lb = self.lower_bound(x, y).item()
            ub = self.upper_bound(x[:n_ub], y[:n_ub]).item()
            mi = (lb + ub) / 2
            print('mi', mi, 'lb', lb, 'ub', ub)
            return mi

    def lower_bound(self, x, y):
        """Return the mean over paired rows of log p(x, y) - log p(x) - log p(y)."""
        xy = torch.cat([x, y], dim=1)
        # Pass x's width explicitly so the joint is split correctly even when
        # x and y have different numbers of columns.
        return self.log_joint_marginal(xy, dx=x.size(1)).mean()

    def upper_bound(self, x, y):
        """
        Return mean_i log p(y_i | x_i) - mean_{i,j} log p(y_i | x_j).

        The conditional log-density is computed by ``log_conditional``. All
        n*n (x_j, y_i) combinations are evaluated, so time and memory grow
        quadratically with the number of rows n.
        """
        n, dx = x.size()
        xy_joint = torch.cat([x, y], dim=1)
        A = self.log_conditional(xy_joint, dx=dx).mean()
        # Entry (i, j) pairs x[j] with y[i]. expand() gives broadcast views, so
        # only the concatenated (n, n, dx + dy) tensor is materialised.
        x_tiled = x.unsqueeze(0).expand(n, -1, -1)
        y_tiled = y.unsqueeze(1).expand(-1, n, -1)
        xy_marginal = torch.cat((x_tiled, y_tiled), dim=2).reshape(n * n, -1)
        B = self.log_conditional(xy_marginal, dx=dx).mean()
        return A - B

    @staticmethod
    def _split(xy, dx=None):
        """Split the columns of 2-D ``xy`` into (x, y) at ``dx`` (default: half the width)."""
        n, d = xy.size()
        if dx is None:
            # Legacy behaviour: only correct when x and y have equal width.
            dx = d // 2
        return xy[:, :dx], xy[:, dx:]

    def log_joint_marginal(self, xy, dx=None):
        """
        Per-row log p(x, y) - log p(x) - log p(y).

        Args:
            xy: 2-D tensor whose first ``dx`` columns are x and the rest are y.
            dx: number of x columns. Defaults to ``xy.size(1) // 2``, which is
                only correct when x and y have the same width.

        Returns:
            The difference of the three estimators' ``log_prob`` outputs
            (presumably one value per row -- depends on the estimators).
        """
        x, y = self._split(xy, dx)
        log_copula_density_xy = self.nde.log_prob(xy)
        log_copula_density_x = self.nde_x.log_prob(x)
        log_copula_density_y = self.nde_y.log_prob(y)
        return log_copula_density_xy - log_copula_density_x - log_copula_density_y

    def log_conditional(self, xy, dx=None):
        """
        Per-row log p(x, y) - log p(x), i.e. log p(y | x) under the fitted densities.

        Args:
            xy: 2-D tensor whose first ``dx`` columns are x and the rest are y.
            dx: number of x columns. Defaults to ``xy.size(1) // 2``, which is
                only correct when x and y have the same width.
        """
        x, _ = self._split(xy, dx)
        log_copula_density_xy = self.nde.log_prob(xy)
        log_copula_density_x = self.nde_x.log_prob(x)
        return log_copula_density_xy - log_copula_density_x

    def learn(self, x, y):
        """Delegate training on (x, y) to ``optimizer.NNOptimizer.learn`` with this estimator, returning its result."""
        return optimizer.NNOptimizer.learn(self, x, y)