#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
"""

import os
from os import path

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# Embed TrueType (Type 42) fonts in PDF/PS output instead of Type 3 fonts.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

import seaborn as sns

# Eleven named xkcd colours used to build the plotting palette.
# (A "colorblind" palette was previously computed here but was immediately
# overwritten by this one, so it has been removed.)
colors = ["black", "reddish purple", "salmon pink", "neon pink", "cornflower",
          "cobalt blue", "blue green", "aquamarine", "dark orange",
          "golden yellow", "reddish pink"]
color_pal = sns.xkcd_palette(colors)

# Close any figures still open (e.g. from an earlier run in the same
# interactive session) so the plot below starts from a clean state.
plt.close("all")


def read_perf(res, key):
    """Summarise the measurements stored under ``key`` in ``res``.

    Parameters
    ----------
    res : mapping
        Container indexed by ``key`` (for example the result of ``np.load``
        on an ``.npz`` file). ``res[key]`` must be an array with at least
        two dimensions, since statistics are taken along axis 1.
    key : str
        Name of the entry to summarise.

    Returns
    -------
    tuple of ndarray
        ``(mean, std)`` of ``res[key]`` computed along axis 1; ``std`` is
        NumPy's default population standard deviation (``ddof=0``).
    """
    values = res[key]
    return values.mean(axis=1), values.std(axis=1)
    


# Models to compare, and their legend labels (same order as `models`).
# To compare only SGW and DSE, drop 'gw' / 'GW' from both lists.
models = ['gw', 'sgw', 'distrib_min_sse']
label_algo = ['GW', 'SGW', 'DSE']

# Experiment settings; they are combined to build each result file name.
data = 'shapes'
expe = "timing"
scaling = "standard"
nproj = 1000
nproj_dist_d = 10
max_epoch = 50
n_iter_inner = 1
dim_latent = 5

# Values of n on the x-axis (number of meshes per shape).
vector_n = [100, 250, 500, 1000, 1500, 2000]

# Paths
filename = data  # NOTE(review): appears unused in this script
pathres = './result/'

# Per-model timing statistics: one row per entry of `models`, one column per
# entry of `vector_n`.
all_perfs_mean = np.empty((len(models), len(vector_n)))
all_perfs_std = np.empty((len(models), len(vector_n)))

for model_idx, model in enumerate(models):
    res_filename = (
        f"shape_{expe}_{model}_dataset_{data}_s_{scaling}_nproj_{nproj}"
        f"_nproj_d_{nproj_dist_d}_nb{max_epoch}_iterinner{n_iter_inner:d}"
        f"_latent{dim_latent:d}"
    )
    res_path = path.join(pathres, res_filename + '.npz')
    try:
        res = np.load(res_path)
    except FileNotFoundError as err:
        # FileNotFoundError is a subclass of OSError, which the script raised before.
        raise FileNotFoundError(
            f"Missing timing results for model '{model}': {res_path}"
        ) from err

    key = 'time_' + model
    # Close the .npz archive as soon as the statistics have been read.
    with res:
        mean, std = read_perf(res, key)

    # Guard against a size-1 result silently broadcasting across every n.
    if mean.shape != (len(vector_n),):
        raise ValueError(
            f"'{key}' in {res_path} gives statistics of shape {mean.shape}; "
            f"expected ({len(vector_n)},), one value per entry of vector_n"
        )
    all_perfs_mean[model_idx] = mean
    all_perfs_std[model_idx] = std



#%% plots
# Marker and colour cycles. `colort` skips palette entry 2 and the last entry.
markert = ['d', 'p', 'o', 'h', 'o', 'p', '<', '>', '8', 'P']
colort = color_pal[:2] + color_pal[3:-1]

fig, ax = plt.subplots()

for i, (perf_mean, perf_std) in enumerate(zip(all_perfs_mean, all_perfs_std)):
    # Wrap around rather than raising IndexError if there are more models
    # than markers or colours.
    color = colort[i % len(colort)]
    ax.plot(vector_n, perf_mean, label=label_algo[i], lw=2,
            marker=markert[i % len(markert)], markersize=10, c=color)
    # Shaded band: mean +/- one standard deviation (as returned by read_perf).
    # NOTE(review): on the log y-axis a non-positive lower bound
    # (mean - std <= 0) cannot be drawn; clip it if that ever occurs.
    ax.fill_between(vector_n, perf_mean - perf_std, perf_mean + perf_std,
                    color=color, alpha=0.2)

ax.set_ylabel('Running time', fontsize=14)
ax.set_xlabel('Number of meshes per each shape', fontsize=14)
ax.legend(fontsize=15)
ax.set_yscale('log')
ax.set_xscale('log')

fig.tight_layout()


# Save the figure. The output name reuses `res_filename`, i.e. the result
# file name built for the last entry of `models`. With no extension in the
# name, matplotlib uses its default save format (rcParams['savefig.format']).
pathfig = './figure/'
os.makedirs(pathfig, exist_ok=True)
# `bbox_inches` is the savefig keyword for tight cropping; the previous
# `bb_box` keyword is not a savefig parameter.
plt.savefig(path.join(pathfig, res_filename), bbox_inches='tight')


