# -*- coding: utf-8 -*-
"""
Created on Sun Jun 30 15:49:24 2024

@author: User
"""
import os
# Switch the process working directory to the folder that contains this file.
# NOTE(review): presumably so relative paths used by the project modules
# imported below resolve no matter where the script is launched from - confirm.
dir_path = os.path.dirname(os.path.realpath(__file__))
os.chdir(dir_path)
# import torch

# Project-local estimator back-ends are pulled in with star imports, so the
# names used by MI_Estimator (MINE_estimator, kraskov_mi, revised_mi, Bin_MI,
# global_normalize, local_normalize, get_data_norm, MVIG, VI, ...) are not
# traceable to a specific module from this file alone. Later star imports
# shadow earlier ones on a name clash, so the order below matters.
from MINE_pytorch import * 
from knnie import * 
from simplebinmi import * 
from normalize_functions import * 
import numpy as np 
from vig_lib import *
from MI_hybrid_generators import * 
# Seed NumPy, the stdlib RNG and torch with the same fixed value (42).
np.random.seed(42)
import random
random.seed(42)
# NOTE(review): `torch` is not imported explicitly (the import above is
# commented out); this line relies on one of the star imports exposing the
# name - confirm, or restore the explicit `import torch`.
torch.manual_seed(42)
from lnc import MI_LNC
# import bmi

# from sklearn.feature_selection import mutual_info_classif

# `ee` is used by MI_Estimator.Mixed_KSG (ee.micd).
from npeet import entropy_estimators as ee

