
import os
dir_path = os.path.dirname(os.path.realpath(__file__))
os.chdir(dir_path)

os.environ['FOR_DISABLE_CONSOLE_CTRL_HANDLER'] = '1'

import torch
from torch.utils.data import Dataset
from torch.distributions import MultivariateNormal
import numpy as np
from all_estimators import *
np.random.seed(42)
import random 
random.seed(42)
import argparse
from scipy import stats
import sys





def list_of_strings(arg):
    """Parse a comma-separated command-line value into a list of names.

    Whitespace around each item is stripped and empty items are dropped, so
    ``"cube, sigmoid"`` and ``"cube,sigmoid,"`` both yield
    ``['cube', 'sigmoid']`` instead of entries such as ``' sigmoid'`` or
    ``''`` that no transform name would match.

    Args:
        arg: Raw option string as handed over by argparse.

    Returns:
        The non-empty, stripped items in their original order.
    """
    stripped = (item.strip() for item in arg.split(','))
    return [item for item in stripped if item]

# Command-line interface: three positional values (kept as strings here and
# converted with int() further down) plus optional comma-separated transform
# lists for X and Y.
parser = argparse.ArgumentParser()
parser.add_argument("N", help="number of samples per dataset")
parser.add_argument("dim", help="dimensionality of X and of Y")
parser.add_argument(
    "data_ID",
    help="index into the dataset list: 0 = MultivariateNormalDataset, "
         "1 = GaussianAdditionDataset, 2 = MultivariateStudentTDataset",
)
# Default to the identity transform: without a default, omitting --tx/--ty
# leaves the value as None and the dataset classes fail when iterating it.
parser.add_argument("--tx", type=list_of_strings, default=['none'],
                    help="comma-separated transforms applied to X")
parser.add_argument("--ty", type=list_of_strings, default=['none'],
                    help="comma-separated transforms applied to Y")

args = parser.parse_args()

# Set to 1 to append everything printed from here on to the results file
# instead of the console; stdout is restored at the end of the script.
WRITE_FILE = 0

if WRITE_FILE == 1:

    orig_stdout = sys.stdout
    f = open('Results_studentT_MINE.txt', "a")
    sys.stdout = f



# print(args.tx[0])
# print(args.ty)
# print(args)

import bmi.samplers._splitmultinormal as spl
from scipy.special import digamma, gamma


def _differential_entropy(k: int, dof: int) -> float:
    """Differential entropy of a :math:`Student-t(0, I_k, dof)`.

    See Eq. (7) of
      R.B. Arellano-Valle, J.E. Contreras-Reyes, M.G. Genton,
      Shannon Entropy and Mutual Information for Multivariate
      Skew-Elliptical Distributions,
      Scandinavian Journal of Statistics, vol. 40, pp. 46-47, 2013

    Args:
        k: Dimension of the distribution.
        dof: Degrees of freedom (must be positive).

    Returns:
        The entropy in nats.
    """
    # Local import: the file header only brings in ``digamma`` and ``gamma``.
    from scipy.special import gammaln

    half_dof = 0.5 * dof
    half_sum = 0.5 * (dof + k)
    digamma_term = half_sum * (digamma(half_sum) - digamma(half_dof))

    # Work with log-Gamma directly.  gamma() overflows to inf once its
    # argument passes roughly 171, which made log(gamma(.)) -- and therefore
    # the entropy -- non-finite for large k or dof.
    log_term = gammaln(half_dof) - gammaln(half_sum) + 0.5 * k * np.log(dof * np.pi)

    return log_term + digamma_term


