import sys
### Replace 'PATH' below with the path of the folder to add to the module search path
sys.path.append('PATH')

from pycle.utils.datasets import generatedataset_GMM
from pycle.sketching.feature_maps.MatrixFeatureMap import MatrixFeatureMap
from pycle.sketching import computeSketch
from pycle.compressive_learning.AltCLOMP_grid import AltCLOMP_grid

from pycle.compressive_learning.D_OMP import D_OMP
from pycle.compressive_learning.D_OMP_CKM import D_OMP_CKM

from pycle.compressive_learning.CLOMP_CKM import CLOMP_CKM

from loguru import logger
from pycle.utils.metrics import SSE
# from pycle.utils.vizualization import simple_plot_clustering, line_plot_clustering
from pycle.sketching.frequency_sampling import drawFrequencies
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.cluster import kmeans_plusplus



from pylab import figure, cm
from matplotlib.ticker import LogFormatter 



from matplotlib.colors import LogNorm

from sklearn.cluster import KMeans



import numpy as np
import torch

import copy

#


def run_kmeans_plus(X, nb_clust, n_i):
    """Cluster ``X`` with scikit-learn ``KMeans`` using 'k-means++' initialisation.

    Args:
        X: Data handed directly to ``KMeans.fit``.
        nb_clust: Number of clusters (``n_clusters``).
        n_i: Number of initialisations (``n_init``).

    Returns:
        The fitted estimator's ``cluster_centers_``.
    """
    estimator = KMeans(n_clusters=nb_clust, init='k-means++', n_init=n_i)
    estimator.fit(X)
    return estimator.cluster_centers_


def run_kmeans(X, nb_clust, n_i):
    """Cluster ``X`` with scikit-learn ``KMeans`` using 'random' initialisation.

    Args:
        X: Data handed directly to ``KMeans.fit``.
        nb_clust: Number of clusters (``n_clusters``).
        n_i: Number of initialisations (``n_init``).

    Returns:
        The fitted estimator's ``cluster_centers_``.
    """
    estimator = KMeans(n_clusters=nb_clust, init='random', n_init=n_i)
    estimator.fit(X)
    return estimator.cluster_centers_



def gaussian_kernel(sigma):
    """Build the kernel ``k(x, y) = exp(-||x - y||**2 / sigma**2)``.

    Note that the exponent is divided by ``sigma**2`` (there is no factor 2
    in the denominator).

    Args:
        sigma: Bandwidth used in the denominator of the exponent.

    Returns:
        A two-argument callable evaluating the kernel on ``x`` and ``y``.
    """
    def gaussian_kernel_aux(x, y):
        # Euclidean distance first, then square it and scale by the bandwidth.
        distance = np.linalg.norm(x - y)
        exponent = -np.power(distance, 2) / sigma**2
        return np.exp(exponent)

    return gaussian_kernel_aux


def experiment_CLOMP(Phi_emp, z, mode, grid_size, int_grid_size):
    """Run one ``CLOMP_CKM`` fit and return the recovered centroids and weights.

    The function reads the module-level globals ``dim`` and ``nb_clust``;
    they must be defined before it is called.

    Args:
        Phi_emp: Feature map passed to the solver as ``phi``.
        z: Sketch passed to the solver as ``sketch_z``.
        mode: Stored under the ``"grid_mode"`` hyperparameter key.
        grid_size: Stored under the ``"grid_size"`` hyperparameter key.
        int_grid_size: Stored under the ``"int_grid_size"`` hyperparameter key.

    Returns:
        A ``(centroids, weights)`` tuple, where ``centroids`` holds the first
        ``dim`` entries along the last axis of the solver's ``theta`` and
        ``weights`` is the second element of ``current_solution``.
    """
    # Hyperparameters forwarded verbatim to the solver.
    dct_bfgs = {
        "maxiter_inner_optimizations": 15000,
        "tol_inner_optimizations": 1e-9,
        "lr_inner_optimizations": 0.01,
        "opt_method_step_1": "lbfgs",
        "opt_method_step_34": "nnls",
        "opt_method_step_5": "lbfgs",
        "grid_mode": mode,
        "grid_size": grid_size,
        "int_grid_size": int_grid_size,
    }
    # Search box [-1, 1]^dim: the data is assumed to be normalized between -1 and 1.
    bounds = torch.tensor(np.array([-np.ones(dim), np.ones(dim)]))

    ckm_solver = CLOMP_CKM(
        phi=Phi_emp, size_mixture_K=nb_clust, bounds=bounds, sketch_z=z,
        store_objective_values=False, dct_optim_method_hyperparameters=dct_bfgs
    )

    # Run a single optimization pass, then read back the solution.
    ckm_solver.fit_once()
    theta, weights = ckm_solver.current_solution

    # Only the leading `dim` coordinates of theta are returned as centroids.
    centroids = theta[..., :dim]
    return centroids, weights






