import sys
### Replace 'PATH' below with the path of the folder that contains the pycle package
sys.path.append('PATH')

from datetime import datetime


from pycle.utils.datasets import generatedataset_GMM
from pycle.sketching.feature_maps.MatrixFeatureMap import MatrixFeatureMap
from pycle.sketching import computeSketch
from pycle.compressive_learning.AltCLOMP_grid import AltCLOMP_grid

from pycle.compressive_learning.D_OMP import D_OMP
from pycle.compressive_learning.D_OMP_CKM import D_OMP_CKM

from pycle.compressive_learning.CLOMP_CKM import CLOMP_CKM

from loguru import logger
from pycle.utils.metrics import SSE
# from pycle.utils.vizualization import simple_plot_clustering, line_plot_clustering
from pycle.sketching.frequency_sampling import drawFrequencies
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.cluster import kmeans_plusplus

import matplotlib

from sklearn.cluster import KMeans



import numpy as np
import torch

import copy

#


def run_kmeans_plus(X, nb_clust, n_i):
    """Fit k-means with k-means++ seeding and return the cluster centres.

    Args:
        X: data accepted by ``sklearn.cluster.KMeans.fit`` (one sample per row).
        nb_clust: number of clusters.
        n_i: number of k-means++ initialisations (``n_init``); scikit-learn
            keeps the run with the lowest inertia.

    Returns:
        The fitted ``cluster_centers_`` array, shape (nb_clust, n_features).
    """
    model = KMeans(n_clusters=nb_clust, init='k-means++', n_init=n_i)
    model.fit(X)
    return model.cluster_centers_


def run_kmeans(X, nb_clust, n_i):
    """Fit k-means with random initial centroids (``init='random'``) and return the centres.

    Args:
        X: data accepted by ``sklearn.cluster.KMeans.fit`` (one sample per row).
        nb_clust: number of clusters.
        n_i: number of random initialisations (``n_init``); scikit-learn keeps
            the run with the lowest inertia.

    Returns:
        The fitted ``cluster_centers_`` array, shape (nb_clust, n_features).
    """
    model = KMeans(n_clusters=nb_clust, init='random', n_init=n_i)
    model.fit(X)
    return model.cluster_centers_



def gaussian_kernel(sigma):
    """Build a Gaussian (RBF) kernel with bandwidth ``sigma``.

    Returns:
        A function ``k(x, y) = exp(-||x - y||^2 / sigma^2)``, where ``||.||`` is
        ``np.linalg.norm`` applied to the difference ``x - y``.
    """
    def gaussian_kernel_aux(x, y):
        diff_norm = np.linalg.norm(x - y)
        scaled_sq_dist = np.power(diff_norm, 2) / sigma**2
        return np.exp(-scaled_sq_dist)

    return gaussian_kernel_aux


def experiment_CLOMP(Phi_emp, z, mode, grid_size, int_grid_size, n_dims=None, n_clusters=None):
    """Decode a sketch with CL-OMPR (``CLOMP_CKM``) and return the centroids.

    Args:
        Phi_emp: sketching feature map that ``z`` was computed with.
        z: empirical sketch of the dataset.
        mode: forwarded to the solver as the ``grid_mode`` hyperparameter.
        grid_size: forwarded as ``grid_size``.
        int_grid_size: forwarded as ``int_grid_size``.
        n_dims: data dimension. Defaults to the module-level ``dim``.
        n_clusters: number of mixture components K. Defaults to the
            module-level ``nb_clust``.

    Returns:
        ``(centroids, weights)``: the first ``n_dims`` columns of the solver's
        ``theta`` and the mixture weights, both taken from
        ``ckm_solver.current_solution``.
    """
    # Fall back to the script-level settings so existing callers are unchanged.
    if n_dims is None:
        n_dims = dim
    if n_clusters is None:
        n_clusters = nb_clust

    dct_bfgs = {
        "maxiter_inner_optimizations": 15000,
        "tol_inner_optimizations": 1e-9,
        "lr_inner_optimizations": 0.01,
        "opt_method_step_1": "lbfgs",
        "opt_method_step_34": "nnls",
        "opt_method_step_5": "lbfgs",
        "grid_mode": mode,
        "grid_size": grid_size,
        "int_grid_size": int_grid_size,
    }
    # Search box [-1, 1]^n_dims: the data is assumed to be normalized between -1 and 1.
    bounds = torch.tensor(np.array([-np.ones(n_dims), np.ones(n_dims)]))
    ckm_solver = CLOMP_CKM(
        phi=Phi_emp, size_mixture_K=n_clusters, bounds=bounds, sketch_z=z,
        store_objective_values=False, dct_optim_method_hyperparameters=dct_bfgs
    )

    # Launch the CLOMP optimization procedure.
    ckm_solver.fit_once()

    # Only the leading n_dims columns of theta (the centroids) are returned.
    theta, weights = ckm_solver.current_solution
    centroids = theta[..., :n_dims]
    return centroids, weights






def experiment_D_OMP_grid(Phi_emp, z, s_sigma, grid_size, opt_mode, mode, ms_step_size,
                          ms_iteration_number, int_grid_size, n_dims=None, n_clusters=None):
    """Decode a sketch with the D-OMP solver (``D_OMP_CKM``) and return the centroids.

    Args:
        Phi_emp: sketching feature map that ``z`` was computed with.
        z: empirical sketch of the dataset.
        s_sigma: forwarded to the solver as the ``s_sigma`` hyperparameter.
        grid_size: forwarded as ``grid_size``.
        opt_mode: forwarded as ``opt_mode``.
        mode: forwarded as ``grid_mode``.
        ms_step_size: forwarded as ``ms_step_size``.
        ms_iteration_number: forwarded as ``ms_iteration_number``.
        int_grid_size: forwarded as ``int_grid_size``.
        n_dims: data dimension. Defaults to the module-level ``dim``.
        n_clusters: number of mixture components K. Defaults to the
            module-level ``nb_clust``.

    Returns:
        ``(centroids, t_grid)``: the first ``n_dims`` columns of the solver's
        ``theta`` (from ``current_solution``) and the second value returned by
        ``fit_once()``.
    """
    # Fall back to the script-level settings so existing callers are unchanged.
    if n_dims is None:
        n_dims = dim
    if n_clusters is None:
        n_clusters = nb_clust

    dct_bfgs = {
        "maxiter_inner_optimizations": 15000,
        "tol_inner_optimizations": 1e-9,
        "lr_inner_optimizations": 0.01,
        "opt_method_step_1": "lbfgs",
        "opt_method_step_34": "nnls",
        "opt_method_step_5": "lbfgs",
        "s_sigma": s_sigma,
        "grid_mode": mode,
        "grid_size": grid_size,
        "opt_mode": opt_mode,
        "ms_step_size": ms_step_size,
        "ms_iteration_number": ms_iteration_number,
        "int_grid_size": int_grid_size,
    }
    # Search box [-1, 1]^n_dims: the data is assumed to be normalized between -1 and 1.
    bounds = torch.tensor(np.array([-np.ones(n_dims), np.ones(n_dims)]))

    ckm_solver = D_OMP_CKM(
        phi=Phi_emp, size_mixture_K=n_clusters, bounds=bounds, sketch_z=z,
        store_objective_values=False, dct_optim_method_hyperparameters=dct_bfgs
    )

    # fit_once() returns a pair for this solver; only its second element is used.
    _, t_grid = ckm_solver.fit_once()

    # Only the leading n_dims columns of theta (the centroids) are returned.
    theta, _weights = ckm_solver.current_solution
    centroids = theta[..., :n_dims]
    return centroids, t_grid






