import copy
import os
import sys

### Change PATH by the path of the folder
sys.path.append('PATH')

from pycle.utils.datasets import generatedataset_GMM
from pycle.sketching.feature_maps.MatrixFeatureMap import MatrixFeatureMap
from pycle.sketching import computeSketch

from pycle.compressive_learning.AltCLOMP_grid import AltCLOMP_grid

from pycle.compressive_learning.AltCLOMP import AltCLOMP
from pycle.compressive_learning.CLOMP_CKM import CLOMP_CKM

from pycle.compressive_learning.D_OMP import D_OMP
from pycle.compressive_learning.D_OMP_CKM import D_OMP_CKM

from loguru import logger
from pycle.utils.metrics import SSE
# from pycle.utils.vizualization import simple_plot_clustering, line_plot_clustering
from pycle.sketching.frequency_sampling import drawFrequencies
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from sklearn.cluster import KMeans

import numpy as np
import torch





def run_kmeans_plus(X, nb_clust, n_i):
    """Cluster ``X`` with scikit-learn's KMeans using ``init='k-means++'``.

    Args:
        X: Data handed unchanged to ``KMeans.fit``.
        nb_clust: Number of clusters (``n_clusters``).
        n_i: Number of initialisations (``n_init``).

    Returns:
        The fitted estimator's ``cluster_centers_``.
    """
    estimator = KMeans(n_clusters=nb_clust, n_init=n_i, init='k-means++')
    estimator.fit(X)
    return estimator.cluster_centers_


def run_kmeans(X, nb_clust, n_i):
    """Cluster ``X`` with scikit-learn's KMeans using ``init='random'``.

    Args:
        X: Data handed unchanged to ``KMeans.fit``.
        nb_clust: Number of clusters (``n_clusters``).
        n_i: Number of initialisations (``n_init``).

    Returns:
        The fitted estimator's ``cluster_centers_``.
    """
    estimator = KMeans(n_clusters=nb_clust, n_init=n_i, init='random')
    estimator.fit(X)
    return estimator.cluster_centers_

def gaussian_kernel(sigma):
    """Build a Gaussian kernel function with scale ``sigma``.

    The returned callable maps ``(x, y)`` to
    ``exp(-||x - y||^2 / sigma^2)``, where the norm is ``np.linalg.norm``
    applied to ``x - y``.
    """
    def gaussian_kernel_aux(x, y):
        distance = np.linalg.norm(x - y)
        exponent = -np.power(distance, 2) / sigma**2
        return np.exp(exponent)

    return gaussian_kernel_aux

def experiment_CLOMP(Phi_emp, z, mode, grid_size, int_grid_size):
    """Run one ``CLOMP_CKM`` fit on a sketch and return the recovered centroids.

    Relies on the module-level globals ``dim`` (used for the search bounds and
    to slice the solution) and ``nb_clust`` (size of the mixture).

    Args:
        Phi_emp: Feature map handed to the solver as ``phi``.
        z: Sketch handed to the solver as ``sketch_z``.
        mode: Value stored under the ``"grid_mode"`` hyperparameter.
        grid_size: Value stored under the ``"grid_size"`` hyperparameter.
        int_grid_size: Value stored under the ``"int_grid_size"`` hyperparameter.

    Returns:
        Tuple ``(centroids, weights)``: the first ``dim`` entries along the
        last axis of the solver's solution parameters, and the solution
        weights as returned by the solver.
    """
    solver_hyperparameters = {
        "maxiter_inner_optimizations": 15000,
        "tol_inner_optimizations": 1e-9,
        "lr_inner_optimizations": 0.01,
        "opt_method_step_1": "lbfgs",
        "opt_method_step_34": "nnls",
        "opt_method_step_5": "lbfgs",
        "grid_mode": mode,
        "grid_size": grid_size,
        "int_grid_size": int_grid_size,
    }

    # Search box [-1, 1]^dim: the data is assumed to be normalized to it.
    upper = np.ones(dim)
    bounds = torch.tensor(np.array([-upper, upper]))

    ckm_solver = CLOMP_CKM(
        phi=Phi_emp,
        size_mixture_K=nb_clust,
        bounds=bounds,
        sketch_z=z,
        store_objective_values=False,
        dct_optim_method_hyperparameters=solver_hyperparameters,
    )

    # A single run of the solver; the result is read back from the solver.
    ckm_solver.fit_once()
    theta, weights = ckm_solver.current_solution

    # Only the leading ``dim`` coordinates of each atom are used as centroids.
    centroids = theta[..., :dim]
    return centroids, weights