def experiment_D_OMP_grid(Phi_emp, z, s_sigma, grid_size, opt_mode, mode, ms_step_size, ms_iteration_number, int_grid_size):
    """Run one ``D_OMP_CKM`` fit and return the centroids and the solver's grid output.

    The function reads the module-level globals ``dim`` and ``nb_clust``;
    they must be defined before it is called.

    Args:
        Phi_emp: Feature map passed to the solver as ``phi``.
        z: Sketch passed to the solver as ``sketch_z``.
        s_sigma: Stored under the ``"s_sigma"`` hyperparameter key.
        grid_size: Stored under the ``"grid_size"`` hyperparameter key.
        opt_mode: Stored under the ``"opt_mode"`` hyperparameter key.
        mode: Stored under the ``"grid_mode"`` hyperparameter key.
        ms_step_size: Stored under the ``"ms_step_size"`` hyperparameter key.
        ms_iteration_number: Stored under the ``"ms_iteration_number"``
            hyperparameter key.
        int_grid_size: Stored under the ``"int_grid_size"`` hyperparameter key.

    Returns:
        A ``(centroids, t_grid)`` tuple, where ``centroids`` holds the first
        ``dim`` entries along the last axis of the solver's ``theta`` and
        ``t_grid`` is the second value returned by ``fit_once()``.
    """
    # Hyperparameters forwarded verbatim to the solver.
    dct_bfgs = {
        "maxiter_inner_optimizations": 15000,
        "tol_inner_optimizations": 1e-9,
        "lr_inner_optimizations": 0.01,
        "opt_method_step_1": "lbfgs",
        "opt_method_step_34": "nnls",
        "opt_method_step_5": "lbfgs",
        "s_sigma": s_sigma,
        "grid_mode": mode,
        "grid_size": grid_size,
        "opt_mode": opt_mode,
        "ms_step_size": ms_step_size,
        "ms_iteration_number": ms_iteration_number,
        "int_grid_size": int_grid_size,
    }
    # Search box [-1, 1]^dim: the data is assumed to be normalized between -1 and 1.
    bounds = torch.tensor(np.array([-np.ones(dim), np.ones(dim)]))

    ckm_solver = D_OMP_CKM(
        phi=Phi_emp, size_mixture_K=nb_clust, bounds=bounds, sketch_z=z,
        store_objective_values=False, dct_optim_method_hyperparameters=dct_bfgs
    )

    # fit_once() is unpacked as a pair; only its second element is needed here.
    _, t_grid = ckm_solver.fit_once()

    # Only theta is needed from the solution; the weights are not returned.
    theta, _ = ckm_solver.current_solution
    centroids = theta[..., :dim]

    return centroids, t_grid








# Dataset
#np.random.seed(20)  # easy
# Experiment constants; `dim` and `nb_clust` are also read as globals by the
# experiment_* functions defined above.
dim = 10
nb_clust = 10
nb_sample = 70000  # divisor used to normalise every SSE computed below
sketch_dim = 1000

# Load the data tensor from disk and cast it to double precision.
X = torch.load('mnist_70k.t').double()

# Baselines with n_init = nb_clust: random initialisation first, then k-means++.
kmeans_centroids = run_kmeans(X, nb_clust, nb_clust)
k_means_SSE_10 = SSE(X, kmeans_centroids) / nb_sample