# Dataset and sketch parameters.
dim = 6            # data dimension (must match the number of columns of X)
nb_clust = 3       # number of clusters K
sketch_dim = 1000  # sketch size, passed to drawFrequencies as the number of frequencies

# NOTE(review): torch.load unpickles the file -- only load trusted dataset files.
X = torch.load('dataset_6D.t')
X = X.double()

# Rows of X are samples (KMeans.fit and SSE consume X that way). Derive the
# sample count from the data instead of hard-coding it, so every
# SSE / nb_sample below is a true per-sample MSE whatever the file size.
if X.shape[1] != dim:
    raise ValueError(f"dataset has {X.shape[1]} features but dim = {dim}")
nb_sample = X.shape[0]




# plt.figure(figsize=(5, 5))
# plt.scatter(X[:, 0], X[:, 1], s=1, alpha=0.3)
# plt.xlim(-1, 1)
# plt.ylim(-1, 1)
# plt.show()




# Result containers. Suffix convention: "sg" = shifted-grid D-OMP runs,
# "g"/"z" = static-grid naming, "g_sg" = gradient-ascent variant; the trailing
# 1-4 selects the grid-size setting. Every container is a distinct object.

# Per-sigma archives of per-trial SSE lists.
CLOMP_SSE_list_of_lists, D_OMP_grid_SSE_list_of_lists = [], []
(AltCLOMP_SSE_g_1_list_of_lists, AltCLOMP_SSE_g_2_list_of_lists,
 AltCLOMP_SSE_g_3_list_of_lists, AltCLOMP_SSE_g_4_list_of_lists) = ([] for _ in range(4))
(AltCLOMP_SSE_sg_1_list_of_lists, AltCLOMP_SSE_sg_2_list_of_lists,
 AltCLOMP_SSE_sg_3_list_of_lists, AltCLOMP_SSE_sg_4_list_of_lists) = ([] for _ in range(4))
(AltCLOMP_SSE_g_sg_1_list_of_lists, AltCLOMP_SSE_g_sg_2_list_of_lists,
 AltCLOMP_SSE_g_sg_3_list_of_lists, AltCLOMP_SSE_g_sg_4_list_of_lists) = ([] for _ in range(4))

# One mean SSE per sigma value.
CLOMP_mean_SSE = []
(AltCLOMP_mean_SSE_g_1, AltCLOMP_mean_SSE_g_2,
 AltCLOMP_mean_SSE_g_3, AltCLOMP_mean_SSE_g_4) = ([] for _ in range(4))
(AltCLOMP_mean_SSE_sg_1, AltCLOMP_mean_SSE_sg_2,
 AltCLOMP_mean_SSE_sg_3, AltCLOMP_mean_SSE_sg_4) = ([] for _ in range(4))
(AltCLOMP_mean_SSE_g_sg_1, AltCLOMP_mean_SSE_g_sg_2,
 AltCLOMP_mean_SSE_g_sg_3, AltCLOMP_mean_SSE_g_sg_4) = ([] for _ in range(4))

# One median SSE per sigma value.
CLOMP_median_SSE = []
(AltCLOMP_median_SSE_g_1, AltCLOMP_median_SSE_g_2,
 AltCLOMP_median_SSE_g_3, AltCLOMP_median_SSE_g_4) = ([] for _ in range(4))
(AltCLOMP_median_SSE_sg_1, AltCLOMP_median_SSE_sg_2,
 AltCLOMP_median_SSE_sg_3, AltCLOMP_median_SSE_sg_4) = ([] for _ in range(4))
(AltCLOMP_median_SSE_g_sg_1, AltCLOMP_median_SSE_g_sg_2,
 AltCLOMP_median_SSE_g_sg_3, AltCLOMP_median_SSE_g_sg_4) = ([] for _ in range(4))

# Experiment grid: frequency variances swept, and trials per variance.
sigma_list = [0.0025, 0.0049, 0.01, 0.02, 0.03, 0.04, 0.05, 0.09, 0.1, 0.12, 0.16, 0.25, 0.49, 1]
T = 100
S = len(sigma_list)

# Per-trial errors, indexed [sigma index, trial] -> shape (S, T).
CLOMP_err_array = np.zeros((S, T))
(AltCLOMP_sg_1_err_array, AltCLOMP_sg_2_err_array,
 AltCLOMP_sg_3_err_array, AltCLOMP_sg_4_err_array) = (np.zeros((S, T)) for _ in range(4))
(AltCLOMP_g_1_err_array, AltCLOMP_g_2_err_array,
 AltCLOMP_g_3_err_array, AltCLOMP_g_4_err_array) = (np.zeros((S, T)) for _ in range(4))
(AltCLOMP_z_1_err_array, AltCLOMP_z_2_err_array,
 AltCLOMP_z_3_err_array, AltCLOMP_z_4_err_array) = (np.zeros((S, T)) for _ in range(4))

# Per-sigma mean errors, shape (S,).
CLOMP_mean_err_array = np.zeros(S)
(AltCLOMP_sg_1_mean_err_array, AltCLOMP_sg_2_mean_err_array,
 AltCLOMP_sg_3_mean_err_array, AltCLOMP_sg_4_mean_err_array) = (np.zeros(S) for _ in range(4))