class MI_Estimator:
    """Uniform front-end over several mutual-information estimators.

    Every public estimator method takes an ``x_sample`` / ``y_sample`` pair,
    converts the inputs it needs to 2-D NumPy arrays, optionally normalises
    them and forwards them to one back-end function.  The back-ends
    (``MINE_estimator``, ``kraskov_mi``, ``revised_mi``, ``MVIG``, ``VI``,
    ``Bin_MI``, ``MI_LNC.mi_LNC``, ``ee.micd``) and the normalisation helpers
    (``global_normalize``, ``local_normalize``, ``get_data_norm``) are expected
    to be provided by this module's imports.

    Parameters
    ----------
    params : sequence
        Positional hyper-parameters.  Their meaning depends on which method is
        called; each method's docstring lists the entries it reads.
    mode : str, default ``'continuous'``
        ``'discrete'`` makes :meth:`make_presentable_hot` one-hot encode
        ``y_sample``; any other value makes it flatten ``y_sample`` exactly
        like ``x_sample``.
    """

    def __init__(self, params, mode='continuous'):
        self.params = params
        self.mode = mode

    # ------------------------------------------------------------------
    # Input preparation
    # ------------------------------------------------------------------
    def _to_numpy(self, sample):
        """Return ``sample`` as a NumPy array.

        Accepts anything ``np.asarray`` understands.  Objects exposing a
        ``detach`` method (torch tensors) are first detached and moved to host
        memory so tensors that track gradients or live on an accelerator can
        be converted as well.
        """
        if hasattr(sample, "detach"):
            sample = sample.detach().cpu()
        return np.asarray(sample)

    def make_presentable(self, x_sample):
        """Flatten ``x_sample`` to a 2-D array of shape ``(n_samples, -1)``.

        The first axis is kept as the sample axis; all remaining axes are
        merged.  Torch tensors and NumPy arrays / array-likes are accepted.
        """
        flat = self._to_numpy(x_sample)
        return flat.reshape(flat.shape[0], -1)

    def make_presentable_hot(self, y_sample):
        """Prepare ``y_sample`` for the estimators.

        In ``'discrete'`` mode the labels are flattened to 1-D and one-hot
        encoded into a float array of shape ``(n_labels, max_label + 1)``.
        In every other mode the result is the same as
        :meth:`make_presentable`.

        Raises
        ------
        TypeError
            In discrete mode, if the labels do not have an integer dtype.
        """
        if self.mode != 'discrete':
            return self.make_presentable(y_sample)

        # Flatten first: an (N, 1) label array used directly as a fancy index
        # would broadcast against arange(N) and set the wrong cells.
        labels = self._to_numpy(y_sample).reshape(-1)
        if labels.dtype.kind not in "iu":
            raise TypeError("discrete mode expects integer class labels")
        y_hot = np.zeros((labels.size, int(labels.max()) + 1))
        y_hot[np.arange(labels.size), labels] = 1
        return y_hot

    # ------------------------------------------------------------------
    # MINE-based estimators
    #   params[0] -> iters, params[1] -> batch_size, params[2] -> hidden
    # ------------------------------------------------------------------
    def _mine(self, x, y):
        """Call ``MINE_estimator`` on prepared arrays with the settings in ``params``."""
        return MINE_estimator(
            x,
            y,
            iters=self.params[0],
            batch_size=self.params[1],
            hidden=self.params[2],
        )

    def MINE_MI(self, x_sample, y_sample):
        """MINE estimate on the un-normalised samples.

        Reads ``params[0:3]`` (iters, batch_size, hidden).
        """
        x = self.make_presentable(x_sample)
        y = self.make_presentable_hot(y_sample)
        return self._mine(x, y)

    def MINE_Global_MI(self, x_sample, y_sample):
        """MINE estimate after ``global_normalize`` with ``C=1`` on both inputs.

        Reads ``params[0:3]`` (iters, batch_size, hidden) and ``params[3]``,
        which is forwarded as ``dim_correction``.
        """
        x = self.make_presentable(x_sample)
        y = self.make_presentable_hot(y_sample)
        normx = global_normalize(x, C=1, dim_correction=self.params[3])
        normy = global_normalize(y, C=1, dim_correction=self.params[3])
        return self._mine(normx, normy)

    def MINE_Local_MI(self, x_sample, y_sample):
        """MINE estimate after ``local_normalize`` with ``C=1`` on both inputs.

        Reads ``params[0:3]`` (iters, batch_size, hidden).
        """
        x = self.make_presentable(x_sample)
        y = self.make_presentable_hot(y_sample)
        return self._mine(local_normalize(x, C=1), local_normalize(y, C=1))

    # ------------------------------------------------------------------
    # kraskov_mi-based estimators
    #   params[0] -> k, params[1] -> iterable of C values (global variants)
    # ------------------------------------------------------------------
    def KSG(self, x_sample, y_sample):
        """``kraskov_mi`` estimate on the un-normalised samples (``k = params[0]``)."""
        x = self.make_presentable(x_sample)
        y = self.make_presentable_hot(y_sample)
        return kraskov_mi(x, y, k=self.params[0])

    def _ksg_local(self, x_sample, y_sample, p):
        """``kraskov_mi`` after ``local_normalize(C=1, p=p)`` on both inputs."""
        x = self.make_presentable(x_sample)
        y = self.make_presentable_hot(y_sample)
        normx = local_normalize(x, C=1, p=p)
        normy = local_normalize(y, C=1, p=p)
        return kraskov_mi(normx, normy, k=self.params[0])

    def KSG_local(self, x_sample, y_sample):
        """``kraskov_mi`` with local normalisation, ``p=2`` (``k = params[0]``)."""
        return self._ksg_local(x_sample, y_sample, p=2)

    def KSG_local_infnorm(self, x_sample, y_sample):
        """``kraskov_mi`` with local normalisation, ``p=np.inf`` (``k = params[0]``)."""
        return self._ksg_local(x_sample, y_sample, p=np.inf)

    def _ksg_global_sweep(self, x_sample, y_sample, p):
        """Return the largest ``kraskov_mi`` estimate over the C values in ``params[1]``.

        ``x`` is globally normalised once per candidate ``C``; ``y`` is always
        normalised with ``C=1``.  Raises ``ValueError`` (from ``max``) when
        ``params[1]`` is empty.
        """
        x = self.make_presentable(x_sample)
        y = self.make_presentable_hot(y_sample)
        # The y normalisation does not depend on the swept C, so do it once.
        # NOTE(review): assumes global_normalize does not modify its input.
        normy = global_normalize(y, C=1, p=p)
        estimates = [
            kraskov_mi(global_normalize(x, C=c, p=p), normy, k=self.params[0])
            for c in self.params[1]
        ]
        return max(estimates)

    def KSG_global(self, x_sample, y_sample):
        """Max over ``params[1]`` of ``kraskov_mi`` with global normalisation, ``p=2``."""
        return self._ksg_global_sweep(x_sample, y_sample, p=2)

    def KSG_global_infnorm(self, x_sample, y_sample):
        """Max over ``params[1]`` of ``kraskov_mi`` with global normalisation, ``p=np.inf``."""
        return self._ksg_global_sweep(x_sample, y_sample, p=np.inf)

    def KSG_revised(self, x_sample, y_sample):
        """``revised_mi`` estimate with ``k = params[0]`` and ``q = params[1]``."""
        x = self.make_presentable(x_sample)
        y = self.make_presentable_hot(y_sample)
        return revised_mi(x, y, k=self.params[0], q=self.params[1])

    # ------------------------------------------------------------------
    # Estimators that receive y_sample unchanged
    # ------------------------------------------------------------------
    def M_VIG(self, x_sample, y_sample):
        """Call ``MVIG`` on flattened ``x`` and the raw ``y_sample``.

        ``params[0]`` is forwarded as ``hidden_ratio`` and ``params[1]`` as
        ``batch_size``; ``verbose=1`` and ``dataset_name='temp'`` are fixed.
        """
        x = self.make_presentable(x_sample)
        return MVIG(
            x,
            y_sample,
            hidden_ratio=self.params[0],
            batch_size=self.params[1],
            verbose=1,
            dataset_name='temp',
        )

    def VInfo(self, x_sample, y_sample):
        """Call ``VI`` on flattened ``x`` and the raw ``y_sample``.

        ``params[0]`` is forwarded as ``hidden_ratio`` and ``params[1]`` as
        ``batch_size``; ``verbose=0`` and ``dataset_name='temp'`` are fixed.
        """
        x = self.make_presentable(x_sample)
        return VI(
            x,
            y_sample,
            hidden_ratio=self.params[0],
            batch_size=self.params[1],
            verbose=0,
            dataset_name='temp',
        )

    def Mixed_KSG(self, x_sample, y_sample):
        """``ee.micd`` estimate (base ``e``) on locally normalised ``x``.

        ``y_sample`` is passed with a new trailing axis added
        (``np.expand_dims(..., axis=1)``); ``k = params[0]``.
        """
        x = self.make_presentable(x_sample)
        normx = local_normalize(x, C=1)
        return ee.micd(
            normx,
            np.expand_dims(y_sample, axis=1),
            k=self.params[0],
            base=np.exp(1),
        )

    def LNC_MI(self, x_sample, y_sample):
        """``MI_LNC.mi_LNC`` estimate (base ``e``) over all x and y coordinates jointly.

        Each column of the flattened inputs becomes one variable (a list of
        per-sample values); ``k = params[0]`` and ``alpha = params[1]``.
        """
        x = self.make_presentable(x_sample)
        y = self.make_presentable_hot(y_sample)
        # One list per coordinate, i.e. the column-major view of each array.
        variables = x.T.tolist() + y.T.tolist()

        # NOTE(review): earlier revisions also evaluated mi_LNC on the x-only
        # and y-only variable lists and discarded both results; those two
        # calls were removed.  If the intended quantity is
        # joint - x_only - y_only, that subtraction still needs to be added -
        # confirm with the author.
        return MI_LNC.mi_LNC(
            variables,
            k=self.params[0],
            base=np.exp(1),
            alpha=self.params[1],
        )

    # InfoNCE-based variants (bmi.estimators.InfoNCEEstimator) are not wired in.

    def bin_MI(self, x_sample, y_sample):
        """``Bin_MI`` estimate on the flattened samples (reads no ``params``)."""
        x = self.make_presentable(x_sample)
        y = self.make_presentable_hot(y_sample)
        return Bin_MI(x, y)

    def get_norm(self, x_sample):
        """Return ``get_data_norm(x_sample)``; the input is passed through unchanged."""
        return get_data_norm(x_sample)
    
    
    