kmeans_plus_centroids = run_kmeans_plus(X, nb_clust, nb_clust)
k_means_plus_SSE_10 = SSE(X, kmeans_plus_centroids) / nb_sample

# Normalised SSE averaged over 100 random-initialisation fits with n_init = 5 each.
k_means_SSE_5_list = []
for t in range(100):
    kmeans_centroids = run_kmeans(X, nb_clust, 5)
    k_means_SSE_5_list.append(SSE(X, kmeans_centroids) / nb_sample)

k_means_SSE = np.mean(k_means_SSE_5_list)




# sqrt_sigma_list = [0.05,0.07,0.1,0.2,0.3,0.4,0.5,0.7,1]

# sigma_list = [0.0025,0.0049,0.01,0.04,0.09,0.16,0.25,0.49,1]
# m_list = [100,200,500,1000,2000,5000]


# Parameter grids. `sigma_list` holds the squares of `sqrt_sigma_list`;
# the latter is what gets printed as tick labels on the plots.
sqrt_sigma_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 1]
sigma_list = [0.01, 0.04, 0.09, 0.16, 0.25, 0.49, 1]
m_list = [100, 200, 500, 1000, 2000, 5000]

S_len, M_len = len(sigma_list), len(m_list)

# One averaged error per (sigma index, m index) cell.
CLOMP_mean_err_array = np.zeros((S_len, M_len))
AltCLOMP_mean_err_array = np.zeros((S_len, M_len))

# Previously computed per-trial errors.
# NOTE(review): presumably shaped (S_len, M_len, T) -- confirm against the
# script that produced these files.
CLOMP_err_array = np.load("npy/T_10/CLOMP_err_array.npy")
AltCLOMP_err_array = np.load("npy/T_10/AltCLOMP_err_array.npy")

# Average over the trailing axis, cell by cell (sigma index outer, m index inner).
for row, col in np.ndindex(S_len, M_len):
    CLOMP_mean_err_array[row, col] = np.mean(CLOMP_err_array[row, col, :])
    AltCLOMP_mean_err_array[row, col] = np.mean(AltCLOMP_err_array[row, col, :])






m_str_ticks_list = [str(x) for x in m_list]
sigma_str_ticks_list = [str(x) for x in sqrt_sigma_list]


def _save_error_heatmap(mean_err, out_path, x_tick_labels, y_tick_labels, *,
                        vmax=2, log_scale=False, interpolation=None,
                        cbar_ticks=(1, 1.2, 1.4, 1.6, 1.8, 2),
                        cbar_labels=('1', '1.2', '1.4', '1.6', '1.8', '>2')):
    """Render one grey-scale heatmap of ``mean_err`` and save it to ``out_path``.

    Args:
        mean_err: 2-D array; it is transposed before plotting, so its first
            axis runs along x and its second axis along y.
        out_path: Destination file handed to ``Figure.savefig``.
        x_tick_labels: Labels for the x axis; tick positions are 0..len-1.
        y_tick_labels: Labels for the y axis; tick positions are 0..len-1.
        vmax: Upper end of the colour scale (the lower end is always 1).
        log_scale: If true, colours are mapped with ``LogNorm(vmin=1, vmax=vmax)``;
            otherwise ``vmin=1`` / ``vmax=vmax`` are passed to ``imshow``.
        interpolation: Forwarded to ``imshow`` only when not ``None``, so the
            matplotlib default applies otherwise.
        cbar_ticks: Tick positions on the colour bar.
        cbar_labels: Tick labels on the colour bar, one per entry of ``cbar_ticks``.
    """
    fig = plt.figure(figsize=(12, 6))

    # Pass either a norm or explicit vmin/vmax, never both.
    imshow_kwargs = {"cmap": plt.cm.Greys, "origin": 'lower'}
    if log_scale:
        imshow_kwargs["norm"] = LogNorm(vmin=1, vmax=vmax)
    else:
        imshow_kwargs["vmin"] = 1
        imshow_kwargs["vmax"] = vmax
    if interpolation is not None:
        imshow_kwargs["interpolation"] = interpolation
    im = plt.imshow(mean_err.T, **imshow_kwargs)

    # Tick positions follow the label lists, so they stay in sync with the grids.
    plt.yticks(list(range(len(y_tick_labels))), y_tick_labels, fontsize=16)
    plt.xticks(list(range(len(x_tick_labels))), x_tick_labels, fontsize=16)

    # Explicit tick labels are set right after the colour bar is created.
    cb = plt.colorbar(im, format=LogFormatter(10, labelOnlyBase=False))
    cb.set_ticks(list(cbar_ticks))
    cb.set_ticklabels(list(cbar_labels))
    cb.set_label(label='RSE', fontsize=16)

    plt.ylabel('m', fontsize=24)
    plt.xlabel(r'$\sigma$', fontsize=24)
    fig.savefig(out_path, bbox_inches='tight')

    # pyplot keeps every figure it creates until it is closed explicitly.
    plt.close(fig)