def experiment_D_OMP_grid(Phi_emp, z, s_sigma, grid_size, opt_mode, mode, ms_step_size, ms_iteration_number, int_grid_size):
    """Run one ``D_OMP_CKM`` fit on a sketch and return the recovered centroids.

    Relies on the module-level globals ``dim`` (used for the search bounds and
    to slice the solution) and ``nb_clust`` (size of the mixture).

    Args:
        Phi_emp: Feature map handed to the solver as ``phi``.
        z: Sketch handed to the solver as ``sketch_z``.
        s_sigma: Value stored under the ``"s_sigma"`` hyperparameter.
        grid_size: Value stored under the ``"grid_size"`` hyperparameter.
        opt_mode: Value stored under the ``"opt_mode"`` hyperparameter.
        mode: Value stored under the ``"grid_mode"`` hyperparameter.
        ms_step_size: Value stored under the ``"ms_step_size"`` hyperparameter.
        ms_iteration_number: Value stored under ``"ms_iteration_number"``.
        int_grid_size: Value stored under the ``"int_grid_size"`` hyperparameter.

    Returns:
        Tuple ``(centroids, t_grid)``: the first ``dim`` entries along the
        last axis of the solver's solution parameters, and the second value
        returned by ``fit_once()``.
    """
    solver_hyperparameters = {
        "maxiter_inner_optimizations": 15000,
        "tol_inner_optimizations": 1e-9,
        "lr_inner_optimizations": 0.01,
        "opt_method_step_1": "lbfgs",
        "opt_method_step_34": "nnls",
        "opt_method_step_5": "lbfgs",
        "s_sigma": s_sigma,
        "grid_mode": mode,
        "grid_size": grid_size,
        "opt_mode": opt_mode,
        "ms_step_size": ms_step_size,
        "ms_iteration_number": ms_iteration_number,
        "int_grid_size": int_grid_size,
    }

    # Search box [-1, 1]^dim: the data is assumed to be normalized to it.
    upper = np.ones(dim)
    bounds = torch.tensor(np.array([-upper, upper]))

    ckm_solver = D_OMP_CKM(
        phi=Phi_emp,
        size_mixture_K=nb_clust,
        bounds=bounds,
        sketch_z=z,
        store_objective_values=False,
        dct_optim_method_hyperparameters=solver_hyperparameters,
    )

    # fit_once() returns a pair; only its second element is passed on.
    _, t_grid = ckm_solver.fit_once()

    # The solution weights are not needed by the caller.
    theta, _ = ckm_solver.current_solution

    # Only the leading ``dim`` coordinates of each atom are used as centroids.
    centroids = theta[..., :dim]
    return centroids, t_grid







# ---------------------------------------------------------------------------
# Dataset and experiment settings
# ---------------------------------------------------------------------------
# np.random.seed(20)  # easy
dim = 2  # dimension used for the solver bounds and to slice the centroids
nb_clust = 3  # number of clusters / mixture size
# Divisor turning every SSE below into a per-sample error.
# NOTE(review): presumably the number of rows of X -- confirm against the file.
nb_sample = 100000
X = torch.load('dataset_f.t')

# Number of repetitions per (sketch size, sigma) pair in the experiment loop.
T = 1

# ---------------------------------------------------------------------------
# k-means baselines
# ---------------------------------------------------------------------------
# Randomly initialised k-means for an increasing number of restarts
# (printed only).
for n_i in (1, 2, 5, 10, 20):
    kmeans_centroids = run_kmeans(X, nb_clust, n_i)
    print(SSE(X, kmeans_centroids) / nb_sample)

# Reference run with 20 restarts; k_means_SSE is the horizontal baseline drawn
# on the figure.
kmeans_centroids = run_kmeans(X, nb_clust, 20)
k_means_SSE = SSE(X, kmeans_centroids) / nb_sample
print(k_means_SSE)

