# -*- coding: utf-8 -*-
"""
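Benchmark the KSG-family mutual-information estimators (plain, local, global):
draw samples from each configuration stored in a pickled experiment file with
Hybrid_Generator, estimate MI, and report RMSE, MAE and Spearman correlation
against the stored ground-truth MI values.

Created on Fri Jul 12 18:41:51 2024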

@author: User
"""

# -*- coding: utf-8 -*-
"""
Created on Mon Jul  8 19:45:14 2024

@author: User
"""
import os
# Resolve this script's real location (symlinks followed) and make its folder
# the working directory, so the relative pickle names in file_list load from
# here no matter where Python was launched.
_this_file = os.path.realpath(__file__)
dir_path = os.path.split(_this_file)[0]
os.chdir(dir_path)


        
        

from all_estimators import * 
from MI_hybrid_generators import * 

from scipy import stats


# Fix the RNG seeds so repeated runs draw the same synthetic samples.
# Previously only Python's `random` module was seeded, although this script
# works with NumPy arrays. NumPy's global generator is therefore seeded too.
# NOTE(review): assumes Hybrid_Generator / the estimators draw from np.random
# (or from `random`); confirm in MI_hybrid_generators / all_estimators.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Number of (x, y) samples drawn per distribution configuration.
num_samples = 1000
# Index into file_list (built below) selecting which experiment file to load.
file_to_test = 0
 


        

# save_name = 'MI_Estimation_File'+ str(file_to_test)+ '_NumSamples_1000_results.pkl'
# save_name = 'MI_Estimation_File4_NumSamples_1000_results.pkl'
# file_dir = os.path.join(dir_path, save_name)   

# with open(file_dir, 'rb') as f:  # Python 3: open(..., 'wb')
#         [estimated_mi_mine, estimated_mi_ksg, estimated_mi_mvig, estimated_mi_VI,MI_list] = pickle.load(f)
        
# print('KSG:',np.corrcoef(np.array(estimated_mi_ksg),np.array(MI_list)))
# print('MINE:',np.corrcoef(np.array(estimated_mi_mine),np.array(MI_list)))
# print('VI:',np.corrcoef(np.array(estimated_mi_VI),np.array(MI_list)))
# print('M-VIG:',np.corrcoef(np.array(estimated_mi_mvig),np.array(MI_list)))


# print('KSG:',np.mean((np.array(estimated_mi_ksg) -np.array(MI_list))**2))
# print('MINE:',np.mean((np.array(estimated_mi_mine) -np.array(MI_list))**2))
# print('VI:',np.mean((np.array(estimated_mi_VI) -np.array(MI_list))**2))
# print('M-VIG:',np.mean((np.array(estimated_mi_mvig) -np.array(MI_list))**2))

# a = input('')

# -----------------------------------

# ----------------------------------- MINE-family estimators
# Configuration list: [training epochs, mini-batch size, hidden-layer width].
total_epochs, batch_size, hidden_layer = 100, 400, 50
mine_est = MI_Estimator([total_epochs, batch_size, hidden_layer])
mine_est_local = MI_Estimator([total_epochs, batch_size, hidden_layer])
mine_est_global = MI_Estimator([total_epochs, batch_size, hidden_layer])

# ----------------------------------- KSG-family estimators
# k1 is presumably the nearest-neighbour count. C_z is a grid of 10 values in
# [0.1, 2.0] handed to the local/global variants (its meaning is defined in
# all_estimators).
k1 = 1
C_z = np.linspace(0.1, 2.0, 10)
print(C_z)
KSG_est = MI_Estimator([k1])
KSG_local_est = MI_Estimator([k1, C_z])
Mixed_est = MI_Estimator([k1])
KSG_global_est = MI_Estimator([k1, C_z])

# ----------------------------------- revised KSG
k2, q = 3, 6
revised_KSG_est = MI_Estimator([k2, q])

# ----------------------------------- variational estimators
# hidden_ratio = [np.linspace(0.1,2.0,num=10)]
hidden_ratio = np.arange(1, 20) / 50.0  # 0.02, 0.04, ..., 0.38
# NOTE: batch_size is rebound here (400 -> 200). The MINE estimators above
# already received 400 when they were constructed.
batch_size = 200
MVIG_est = MI_Estimator([hidden_ratio, batch_size])
VI_est = MI_Estimator([hidden_ratio[-1], batch_size])
# -----------------------------------

# -----------------------------------


# One accumulator per estimator. Each receives one MI estimate per
# configuration in the loaded experiment file. Every name gets its own
# fresh list.
(estimated_mi_mine, estimated_mi_ksg, estimated_mi_ksg_local,
 estimated_mi_mixed, estimated_mi_revised_ksg, estimated_mi_mine_local,
 estimated_mi_mine_global, estimated_mi_ksg_global) = ([] for _ in range(8))

# Candidate experiment files. file_to_test selects which one is loaded.
file_list = [
    'Random_MI_Est_Experiments_num_classes_2_Dim_3_radial_var_0.002.pkl',
    'MI_Est_Experiments_num_classes_2_Dim_3_radial_var_0.03.pkl',
    'MI_Est_Experiments_num_classes_2_Dim_10_radial_var_0.005.pkl',
    'MI_Est_Experiments_num_classes_2_Dim_100_radial_var_0.001.pkl',
    'MI_Est_Experiments_num_classes_2_Dim_100_radial_var_0.1extreme_configs.pkl',
]




# The experiment pickle unpacks into exactly three parallel sequences:
# per-configuration means, covariances, and the ground-truth MI that the
# estimates are scored against below.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
experiment_path = file_list[file_to_test]
with open(experiment_path, 'rb') as f:
    means_list, covs_list, MI_list = pickle.load(f)
    

for j, mean_j in enumerate(means_list):
    # Build the j-th hybrid distribution and draw num_samples (x, y) pairs.
    prob_gen = Hybrid_Generator(mean_j, covs_list[j])
    x_samples, y_samples = prob_gen.generate_samples(num_samples)
    x_samples, y_samples = np.array(x_samples), np.array(y_samples)

    # Active estimators, evaluated in this fixed order.
    estimated_mi_ksg.append(KSG_est.KSG(x_samples, y_samples))
    estimated_mi_ksg_local.append(KSG_local_est.KSG_local(x_samples, y_samples))
    estimated_mi_ksg_global.append(KSG_global_est.KSG_global(x_samples, y_samples))
    # Disabled estimators; uncomment together with the matching prints.
    # NOTE(review): the MINE local/global lines call mine_est rather than
    # mine_est_local / mine_est_global; confirm which instance is intended.
    # estimated_mi_mine.append(mine_est.MINE_MI(x_samples,y_samples))
    # estimated_mi_mixed.append(Mixed_est.Mixed_KSG(x_samples, y_samples))
    # estimated_mi_revised_ksg.append(0)
    # estimated_mi_mine_local.append(mine_est.MINE_Local_MI(x_samples,y_samples))
    # estimated_mi_mine_global.append(mine_est.MINE_Global_MI(x_samples,y_samples))

    # Ground truth first, then the newest estimate from each active method.
    print('MI:', MI_list[j])
    print('Estimated MI (KSG):', estimated_mi_ksg[-1])
    print('Estimated MI (KSG-local):', estimated_mi_ksg_local[-1])
    print('Estimated MI (KSG-global):', estimated_mi_ksg_global[-1])
    # print('Estimated MI (Mixed-KSG):',estimated_mi_mixed[-1])
    # print('Estimated MI (KSG-revised):',estimated_mi_revised_ksg[-1])
    # print('Estimated MI (MINE):',estimated_mi_mine[-1])
    # print('Estimated MI (MINE-local):',estimated_mi_mine_local[-1])
    # print('Estimated MI (MINE-global):',estimated_mi_mine_global[-1])
    # No result lists exist yet for M-VIG / V-I; define them before enabling:
    # print('Estimated MI (M-VIG):',estimated_mi_mvig[-1])
    # print('Estimated MI (V-I):',estimated_mi_VI[-1])



    
# Score every active estimator against the ground-truth MI values.
# To report another estimator, enable its .append(...) in the sampling loop
# and add its (label, results) pair to active_estimates, e.g.
#   ('Mixed-KSG', estimated_mi_mixed), ('KSG-revised', estimated_mi_revised_ksg),
#   ('MINE', estimated_mi_mine), ('MINE-local', estimated_mi_mine_local),
#   ('MINE-global', estimated_mi_mine_global)
true_mi = np.array(MI_list)
active_estimates = [
    ('KSG', estimated_mi_ksg),
    ('KSG-Local', estimated_mi_ksg_local),
    ('KSG-global', estimated_mi_ksg_global),
]

print('--------------RMSE------------------')
for label, estimates in active_estimates:
    # Root-mean-square error between estimated and true MI.
    print(f'{label}:', np.sqrt(np.mean((np.array(estimates) - true_mi) ** 2)))

print('--------------MAE------------------')
for label, estimates in active_estimates:
    # Mean absolute error. No square root: MAE is already in MI units.
    # Taking np.sqrt here would report sqrt(MAE), not MAE.
    print(f'{label}:', np.mean(np.abs(np.array(estimates) - true_mi)))

print('--------------Spearman------------------')
for label, estimates in active_estimates:
    # Rank correlation: how well each estimator orders the configurations.
    print(f'{label}:', stats.spearmanr(np.array(estimates), true_mi))


# save_name = 'MI_Estimation_v2_'+ str(file_to_test)+ '_NumSamples_'+str(num_samples)+ '_results.pkl'
# file_dir = os.path.join(dir_path, save_name)         


# with open(file_dir, 'wb') as f:  # Python 3: open(..., 'wb')
#         pickle.dump([estimated_mi_mine,	estimated_mi_ksg,	estimated_mi_ksg_local,	
#                      estimated_mi_mixed,	estimated_mi_revised_ksg,	estimated_mi_mine_local,
#                      estimated_mi_mine_global, estimated_mi_ksg_global], f)
        
        
  
        

    
        
        

