import os
import sys
### Change PATH by the path of the folder
sys.path.append('PATH')

from datetime import datetime

from pycle.utils.datasets import generatedataset_GMM
from pycle.sketching.feature_maps.MatrixFeatureMap import MatrixFeatureMap
from pycle.sketching import computeSketch
from pycle.compressive_learning.AltCLOMP_grid import AltCLOMP_grid

from pycle.compressive_learning.D_OMP import D_OMP
from pycle.compressive_learning.D_OMP_CKM import D_OMP_CKM

from pycle.compressive_learning.CLOMP_CKM import CLOMP_CKM

from loguru import logger
from pycle.utils.metrics import SSE
# from pycle.utils.vizualization import simple_plot_clustering, line_plot_clustering
from pycle.sketching.frequency_sampling import drawFrequencies
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.cluster import kmeans_plusplus


from matplotlib.colors import LogNorm

from sklearn.cluster import KMeans



import numpy as np
import torch

import copy

#


def run_kmeans_plus(X, nb_clust, n_i):
    """Fit k-means with k-means++ seeding and return the fitted centroids.

    Args:
        X: data matrix passed to ``KMeans.fit`` (one sample per row).
        nb_clust: number of clusters.
        n_i: number of initializations (``KMeans`` ``n_init``); sklearn keeps
            the run with the lowest inertia.

    Returns:
        The ``cluster_centers_`` array of the fitted model.
    """
    model = KMeans(n_clusters=nb_clust, init='k-means++', n_init=n_i)
    model.fit(X)
    return model.cluster_centers_


def run_kmeans(X, nb_clust, n_i):
    """Fit k-means with uniformly random seeding and return the fitted centroids.

    Args:
        X: data matrix passed to ``KMeans.fit`` (one sample per row).
        nb_clust: number of clusters.
        n_i: number of initializations (``KMeans`` ``n_init``); sklearn keeps
            the run with the lowest inertia.

    Returns:
        The ``cluster_centers_`` array of the fitted model.
    """
    model = KMeans(n_clusters=nb_clust, init='random', n_init=n_i)
    model.fit(X)
    return model.cluster_centers_



def gaussian_kernel(sigma):
    """Build the Gaussian kernel k(x, y) = exp(-||x - y||^2 / sigma^2).

    Args:
        sigma: bandwidth of the kernel.

    Returns:
        A two-argument callable evaluating the kernel on ``x`` and ``y``
        (``np.linalg.norm`` is applied to ``x - y``).
    """
    def kernel(x, y):
        distance = np.linalg.norm(x - y)
        exponent = -np.power(distance, 2) / sigma**2
        return np.exp(exponent)

    return kernel


def experiment_CLOMP(Phi_emp, z, mode, grid_size, int_grid_size):
    """Run CLOMP (via ``CLOMP_CKM``) on a sketch and return the recovered mixture.

    NOTE: relies on the module-level globals ``dim`` (data dimension) and
    ``nb_clust`` (number of mixture components K).

    Args:
        Phi_emp: feature map the sketch was computed with.
        z: sketch of the dataset computed with ``Phi_emp``.
        mode: forwarded to the solver as the ``grid_mode`` hyperparameter.
        grid_size: forwarded as ``grid_size``.
        int_grid_size: forwarded as ``int_grid_size``.

    Returns:
        ``(centroids, weights)``: the first ``dim`` entries of each atom
        parameter in ``ckm_solver.current_solution`` and the mixture weights.
    """
    dct_bfgs = {
        "maxiter_inner_optimizations": 15000,
        "tol_inner_optimizations": 1e-9,
        "lr_inner_optimizations": 0.01,
        "opt_method_step_1": "lbfgs",
        "opt_method_step_34": "nnls",
        "opt_method_step_5": "lbfgs",
        "grid_mode": mode,
        "grid_size": grid_size,
        "int_grid_size": int_grid_size,
    }
    # Search box for the centroids: we assume the data is normalized to [-1, 1]^dim.
    bounds = torch.tensor(np.array([-np.ones(dim), np.ones(dim)]))

    ckm_solver = CLOMP_CKM(
        phi=Phi_emp, size_mixture_K=nb_clust, bounds=bounds, sketch_z=z,
        store_objective_values=False, dct_optim_method_hyperparameters=dct_bfgs,
    )

    # Launch the CLOMP optimization procedure.
    ckm_solver.fit_once()

    # Only the centroid part of each atom's parameters is needed here.
    theta, weights = ckm_solver.current_solution
    centroids = theta[..., :dim]

    return centroids, weights