(AltCLOMP_g_1_mean_err_array, AltCLOMP_g_2_mean_err_array,
 AltCLOMP_g_3_mean_err_array, AltCLOMP_g_4_mean_err_array) = (np.zeros(S) for _ in range(4))
(AltCLOMP_z_1_mean_err_array, AltCLOMP_z_2_mean_err_array,
 AltCLOMP_z_3_mean_err_array, AltCLOMP_z_4_mean_err_array) = (np.zeros(S) for _ in range(4))

# Per-sigma median errors, shape (S,).
CLOMP_median_err_array = np.zeros(S)
(AltCLOMP_sg_1_median_err_array, AltCLOMP_sg_2_median_err_array,
 AltCLOMP_sg_3_median_err_array, AltCLOMP_sg_4_median_err_array) = (np.zeros(S) for _ in range(4))
(AltCLOMP_g_1_median_err_array, AltCLOMP_g_2_median_err_array,
 AltCLOMP_g_3_median_err_array, AltCLOMP_g_4_median_err_array) = (np.zeros(S) for _ in range(4))
(AltCLOMP_z_1_median_err_array, AltCLOMP_z_2_median_err_array,
 AltCLOMP_z_3_median_err_array, AltCLOMP_z_4_median_err_array) = (np.zeros(S) for _ in range(4))





# k-means++ baseline (nb_clust restarts); its per-sample SSE is printed once.
kmeans_plus_centroids = run_kmeans_plus(X, nb_clust, nb_clust)
k_means_plus_SSE = SSE(X, kmeans_plus_centroids) / nb_sample
print(k_means_plus_SSE)

# Lloyd's-algorithm baseline: mean per-sample SSE over 100 independent fits of
# random-init k-means, each keeping the best of 5 initialisations.
k_means_SSE_5_list = []
for t in range(100):
    kmeans_centroids = run_kmeans(X, nb_clust, 5)
    k_means_SSE_5_list.append(SSE(X, kmeans_centroids) / nb_sample)
k_means_SSE = np.mean(k_means_SSE_5_list)


#[0.01,0.02,0.04,0.05,0.09,0.16,0.25,0.49,1]
#sigma_list = [0.16]
#[0.001,0.002,0.005,0.007,0.01,0.02,0.05,0.07,0.1,0.2,0.5]
#[0.01,0.02,0.05,0.07,0.1,0.2,0.5,0.7,1]

#[0.001,0.002,0.005,0.007,0.01,0.02,0.05,0.07,0.1,0.2,0.5]
# Main experiment. For every frequency variance s_sigma in sigma_list, run T
# independent trials. Each trial:
#   1. draws fresh Gaussian frequencies;
#   2. sketches X once;
#   3. decodes that sketch with CL-OMPR, then with D-OMP on a shifted grid
#      and on a static grid, for every grid size in grid_size_list.
grid_size_list = [1, 10, 100, 1000]

for s_idx, s_sigma in enumerate(sigma_list):
    # Per-trial per-sample SSEs for the current sigma, one list per method and
    # grid size ("sg" = shifted grid, "g" = static grid).
    tmp_CLOMP_sse_list = []

    tmp_AltCLOMP_sse_list_sg_1 = []
    tmp_AltCLOMP_sse_list_sg_2 = []
    tmp_AltCLOMP_sse_list_sg_3 = []
    tmp_AltCLOMP_sse_list_sg_4 = []

    tmp_AltCLOMP_sse_list_g_1 = []
    tmp_AltCLOMP_sse_list_g_2 = []
    tmp_AltCLOMP_sse_list_g_3 = []
    tmp_AltCLOMP_sse_list_g_4 = []

    # Views of the containers above in grid_size_list order, so each grid mode
    # is driven by one loop. Static-grid errors go to the AltCLOMP_z_* arrays.
    sg_sse_lists = [tmp_AltCLOMP_sse_list_sg_1, tmp_AltCLOMP_sse_list_sg_2,
                    tmp_AltCLOMP_sse_list_sg_3, tmp_AltCLOMP_sse_list_sg_4]
    sg_err_arrays = [AltCLOMP_sg_1_err_array, AltCLOMP_sg_2_err_array,
                     AltCLOMP_sg_3_err_array, AltCLOMP_sg_4_err_array]
    static_sse_lists = [tmp_AltCLOMP_sse_list_g_1, tmp_AltCLOMP_sse_list_g_2,
                        tmp_AltCLOMP_sse_list_g_3, tmp_AltCLOMP_sse_list_g_4]
    static_err_arrays = [AltCLOMP_z_1_err_array, AltCLOMP_z_2_err_array,
                         AltCLOMP_z_3_err_array, AltCLOMP_z_4_err_array]

    for t in range(T):
        print("sigma = " + str(s_sigma) + " and t = " + str(t))

        # Fresh frequencies with covariance s_sigma * I, and the sketch of X.
        Sigma_g = s_sigma * np.eye(dim)
        Omega_g = drawFrequencies("gaussian", dim, sketch_dim, Sigma_g, return_torch=True)
        Phi_emp_g = MatrixFeatureMap("complexexponential", Omega_g, c_norm='unit', device=torch.device("cpu"))
        z_gaussian = computeSketch(X, Phi_emp_g)

        # CL-OMPR. Each SSE over the full dataset is evaluated once and reused
        # for both the per-sigma list and the (S, T) error array.
        centroids_CLOMP, weights_CLOMP = experiment_CLOMP(Phi_emp_g, z_gaussian, 'fast', 1, 2 * nb_clust)
        sse = SSE(X, centroids_CLOMP) / nb_sample
        tmp_CLOMP_sse_list.append(sse)
        CLOMP_err_array[s_idx, t] = sse

        # D-OMP runs in a fixed order: the shifted grid for every grid size,
        # then the static grid for every grid size. The order is kept in case
        # the solvers draw random numbers.
        for grid_mode, sse_lists, err_arrays in (
            ('shifted_grid', sg_sse_lists, sg_err_arrays),
            ('static_grid', static_sse_lists, static_err_arrays),
        ):
            for grid_size, sse_list, err_array in zip(grid_size_list, sse_lists, err_arrays):
                centroids_D_OMP, t_grid = experiment_D_OMP_grid(
                    Phi_emp_g, z_gaussian, s_sigma, grid_size, 'normal', grid_mode,
                    s_sigma, 10, 2 * nb_clust)
                sse = SSE(X, centroids_D_OMP) / nb_sample
                sse_list.append(sse)
                err_array[s_idx, t] = sse

    # Per-sigma mean / median of the per-trial SSE lists.
    CLOMP_mean_SSE.append(np.mean(tmp_CLOMP_sse_list))
    CLOMP_median_SSE.append(np.median(tmp_CLOMP_sse_list))
    for sse_list, mean_list, median_list in zip(
        static_sse_lists + sg_sse_lists,
        [AltCLOMP_mean_SSE_g_1, AltCLOMP_mean_SSE_g_2, AltCLOMP_mean_SSE_g_3, AltCLOMP_mean_SSE_g_4,
         AltCLOMP_mean_SSE_sg_1, AltCLOMP_mean_SSE_sg_2, AltCLOMP_mean_SSE_sg_3, AltCLOMP_mean_SSE_sg_4],
        [AltCLOMP_median_SSE_g_1, AltCLOMP_median_SSE_g_2, AltCLOMP_median_SSE_g_3, AltCLOMP_median_SSE_g_4,
         AltCLOMP_median_SSE_sg_1, AltCLOMP_median_SSE_sg_2, AltCLOMP_median_SSE_sg_3, AltCLOMP_median_SSE_sg_4],
    ):
        mean_list.append(np.mean(sse_list))
        median_list.append(np.median(sse_list))

    # The gradient-ascent ("g_sg") variant is not run in this loop, so it has no
    # samples. Record NaN directly: np.mean of an empty list gives the same NaN
    # but emits a "Mean of empty slice" RuntimeWarning.
    for mean_list in (AltCLOMP_mean_SSE_g_sg_1, AltCLOMP_mean_SSE_g_sg_2,
                      AltCLOMP_mean_SSE_g_sg_3, AltCLOMP_mean_SSE_g_sg_4):
        mean_list.append(np.nan)

    # Row-wise means of the (S, T) error arrays for this sigma.
    # NOTE(review): AltCLOMP_g_*_err_array is not written in this loop (the
    # static-grid results go to AltCLOMP_z_*), so the AltCLOMP_g_* means stay 0.
    for err_array, mean_array in zip(
        [CLOMP_err_array,
         AltCLOMP_z_1_err_array, AltCLOMP_z_2_err_array, AltCLOMP_z_3_err_array, AltCLOMP_z_4_err_array,
         AltCLOMP_g_1_err_array, AltCLOMP_g_2_err_array, AltCLOMP_g_3_err_array, AltCLOMP_g_4_err_array,
         AltCLOMP_sg_1_err_array, AltCLOMP_sg_2_err_array, AltCLOMP_sg_3_err_array, AltCLOMP_sg_4_err_array],
        [CLOMP_mean_err_array,
         AltCLOMP_z_1_mean_err_array, AltCLOMP_z_2_mean_err_array,
         AltCLOMP_z_3_mean_err_array, AltCLOMP_z_4_mean_err_array,
         AltCLOMP_g_1_mean_err_array, AltCLOMP_g_2_mean_err_array,
         AltCLOMP_g_3_mean_err_array, AltCLOMP_g_4_mean_err_array,
         AltCLOMP_sg_1_mean_err_array, AltCLOMP_sg_2_mean_err_array,
         AltCLOMP_sg_3_mean_err_array, AltCLOMP_sg_4_mean_err_array],
    ):
        mean_array[s_idx] = np.mean(err_array[s_idx, :])



