import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
import seaborn as sns
colors = sns.color_palette()
from matplotlib import rcParams
from matplotlib.ticker import (MultipleLocator,
                               FormatStrFormatter,
                               AutoMinorLocator)
# Global matplotlib style for publication-sized figures: large fonts and
# ticks, a faint dash-dot grid, LaTeX text rendering and automatic layout.
# Settings are grouped by rcParams prefix and flattened into "prefix.name"
# keys; the (prefix, settings) pairs are listed in the order they are applied.
_tick_style = {'labelsize': 34,
               'major.width': 1.6,
               'major.size': 15,
               'minor.width': 1.0,
               'minor.size': 4}
_style_groups = (
    ('axes', {'labelsize': 32, 'grid': True, 'linewidth': 1.6,
              'titlepad': 20, 'xmargin': 0.05, 'ymargin': 0.05}),
    ('grid', {'alpha': 0.2, 'color': '#666666', 'linestyle': '-.'}),
    ('legend', {'fontsize': 32, 'loc': 'best'}),
    # x and y ticks share exactly the same styling.
    ('xtick', _tick_style),
    ('ytick', _tick_style),
    ('text', {'usetex': True}),
    ('figure', {'figsize': [12, 8]}),
    ('font', {'size': 32.0}),
    ('lines', {'markersize': np.sqrt(20) * 2.5}),
    ('figure', {'autolayout': True}),
)
params = {prefix + '.' + name: value
          for prefix, group in _style_groups
          for name, value in group.items()}
del _tick_style, _style_groups
rcParams.update(params)
import json
import os

# Directory holding the per-run *.json result files; figures are saved here too.
path_file = './last_results'

# (result key without the ".json" suffix, legend label) for each curve,
# in drawing order. Keeping them paired prevents the two lists drifting apart.
_curves = (
    ('AllGConv_7_layers', 'G-CNN'),
    ('AllBConv_5_layers_2', 'B-CNN'),
    ('AllConv_7_layers', 'CNN'),
)
model_to_draw = [key for key, _ in _curves]
legends = [label for _, label in _curves]
del _curves

# Curve colours, indexed in parallel with model_to_draw
# (replaces the seaborn palette bound to `colors` near the imports).
colors = '#d7191c #2b83ba #abdda4 #fdae61 #ffffbf #f0027f'.split()

def mean_confidence_interval(data, confidence=0.95):
    """Return the sample mean with a two-sided Student-t confidence interval.

    Parameters
    ----------
    data : array_like
        Sample values; converted with ``np.array`` and multiplied by 1.0
        (so integer input becomes floating point).
    confidence : float, optional
        Coverage of the interval, e.g. 0.95 for a 95% interval.

    Returns
    -------
    tuple
        ``(mean, lower, upper)``, where the bounds are
        ``mean -/+ sem * t.ppf((1 + confidence) / 2, n - 1)`` and ``n`` is
        ``len(data)``.
    """
    sample = np.array(data) * 1.0
    center = np.mean(sample)
    std_err = scipy.stats.sem(sample)
    # Two-sided interval: use the upper (1 + confidence) / 2 quantile of a
    # Student-t distribution with n - 1 degrees of freedom.
    t_crit = scipy.stats.t.ppf((1 + confidence) / 2., len(sample) - 1)
    half_width = std_err * t_crit
    return center, center - half_width, center + half_width

# Every *.json file in path_file holds the results of one independent run.
# endswith() avoids matching names such as "x.json.bak", and sorting makes the
# run -> row mapping reproducible (os.listdir order is arbitrary).
j_files = sorted(x for x in os.listdir(path_file) if x.endswith('.json'))
if not j_files:
    raise FileNotFoundError('no .json result files found in %r' % (path_file,))

n_runs = len(j_files)

# Metric names stored per model (in this order), with the factor applied when
# storing: accuracies are scaled by 100 so they plot as percentages.
_METRIC_SCALE = (('testing_loss', 1.),
                 ('training_loss', 1.),
                 ('testing_acc', 100.),
                 ('training_acc', 100.))

# summary[model][metric] is an (n_runs, n_epochs) array: one row per run file.
summary = {}
for idx, j_file in enumerate(j_files):

    with open(os.path.join(path_file, j_file), encoding='utf-8') as json_file:
        data = json.load(json_file)

    models = data.keys()

    for model in models:

        # Allocate storage the first time a model is seen (in any file, not
        # only the first one). NaN-filling means a model missing from some
        # runs leaves visible gaps instead of uninitialised np.empty memory.
        if model not in summary:
            n_epochs = data[model]['n_epochs']
            summary[model] = {key: np.full((n_runs, n_epochs), np.nan, dtype=np.double)
                              for key, _ in _METRIC_SCALE}

        for key, scale in _METRIC_SCALE:
            summary[model][key][idx, :] = np.array(data[model][key]) * scale

# One figure per metric. The metric names are listed explicitly, in the order
# the loader stores them, instead of being read from summary[model] -- that
# relied on `model` leaking out of the loading loop.
for metric in ('testing_loss', 'training_loss', 'testing_acc', 'training_acc'):

    fig, ax = plt.subplots()
    ax.set_xlabel(r'Epoch')
    if 'loss' in metric:
        ax.set_ylabel(r'Loss')
    else:
        ax.set_ylabel(r'Accuracy (\%)')

    for idx, model in enumerate(model_to_draw):

        # One row per run, one column per epoch.
        runs = summary[model + '.json'][metric]
        # Use this model's own epoch count rather than the global n_epochs,
        # which only reflects the last model read from the first result file.
        epochs = np.arange(runs.shape[1])
        # Mean and 99% Student-t interval across runs, computed per epoch;
        # reshape keeps the unpacking valid even for zero epochs.
        stats = np.array([mean_confidence_interval(runs[:, i], confidence=0.99)
                          for i in epochs]).reshape(-1, 3)
        mean, lower, upper = stats.T

        ax.fill_between(epochs, lower, upper, color=colors[idx], alpha=.5)
        ax.plot(epochs, mean, color=colors[idx], lw=3, label=legends[idx])

    ax.legend(loc='best')

    fig.tight_layout()
    fig.savefig(os.path.join(path_file, metric + '.png'), dpi=80)
    # Release only the figure created in this iteration.
    plt.close(fig)
        

        