class MultivariateStudentTDataset(Dataset):
    """Paired samples (x, y) drawn jointly from a multivariate Student-t law.

    The joint dispersion matrix over the ``2 * dim`` coordinates has a unit
    diagonal and entry ``rho`` linking coordinate ``i`` with coordinate
    ``i + dim``.  The first ``dim`` columns of every draw form ``x`` and the
    remaining ones form ``y``; both are stored as torch tensors.
    """

    def __init__(self, N, dim, df, rho, transforms_x=['none'], transforms_y=['none']):
        """
        Args:
            N: Number of samples
            dim: Dimensionality of each variable
            df: Degrees of freedom for Student-t distribution
            rho: Correlation coefficient
            transforms_x: List of transformations for X
            transforms_y: List of transformations for Y
        """
        self.N, self.dim = N, dim
        self.df, self.rho = df, rho
        self.x_transforms = transforms_x
        self.y_transforms = transforms_y

        # Gaussian with the same dispersion matrix; true_mi asks it for the
        # Gaussian part of the mutual information.
        dispersion = self.cov_matrix.numpy()
        self._multinormal = spl.SplitMultinormal(
            dim_x=dim, dim_y=dim, mean=np.zeros(2 * self.dim), covariance=dispersion
        )

        self.dist = self.build_dist
        self.x, self.y = self.sample_data()

        self.transform_both()

    def sample_data(self):
        """Draw N joint samples and split them into the x and y halves."""
        d = self.dim
        joint = stats.multivariate_t.rvs(
            loc=np.zeros(2 * d),
            shape=self.cov_matrix.numpy(),
            df=self.df,
            size=self.N,
        )
        first_half = torch.from_numpy(joint[:, :d])
        second_half = torch.from_numpy(joint[:, d:])
        return first_half, second_half

    def transform_both(self):
        """Run the configured transforms, y first and then x, each in list order."""
        for name in self.y_transforms:
            self.y = transform_points(self.y, name)
        for name in self.x_transforms:
            self.x = transform_points(self.x, name)

    def __getitem__(self, ix):
        return self.x[ix], self.y[ix]

    def __len__(self):
        return self.N

    @property
    def build_dist(self):
        """Frozen scipy multivariate-t distribution with this dataset's parameters."""
        centre = np.zeros(2 * self.dim)
        return stats.multivariate_t(loc=centre, shape=self.cov_matrix.numpy(), df=self.df)

    @property
    def cov_matrix(self):
        """Dispersion matrix: unit diagonal, ``rho`` between coordinates i and i + dim."""
        d = self.dim
        lower = torch.arange(d)
        upper = torch.arange(d, 2 * d)
        every = torch.arange(2 * d)

        cov = torch.zeros((2 * d, 2 * d))
        cov[every, every] = 1.0
        cov[lower, upper] = self.rho
        cov[upper, lower] = self.rho
        return cov

    @property
    def true_mi(self):
        """Mutual information of the untransformed (x, y) pair.

        Sum of the entropy combination h(x) + h(y) - h(x, y) evaluated with
        ``_differential_entropy`` and the mutual information reported by the
        Gaussian with the same dispersion matrix.  Only the distribution
        parameters enter; the transform lists are not consulted.
        """
        marginal_x = _differential_entropy(k=self.dim, dof=self.df)
        marginal_y = _differential_entropy(k=self.dim, dof=self.df)
        joint = _differential_entropy(k=2 * self.dim, dof=self.df)
        gaussian_part = self._multinormal.mutual_information()
        return marginal_x + marginal_y - joint + gaussian_part


def transform_points(data, transform):
    """Apply one named transformation to a tensor of points.

    Args:
        data: torch tensor.  The 'concat_self', 'concat_self_noisy' and
            'randmat' branches work along axis 1 and therefore need a 2-D
            tensor; 'sigmoid' and 'cube' are element-wise.
        transform: One of 'none', 'sigmoid', 'concat_self',
            'concat_self_noisy', 'cube', 'randmat'.

    Returns:
        The transformed tensor ('none' returns ``data`` itself).

    Raises:
        ValueError: If ``transform`` is not one of the names above.
            Previously such a name fell through and returned None.
    """
    if transform == 'none':
        return data

    if transform == 'sigmoid':
        return torch.nn.functional.sigmoid(data)

    if transform == 'concat_self':
        params = params_dict[transform]
        # params[0] extra copies are appended after the original columns.
        # Concatenation never mutates its inputs, so no defensive copy is needed.
        pieces = [data] * (1 + max(params[0], 0))
        return torch.concatenate(pieces, axis=1)

    if transform == 'concat_self_noisy':
        params = params_dict[transform]
        # params[0] blocks of uniform [0, 1) noise, each scaled by params[1],
        # are appended after the original columns.
        noise_blocks = [torch.rand_like(data) * params[1] for _ in range(params[0])]
        return torch.concatenate([data, *noise_blocks], axis=1)

    if transform == 'cube':
        return data ** 3.0

    if transform == 'randmat':
        n_features = data.shape[1]
        # Redraw until the matrix is invertible.  Only the inversion failure
        # is caught: a bare ``except`` here would retry forever on any other
        # error and also swallow KeyboardInterrupt.
        while True:
            rand_mat = torch.rand(1) * torch.rand(n_features, n_features)
            try:
                torch.linalg.inv(rand_mat)
            except RuntimeError:
                continue
            break
        return torch.matmul(data.double(), rand_mat.double())

    raise ValueError(
        f"Unknown transform {transform!r}; expected one of 'none', 'sigmoid', "
        "'concat_self', 'concat_self_noisy', 'cube', 'randmat'"
    )
    
    