dt = datetime.now()
print("Date and time is:", dt)

# Tick positions / labels shared by the figures (x axis shows sqrt of sigma_list).
x_ticks_list = [0.05, 0.1, 0.3, 1]
x_str_ticks_list = [str(x) for x in x_ticks_list]

y_ticks_list = [0.05, 0.1, 1]
y_str_ticks_list = [str(x) for x in y_ticks_list]

# Palette: colors_list[0] = Blues at 2/3 (D-OMP curves), colors_list[1] = Reds at 2/3 (CL-OMPR).
colors = plt.cm.Reds(np.linspace(0, 1, 4))
colors_2 = plt.cm.Blues(np.linspace(0, 1, 4))
colors_list = [colors_2[2], colors[2]]

# savefig does not create missing directories; make sure "pdf/" exists before
# the first figure is written.
import os
os.makedirs("pdf", exist_ok=True)

# Grid size 1: two figures with identical layout. The first uses the per-sigma
# mean-SSE lists, the second the row means of the (S, T) error arrays.
for clomp_curve, sg_curve, disc_curve, file_suffix in (
    (CLOMP_mean_SSE, AltCLOMP_mean_SSE_sg_1, AltCLOMP_mean_SSE_g_1,
     "_log_2k_2_SSE_L_1.pdf"),
    (CLOMP_mean_err_array, AltCLOMP_sg_1_mean_err_array, AltCLOMP_z_1_mean_err_array,
     "_log_2k_2_SSE_L_1_.pdf"),
):
    f = plt.figure(figsize=(14, 10))

    plt.plot(np.sqrt(sigma_list), clomp_curve, label='CL-OMPR', linewidth=4, color=colors_list[1])
    plt.plot(np.sqrt(sigma_list), sg_curve, label='Algorithm 2 (sketched mean shift)',
             linewidth=4, color=colors_list[0])
    plt.plot(np.sqrt(sigma_list), disc_curve, label='Algorithm 2 (discretized)',
             linewidth=4, color=colors_list[0], linestyle='--')

    plt.axhline(y=k_means_SSE, color='black', linestyle='--', label="Lloyd's algorithm", linewidth=3)
    for t in x_ticks_list:
        plt.axvline(x=t, color='black', linestyle='dotted', linewidth=3)

    plt.grid(alpha=1, linestyle=':')

    plt.xscale('log')
    plt.yscale('log')
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)

    plt.xlabel(r'$\sigma$', fontsize=24)
    plt.ylabel('MSE', fontsize=24)

    plt.xticks(x_ticks_list, x_str_ticks_list)
    plt.yticks(y_ticks_list, y_str_ticks_list)

    plt.legend(fontsize=18, loc='upper right')

    f.savefig("pdf/CLOMP_vs_AltCLOMP_d_" + str(dim) + "_k_" + str(nb_clust) + file_suffix,
              bbox_inches='tight')




#####