if __name__ == "__main__":
    # Interactive scaffold: sets up two estimator callables and a toy
    # two-column Gaussian sample. No estimator is actually invoked and
    # nothing is printed.
    #
    # All Estimators are in nats

    # ---- kraskov_mi-based estimator settings ----
    k1 = 3
    c_local = [1.0]
    c_global = np.linspace(0.1, 2.0, num=20)
    KSG_est = MI_Estimator([k1]).KSG  # bound method, call as KSG_est(x, y)

    # ---- MINE settings: [iters, batch_size, hidden] ----
    total_epochs = 100
    batch_size = 400
    hidden_layer = 10
    mine_est = MI_Estimator(
        [total_epochs, batch_size, hidden_layer]
    ).MINE_MI  # bound method, call as mine_est(x, y)

    # ---- toy data: 2000 draws from a zero-mean bivariate normal ----
    joint_samples = np.random.multivariate_normal(
        np.array([0, 0]),
        np.array([[1, 1], [1, 1]]),
        size=2000,
    )

    # First column becomes X with an explicit feature axis; second column is Y.
    X = joint_samples[:, 0]
    Y = joint_samples[:, 1]
    X = X[:, np.newaxis]

    # Estimator calls are intentionally left out; invoke KSG_est / mine_est
    # by hand with inputs of the type MI_Estimator's helpers accept.

    # No-op assignment; presumably a convenient line for a debugger breakpoint.
    here = 1