#!/usr/bin/env python3
# -*- coding: utf-8 -*-
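"""
Plot results of the toy experiments stored as ``.npz`` files in ./result/toy/:
the distance computed by each algorithm as a function of the experiment
parameter, and the running time of each algorithm versus the number of samples.
"""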


import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# Curve colours shared by every figure below, indexed in the same order as
# ``liste_algo`` / ``label_algo`` (11 entries, one per potential curve).
colors = [
    "black", "reddish purple", "salmon pink", "neon pink", "cornflower",
    "cobalt blue", "blue green", "aquamarine", "dark orange",
    "golden yellow", "reddish pink",
]
color_pal = sns.xkcd_palette(colors)
# Close any figures left open from an earlier run so each execution starts
# from a clean state (useful when the ``#%%`` cells are re-run interactively).
plt.close("all")

# Directory that holds the ``.npz`` result files read by this script.
pathres = './result/toy/'

# Experiment settings. They are only used here to rebuild the result file
# names, so they must match the values the experiments were run with.
n_samples = 500
nproj = 100
nproj_dist_d = 10
nproj_dist = 10
nb_iter = 20
n_iter_inner = 50

# Algorithm keys (used in file names and npz keys) and their plot labels,
# paired by position; labels beyond len(liste_algo) are not plotted.
liste_algo = ['sgw', 'distrib_min_sse']
label_algo = ['SGW', 'HWD', 'RI-SGW']

# Distance experiment: for each algorithm, load its result file and reduce
# the stored distances / run times along axis 1 (presumably the repetition
# axis -- TODO confirm against the script that writes these files).
data = 'spiral'
M_time = []
M_val = []
S_val = []
for algo in liste_algo:
    filename = f"toy_{data}_{algo}_nsample{n_samples:d}_{nproj}_{nproj_dist}_{nproj_dist_d}_nb{nb_iter}_iterinner{n_iter_inner:d}"
    # ``np.load`` on an .npz returns an NpzFile that keeps the file open;
    # the ``with`` block closes it once the needed arrays are read.
    with np.load(pathres + filename + '.npz') as res:
        vec_value = res['alldist_' + algo]
        vec_time = res['time_' + algo]
        # Every curve is plotted against this grid and the last file's copy
        # is kept, so all result files are assumed to share it.
        param_vec = res['param_vec']
    M_time.append(vec_time.mean(axis=1))
    M_val.append(vec_value.mean(axis=1))
    S_val.append(vec_value.std(axis=1))

# One row per algorithm, in liste_algo order.
M_time = np.array(M_time)
M_val = np.array(M_val)
S_val = np.array(S_val)
print(M_time)
print(M_val)


# Figure 1: mean distance per algorithm versus the experiment parameter,
# with a shaded band of +/- one standard deviation.
plt.figure(figsize=(7, 5))
ax = []
markert = ['o', 'p', 's', 'd', 'h', 'o', 'p', '<', '>', '8', 'P']
colort = color_pal
# M_val / S_val hold one row per algorithm, in liste_algo order.
for i, (mean, error) in enumerate(zip(M_val, S_val)):
    ax1, = plt.plot(param_vec, mean, label=str(label_algo[i]), lw=2,
                    marker=markert[i], markersize=10, c=colort[i])
    plt.fill_between(param_vec, mean - error, mean + error,
                     color=colort[i], alpha=0.1)
    ax.append(ax1)

plt.grid()
plt.legend(fontsize=15)
if data == 'spiral':
    plt.xlabel('Angle of rotation', fontsize=14)
else:
    plt.xlabel('distance between means', fontsize=14)

plt.ylabel('Distance', fontsize=14)
figsave = f"{data}-distance.png"
# ``bbox_inches='tight'`` is the savefig option that crops surrounding
# whitespace (``bb_box`` is not a savefig keyword).
plt.savefig(figsave, dpi=200, bbox_inches='tight')

#%%


# Timing experiment: mean and standard deviation of the recorded run times
# for every (algorithm, sample size) pair. Rows follow liste_algo, columns
# follow sample_vec.
sample_vec = [1000, 5000, 10000, 20000, 50000, 100000]

M_time = np.zeros((len(liste_algo), len(sample_vec)))
S_time = np.zeros((len(liste_algo), len(sample_vec)))

for j, n_samples in enumerate(sample_vec):
    for i, algo in enumerate(liste_algo):
        filename = f"toy_{data}_{algo}_nsample{n_samples:d}_{nproj}_{nproj_dist}_{nproj_dist_d}_nb{nb_iter}_iterinner{n_iter_inner:d}"
        # Close each NpzFile as soon as its timings are read; this loop
        # opens len(liste_algo) * len(sample_vec) files.
        with np.load(pathres + filename + '.npz') as res:
            vec_time = res['time_' + algo]
        # Statistics over every stored timing (all axes pooled).
        M_time[i, j] = vec_time.mean()
        S_time[i, j] = vec_time.std()



    


# Figure 2: mean running time per algorithm versus number of samples on a
# logarithmic y-axis, with a shaded band of +/- one standard deviation.
plt.figure(figsize=(7, 5))
ax = []
# M_time / S_time hold one row per algorithm, in liste_algo order.
for i, (mean, error) in enumerate(zip(M_time, S_time)):
    ax1, = plt.semilogy(sample_vec, mean, label=str(label_algo[i]), lw=2,
                        marker=markert[i], markersize=10, c=colort[i])
    plt.fill_between(sample_vec, mean - error, mean + error,
                     color=colort[i], alpha=0.1)
    ax.append(ax1)

plt.grid()
plt.legend(fontsize=15)
plt.xlabel('Number of samples', fontsize=14)
plt.ylabel('Running time (s)', fontsize=14)
figsave = f"{data}-time.png"
# ``bbox_inches='tight'`` is the savefig option that crops surrounding
# whitespace (``bb_box`` is not a savefig keyword).
plt.savefig(figsave, dpi=200, bbox_inches='tight')