# Colormaps for the shaded percentile bands:
#   Blues   -> "sketched mean shift" (shifted grid)
#   Reds    -> CL-OMPR
#   Purples -> "discretized" (static grid)
# plt.get_cmap is used because matplotlib.cm.get_cmap (plt.cm.get_cmap) was
# deprecated in Matplotlib 3.7 and removed in 3.9.
cmap1 = plt.get_cmap('Blues')
cmap2 = plt.get_cmap('Reds')
cmap3 = plt.get_cmap('Purples')

normalize = matplotlib.colors.Normalize(vmin=0, vmax=1)

# Bands span the [perc, 100 - perc] percentiles across trials (inter-quartile range).
perc = 25

fig, ax = plt.subplots(figsize=(14, 10))

# Grid size 1: mean error per sigma plus its inter-quartile band, per method.
for mean_curve, trials, label, line_color, band_cmap in (
    (CLOMP_mean_err_array, CLOMP_err_array, 'CL-OMPR', colors_list[1], cmap2),
    (AltCLOMP_sg_1_mean_err_array, AltCLOMP_sg_1_err_array,
     'Algorithm 2 (sketched mean shift)', colors_list[0], cmap1),
    (AltCLOMP_z_1_mean_err_array, AltCLOMP_z_1_err_array,
     'Algorithm 2 (discretized)', 'purple', cmap3),
):
    band_color = band_cmap(normalize(1))
    ax.plot(np.sqrt(sigma_list), mean_curve, label=label, linewidth=4, color=line_color)
    ax.fill_between(np.sqrt(sigma_list),
                    np.percentile(trials, perc, axis=1),
                    np.percentile(trials, 100 - perc, axis=1),
                    alpha=0.2,
                    edgecolor=band_color,
                    facecolor=band_color,
                    linewidth=2,
                    linestyle='dashdot',
                    antialiased=True)

plt.grid(alpha=1, linestyle=':')

plt.xscale('log')
plt.yscale('log')

plt.xticks(x_ticks_list, x_str_ticks_list)
plt.yticks(y_ticks_list, y_str_ticks_list)

plt.axhline(y=k_means_SSE, color='black', linestyle='--', label="Lloyd's algorithm", linewidth=3)

for t in x_ticks_list:
    plt.axvline(x=t, color='black', linestyle='dotted', linewidth=3)

plt.xticks(fontsize=20)
plt.yticks(fontsize=20)

plt.xlabel(r'$\sigma$', fontsize=24)
plt.ylabel('MSE', fontsize=24)

plt.legend(fontsize=18, loc='upper right')

fig.savefig("pdf/CLOMP_vs_AltCLOMP_d_" + str(dim) + "_k_" + str(nb_clust) + "_log_2k_2_SSE_L_1__.pdf",
            bbox_inches='tight')



########################



# ---- Grid size 10: mean-SSE line plot -------------------------------------
f = plt.figure(figsize=(14, 10))
for curve, style in (
    (CLOMP_mean_SSE, dict(label='CL-OMPR', color=colors_list[1])),
    (AltCLOMP_mean_SSE_sg_2, dict(label='Algorithm 2 (sketched mean shift)', color=colors_list[0])),
    (AltCLOMP_mean_SSE_g_2, dict(label='Algorithm 2 (discretized)', color=colors_list[0], linestyle='--')),
):
    plt.plot(np.sqrt(sigma_list), curve, linewidth=4, **style)

plt.axhline(y=k_means_SSE, color='black', linestyle='--', label="Lloyd's algorithm", linewidth=3)
for t in x_ticks_list:
    plt.axvline(x=t, color='black', linestyle='dotted', linewidth=3)
plt.grid(alpha=1, linestyle=':')

plt.xscale('log')
plt.yscale('log')
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel(r'$\sigma$', fontsize=24)
plt.ylabel('MSE', fontsize=24)
plt.xticks(x_ticks_list, x_str_ticks_list)
plt.yticks(y_ticks_list, y_str_ticks_list)
plt.legend(fontsize=18, loc='upper right')

f.savefig("pdf/CLOMP_vs_AltCLOMP_d_" + str(dim) + "_k_" + str(nb_clust) + "_log_2k_2_SSE_L_10.pdf",
          bbox_inches='tight')

# ---- Grid size 10: mean error with [perc, 100 - perc] percentile bands ----
perc = 25

fig, ax = plt.subplots(figsize=(14, 10))
for mean_values, trials, label, line_color, band_cmap in (
    (CLOMP_mean_err_array, CLOMP_err_array, 'CL-OMPR', colors_list[1], cmap2),
    (AltCLOMP_sg_2_mean_err_array, AltCLOMP_sg_2_err_array,
     'Algorithm 2 (sketched mean shift)', colors_list[0], cmap1),
    (AltCLOMP_z_2_mean_err_array, AltCLOMP_z_2_err_array,
     'Algorithm 2 (discretized)', 'purple', cmap3),
):
    shade = band_cmap(normalize(1))
    ax.plot(np.sqrt(sigma_list), mean_values, label=label, linewidth=4, color=line_color)
    ax.fill_between(np.sqrt(sigma_list),
                    np.percentile(trials, perc, axis=1),
                    np.percentile(trials, 100 - perc, axis=1),
                    alpha=0.2, edgecolor=shade, facecolor=shade,
                    linewidth=2, linestyle='dashdot', antialiased=True)

plt.grid(alpha=1, linestyle=':')

plt.axhline(y=k_means_SSE, color='black', linestyle='--', label="Lloyd's algorithm", linewidth=3)
for t in x_ticks_list:
    plt.axvline(x=t, color='black', linestyle='dotted', linewidth=3)

plt.xscale('log')
plt.yscale('log')
plt.xticks(x_ticks_list, x_str_ticks_list)
plt.yticks(y_ticks_list, y_str_ticks_list)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel(r'$\sigma$', fontsize=24)
plt.ylabel('MSE', fontsize=24)
plt.legend(fontsize=18, loc='upper right')

fig.savefig("pdf/CLOMP_vs_AltCLOMP_d_" + str(dim) + "_k_" + str(nb_clust) + "_log_2k_2_SSE_L_10___.pdf",
            bbox_inches='tight')




######################


# ---- Grid size 100: mean-SSE line plot ------------------------------------
f = plt.figure(figsize=(14, 10))
sqrt_sigma_values = np.sqrt(sigma_list)
plt.plot(sqrt_sigma_values, CLOMP_mean_SSE,
         label='CL-OMPR', linewidth=4, color=colors_list[1])