class MultivariateNormalDataset(Dataset):
    """Paired samples (x, y) from a zero-mean joint Gaussian.

    The joint covariance over the ``2 * dim`` coordinates has a unit diagonal
    and entry ``rho`` linking coordinate ``i`` with coordinate ``i + dim``.
    The first ``dim`` columns of each draw become ``x`` and the rest ``y``.
    After the transforms have run, both are stored as numpy arrays.

    Args:
        N: Number of samples.
        dim: Dimensionality of x and of y before any transform.
        rho: Correlation between matching coordinates of x and y.
        transforms_x: Transform names applied to x, in order.
        transforms_y: Transform names applied to y, in order.
    """

    def __init__(self, N, dim, rho, transforms_x=['none'], transforms_y=['none']):
        self.N = N
        self.rho = rho
        self.dim = dim
        self.x_transforms = transforms_x
        self.y_transforms = transforms_y

        self.dist = self.build_dist

        joint = self.dist.sample((N, ))
        self.x = joint[:, :dim]
        self.y = joint[:, dim:]

        self.transform_both()

    def __getitem__(self, ix):
        """Return the ``ix``-th (x, y) pair.

        x and y live in separate arrays once ``__init__`` has split the joint
        draw, so both are indexed directly.  Slicing ``self.x`` for the second
        half, as was done before, always produced an empty y.
        """
        return self.x[ix], self.y[ix]

    def transform_both(self):
        """Run the configured transforms (y first, then x) and convert to numpy."""
        for name in self.y_transforms:
            self.y = transform_points(self.y, name)

        for name in self.x_transforms:
            self.x = transform_points(self.x, name)

        self.x = self.x.numpy()
        self.y = self.y.numpy()

    def __len__(self):
        return self.N

    @property
    def build_dist(self):
        """Zero-mean ``MultivariateNormal`` with this dataset's covariance."""
        mu = torch.zeros(2 * self.dim)
        dist = MultivariateNormal(mu, self.cov_matrix)
        return dist

    @property
    def cov_matrix(self):
        """Joint covariance: unit diagonal, ``rho`` between coordinates i and i + dim."""
        cov = torch.zeros((2 * self.dim, 2 * self.dim))
        cov[torch.arange(self.dim), torch.arange(self.dim, 2 * self.dim)] = self.rho
        cov[torch.arange(self.dim, 2 * self.dim), torch.arange(self.dim)] = self.rho
        cov[torch.arange(2 * self.dim), torch.arange(2 * self.dim)] = 1.0
        return cov

    @property
    def true_mi(self):
        """Mutual information of the untransformed pair: ``-0.5 * log det(cov)``.

        The log-determinant is computed in float64 with ``slogdet``.  Taking
        ``log(det(.))`` of the float32 matrix underflows to ``log(0)`` once
        ``dim`` is large, because the determinant equals ``(1 - rho**2)**dim``.
        """
        cov = self.cov_matrix.numpy().astype(np.float64)
        sign, log_abs_det = np.linalg.slogdet(cov)
        if sign < 0:
            # A negative determinant has no real logarithm; np.log returned nan.
            return np.nan
        return -0.5 * log_abs_det
    
    
    

    
class GaussianAdditionDataset(Dataset):
    """Paired samples where y is x plus independent Gaussian noise.

    x has i.i.d. standard-normal entries and ``y = x + noise`` with noise
    variance ``1 / SNR`` per entry.  After the transforms have run, both are
    stored as numpy arrays.

    Args:
        N: Number of samples.
        dim: Dimensionality of x and of y before any transform.
        SNR: Signal-to-noise ratio; must be strictly positive.
        transforms_x: Transform names applied to x, in order.
        transforms_y: Transform names applied to y, in order.

    Raises:
        ValueError: If ``SNR`` is not strictly positive (the noise scale
            ``sqrt(1 / SNR)`` would be undefined).
    """

    def __init__(self, N, dim, SNR, transforms_x=['none'], transforms_y=['none']):
        if SNR <= 0:
            raise ValueError(f"SNR must be strictly positive, got {SNR!r}")

        self.N = N
        self.dim = dim
        self.SNR = SNR

        self.x_transforms = transforms_x
        self.y_transforms = transforms_y

        self.x = np.random.normal(0., 1, [self.N, self.dim])
        self.y = self.x + np.random.normal(0., np.sqrt(1/self.SNR), [self.N, self.dim])
        self.x = torch.from_numpy(self.x)
        self.y = torch.from_numpy(self.y)

        self.transform_both()

    def __getitem__(self, ix):
        """Return the ``ix``-th (x, y) pair, matching the sibling dataset classes."""
        return self.x[ix], self.y[ix]

    def transform_both(self):
        """Run the configured transforms (y first, then x) and convert to numpy."""
        for name in self.y_transforms:
            self.y = transform_points(self.y, name)

        for name in self.x_transforms:
            self.x = transform_points(self.x, name)

        self.x = self.x.numpy()
        self.y = self.y.numpy()

    def __len__(self):
        return self.N

    @property
    def true_mi(self):
        """Mutual information of the untransformed pair: ``dim * 0.5 * log(1 + SNR)``."""
        return self.dim*0.5*np.log(1 + self.SNR)
    