# (file-name suffix, rendering options) for each saved variant, in output order.
_HEATMAP_VARIANTS = (
    ("log_interpolation", {"log_scale": True, "interpolation": 'bilinear'}),
    ("nolog_interpolation", {"interpolation": 'bilinear'}),
    ("log_nointerpolation", {"log_scale": True}),
    ("nolog_nointerpolation", {}),
    ("nolog_nointerpolation_2", {
        "vmax": 5,
        "cbar_ticks": (1, 2, 3, 4, 5),
        "cbar_labels": ('1', '2', '3', '4', '>5'),
    }),
)

# Same five variants for each algorithm: all CLOMP figures first, then AltCLOMP.
for _algo_name, _mean_err in (("CLOMP", CLOMP_mean_err_array),
                              ("AltCLOMP", AltCLOMP_mean_err_array)):
    for _suffix, _options in _HEATMAP_VARIANTS:
        _save_error_heatmap(
            _mean_err,
            "pdf_reader/" + _algo_name + "_m_sigma_gray_1_2_" + _suffix + ".pdf",
            sigma_str_ticks_list,
            m_str_ticks_list,
            **_options
        )

















# k_means_SSE_list = []
# k_means_plus_SSE_list = []
# n_i_list = [1,2,3,4,5,6,7,8,9,10,20]

# for n_i in n_i_list:
#     kmeans_centroids = run_kmeans(X, nb_clust,n_i)
#     k_means_SSE = SSE(X, kmeans_centroids)/nb_sample
#     k_means_SSE_list.append(k_means_SSE)
#     print(k_means_SSE)

# #kmeans_centroids = run_kmeans(X, nb_clust,20)
# #k_means_SSE = SSE(X, kmeans_centroids)/nb_sample
# #print(k_means_SSE)

# for n_i in n_i_list:
#     kmeans_centroids = run_kmeans_plus(X, nb_clust,n_i)
#     k_means_plus_SSE = SSE(X, kmeans_centroids)/nb_sample
#     k_means_plus_SSE_list.append(k_means_plus_SSE)
#     print(k_means_plus_SSE)







# f = plt.figure(figsize=(14, 10))
# plt.plot(n_i_list,k_means_SSE_list)
# plt.ylabel('kmeans',fontsize =24)
# f.savefig("pdf_reader/kmeans_vs_ninit.pdf", bbox_inches='tight')
# #plt.show()


# f = plt.figure(figsize=(14, 10))
# plt.plot(n_i_list,k_means_plus_SSE_list)
# plt.ylabel('kmeans++',fontsize =24)
# f.savefig("pdf_reader/kmeans_plus_vs_ninit.pdf", bbox_inches='tight')
# #plt.show()




# np.save("npy/CLOMP_err_array.npy",CLOMP_err_array)
# np.save("npy/AltCLOMP_err_array.npy",AltCLOMP_err_array)
# np.save("npy/sqrtsigma_d_"+str(dim)+"_k_"+str(nb_clust)+"_m_"+str(sketch_dim)+".npy", np.sqrt(sigma_list))
# np.save("npy/k_means_SSE_list.npy", k_means_SSE_list)
# np.save("npy/k_means_plus_SSE_list.npy", k_means_plus_SSE_list)


    