plt.plot(sqrt_sigma_values, AltCLOMP_mean_SSE_sg_3,
         label='Algorithm 2 (sketched mean shift)', linewidth=4, color=colors_list[0])
plt.plot(sqrt_sigma_values, AltCLOMP_mean_SSE_g_3,
         label='Algorithm 2 (discretized)', linewidth=4, color=colors_list[0], linestyle='--')

plt.axhline(y=k_means_SSE, color='black', linestyle='--', label="Lloyd's algorithm", linewidth=3)
for t in x_ticks_list:
    plt.axvline(x=t, color='black', linestyle='dotted', linewidth=3)
plt.grid(alpha=1, linestyle=':')

plt.xscale('log')
plt.yscale('log')
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xticks(x_ticks_list, x_str_ticks_list)
plt.yticks(y_ticks_list, y_str_ticks_list)
plt.xlabel(r'$\sigma$', fontsize=24)
plt.ylabel('MSE', fontsize=24)
plt.legend(fontsize=18, loc='upper right')

f.savefig("pdf/CLOMP_vs_AltCLOMP_d_" + str(dim) + "_k_" + str(nb_clust) + "_log_2k_2_SSE_L_100.pdf",
          bbox_inches='tight')

# ---- Grid size 100: mean error with [perc, 100 - perc] percentile bands ---
perc = 25

fig, ax = plt.subplots(figsize=(14, 10))
band_specs = [
    (CLOMP_mean_err_array, CLOMP_err_array, 'CL-OMPR', colors_list[1], cmap2),
    (AltCLOMP_sg_3_mean_err_array, AltCLOMP_sg_3_err_array,
     'Algorithm 2 (sketched mean shift)', colors_list[0], cmap1),
    (AltCLOMP_z_3_mean_err_array, AltCLOMP_z_3_err_array,
     'Algorithm 2 (discretized)', 'purple', cmap3),
]
for mean_values, trials, label, line_color, band_cmap in band_specs:
    lower = np.percentile(trials, perc, axis=1)
    upper = np.percentile(trials, 100 - perc, axis=1)
    ax.plot(np.sqrt(sigma_list), mean_values, label=label, linewidth=4, color=line_color)
    ax.fill_between(np.sqrt(sigma_list), lower, upper,
                    alpha=0.2,
                    edgecolor=band_cmap(normalize(1)),
                    facecolor=band_cmap(normalize(1)),
                    linewidth=2, linestyle='dashdot', antialiased=True)

plt.grid(alpha=1, linestyle=':')

plt.axhline(y=k_means_SSE, color='black', linestyle='--', label="Lloyd's algorithm", linewidth=3)
for t in x_ticks_list:
    plt.axvline(x=t, color='black', linestyle='dotted', linewidth=3)

plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xscale('log')
plt.yscale('log')
plt.xticks(x_ticks_list, x_str_ticks_list)
plt.yticks(y_ticks_list, y_str_ticks_list)
plt.xlabel(r'$\sigma$', fontsize=24)
plt.ylabel('MSE', fontsize=24)
plt.legend(fontsize=18, loc='upper right')

fig.savefig("pdf/CLOMP_vs_AltCLOMP_d_" + str(dim) + "_k_" + str(nb_clust) + "_log_2k_2_SSE_L_100__.pdf",
            bbox_inches='tight')







###############

# Mean SSE versus sigma: CL-OMPR against the two L = 1000 variants of
# Algorithm 2, with Lloyd's algorithm as a horizontal reference.
f = plt.figure(figsize=(14, 10))

plt.plot(
    np.sqrt(sigma_list),
    CLOMP_mean_SSE,
    label='CL-OMPR',
    linewidth=4,
    color=colors_list[1],
)
plt.plot(
    np.sqrt(sigma_list),
    AltCLOMP_mean_SSE_sg_4,
    label='Algorithm 2 (sketched mean shift)',
    linewidth=4,
    color=colors_list[0],
)
plt.plot(
    np.sqrt(sigma_list),
    AltCLOMP_mean_SSE_g_4,
    label='Algorithm 2 (discretized)',
    linewidth=4,
    color=colors_list[0],
    linestyle='--',
)

plt.axhline(
    y=k_means_SSE,
    color='black',
    linestyle='--',
    label="Lloyd's algorithm",
    linewidth=3,
)

# Vertical guides at every labelled sigma tick.
for t in x_ticks_list:
    plt.axvline(x=t, color='black', linestyle='dotted', linewidth=3)

plt.grid(alpha=1, linestyle=':')

# Log axes, tick font sizes, axis labels, then explicit ticks (original order).
plt.xscale('log')
plt.yscale('log')
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel(r'$\sigma$', fontsize=24)
plt.ylabel('MSE', fontsize=24)
plt.xticks(x_ticks_list, x_str_ticks_list)
plt.yticks(y_ticks_list, y_str_ticks_list)

plt.legend(fontsize=18, loc='upper right')

f.savefig(
    "pdf/CLOMP_vs_AltCLOMP_d_" + str(dim) + "_k_" + str(nb_clust)
    + "_log_2k_2_SSE_L_1000.pdf",
    bbox_inches='tight',
)







# Two-colour palette for the following figures: colors_list[0] is a blue
# (used for Algorithm 2) and colors_list[1] a red (used for CL-OMPR), each
# sampled at 2/3 along its sequential colormap (index 2 of linspace(0, 1, 4)).
colors_2 = plt.cm.Blues(np.linspace(0, 1, 4))
colors = plt.cm.Reds(np.linspace(0, 1, 4))
colors_list = [colors_2[2], colors[2]]




#####


# Error curves with shaded [perc, 100 - perc] percentile bands (taken along
# axis 1 of each error array) for CL-OMPR and the L = 1000 variants of
# Algorithm 2. cmap1/cmap2/cmap3/normalize come from earlier in the script.
perc = 25

fig, ax = plt.subplots(figsize=(14, 10))

# CL-OMPR: mean curve and band.
ax.plot(
    np.sqrt(sigma_list),
    CLOMP_mean_err_array,
    label='CL-OMPR',
    linewidth=4,
    color=colors_list[1],
)
ax.fill_between(
    np.sqrt(sigma_list),
    np.percentile(CLOMP_err_array, perc, axis=1),
    np.percentile(CLOMP_err_array, 100 - perc, axis=1),
    alpha=0.2,
    edgecolor=cmap2(normalize(1)),
    facecolor=cmap2(normalize(1)),
    linewidth=2,
    linestyle='dashdot',
    antialiased=True,
)

