#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
"""
import os
from os import path

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# Embed fonts as TrueType (Type 42) instead of Type 3 in PDF and PostScript
# output, so text in saved figures stays editable/selectable.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

import seaborn as sns
color_pal = sns.color_palette("colorblind", 11).as_hex()
colors = ["black",  "reddish purple", "salmon pink", "neon pink", "cornflower","cobalt blue"
          ,"blue green", "aquamarine", "dark orange", "golden yellow", "reddish pink" ]

color_pal = sns.xkcd_palette(colors)
plt.close("all")


def read_perf(res, key, n_repeat=10):
    """Return the mean and standard deviation of a stored result array.

    Parameters
    ----------
    res : mapping
        Container indexed by ``key``, e.g. the ``NpzFile`` returned by
        ``np.load`` or a plain dict of arrays.
    key : str
        Name of the array to read from ``res``.
    n_repeat : int, optional
        Number of leading rows (along axis 0) to aggregate, presumably one per
        experimental repeat. If the array has fewer rows, all rows are used.

    Returns
    -------
    perf_mean, perf_std : numpy.ndarray
        Mean and population (``ddof=0``) standard deviation over axis 0 of the
        first ``n_repeat`` rows. Their shape is the array's shape without
        axis 0.
    """
    # Slice once and reuse it for both statistics. `[:n_repeat]` only touches
    # axis 0, so arrays of any rank >= 1 (and array-likes) are accepted.
    perf = np.asarray(res[key])[:n_repeat]
    return perf.mean(axis=0), perf.std(axis=0)
    


# --- Models --------------------------------------------------------------
# Internal identifiers (used in result filenames and result keys) paired
# position-by-position with the labels shown on the plot's x-axis.
models = ['gw', 'sgw', 'risgw', 'distrib_min_sse']
label_algo = ['GW', 'SGW', 'RI-SGW', 'DSE']

# --- Run parameters ------------------------------------------------------
# All of these except `max_iter_ri` are part of the result filename pattern.
data = 'correspondence'
expe = "InvarianceKNN"
n_vertices = 1000
scaling = "minmax"          # alternative: "standard"
nproj = 1000
nproj_dist_d = 10
max_iter_ri = 500
max_epoch = 50
n_iter_inner = 1
dim_latent = 5

# --- Paths ---------------------------------------------------------------
filename = data             # currently unused by the rest of this script
pathres = './result/'       # directory holding the .npz result files

# --- Result buffers ------------------------------------------------------
# One row per model, two statistic columns per model; every row is filled
# by the loading loop.
all_perfs_mean, all_perfs_std = (np.empty((len(models), 2)) for _ in range(2))

# Load each model's results and reduce them to a mean/std over the repeats.
# `res_filename` intentionally stays bound after the loop (the last model's
# name) because the figure-saving step reuses it.
for idx, model in enumerate(models):
    res_filename = (
        f"shape_{expe}_{model}_dataset_{data}_vertices_{n_vertices}"
        f"_s_{scaling}_nproj_{nproj}_nproj_d_{nproj_dist_d}"
        f"_nb{max_epoch}_iterinner{n_iter_inner:d}_latent{dim_latent:d}"
    )
    res_path = path.join(pathres, res_filename + '.npz')
    if path.exists(res_path):
        n_repeat = 10
    else:
        # Fall back to the "-partial" results file (presumably written by an
        # unfinished run — confirm); only its first 2 repeats are read.
        res_path = path.join(pathres, res_filename + '-partial.npz')
        n_repeat = 2

    # np.load on an .npz returns an NpzFile that holds the file open; the
    # context manager closes it once the statistics have been computed.
    with np.load(res_path) as res:
        mean, std = read_perf(res, 'perf_' + model, n_repeat=n_repeat)
    all_perfs_mean[idx] = mean
    all_perfs_std[idx] = std


#%% plots
# One marker style and one colour per model, indexed in the order of `models`.
markert = ['d', 'p', 's', 'o', 'h', 'o', 'p', '<', '>', '8', 'P']
colort = color_pal

fig, ax = plt.subplots()
x_pos = np.arange(len(models))

# Draw each model as a single marker at the mean of column 0, with a
# standard-deviation error bar.
for idx, x in enumerate(x_pos):
    ax.errorbar(x=x, y=all_perfs_mean[idx, 0], yerr=all_perfs_std[idx, 0],
                color=colort[idx], marker=markert[idx],
                mfc=colort[idx], mec=colort[idx], alpha=0.9, ms=20, mew=2)

ax.set_ylabel('Classification accuracy', fontsize=16)
ax.set_xticks(x_pos)
ax.set_xticklabels(label_algo, fontsize=16)
fig.tight_layout()
# plt.show()  # uncomment for interactive display

# Save the figure. The tight bounding box is requested via `bbox_inches`;
# `bb_box` is not a savefig option (ignored or rejected depending on the
# Matplotlib version). The output directory is created if it is missing.
# NOTE(review): `res_filename` holds the last model's result name from the
# loading loop, so the figure is named after that model — confirm intended.
pathfig = './figure/'
os.makedirs(pathfig, exist_ok=True)
fig.savefig(path.join(pathfig, res_filename), bbox_inches='tight')


#%%
#np.savez("data_perf",
#     mean_perf = all_perfs_mean,
#     std_pef = all_perfs_std,
#     label_algo = label_algo,
#     )