def experiment_D_OMP_grid(Phi_emp, z, s_sigma, grid_size, opt_mode, mode,
                          ms_step_size, ms_iteration_number, int_grid_size):
    """Run the grid-based D-OMP solver (``D_OMP_CKM``) on a sketch.

    NOTE: relies on the module-level globals ``dim`` (data dimension) and
    ``nb_clust`` (number of mixture components K).

    Args:
        Phi_emp: feature map the sketch was computed with.
        z: sketch of the dataset computed with ``Phi_emp``.
        s_sigma: forwarded to the solver as the ``s_sigma`` hyperparameter.
        grid_size: forwarded as ``grid_size``.
        opt_mode: forwarded as ``opt_mode``.
        mode: forwarded as ``grid_mode``.
        ms_step_size: forwarded as ``ms_step_size``.
        ms_iteration_number: forwarded as ``ms_iteration_number``.
        int_grid_size: forwarded as ``int_grid_size``.

    Returns:
        ``(centroids, t_grid)``: the first ``dim`` entries of each atom
        parameter in ``ckm_solver.current_solution``, and the second value
        returned by ``ckm_solver.fit_once()``.
    """
    dct_bfgs = {
        "maxiter_inner_optimizations": 15000,
        "tol_inner_optimizations": 1e-9,
        "lr_inner_optimizations": 0.01,
        "opt_method_step_1": "lbfgs",
        "opt_method_step_34": "nnls",
        "opt_method_step_5": "lbfgs",
        "s_sigma": s_sigma,
        "grid_mode": mode,
        "grid_size": grid_size,
        "opt_mode": opt_mode,
        "ms_step_size": ms_step_size,
        "ms_iteration_number": ms_iteration_number,
        "int_grid_size": int_grid_size,
    }
    # Search box for the centroids: we assume the data is normalized to [-1, 1]^dim.
    bounds = torch.tensor(np.array([-np.ones(dim), np.ones(dim)]))

    ckm_solver = D_OMP_CKM(
        phi=Phi_emp, size_mixture_K=nb_clust, bounds=bounds, sketch_z=z,
        store_objective_values=False, dct_optim_method_hyperparameters=dct_bfgs,
    )

    # Launch the D-OMP optimization procedure; its first output is not used here.
    _, t_grid = ckm_solver.fit_once()

    # Only the centroid part of each atom's parameters is needed here.
    theta, _ = ckm_solver.current_solution
    centroids = theta[..., :dim]

    return centroids, t_grid








# Dataset
#np.random.seed(20)  # easy
dim = 6  # dimension of the data points
nb_clust = 3  # number of clusters K to recover

# Pre-computed dataset. The solvers' bounds assume it is normalized to
# [-1, 1]^dim -- TODO confirm for dataset_6D.t.
X = torch.load('dataset_6D.t')
X = X.double()

# Use the real number of loaded samples instead of a hard-coded constant, so the
# per-sample SSE normalization stays correct if the dataset file changes.
nb_sample = X.shape[0]

L = 10  # grid size passed to the D-OMP solver
T = 10  # number of independent trials per (sigma, m) pair

# When True, errors are reported relative to the mean per-sample SSE of
# k-means (100 runs with n_init=5). When False, raw per-sample SSE is reported
# and the costly k-means baseline is skipped entirely (its result would be
# discarded anyway).
NORMALIZE_BY_KMEANS = False

k_means_SSE_5_list = []
if NORMALIZE_BY_KMEANS:
    for _ in range(100):
        kmeans_centroids = run_kmeans(X, nb_clust, 5)
        k_means_SSE_5_list.append(SSE(X, kmeans_centroids) / nb_sample)
    k_means_SSE = np.mean(k_means_SSE_5_list)
else:
    k_means_SSE = 1

# Kernel bandwidths to sweep. sqrt_sigma_list provides the plot tick labels;
# sigma_list (its element-wise square) scales the frequency distribution.
# Deriving one list from the other keeps them in sync; rounding removes the
# floating-point noise of squaring (e.g. 0.07**2 -> 0.0049).
sqrt_sigma_list = [0.05, 0.07, 0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 1]
sigma_list = [round(s ** 2, 6) for s in sqrt_sigma_list]
# Sketch sizes (number of drawn frequencies).
m_list = [100, 200, 500, 1000, 2000, 5000]

S_len = len(sigma_list)
M_len = len(m_list)

# Per-trial normalized SSE, indexed as [sigma index, m index, trial].
CLOMP_err_array = np.zeros((S_len, M_len, T))
AltCLOMP_err_array = np.zeros((S_len, M_len, T))

# Every SSE is divided by this: number of samples times the k-means reference
# SSE (which is 1 when no relative normalization is used).
err_scale = k_means_SSE * nb_sample

for i_s, s_sigma in enumerate(sigma_list):
    for i_m, m in enumerate(m_list):
        for t in range(T):
            print(f"sigma = {s_sigma}, m = {m}, t = {t}")

            # Draw m Gaussian frequencies (Sigma_g is presumably their
            # covariance -- see drawFrequencies) and sketch the data.
            Sigma_g = s_sigma * np.eye(dim)
            Omega_g = drawFrequencies("gaussian", dim, m, Sigma_g, return_torch=True)
            Phi_emp_g = MatrixFeatureMap("complexexponential", Omega_g, c_norm='unit',
                                         device=torch.device("cpu"))
            z_gaussian = computeSketch(X, Phi_emp_g)

            centroids_CLOMP, _ = experiment_CLOMP(
                Phi_emp_g, z_gaussian, 'fast', 1, 2 * nb_clust)
            CLOMP_err_array[i_s, i_m, t] = SSE(X, centroids_CLOMP) / err_scale

            centroids_AltCLOMP_shifted_grid, _ = experiment_D_OMP_grid(
                Phi_emp_g, z_gaussian,
                s_sigma=s_sigma,
                grid_size=L,
                opt_mode='normal',
                mode='shifted_grid',
                ms_step_size=s_sigma,
                ms_iteration_number=10,
                int_grid_size=2 * nb_clust,
            )
            AltCLOMP_err_array[i_s, i_m, t] = (
                SSE(X, centroids_AltCLOMP_shifted_grid) / err_scale
            )