# Algorithm 2, sketched mean shift: mean curve and band.
ax.plot(
    np.sqrt(sigma_list),
    AltCLOMP_sg_4_mean_err_array,
    label='Algorithm 2 (sketched mean shift)',
    linewidth=4,
    color=colors_list[0],
)
ax.fill_between(
    np.sqrt(sigma_list),
    np.percentile(AltCLOMP_sg_4_err_array, perc, axis=1),
    np.percentile(AltCLOMP_sg_4_err_array, 100 - perc, axis=1),
    alpha=0.2,
    edgecolor=cmap1(normalize(1)),
    facecolor=cmap1(normalize(1)),
    linewidth=2,
    linestyle='dashdot',
    antialiased=True,
)

# Algorithm 2, discretized: mean curve and band.
ax.plot(
    np.sqrt(sigma_list),
    AltCLOMP_z_4_mean_err_array,
    label='Algorithm 2 (discretized)',
    linewidth=4,
    color='purple',
)
ax.fill_between(
    np.sqrt(sigma_list),
    np.percentile(AltCLOMP_z_4_err_array, perc, axis=1),
    np.percentile(AltCLOMP_z_4_err_array, 100 - perc, axis=1),
    alpha=0.2,
    edgecolor=cmap3(normalize(1)),
    facecolor=cmap3(normalize(1)),
    linewidth=2,
    linestyle='dashdot',
    antialiased=True,
)

# Grid, log axes and explicit ticks (original order).
plt.grid(alpha=1, linestyle=':')
plt.xscale('log')
plt.yscale('log')
plt.xticks(x_ticks_list, x_str_ticks_list)
plt.yticks(y_ticks_list, y_str_ticks_list)

plt.axhline(
    y=k_means_SSE,
    color='black',
    linestyle='--',
    label="Lloyd's algorithm",
    linewidth=3,
)
for t in x_ticks_list:
    plt.axvline(x=t, color='black', linestyle='dotted', linewidth=3)

plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel(r'$\sigma$', fontsize=24)
plt.ylabel('MSE', fontsize=24)

plt.legend(fontsize=18, loc='upper right')

fig.savefig(
    "pdf/CLOMP_vs_AltCLOMP_d_" + str(dim) + "_k_" + str(nb_clust)
    + "_log_2k_2_SSE_L_1000__.pdf",
    bbox_inches='tight',
)





#################

# Mean SSE versus sigma for CL-OMPR and the sketched-mean-shift variant of
# Algorithm 2 at L = 1, 10, 100 and 1000 (line style distinguishes L).
f = plt.figure(figsize=(14, 10))

plt.plot(
    np.sqrt(sigma_list),
    CLOMP_mean_SSE,
    label='CL-OMPR',
    linewidth=4,
    color=colors_list[1],
)

plt.plot(
    np.sqrt(sigma_list),
    AltCLOMP_mean_SSE_sg_1,
    label='Algorithm 2 (sketched mean shift, L = 1)',
    linewidth=4,
    color=colors_list[0],
    linestyle='dotted',
)
plt.plot(
    np.sqrt(sigma_list),
    AltCLOMP_mean_SSE_sg_2,
    label='Algorithm 2 (sketched mean shift, L = 10)',
    linewidth=4,
    color=colors_list[0],
    linestyle='-.',
)
plt.plot(
    np.sqrt(sigma_list),
    AltCLOMP_mean_SSE_sg_3,
    label='Algorithm 2 (sketched mean shift, L = 100)',
    linewidth=4,
    color=colors_list[0],
    linestyle='--',
)
plt.plot(
    np.sqrt(sigma_list),
    AltCLOMP_mean_SSE_sg_4,
    label='Algorithm 2 (sketched mean shift, L = 1000)',
    linewidth=4,
    color=colors_list[0],
)

# NOTE(review): unlike the single-L figures, this reference line uses the
# default line width.
plt.axhline(y=k_means_SSE, color='black', linestyle='--',
            label="Lloyd's algorithm")

plt.grid(alpha=1, linestyle=':')

plt.xscale('log')
plt.yscale('log')
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)

plt.xlabel(r'$\sigma$', fontsize=24)
plt.ylabel('MSE', fontsize=24)
plt.legend(fontsize=18, loc='upper right')
plt.xticks(x_ticks_list, x_str_ticks_list)
plt.yticks(y_ticks_list, y_str_ticks_list)

f.savefig(
    "pdf/CLOMP_vs_AltCLOMP_d_" + str(dim) + "_k_" + str(nb_clust)
    + "_log_2k_2_SSE_variousL.pdf",
    bbox_inches='tight',
)



#################

# Same as the "variousL" figure but without the L = 1 curve
# (L = 10 drawn dash-dot).
f = plt.figure(figsize=(14, 10))

plt.plot(
    np.sqrt(sigma_list),
    CLOMP_mean_SSE,
    label='CL-OMPR',
    linewidth=4,
    color=colors_list[1],
)

plt.plot(
    np.sqrt(sigma_list),
    AltCLOMP_mean_SSE_sg_2,
    label='Algorithm 2 (sketched mean shift, L = 10)',
    linewidth=4,
    color=colors_list[0],
    linestyle='-.',
)
plt.plot(
    np.sqrt(sigma_list),
    AltCLOMP_mean_SSE_sg_3,
    label='Algorithm 2 (sketched mean shift, L = 100)',
    linewidth=4,
    color=colors_list[0],
    linestyle='--',
)
plt.plot(
    np.sqrt(sigma_list),
    AltCLOMP_mean_SSE_sg_4,
    label='Algorithm 2 (sketched mean shift, L = 1000)',
    linewidth=4,
    color=colors_list[0],
)

plt.axhline(y=k_means_SSE, color='black', linestyle='--',
            label="Lloyd's algorithm")

plt.grid(alpha=1, linestyle=':')

plt.xscale('log')
plt.yscale('log')
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)

plt.xlabel(r'$\sigma$', fontsize=24)
plt.ylabel('MSE', fontsize=24)
plt.legend(fontsize=18, loc='upper right')
plt.xticks(x_ticks_list, x_str_ticks_list)
plt.yticks(y_ticks_list, y_str_ticks_list)