# Example usage:

#--------------------Configurations used for experiments:------------------------------
# params_dict = {
#   "concat_self": [20],
#   "randmat": [],
#    "cube": [],
#    "concat_self_noisy": [20,0.2],
#    "sigmoid": [],
#    "scale": []
# }
# --------------------------------------------------------------------------------------

# Per-transform parameter lists, looked up by name inside transform_points().
# Only the two concat_* entries are read there; the others are empty.
params_dict = {
    # [number of extra copies of the data appended along axis 1]
    "concat_self": [20],
    "randmat": [],
    "cube": [],
    # [number of noise blocks appended, factor applied to each noise block]
    "concat_self_noisy": [20, 0.2],
    "sigmoid": [],
    "scale": [],
}



# ---- MINE estimators -------------------------------------------------------
# Each MI_Estimator receives a positional parameter list; the bound method
# picked from it is the callable used later as ``estimator(x, y)``.
total_epochs = 80
batch_size = 400
hidden_layer = 15
mine_est = MI_Estimator([total_epochs, batch_size, hidden_layer]).MINE_MI
mine_est_local = MI_Estimator([total_epochs, batch_size, hidden_layer]).MINE_Local_MI
# The two global variants differ only in the fourth entry (True vs False).
# NOTE(review): going by the variable names it toggles a correction step --
# confirm against all_estimators.
mine_est_global = MI_Estimator([total_epochs, batch_size, hidden_layer, True]).MINE_Global_MI
mine_est_global_nocorrection = MI_Estimator([total_epochs, batch_size, hidden_layer, False]).MINE_Global_MI

# Disabled variant:
# mine_est_global_infnorm = MI_Estimator([total_epochs,batch_size,hidden_layer,True]).MINE_Global_infnorm

# ---- KSG-family estimators -------------------------------------------------
k1 = 3
c_local = [1.0]
c_global = np.linspace(0.1, 2.0, 20)
# Alternative grids tried for c_global:
# c_global = [0.8,0.9,1.0,1.1,1.2]
# c_global = [1.0]
KSG_est = MI_Estimator([k1]).KSG
KSG_local_est = MI_Estimator([k1, c_local]).KSG_local
KSG_global_est = MI_Estimator([k1, c_global]).KSG_global
KSG_global_est_nomax = MI_Estimator([k1, [1.0]]).KSG_global
KSG_local_est_infnorm = MI_Estimator([k1, c_local]).KSG_local_infnorm
KSG_global_est_infnorm = MI_Estimator([k1, c_global]).KSG_global_infnorm

# Disabled variant:
# Mixed_est = MI_Estimator([k1]).Mixed_KSG

# ---- Revised KSG -----------------------------------------------------------
k2 = 3
q = np.inf
revised_KSG_est = MI_Estimator([k2, q]).KSG_revised

# ---- LNC -------------------------------------------------------------------
k3 = 5
alpha = 0.25
LNC_est = MI_Estimator([k3, alpha]).LNC_MI

# ---- Binning ---------------------------------------------------------------
bin_est = MI_Estimator([]).bin_MI

# ----------------------------------
# infonce_est = MI_Estimator([100]).info_nce_MI
# infonce_est_local = MI_Estimator([100]).info_nce_MI_local
# infonce_est_global = MI_Estimator([100]).info_nce_MI_global


# hidden_ratio = [np.linspace(0.1,2.0,num=10)]
# hidden_ratio = np.arange(1,20)/50.0
# batch_size = 200
# MVIG_est = MI_Estimator([hidden_ratio,batch_size])
# VI_est = MI_Estimator([hidden_ratio[-1],batch_size])