# k-means++ initialisation with 5 restarts.
kmeans_plus_centroids = run_kmeans_plus(X, nb_clust, 5)
k_means_plus_SSE = SSE(X, kmeans_plus_centroids) / nb_sample
print(k_means_plus_SSE)




# Raw per-trial errors: one inner list (one value per trial) for every
# (sketch size, sigma) pair, appended in loop order with the sketch size as
# the outer loop.
CLOMP_SSE_list_of_lists = []
D_OMP_grid_SSE_list_of_lists = []

# Aggregated errors: one entry per sketch size, each entry being a list with
# one value per element of sigma_list.
CLOMP_mean_SSE_list = []
D_OMP_grid_mean_SSE_list = []
CLOMP_median_SSE_list = []
D_OMP_grid_median_SSE_list = []

# Scales s of the frequency distribution; drawFrequencies receives s * I.
sigma_list = [0.0001, 0.0009, 0.0025, 0.0049, 0.01, 0.04, 0.09, 0.16, 0.25, 0.36, 0.49, 1]

# Sketch sizes m to compare (third argument of drawFrequencies).
sketch_dim_list = [30, 1000]

for sketch_dim in sketch_dim_list:
    CLOMP_mean_SSE = []
    CLOMP_median_SSE = []
    D_OMP_grid_mean_SSE = []
    D_OMP_grid_median_SSE = []

    for s_sigma in sigma_list:
        tmp_CLOMP_sse_list = []
        tmp_D_OMP_grid_sse_list = []

        for t in range(T):
            print(f"sigma = {s_sigma} and t = {t}")

            # Draw fresh frequencies for this trial and sketch the dataset.
            Sigma_g = s_sigma * np.eye(dim)
            Omega_g = drawFrequencies("gaussian", dim, sketch_dim, Sigma_g, return_torch=True)
            Phi_emp_g = MatrixFeatureMap(
                "complexexponential", Omega_g, c_norm='unit', device=torch.device("cpu")
            )
            z_gaussian = computeSketch(X, Phi_emp_g)

            # Both solvers decode the same sketch, so their errors are
            # directly comparable within a trial.
            centroids_CLOMP, _ = experiment_CLOMP(
                Phi_emp_g,
                z_gaussian,
                mode='fast',
                grid_size=1,
                int_grid_size=2 * nb_clust,
            )
            tmp_CLOMP_sse_list.append(SSE(X, centroids_CLOMP) / nb_sample)

            centroids_D_OMP_grid, _ = experiment_D_OMP_grid(
                Phi_emp_g,
                z_gaussian,
                s_sigma=s_sigma,
                grid_size=500,
                opt_mode='fast',
                mode='shifted_grid',
                ms_step_size=s_sigma,
                ms_iteration_number=1000,
                int_grid_size=2 * nb_clust,
            )
            tmp_D_OMP_grid_sse_list.append(SSE(X, centroids_D_OMP_grid) / nb_sample)

        # The per-trial lists are rebuilt for every sigma, so they can be
        # stored directly without copying.
        CLOMP_SSE_list_of_lists.append(tmp_CLOMP_sse_list)
        CLOMP_mean_SSE.append(np.mean(tmp_CLOMP_sse_list))
        CLOMP_median_SSE.append(np.median(tmp_CLOMP_sse_list))

        D_OMP_grid_SSE_list_of_lists.append(tmp_D_OMP_grid_sse_list)
        D_OMP_grid_mean_SSE.append(np.mean(tmp_D_OMP_grid_sse_list))
        D_OMP_grid_median_SSE.append(np.median(tmp_D_OMP_grid_sse_list))

    # Likewise, the per-sketch-size lists are rebuilt on every outer iteration.
    CLOMP_mean_SSE_list.append(CLOMP_mean_SSE)
    CLOMP_median_SSE_list.append(CLOMP_median_SSE)
    D_OMP_grid_mean_SSE_list.append(D_OMP_grid_mean_SSE)
    D_OMP_grid_median_SSE_list.append(D_OMP_grid_median_SSE)