f.savefig(
    "pdf/CLOMP_vs_AltCLOMP_d_" + str(dim) + "_k_" + str(nb_clust)
    + "_log_2k_2_SSE_variousL_2.pdf",
    bbox_inches='tight',
)


# Variant of "variousL_2" where the L = 10 curve is drawn dotted instead of
# dash-dot.
f = plt.figure(figsize=(14, 10))

plt.plot(
    np.sqrt(sigma_list),
    CLOMP_mean_SSE,
    label='CL-OMPR',
    linewidth=4,
    color=colors_list[1],
)

plt.plot(
    np.sqrt(sigma_list),
    AltCLOMP_mean_SSE_sg_2,
    label='Algorithm 2 (sketched mean shift, L = 10)',
    linewidth=4,
    color=colors_list[0],
    linestyle='dotted',
)
plt.plot(
    np.sqrt(sigma_list),
    AltCLOMP_mean_SSE_sg_3,
    label='Algorithm 2 (sketched mean shift, L = 100)',
    linewidth=4,
    color=colors_list[0],
    linestyle='--',
)
plt.plot(
    np.sqrt(sigma_list),
    AltCLOMP_mean_SSE_sg_4,
    label='Algorithm 2 (sketched mean shift, L = 1000)',
    linewidth=4,
    color=colors_list[0],
)

plt.axhline(y=k_means_SSE, color='black', linestyle='--',
            label="Lloyd's algorithm")

plt.grid(alpha=1, linestyle=':')

plt.xscale('log')
plt.yscale('log')
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)

plt.xlabel(r'$\sigma$', fontsize=24)
plt.ylabel('MSE', fontsize=24)
plt.legend(fontsize=18, loc='upper right')
plt.xticks(x_ticks_list, x_str_ticks_list)
plt.yticks(y_ticks_list, y_str_ticks_list)

f.savefig(
    "pdf/CLOMP_vs_AltCLOMP_d_" + str(dim) + "_k_" + str(nb_clust)
    + "_log_2k_2_SSE_variousL_3.pdf",
    bbox_inches='tight',
)








dt = datetime.now()
print("Date and time is:", dt)


def _mean_sse_vs_n_init(run, data, n_clusters, n_samples, n_init_values):
    """Return the per-sample SSE of ``run``'s centroids for each ``n_init``.

    ``run`` is one of the k-means wrappers defined at the top of this script
    (``run_kmeans`` / ``run_kmeans_plus``), called as
    ``run(data, n_clusters, n_i)`` and returning cluster centers. Each SSE is
    divided by ``n_samples``, printed as soon as it is computed, and the
    values are returned in the order of ``n_init_values``.

    The loop lives in a function so that it no longer overwrites the
    module-level ``k_means_SSE`` / ``kmeans_centroids`` (the Lloyd's
    baseline drawn in the figures above) with the last sweep result.
    """
    sse_values = []
    for n_i in n_init_values:
        centroids = run(data, n_clusters, n_i)
        sse = SSE(data, centroids) / n_samples
        sse_values.append(sse)
        print(sse)
    return sse_values


n_i_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20]

# Random initialisation first, then k-means++ (same order as before, so the
# global random state is consumed identically).
k_means_SSE_list = _mean_sse_vs_n_init(
    run_kmeans, X, nb_clust, nb_sample, n_i_list
)
k_means_plus_SSE_list = _mean_sse_vs_n_init(
    run_kmeans_plus, X, nb_clust, nb_sample, n_i_list
)





# Per-sample SSE of k-means as a function of n_init, one figure per
# initialisation scheme. The x axis was previously unlabeled; label it.
f = plt.figure(figsize=(14, 10))
plt.plot(n_i_list, k_means_SSE_list)
plt.xlabel('n_init', fontsize=24)
plt.ylabel('kmeans', fontsize=24)
f.savefig("pdf/kmeans_vs_ninit.pdf", bbox_inches='tight')

f = plt.figure(figsize=(14, 10))
plt.plot(n_i_list, k_means_plus_SSE_list)
plt.xlabel('n_init', fontsize=24)
plt.ylabel('kmeans++', fontsize=24)
f.savefig("pdf/kmeans_plus_vs_ninit.pdf", bbox_inches='tight')



import os

# Persist the raw error arrays and the k-means n_init sweep as .npy files.
# np.save does not create missing parent directories, so make sure "npy/"
# exists first instead of failing with FileNotFoundError at the first save.
os.makedirs("npy", exist_ok=True)

np.save("npy/CLOMP_err_array.npy", CLOMP_err_array)

# Algorithm 2 variants; suffixes 1..4 presumably match L = 1, 10, 100, 1000
# as in the "variousL" figure labels -- TODO confirm.
np.save("npy/AltCLOMP_z_1_err_array.npy", AltCLOMP_z_1_err_array)
np.save("npy/AltCLOMP_z_2_err_array.npy", AltCLOMP_z_2_err_array)
np.save("npy/AltCLOMP_z_3_err_array.npy", AltCLOMP_z_3_err_array)
np.save("npy/AltCLOMP_z_4_err_array.npy", AltCLOMP_z_4_err_array)

np.save("npy/AltCLOMP_g_1_err_array.npy", AltCLOMP_g_1_err_array)
np.save("npy/AltCLOMP_g_2_err_array.npy", AltCLOMP_g_2_err_array)
np.save("npy/AltCLOMP_g_3_err_array.npy", AltCLOMP_g_3_err_array)
np.save("npy/AltCLOMP_g_4_err_array.npy", AltCLOMP_g_4_err_array)

np.save("npy/AltCLOMP_sg_1_err_array.npy", AltCLOMP_sg_1_err_array)
np.save("npy/AltCLOMP_sg_2_err_array.npy", AltCLOMP_sg_2_err_array)
np.save("npy/AltCLOMP_sg_3_err_array.npy", AltCLOMP_sg_3_err_array)
np.save("npy/AltCLOMP_sg_4_err_array.npy", AltCLOMP_sg_4_err_array)

# x-axis values used in every figure, tagged with the run's dimensions.
np.save(
    "npy/sqrtsigma_d_" + str(dim) + "_k_" + str(nb_clust)
    + "_m_" + str(sketch_dim) + ".npy",
    np.sqrt(sigma_list),
)

np.save("npy/k_means_SSE_list.npy", k_means_SSE_list)
np.save("npy/k_means_plus_SSE_list.npy", k_means_plus_SSE_list)

# Closing timestamp for the run.
dt = datetime.now()
print("Date and time is:", dt)