# estimators = [KSG_est,KSG_local_est,KSG_global_est,KSG_local_est_infnorm,KSG_global_est_infnorm,revised_KSG_est] 
# Estimators evaluated in this run.  Alternative selections:
# estimators = [KSG_est,KSG_local_est,KSG_global_est,KSG_local_est_infnorm,KSG_global_est_infnorm,revised_KSG_est,bin_est]
# estimators = [mine_est,mine_est_local,mine_est_global,mine_est_global_infnorm]
estimators = [mine_est, mine_est_local, mine_est_global, mine_est_global_nocorrection]

# One result list per estimator, filled trial by trial.
error_list = [[] for _ in estimators]            # estimate - true MI
output_list = [[] for _ in estimators]           # raw estimates
true_mi_list = []                                # ground truth per trial
one_sided_check_list = [[] for _ in estimators]  # 1 where the estimate is below the truth

N = int(args.N)
dim = int(args.dim)
trials = 40
transforms_x = args.tx
transforms_y = args.ty

datasets = [MultivariateNormalDataset, GaussianAdditionDataset, MultivariateStudentTDataset]

data_ID = int(args.data_ID)
# Reject out-of-range ids up front: a negative value would otherwise wrap
# around and silently select a different dataset class.
if not 0 <= data_ID < len(datasets):
    parser.error(f"data_ID must be between 0 and {len(datasets) - 1}, got {data_ID}")

# Upper bound of the per-trial random parameter: a correlation for the
# Gaussian and Student-t datasets, an SNR for the additive-noise dataset.
if data_ID == 0 or data_ID == 2:
    rho_max = 0.8
else:
    rho_max = 2.0

# Not referenced anywhere else in this file.
ksg_error = []
ksg_local_error = []
ksg_global_error = []

# Run the trials: draw a random dataset parameter, build the dataset, and
# record every estimator's output against the known mutual information.
for i in range(trials):
    rho = np.random.rand()*rho_max
    dataset_cls = datasets[data_ID]
    if dataset_cls is MultivariateStudentTDataset:
        # Only the Student-t dataset takes a degrees-of-freedom argument.
        df = random.choice([1,2,3,5,8])
        dataset = dataset_cls(N, dim, df, rho, transforms_x, transforms_y)
    else:
        # The Gaussian datasets take (N, dim, rho-or-SNR, ...); passing df as
        # well raised TypeError for data_ID 0 and 1.
        dataset = dataset_cls(N, dim, rho, transforms_x, transforms_y)

    # true_mi is a property that recomputes on every access; read it once.
    true_mi = dataset.true_mi
    for temp, estimator in enumerate(estimators):
        E = estimator(dataset.x, dataset.y)
        error_list[temp].append(E - true_mi)
        output_list[temp].append(E)
        one_sided_check_list[temp].append(int((E - true_mi) < 0))

    true_mi_list.append(true_mi)
    

# Reference errors: how far apart two true-MI values are when the trial order
# is shuffled.  They are the denominators of the "normalized" figures below.
true_mi_list = np.array(true_mi_list)
permuted_mi_list = true_mi_list[np.random.permutation(len(true_mi_list))]
ref_error1 = np.mean(np.abs(permuted_mi_list - true_mi_list))
ref_error2 = np.sqrt(np.mean((permuted_mi_list - true_mi_list)**2))

for temp in range(len(error_list)):
    error_list[temp] = np.array(error_list[temp])


print('\n\n\n\n')
print(args)
print('\n')
for temp in range(len(error_list)):
    name = estimators[temp].__name__
    errors = error_list[temp]
    mae = np.mean(np.abs(errors))
    rmse = np.sqrt(np.mean(errors**2))
    bias = np.mean(errors)

    print(name+"(MAE):   "+str(mae))
    print(name+"(RMSE):   "+str(rmse))
    print(name+"(bias):   "+str(bias))
    print(name+"(normalized MAE):   "+str(mae/ref_error1))
    # RMSE divided by the reference RMSE.  The square root previously wrapped
    # the whole quotient, which printed sqrt(MSE / ref_RMSE) instead.
    print(name+"(normalized RMSE):   "+str(rmse/ref_error2))
    print(name+"(normalized bias):   "+str(bias/ref_error1))

    print(name+"(Spearman):"+str(stats.spearmanr(np.array(output_list[temp]),np.array(true_mi_list))))
    # Fraction of trials in which the estimate fell below the true MI.
    print(name+" Direction:",np.mean((np.array(one_sided_check_list[temp])==1)).astype(float))


# Undo the stdout redirection set up near the top of the script.
if WRITE_FILE == 1:
    sys.stdout = orig_stdout
    f.close()


# print(error_list)