# Mean error over the T trials, indexed as [sigma index, m index].
CLOMP_mean_err_array = CLOMP_err_array.mean(axis=2)
AltCLOMP_mean_err_array = AltCLOMP_err_array.mean(axis=2)






m_str_ticks_list = [str(x) for x in m_list]
sigma_str_ticks_list = [str(x) for x in sqrt_sigma_list]


def _save_error_heatmap(mean_err_array, cmap, path, x_labels, y_labels):
    """Save a heatmap of mean errors (sigma on the x axis, m on the y axis).

    Args:
        mean_err_array: 2-D array indexed as [sigma index, m index]; it is
            transposed so that m runs along the vertical axis.
        cmap: matplotlib colormap (name or object).
        path: output file path.
        x_labels: tick labels for the sigma axis.
        y_labels: tick labels for the m axis.

    All heatmaps share the same log colour scale [1, 100] so the figures are
    directly comparable and the fixed 1/10/100 colorbar ticks always apply.
    """
    fig = plt.figure(figsize=(12, 6))
    im = plt.imshow(mean_err_array.T, norm=LogNorm(vmin=1, vmax=100), cmap=cmap,
                    interpolation='bilinear', origin='lower')
    plt.yticks(range(len(y_labels)), y_labels, fontsize=16)
    plt.xticks(range(len(x_labels)), x_labels, fontsize=16)
    cb = plt.colorbar(im)
    cb.set_ticks([1, 10, 100])
    cb.set_ticklabels(['1', '10', '100'])
    cb.set_label(label='RSE', fontsize=16)
    cb.ax.tick_params(labelsize=16)
    plt.ylabel('m', fontsize=24)
    plt.xlabel(r'$\sigma$', fontsize=24)
    fig.savefig(path, bbox_inches='tight')
    # Release the figure: pyplot keeps every open figure alive otherwise.
    plt.close(fig)


# savefig fails if the output directory is missing, which would otherwise
# abort the script after the whole (expensive) sweep has run.
os.makedirs("pdf", exist_ok=True)
for cmap, suffix in (('hot', 'hot'),
                     (plt.cm.Greys, 'greys'),
                     ('gist_heat', 'gistheat'),
                     ('afmhot', 'afmhot')):
    for method, mean_err_array in (('CLOMP', CLOMP_mean_err_array),
                                   ('AltCLOMP', AltCLOMP_mean_err_array)):
        _save_error_heatmap(mean_err_array, cmap,
                            f"pdf/{method}_m_sigma_{suffix}.pdf",
                            sigma_str_ticks_list, m_str_ticks_list)















# Save the sketching-experiment results before the (slow) k-means study below,
# so they are not lost if the study fails. np.save needs the directory to exist.
os.makedirs("npy", exist_ok=True)
np.save("npy/CLOMP_err_array.npy", CLOMP_err_array)
np.save("npy/AltCLOMP_err_array.npy", AltCLOMP_err_array)

# Sensitivity of k-means (random init) and k-means++ to the number of
# initializations n_init, measured by the per-sample SSE of each fit.
n_i_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20]


def _sse_vs_n_init(run_fn):
    """Print and return the per-sample SSE of ``run_fn`` for each n_init in n_i_list."""
    sse_values = []
    for n_i in n_i_list:
        centroids = run_fn(X, nb_clust, n_i)
        sse = SSE(X, centroids) / nb_sample
        sse_values.append(sse)
        print(sse)
    return sse_values


k_means_SSE_list = _sse_vs_n_init(run_kmeans)
k_means_plus_SSE_list = _sse_vs_n_init(run_kmeans_plus)

os.makedirs("pdf", exist_ok=True)
for sse_values, label, path in (
    (k_means_SSE_list, 'kmeans', "pdf/kmeans_vs_ninit.pdf"),
    (k_means_plus_SSE_list, 'kmeans++', "pdf/kmeans_plus_vs_ninit.pdf"),
):
    f = plt.figure(figsize=(14, 10))
    plt.plot(n_i_list, sse_values)
    plt.xlabel('n_init', fontsize=24)
    plt.ylabel(label, fontsize=24)
    f.savefig(path, bbox_inches='tight')
    plt.close(f)

np.save("npy/k_means_SSE_list.npy", k_means_SSE_list)
np.save("npy/k_means_plus_SSE_list.npy", k_means_plus_SSE_list)

# Getting the current date and time
dt = datetime.now()
print("Date and time is:", dt)