#CLOMP_median_SSE_list = [np.median(CLOMP_SSE_list_of_lists[i]) for i in list(range(7))]

# Results


# Line colours: the third of four evenly spaced shades of each colormap.
colors = plt.cm.Reds(np.linspace(0, 1, 4))
colors_2 = plt.cm.Blues(np.linspace(0, 1, 4))
colors_list = [colors_2[2], colors[2]]

f = plt.figure(figsize=(14, 10))

# The x axis shows the square root of the values in sigma_list.
sqrt_sigma = np.sqrt(sigma_list)

# One colour per method; dashed line for the first sketch size, solid line
# for the second. Plot order fixes the legend order.
curve_groups = (
    ('CL-OMP', CLOMP_mean_SSE_list, colors_list[1]),
    ('Proposed algorithm', D_OMP_grid_mean_SSE_list, colors_list[0]),
)
for method_label, mean_curves, curve_color in curve_groups:
    for curve_idx, curve_style in enumerate(('--', '-')):
        plt.plot(
            sqrt_sigma,
            mean_curves[curve_idx],
            label=f"{method_label} (m={sketch_dim_list[curve_idx]})",
            color=curve_color,
            linewidth=4,
            linestyle=curve_style,
        )

# Horizontal reference: error of the k-means baseline.
plt.axhline(y=k_means_SSE, color='black', linestyle='--', label="Lloyd's algorithm", linewidth=3)

# Vertical guides at the x positions that are also used as tick marks.
x_marks = [0.01, 0.05, 0.1, 0.2, 1]
for x_mark in x_marks:
    plt.axvline(x=x_mark, color='black', linestyle='dotted', linewidth=3)

plt.grid(alpha=1, linestyle=':')
plt.xscale('log')
plt.yscale('log')
plt.xticks(fontsize=24)
plt.yticks(fontsize=24)
plt.minorticks_on()
plt.xlabel(r'$\sigma$', fontsize=24)
plt.ylabel('MSE', fontsize=24)
plt.legend(fontsize=18, loc='upper right')
plt.xticks(x_marks, ['0.01', '0.05', '0.1', '0.2', '1'])
plt.yticks([0.01, 0.1], ['0.01', '0.1'])

# savefig does not create missing directories; make sure the target exists so
# the figure is not lost after the whole experiment has run.
os.makedirs("pdf", exist_ok=True)
f.savefig(f"pdf/CLOMP_vs_AltCLOMP_d_{dim}_k_{nb_clust}_log_2k_2.pdf", bbox_inches='tight')
#plt.show()
#plt.show()


# Per-sample SSE of k-means as a function of the number of restarts, first
# with random initialisation and then with k-means++.
n_i_list = [1, 2, 5, 10, 20]

k_means_SSE_list = []
for n_i in n_i_list:
    k_means_SSE = SSE(X, run_kmeans(X, nb_clust, n_i)) / nb_sample
    print(k_means_SSE)
    k_means_SSE_list.append(k_means_SSE)

k_means_plus_SSE_list = []
for n_i in n_i_list:
    k_means_plus_SSE = SSE(X, run_kmeans_plus(X, nb_clust, n_i)) / nb_sample
    print(k_means_plus_SSE)
    k_means_plus_SSE_list.append(k_means_plus_SSE)




# np.save does not create missing directories; make sure the target exists so
# the results are not lost after the whole experiment has run.
os.makedirs("npy", exist_ok=True)

# Raw per-trial errors of both solvers.
np.save("npy/CLOMP_SSE_list_of_lists.npy", CLOMP_SSE_list_of_lists)
np.save("npy/D_OMP_grid_SSE_list_of_lists.npy", D_OMP_grid_SSE_list_of_lists)

# x-axis values of the figure. The file name carries the last sketch size of
# the sweep (previously taken from the leaked loop variable).
np.save(
    f"npy/sqrtsigma_d_{dim}_k_{nb_clust}_m_{sketch_dim_list[-1]}.npy",
    np.sqrt(sigma_list),
)

# k-means errors for each number of restarts.
np.save("npy/k_means_SSE_list.npy", k_means_SSE_list)
np.save("npy/k_means_plus_SSE_list.npy", k_means_plus_SSE_list)





    