import matplotlib.pyplot as plt
import utils
import json
import numpy as np
import statistics
from matplotlib.ticker import MaxNLocator
from scipy import stats
from utils import normality_test



# ---- Script configuration -------------------------------------------------
# Prefix prepended to every result file name (plus '.txt') before it is read.
directory='plot_data/'
# Prefix prepended to every figure file name passed to plt.savefig.
save_directory='saved/'
# Used both as the file extension and as the `format=` argument of plt.savefig.
save_format='png'
# When False, none of the savefig calls below are executed.
save_plots=False
# (start, stop) slice bounds applied to the per-epoch series drawn in the second
# (linear-scale) panel of the delta-loss figure.
delta_loss_default_range_linear_scale=(0,1000)


# True: draw the batch ratios as stacked bars; False: draw them as lines with shaded bands.
epoch_ratio_bar_plot=False

# NOTE(review): these two flags are not referenced in the part of the file reviewed here;
# presumably they gate test-loss / test-accuracy figures further down -- confirm.
plot_test_loss=True
plot_test_acc=True
# Gates the relative loss-improvement figure (figure number 10).
plot_relative_loss_improvements=True

# True: the training-loss figure uses a logarithmic y axis; False: its y axis starts at 0.
log_scale=True
# Keyword arguments forwarded to plt.figure / plt.subplots.
args={'figsize':(5,5),'dpi':100}

# Resolution passed to plt.savefig.
saved_dpi=1000

# Original note: "only for test loss". In the code visible here this flag also gates the
# mean curves of the training-loss figure.
plot_means=True
# Shade mean -/+ utils.confidence_z_score * SEM around the mean curves.
plot_means_err=True
# Draw median curves via medians_plot().
plot_medians=True
# Shade an error band around the median curves inside medians_plot().
plot_medians_err=False

means_linestyle='--'
medians_linestyle='-'

# Extra keyword arguments for the mean curves of the training-loss figure.
plot_means_args={'linewidth':.5}

# Reference line width used by change_legend_linewidth() for legend handles.
legend_linewidth=1.33333

# not_improved_file_name='results_model5_not_improved_training_seed23_200trainings_30e_not_mnist'
# improved_file_name='results_model5_improved_training_seed23_100trainings_30e_not_mnist'
# improved_more_iter_file_name='results_model5_improved_training_iter5_seed23_50trainings_30e_not_mnist'

# not_improved_file_name='results_model6_not_improved_training_seed23_200trainings_40e'
# improved_file_name='results_model6_improved_training_seed23_100trainings_40e'
# improved_more_iter_file_name='results_model6_improved_training_iter5_seed23_50trainings_40e'
# delta_loss_default_range_linear_scale=(4,1000)

# not_improved_file_name='results_model6_not_improved_training_seed23_200trainings_40e_not_mnist'
# improved_file_name='results_model6_improved_training_seed23_100trainings_40e_not_mnist'
# improved_more_iter_file_name='results_model6_improved_training_iter5_seed23_50trainings_40e_not_mnist'
# delta_loss_default_range_linear_scale=(12,1000)

# not_improved_file_name='results_model5_not_improved_training_seed0_200trainings_30e'
# improved_file_name='results_model5_improved_training_seed24_100trainings_30e'
# improved_more_iter_file_name='results_model5_improved_training_iter5_seed23_50trainings_30e'

#0.1 lr
# not_improved_file_name='results_model5_not_improved_training_seed23_200trainings_30e_0.1lr'
# improved_file_name='results_model5_improved_training_seed23_100trainings_30e_0.1lr'
# improved_more_iter_file_name='results_model5_improved_training_iter5_seed23_100trainings_30e_0.1lr'

# not_improved_file_name='results_model6_not_improved_training_seed23_200trainings_40e_0.1lr'
# improved_file_name='results_model6_improved_training_iter2_seed23_100trainings_40e_0.1lr'
# improved_more_iter_file_name='results_model6_improved_training_iter5_seed23_50trainings_40e_0.1lr'
# delta_loss_default_range_linear_scale=(20,1000)

# not_improved_file_name='results_model5_not_improved_training_seed23_200trainings_30e_not_mnist_0.1lr'
# improved_file_name='results_model5_improved_training_iter2_seed23_100trainings_30e_not_mnist_0.1lr'
# improved_more_iter_file_name='results_model5_improved_training_iter5_seed23_100trainings_30e_not_mnist_0.1lr'

# not_improved_file_name='results_model6_not_improved_training_seed23_200trainings_40e_not_mnist_0.1lr'
# improved_file_name='results_model6_improved_training_iter2_seed23_100trainings_40e_not_mnist_0.1lr'
# improved_more_iter_file_name='results_model6_improved_training_iter5_seed23_50trainings_40e_not_mnist_0.1lr'
######################
# Default: no third result set. The loading and plotting code below checks for None;
# an active preset further down may assign a real file name and override this.
improved_more_iter_file_name=None

### not_improved_file_name='results_model5_not_improved_training_seed0_200trainings_30e'
### improved_file_name='results_model5_improved_training_iter5_seed23_50trainings_30e'

# # not_improved_file_name='results_model5_not_improved_training_seed23_100trainings_30e_not_mnist'
# # improved_file_name='results_model5_improved_training_iter2_seed23_100trainings_30e_not_mnist'
# # improved_more_iter_file_name='results_model5_improved_training_iter5_seed23_50trainings_30e_not_mnist'
# #
# # not_improved_file_name='results_model6_not_improved_training_seed23_100trainings_40e'
# # improved_file_name='results_model6_improved_training_iter2_seed23_100trainings_40e'
# # improved_more_iter_file_name='results_model6_improved_training_iter5_seed23_50trainings_40e'
# # delta_loss_default_range_linear_scale=(4,1000)
# #
# # not_improved_file_name='results_model6_not_improved_training_seed23_100trainings_40e_not_mnist'
# # improved_file_name='results_model6_improved_training_iter2_seed23_100trainings_40e_not_mnist'
# # improved_more_iter_file_name='results_model6_improved_training_iter5_seed23_50trainings_40e_not_mnist'
# # delta_loss_default_range_linear_scale=(12,1000)
# #
# # not_improved_file_name='results_model5_not_improved_training_seed23_100trainings_30e'
# # improved_file_name='results_model5_improved_training_iter2_seed23_100trainings_30e'
# # improved_more_iter_file_name='results_model5_improved_training_iter5_seed23_50trainings_30e'
# #
# # 0.1 lr
# # not_improved_file_name='results_model5_not_improved_training_seed23_100trainings_30e_0.1lr'
# # improved_file_name='results_model5_improved_training_iter2_seed23_100trainings_30e_0.1lr'
# # improved_more_iter_file_name='results_model5_improved_training_iter5_seed23_50trainings_30e_0.1lr'
# #
# # not_improved_file_name='results_model6_not_improved_training_seed23_100trainings_40e_0.1lr'
# # improved_file_name='results_model6_improved_training_iter2_seed23_100trainings_40e_0.1lr'
# # improved_more_iter_file_name='results_model6_improved_training_iter5_seed23_50trainings_40e_0.1lr'
# # delta_loss_default_range_linear_scale=(20,1000)
# #
# # not_improved_file_name='results_model5_not_improved_training_seed23_100trainings_30e_not_mnist_0.1lr'
# # improved_file_name='results_model5_improved_training_iter2_seed23_100trainings_30e_not_mnist_0.1lr'
# # improved_more_iter_file_name='results_model5_improved_training_iter5_seed23_50trainings_30e_not_mnist_0.1lr'
# #
# # not_improved_file_name='results_model6_not_improved_training_seed23_100trainings_40e_not_mnist_0.1lr'
# # improved_file_name='results_model6_improved_training_iter2_seed23_100trainings_40e_not_mnist_0.1lr'
# # improved_more_iter_file_name='results_model6_improved_training_iter5_seed23_50trainings_40e_not_mnist_0.1lr'

######################################################################################
# not_improved_file_name='results_model6_not_improved_training_seed5_200trainings_15e'
# improved_file_name='results_model6_improved_training_iter2_seed5_100trainings_15e'
# improved_more_iter_file_name='results_model6_improved_training_iter5_seed5_100trainings_15e'
## delta_loss_default_range_linear_scale=(4,1000)

# not_improved_file_name='results_model6_not_improved_training_seed5_200trainings_15e_not_mnist'
# improved_file_name='results_model6_improved_training_iter2_seed5_100trainings_15e_not_mnist'
# improved_more_iter_file_name='results_model6_improved_training_iter5_seed5_100trainings_15e_not_mnist'

# Active preset: result files (without directory and '.txt' suffix) that are loaded below.
not_improved_file_name='results_model9_not_improved_training_seed4_15trainings_500e'
improved_file_name='results_model9_improved_training_iter2_seed4_15trainings_300e'
# NOTE(review): the first of the next two assignments is dead -- it is overwritten
# immediately, so only the 'improved_training2_seed1_30trainings' file is loaded.
# The plots still label this series "5 Iterations"; confirm that label fits the file.
improved_more_iter_file_name='results_model9_improved_training_iter5_seed4_7trainings_300e'
improved_more_iter_file_name='results_model9_improved_training2_seed1_30trainings_300e'

# not_improved_file_name='results_model9_not_improved_training_seed4_15trainings_500e_not_mnist'
# improved_file_name='results_model9_improved_training_iter2_seed4_15trainings_300e_not_mnist'
# improved_more_iter_file_name='results_model9_improved_training_iter5_seed4_7trainings_300e_not_mnist'
# improved_more_iter_file_name='results_model9_improved_training2_seed1_30trainings_300e_not_mnist'


###############plots for optimal LR for the gradient-based RMSProp:
# not_improved_file_name='results_model9_not_improved_training_seed4_15trainings_500e'
# improved_file_name='results_model9_improved_training_iter2_seed6_3trainings_500e'
# improved_more_iter_file_name='results_model9_improved_training_iter5_seed6_2trainings_500e'

# not_improved_file_name='results_model9_not_improved_training_seed4_15trainings_500e_not_mnist'
# improved_file_name='results_model9_improved_training_iter2_seed6_3trainings_500e_not_mnist'
# improved_more_iter_file_name='results_model9_improved_training_iter5_seed6_2trainings_500e_not_mnist'


# Read the result sets. The third one is optional and stays None when no file
# name has been configured for it.
not_improved=utils.read_file(''.join([directory,not_improved_file_name,'.txt']))
improved=utils.read_file(''.join([directory,improved_file_name,'.txt']))
if improved_more_iter_file_name is None:
    improved_more_iter=None
else:
    improved_more_iter=utils.read_file(''.join([directory,improved_more_iter_file_name,'.txt']))

def count_of_trainings(file_name):
    """Return the integer written directly before 'trainings_' in *file_name*.

    The number is the text between the last underscore preceding the first
    'trainings_' marker and the marker itself, e.g. 15 for
    '..._seed4_15trainings_500e'.

    Raises:
        ValueError: if the marker or the preceding underscore is missing, or
            if the text in between is not an integer.
    """
    marker_start = file_name.index('trainings_')
    before_marker = file_name[:marker_start]
    number_start = before_marker.rindex('_') + 1
    return int(before_marker[number_start:])
# Legend suffixes such as " of 15": the number of trainings encoded in each
# result file name, appended to the "Mean"/"Median" legend labels.
count_not_improved=' of '+str(count_of_trainings(not_improved_file_name))
count_improved=' of '+str(count_of_trainings(improved_file_name))
# The third result set is optional. Parsing a missing (None) file name would
# raise, so fall back to an empty suffix; it is only used when that set exists.
if improved_more_iter_file_name is not None:
    count_improved_more_iter=' of '+str(count_of_trainings(improved_more_iter_file_name))
else:
    count_improved_more_iter=''

def change_legend_linewidth(width=legend_linewidth):
    """Re-create the current legend and normalise the width of its line handles.

    Handles thinner than *width* are set to *width*; all other handles are set
    to ``width*2/1.3333``, so thin and thick curves stay distinguishable in the
    legend.

    Args:
        width: reference line width. Defaults to the module-level
            ``legend_linewidth``. (Previously this argument was accepted but
            ignored in favour of the global.)
    """
    leg=plt.legend()
    for line in leg.get_lines():
        if line.get_linewidth()<width:
            line.set_linewidth(width)
        else:
            line.set_linewidth(width*2./1.3333)

# ---- Figure 1: training loss per epoch (mean curves) ----------------------
plt.figure(1,**args)
not_improved_to_plot=not_improved['train_loss']
improved_to_plot=improved['train_loss']
if improved_more_iter is None:
    improved_more_iter_to_plot=None
else:
    improved_more_iter_to_plot=improved_more_iter['train_loss']
improved_means,improved_sem=utils.means_and_sem_err(improved_to_plot)
not_improved_means,not_improved_sem=utils.means_and_sem_err(not_improved_to_plot)
if improved_more_iter is None:
    improved_more_iter_means,improved_more_iter_sem=None,None
else:
    improved_more_iter_means,improved_more_iter_sem=utils.means_and_sem_err(improved_more_iter_to_plot)

if not log_scale:
    plt.ylim([0, None])
else:
    plt.yscale("log")


def _shade_mean_interval(means,sems,color):
    """Shade mean -/+ utils.confidence_z_score*SEM on the current pyplot axes."""
    lower=[m - utils.confidence_z_score*s for m, s in zip(means.values(), sems.values())]
    upper=[m + utils.confidence_z_score*s for m, s in zip(means.values(), sems.values())]
    plt.fill_between(x=means.keys(),y1=lower,y2=upper,alpha=.2,color=color)


if plot_means:
    # Bands first, so the curves are drawn on top of them.
    if plot_means_err:
        _shade_mean_interval(not_improved_means,not_improved_sem,'darkcyan')
        _shade_mean_interval(improved_means,improved_sem,'magenta')

    plt.plot(not_improved_means.keys(),not_improved_means.values(),
             color='darkcyan',
             label='Standard Training (Mean'+count_not_improved+')',
             linestyle=means_linestyle,**plot_means_args)
    plt.plot(improved_means.keys(),improved_means.values(),
             color='magenta',
             label='My Algorithm (2 Iterations; Mean'+count_improved+')',
             linestyle=means_linestyle,**plot_means_args)

    # The optional third result set is drawn last.
    if improved_more_iter is not None:
        if plot_means_err:
            _shade_mean_interval(improved_more_iter_means,improved_more_iter_sem,'limegreen')
        plt.plot(improved_more_iter_means.keys(),improved_more_iter_means.values(),
                 color='limegreen',
                 label='My Algorithm (5 Iterations; Mean'+count_improved_more_iter+')',
                 linestyle=means_linestyle,**plot_means_args)


def _fill_median_band(target,medians,sems,color):
    """Shade median -/+ z*mul*SEM on *target* (the pyplot module or an Axes)."""
    scale=utils.confidence_z_score*utils.standard_error_of_the_median_mul
    centres=list(medians.values())
    errors=list(sems.values())
    target.fill_between(x=medians.keys(),
                        y1=[y - scale*e for y, e in zip(centres, errors)],
                        y2=[y + scale*e for y, e in zip(centres, errors)],
                        alpha=.2,color=color)


def medians_plot(_plt=plt,confidence_ranges=True):
    """Draw per-epoch median curves for the currently selected series.

    Reads the module-level ``improved_to_plot``, ``not_improved_to_plot`` and
    (when a third result set is loaded) ``improved_more_iter_to_plot``, so the
    caller selects the data by rebinding those names before calling.

    Args:
        _plt: drawing target; either the ``matplotlib.pyplot`` module or an
            ``Axes`` instance.
        confidence_ranges: when True and ``plot_medians_err`` is enabled, shade
            an error band around every non-empty median curve.
    """
    improved_medians, improved_median_sem = utils.medians_and_sem_err(improved_to_plot)
    not_improved_medians, not_improved_median_sem = utils.medians_and_sem_err(not_improved_to_plot)
    if improved_more_iter is not None:
        improved_more_iter_medians, improved_median_more_iter_sem = utils.medians_and_sem_err(
            improved_more_iter_to_plot)
    else:
        improved_more_iter_medians, improved_median_more_iter_sem = None, None

    draw_bands=plot_medians_err and confidence_ranges

    # The standard-training curve is skipped when its series is empty.
    if not_improved_medians:
        _plt.plot(not_improved_medians.keys(), not_improved_medians.values(), color='darkblue', label='Standard Training (Median'+count_not_improved+')', linestyle=medians_linestyle)
    _plt.plot(improved_medians.keys(), improved_medians.values(), color='darkmagenta', label='My Algorithm (2 Iterations; Median'+count_improved+')', linestyle=medians_linestyle)

    axes=None
    saved_ylim=None
    if improved_more_iter is not None:
        _plt.plot(improved_more_iter_medians.keys(), improved_more_iter_medians.values(), color='darkgreen',
                  label='My Algorithm (5 Iterations; Median'+count_improved_more_iter+')', linestyle=medians_linestyle)
        if draw_bands and improved_more_iter_medians:
            # Freeze the y-range set by the curves so the bands cannot stretch it.
            # pyplot exposes ylim()/autoscale() for the *current* axes only and an
            # Axes has no ylim() at all, so resolve the real Axes object first.
            axes=_plt if hasattr(_plt,'get_ylim') else _plt.gca()
            saved_ylim=axes.get_ylim()
            axes.autoscale(False)
            _fill_median_band(_plt,improved_more_iter_medians,improved_median_more_iter_sem,'limegreen')

    if draw_bands and not_improved_medians:
        _fill_median_band(_plt,not_improved_medians,not_improved_median_sem,'darkblue')

    if draw_bands and improved_medians:
        _fill_median_band(_plt,improved_medians,improved_median_sem,'darkmagenta')

    # Restore the y-range captured above (only captured when a band was drawn
    # for the third result set).
    if saved_ylim is not None:
        axes.set_ylim(saved_ylim)

# Overlay the median curves on the training-loss figure, then label it.
if plot_medians:
    medians_plot()

plt.xlabel("Epoch")
plt.ylabel("Training Loss")
plt.legend()
# Rebuilds the legend and adjusts the width of its line handles.
change_legend_linewidth()

if save_plots:
    train_loss_figure_path=save_directory + 'train_loss_' + not_improved_file_name + "." + save_format
    plt.savefig(train_loss_figure_path, format=save_format, dpi=saved_dpi, bbox_inches='tight')
#plt.show()



# ---- Batch-ratio figure ----------------------------------------------------
# Shows, per epoch, the ratio of batches whose loss was lower / the same / higher.
# NOTE(review): the legend text says "Compared to Standard Update"; the exact
# definition of these ratios lives in the code that produced the result files -- confirm.
if epoch_ratio_bar_plot:
    # Variant 1: stacked bars built from the '*_aggregated' series of the 2-iteration run.
    width = 0.8  # the width of the bars: can also be len(x) sequence

    fig, ax = plt.subplots(**args)
    # Running top edge of the stack; one entry per epoch key.
    bottom = np.zeros(len(improved['train_higher_loss_batch_ratio'].keys()),dtype='float64')

    # x positions come from the non-aggregated dict's keys, bar heights from the
    # '*_aggregated' dict's values.
    p = ax.bar(improved['train_lower_loss_batch_ratio'].keys(), improved['train_lower_loss_batch_ratio_aggregated'].values(), width, label="Lower Loss Batches\nCompared to Standard Update", bottom=bottom,color='g')
    bottom += np.array(list(improved['train_lower_loss_batch_ratio_aggregated'].values()))
    p = ax.bar(improved['train_same_loss_batch_ratio'].keys(), improved['train_same_loss_batch_ratio_aggregated'].values(), width, label="Same Loss Batches", bottom=bottom,color='black')
    bottom += np.array(list(improved['train_same_loss_batch_ratio_aggregated'].values()))
    p = ax.bar(improved['train_higher_loss_batch_ratio'].keys(), improved['train_higher_loss_batch_ratio_aggregated'].values(), width, label="Higher Loss Batches", bottom=bottom,color='red')

    #ax.bar_label(p, label_type='center')

    #ax.set_title('Ratios of Batches With Loss Improvement Over Standard Backpropagation Loss Update')
    ax.legend()

    ax.set_xlabel("Epoch")
    ax.set_ylabel("Ratio of All Batches")
    #plt.show()
else:
    # Variant 2: one mean curve per category with a shaded mean -/+ z*SEM band.
    plt.figure(2, **args)
    # NOTE: this loop rebinds the module-level names key, color, label,
    # improved_to_plot, improved_more_iter_to_plot, improved_means, improved_sem,
    # improved_more_iter_means and improved_more_iter_sem.
    for (key,color,label) in [('train_lower_loss_batch_ratio','limegreen','Loss Decrease'),('train_same_loss_batch_ratio','black','Same Loss'),('train_higher_loss_batch_ratio','red','Loss Increase')]:
        #X=list(improved['train_lower_loss_batch_ratio'].keys())

        improved_to_plot = improved[key]
        improved_more_iter_to_plot = improved_more_iter[key] if improved_more_iter is not None else None
        improved_means, improved_sem = utils.means_and_sem_err(improved_to_plot)
        improved_more_iter_means, improved_more_iter_sem = utils.means_and_sem_err(improved_more_iter_to_plot) if improved_more_iter is not None else (None,None)

        # Bands are drawn before the curves; within each pair the optional third
        # result set comes first.
        if improved_more_iter is not None:
            plt.fill_between(x=improved_more_iter_means.keys(),
                         y1=[y - utils.confidence_z_score * e for y, e in
                             zip(improved_more_iter_means.values(), improved_more_iter_sem.values())],
                         y2=[y + utils.confidence_z_score * e for y, e in
                             zip(improved_more_iter_means.values(), improved_more_iter_sem.values())], alpha=.15,color=color)
        plt.fill_between(x=improved_means.keys(),
                         y1=[y - utils.confidence_z_score * e for y, e in
                             zip(improved_means.values(), improved_sem.values())],
                         y2=[y + utils.confidence_z_score * e for y, e in
                             zip(improved_means.values(), improved_sem.values())], alpha=.15,color=color)
        # Same colour per category; dashed = third result set, solid = 2-iteration run.
        if improved_more_iter is not None:
            plt.plot(improved_more_iter_means.keys(), improved_more_iter_means.values(), color=color, linestyle="--",
                 label=label+' (5 Iterations)')
        plt.plot(improved_means.keys(), improved_means.values(), color=color, linestyle='-',
                 label=label+' (2 Iterations)')

        # plt.title("Genetic Algorithm Mean Squared Error on Training Dataset for Sorting Algorithm Search")
        # Axis labels, legend and lower y-limit are re-applied on every iteration.
        plt.xlabel("Epoch")
        plt.ylabel("Ratio of All Batches")
        plt.legend()
        plt.ylim([0, None])#plt.ylim([0, 1])
        #plt.ylim([0, 1])

# Saves whichever figure is current: the bar chart or figure 2.
if save_plots:
    plt.savefig(save_directory + 'train_batch_ratios_' + not_improved_file_name + "." + save_format, bbox_inches='tight',
                dpi=saved_dpi, format=save_format)


# ---- Delta-loss figure: per-epoch mean of 'train_batch_avg_loss_improvement' ----
improved_to_plot=improved['train_batch_avg_loss_improvement']
if improved_more_iter is None:
    improved_more_iter_to_plot=None
else:
    improved_more_iter_to_plot=improved_more_iter['train_batch_avg_loss_improvement']
improved_means,improved_sem=utils.means_and_sem_err(improved_to_plot)
if improved_more_iter is None:
    improved_more_iter_means,improved_more_iter_sem=None,None
else:
    improved_more_iter_means,improved_more_iter_sem=utils.means_and_sem_err(improved_more_iter_to_plot)

# A second panel is added as soon as any per-epoch mean is zero or negative.
plot_double=bool(np.any(np.array(list(improved_means.values()))<=0))
if not plot_double and improved_more_iter is not None:
    plot_double=bool(np.any(np.array(list(improved_more_iter_means.values()))<=0))

if plot_double:
    fig, axs=plt.subplots(2,**args)
else:
    fig, axs=plt.subplots(1,**args)
    # Wrap the single Axes so that axs[0] works in both layouts.
    axs=(axs,)

# First panel: mean curves with error bars of z*SEM, third result set first.
if improved_more_iter is not None:
    axs[0].errorbar(improved_more_iter_means.keys(),
                    improved_more_iter_means.values(),
                    yerr=utils.confidence_z_score*np.array(list(improved_more_iter_sem.values())),
                    color='limegreen',
                    ecolor=(0.7,1,0.7),
                    label='5 Iterations',
                    linestyle=means_linestyle)
axs[0].errorbar(improved_means.keys(),
                improved_means.values(),
                yerr=utils.confidence_z_score*np.array(list(improved_sem.values())),
                color='magenta',
                ecolor=(1,0.7,1),
                label='2 Iterations',
                linestyle=means_linestyle)

if plot_medians:
    # Empty series: medians_plot then draws no standard-training curve.
    not_improved_to_plot={}
    medians_plot(axs[0],confidence_ranges=False)

axs[0].set_yscale("log")
axs[0].set_xlabel("Epoch")
axs[0].set_ylabel("−ΔLoss")
axs[0].legend()


#plt.figure(4,**args)

if plot_double:
    # Second panel: the same series on a linear axis, restricted to the
    # configured epoch window.
    r=delta_loss_default_range_linear_scale
    epoch_window=slice(r[0],r[1])
    if improved_more_iter is not None:
        axs[1].errorbar(list(improved_more_iter_means.keys())[epoch_window],
                        list(improved_more_iter_means.values())[epoch_window],
                        yerr=utils.confidence_z_score * np.array(list(improved_more_iter_sem.values()))[epoch_window],
                        color='limegreen',
                        ecolor=(0.7,1,0.7),
                        label='5 Iterations',
                        linestyle=means_linestyle)
    axs[1].errorbar(list(improved_means.keys())[epoch_window],
                    list(improved_means.values())[epoch_window],
                    yerr=utils.confidence_z_score*np.array(list(improved_sem.values()))[epoch_window],
                    color='magenta',
                    ecolor=(1,0.7,1),
                    label='2 Iterations',
                    linestyle=means_linestyle)

    if plot_medians:
        medians_plot(axs[1])

    axs[1].set_xlabel("Epoch")
    axs[1].set_ylabel("−ΔLoss")
    axs[1].legend()
    # Epoch ticks at integer positions only.
    axs[1].xaxis.set_major_locator(MaxNLocator(integer=True))
    fig.subplots_adjust(hspace=0.25)

if save_plots:
    delta_loss_figure_path=save_directory + 'train_batch_avg_loss_improvement_' + not_improved_file_name + "." + save_format
    plt.savefig(delta_loss_figure_path, format=save_format, dpi=saved_dpi, bbox_inches='tight')


# ---- Relative delta-loss figure: per-epoch mean of
# 'train_batch_avg_relative_loss_improvement' ----
improved_to_plot=improved['train_batch_avg_relative_loss_improvement']
improved_more_iter_to_plot=improved_more_iter['train_batch_avg_relative_loss_improvement'] if improved_more_iter is not None else None
improved_means,improved_sem=utils.means_and_sem_err(improved_to_plot)
improved_more_iter_means,improved_more_iter_sem=utils.means_and_sem_err(improved_more_iter_to_plot) if improved_more_iter is not None else (None,None)


# True when any per-epoch mean is zero or negative.
# NOTE(review): plot_double is recomputed here but not read again in the code visible
# below (the two-panel layout is commented out) -- confirm whether it is still needed.
plot_double=len(np.where(np.array(list(improved_means.values()))<=0)[0])!=0
if improved_more_iter is not None and not plot_double:
    plot_double = len(np.where(np.array(list(improved_more_iter_means.values())) <= 0)[0]) != 0


#fig, axs=plt.subplots(2 if plot_double else 1,**args)
#if not plot_double:
#    axs=(axs,)
if plot_relative_loss_improvements:
    # NOTE(review): unlike the other figures, this one is created without **args.
    plt.figure(10)
    axs=plt.gca()
    # plt.fill_between(x=improved_means.keys(),
    #                  y1=[y - utils.confidence_z_score*e for y, e in zip(improved_means.values(), improved_sem.values())],
    #                 y2=[y + utils.confidence_z_score*e for y, e in zip(improved_means.values(), improved_sem.values())], alpha=.25,color='magenta')
    #plt.plot(improved_means.keys(),improved_means.values(),color='magenta',label='Optimized Training With Optimal Hyperparameters For Standard Training')
    # Mean curves with error bars of utils.confidence_z_score * SEM.
    axs.errorbar(improved_means.keys(),improved_means.values(),yerr=utils.confidence_z_score*np.array(list(improved_sem.values())),color='magenta', ecolor=(1,0.7,1),label='2 Iterations',linestyle=means_linestyle)
    if improved_more_iter is not None:
        axs.errorbar(improved_more_iter_means.keys(), improved_more_iter_means.values(),
                     yerr=utils.confidence_z_score * np.array(list(improved_more_iter_sem.values())), color='limegreen',
                     ecolor=(0.7, 1, 0.7), label='5 Iterations',linestyle=means_linestyle)

    #if log_scale:
    #axs[0].yscale("log")
    #axs.set_yscale("log")
    axs.set_xlabel("Epoch")
    # NOTE(review): same y label as the absolute delta-loss figure although this series is
    # the *relative* improvement -- confirm the label is intended.
    axs.set_ylabel("−ΔLoss")
    # NOTE(review): the legend is built before medians_plot() runs, so the median curves
    # drawn below get no legend entry (the delta-loss figure does it the other way round).
    axs.legend()

    if plot_medians:
        # Empty series: medians_plot then draws no standard-training curve.
        not_improved_to_plot = {}
        medians_plot(axs,confidence_ranges=False)

    if save_plots:
        plt.savefig(save_directory + 'train_batch_avg_relative_loss_improvement_' + not_improved_file_name + "." + save_format, bbox_inches='tight',
                    dpi=saved_dpi, format=save_format)




# Number of entries in 'train_loss_min' for each loaded result set.
print("Standard training count: ",len(not_improved['train_loss_min']),sep='')
print("Improved training count: ",len(improved['train_loss_min']),sep='')
if improved_more_iter is not None:
    print("Improved training count (more iter): ",len(improved_more_iter['train_loss_min']),sep='')
print()

# Mean of 'train_loss_min' per result set, reported as mean +/- z*SEM, followed by
# the 1-based position of the lowest per-epoch mean training loss.
improved_mean,improved_sem=utils.mean_and_sem_err(improved['train_loss_min'])
not_improved_mean,not_improved_sem=utils.mean_and_sem_err(not_improved['train_loss_min'])
if improved_more_iter is None:
    improved_more_iter_means,improved_more_iter_sem=None,None
else:
    improved_more_iter_means,improved_more_iter_sem=utils.mean_and_sem_err(improved_more_iter['train_loss_min'])
    print("Improved training average best loss over training (more iter): ",improved_more_iter_means,
          " +/- ",improved_more_iter_sem*utils.confidence_z_score,sep='')
    means=list(utils.means_and_sem_err(improved_more_iter['train_loss'])[0].values())
    argmin=np.argmin(means)+1
    print("Epoch with best average: ",argmin,sep='')

print("Improved training average best loss over training: ",improved_mean,
      " +/- ",improved_sem*utils.confidence_z_score,sep='')
means=list(utils.means_and_sem_err(improved['train_loss'])[0].values())
argmin=np.argmin(means)+1
print("Epoch with best average: ",argmin,sep='')

print("Not improved training average best loss over training: ",not_improved_mean,
      " +/- ",not_improved_sem*utils.confidence_z_score,sep='')
means=list(utils.means_and_sem_err(not_improved['train_loss'])[0].values())
argmin=np.argmin(means)+1
print("Epoch with best average: ",argmin,sep='')

print()

# Training loss of the last epoch key in each result set, as mean +/- z*SEM.
if improved_more_iter is not None:
    improved_more_iter_mean, improved_more_iter_sem = utils.mean_and_sem_err(
        improved_more_iter['train_loss'][list(improved_more_iter['train_loss'].keys())[-1]])
    print("Improved training last epoch loss (more iter): ",improved_more_iter_mean,
          " +/- ",improved_more_iter_sem*utils.confidence_z_score,sep='')

improved_mean, improved_sem = utils.mean_and_sem_err(
    improved['train_loss'][list(improved['train_loss'].keys())[-1]])
print("Improved training last epoch loss: ",improved_mean,
      " +/- ",improved_sem*utils.confidence_z_score,sep='')

not_improved_mean, not_improved_sem = utils.mean_and_sem_err(
    not_improved['train_loss'][list(not_improved['train_loss'].keys())[-1]])
print("Not improved training last epoch loss: ",not_improved_mean,
      " +/- ",not_improved_sem*utils.confidence_z_score,sep='')

print()

# Mean of 'test_loss_min' per result set, reported as mean +/- z*SEM, followed by
# the 1-based position of the lowest per-epoch mean test loss.
improved_mean,improved_sem=utils.mean_and_sem_err(improved['test_loss_min'])
not_improved_mean,not_improved_sem=utils.mean_and_sem_err(not_improved['test_loss_min'])
if improved_more_iter is None:
    improved_more_iter_means,improved_more_iter_sem=None,None
else:
    improved_more_iter_means,improved_more_iter_sem=utils.mean_and_sem_err(improved_more_iter['test_loss_min'])
    print("Improved test average best loss over training (more iter): ",improved_more_iter_means,
          " +/- ",improved_more_iter_sem*utils.confidence_z_score,sep='')
    means=list(utils.means_and_sem_err(improved_more_iter['test_loss'])[0].values())
    argmin=np.argmin(means)+1
    print("Epoch with best average: ",argmin,sep='')

print("Improved test average best loss over training: ",improved_mean,
      " +/- ",improved_sem*utils.confidence_z_score,sep='')
means=list(utils.means_and_sem_err(improved['test_loss'])[0].values())
argmin=np.argmin(means)+1
print("Epoch with best average: ",argmin,sep='')

print("Not improved test average best loss over training: ",not_improved_mean,
      " +/- ",not_improved_sem*utils.confidence_z_score,sep='')
means=list(utils.means_and_sem_err(not_improved['test_loss'])[0].values())
argmin=np.argmin(means)+1
print("Epoch with best average: ",argmin,sep='')

print()


# Mean of 'test_accuracy_max' per result set, reported as mean +/- z*SEM, followed
# by the 1-based position of the highest per-epoch mean test accuracy.
improved_mean,improved_sem=utils.mean_and_sem_err(improved['test_accuracy_max'])
not_improved_mean,not_improved_sem=utils.mean_and_sem_err(not_improved['test_accuracy_max'])
if improved_more_iter is None:
    improved_more_iter_means,improved_more_iter_sem=None,None
else:
    improved_more_iter_means,improved_more_iter_sem=utils.mean_and_sem_err(improved_more_iter['test_accuracy_max'])
    print("Improved test average best accuracy over training (more iter): ",improved_more_iter_means,
          " +/- ",improved_more_iter_sem*utils.confidence_z_score,sep='')
    means=list(utils.means_and_sem_err(improved_more_iter['test_accuracy'])[0].values())
    argmax=np.argmax(means)+1
    print("Epoch with best average: ",argmax,sep='')

print("Improved test average best accuracy over training: ",improved_mean,
      " +/- ",improved_sem*utils.confidence_z_score,sep='')
means=list(utils.means_and_sem_err(improved['test_accuracy'])[0].values())
argmax=np.argmax(means)+1
print("Epoch with best average: ",argmax,sep='')

print("Not improved test average best accuracy over training: ",not_improved_mean,
      " +/- ",not_improved_sem*utils.confidence_z_score,sep='')
means=list(utils.means_and_sem_err(not_improved['test_accuracy'])[0].values())
argmax=np.argmax(means)+1
print("Epoch with best average: ",argmax,sep='')
print()

# Relative loss improvement at the last epoch key, as mean +/- z*SEM.
# Printed only when the 2-iteration result set contains this series.
key='train_avg_relative_loss_improvement'
if key in improved:
    if improved_more_iter is not None:
        improved_more_iter_mean, improved_more_iter_sem = utils.mean_and_sem_err(
            improved_more_iter[key][list(improved_more_iter[key].keys())[-1]])
        print("Improved training (more iter) average relative loss improvement: ",improved_more_iter_mean,
              " +/- ",improved_more_iter_sem*utils.confidence_z_score,sep='')
    improved_mean, improved_sem = utils.mean_and_sem_err(
        improved[key][list(improved[key].keys())[-1]])
    print("Improved training average relative loss improvement: ",improved_mean,
          " +/- ",improved_sem*utils.confidence_z_score,sep='')
    print()

def display_param(param,percentile=None):
    """Print a summary of ``improved[param]`` (and the more-iterations run, if loaded).

    Without ``percentile`` the mean and the SEM scaled by
    ``utils.confidence_z_score`` are printed.  With ``percentile`` the
    requested percentile replaces the mean and the error is reported as
    infinity (no error estimate is computed for percentiles here).

    :param param: key into the result dictionaries; its value must offer ``.values()``.
    :param percentile: optional percentile passed to ``np.percentile``.
    """
    has_more_iter = improved_more_iter is not None

    # The mean/SEM are always computed first, exactly as before, even if the
    # percentile branch below overrides them.
    centre, spread = utils.mean_and_sem_err(list(improved[param].values()))
    if has_more_iter:
        more_centre, more_spread = utils.mean_and_sem_err(
            list(improved_more_iter[param].values()))
    else:
        more_centre, more_spread = None, None

    if percentile is not None:
        centre = np.percentile(list(improved[param].values()), percentile)
        spread = float("inf")
        if has_more_iter:
            more_centre = np.percentile(list(improved_more_iter[param].values()), percentile)
            more_spread = float("inf")
        else:
            more_centre, more_spread = None, None
        print("Percentile %s:" % (percentile,))

    if has_more_iter:
        print("Improved " + param + " (more iter): %s +/- %s"
              % (more_centre, more_spread * utils.confidence_z_score))

    print("Improved " + param + ": %s +/- %s"
          % (centre, spread * utils.confidence_z_score))

    # The standard-training figures are intentionally not printed here.
    print()
def _training_wise_averages(data, param):
    """Average ``data[param]`` over all of its keys, separately per position.

    ``data[param]`` maps keys to equally indexed sequences; the result holds,
    for every index ``t``, the plain arithmetic mean of ``data[param][k][t]``
    over all keys ``k``.  The number of positions is taken from the sequence
    stored under key ``1`` (presumably the first epoch -- TODO confirm).
    """
    series_by_key = data[param]
    averages = []
    for t in range(len(series_by_key[1])):
        total = 0.
        count = 0
        # Sequential accumulation is kept on purpose so results stay
        # bit-identical to the previous hand-written loops.
        for series in series_by_key.values():
            total += series[t]
            count += 1
        averages.append(total / count)
    return averages


def _training_wise_centre_and_sem(data, param, percentile):
    """Return ``(centre, sem)`` of the training-wise averages of ``data[param]``.

    ``centre`` is the mean, or the requested percentile when ``percentile`` is
    given.  In the percentile case the SEM of the mean is scaled by
    1.2533... (numerically sqrt(pi/2)).
    NOTE(review): that factor is the textbook large-sample ratio between the
    standard error of the median and of the mean for normal data, so it is
    presumably only meaningful for ``percentile=50`` -- confirm.
    """
    vals = _training_wise_averages(data, param)
    centre, sem = utils.mean_and_sem_err(vals)
    if percentile is not None:
        centre, sem = (np.percentile(vals, percentile),
                       1.2533141373155002512078826424055*sem)
    return centre, sem


def display_param_trainingwisely(param,to_percent=True,percentile=None):
    """Print ``param`` averaged per training for the improved run(s).

    Every training's values are first averaged over all keys of
    ``improved[param]``; the mean (or ``percentile``) of those per-training
    averages is then printed with a ``utils.confidence_z_score``-scaled error.
    The same is printed for ``improved_more_iter`` when it is loaded.

    :param param: key into the result dictionaries.
    :param to_percent: multiply the printed numbers by 100 when true.
    :param percentile: optional percentile to report instead of the mean.
    """
    mul = 100 if to_percent else 1

    improved_mean, improved_sem = _training_wise_centre_and_sem(improved, param, percentile)
    if percentile is not None:
        print("Percentile "+str(percentile)+":")
    print("Improved " + param + ": " + str(improved_mean*mul) + " +/- " + str(
        improved_sem * utils.confidence_z_score*mul))

    if improved_more_iter is not None:
        improved_more_iter_means, improved_more_iter_sem = _training_wise_centre_and_sem(
            improved_more_iter, param, percentile)
        print("Improved " + param + " (more iter): " + str(
            improved_more_iter_means*mul) + " +/- " + str(improved_more_iter_sem * utils.confidence_z_score*mul))


# Share of batches whose loss got higher / stayed the same / got lower,
# averaged per training and printed in percent.
for _ratio_param in ('train_higher_loss_batch_ratio',
                     'train_same_loss_batch_ratio',
                     'train_lower_loss_batch_ratio'):
    display_param_trainingwisely(_ratio_param)
# display_param('train_higher_loss_batch_ratio_aggregated')
# display_param('train_same_loss_batch_ratio_aggregated')
# display_param('train_lower_loss_batch_ratio_aggregated')

# 30th percentile of the aggregated values, then the median of the
# training-wise averages of the per-batch loss improvement.
display_param('train_batch_avg_loss_improvement_aggregated', 30)
display_param_trainingwisely('train_batch_avg_loss_improvement', percentile=50)
# print(sorted(improved["train_loss"][18],reverse=True))
# print(sorted(improved["train_loss"][29],reverse=True))
# print(sorted(not_improved["train_loss"][29],reverse=True))
# print(sorted(improved["train_loss"][30],reverse=True))


#print(stats.normaltest(not_improved['test_loss'][28]))
#print(stats.normaltest(not_improved['test_loss'][29]))


def test_param(param,p_value=0.05,verbose=False):
    """Run ``normality_test`` on ``param`` for every loaded result set.

    The standard-training results are only tested when they contain
    ``param``; the improved results are always tested; the more-iterations
    results are tested when that run is loaded.

    :param param: key into the result dictionaries.
    :param p_value: threshold forwarded to ``normality_test``.
    :param verbose: forwards verbosity level 2 instead of 1 when true.
    """
    verbosity = 1
    if verbose:
        verbosity = 2

    print('Tests for '+param+':')

    # (label, results) pairs in the order they are reported.
    runs = []
    if param in not_improved:
        runs.append(('not_improved: ', not_improved))
    runs.append(('improved: ', improved))
    if improved_more_iter is not None:
        runs.append(('improved_more_iter: ', improved_more_iter))

    for label, results in runs:
        print(label, end='')
        normality_test(results[param], p_value, verbosity)
    print('--------------------------------------')


# normality_test(not_improved['test_loss'])
# normality_test(improved['test_loss'])
# normality_test(improved_more_iter['test_loss'])
# #normality_test(improved['test_loss'])
# normality_test(not_improved['train_loss'])
# normality_test(improved['train_loss'])
# normality_test(improved_more_iter['train_loss'])
#
# normality_test(not_improved['test_accuracy'])
# normality_test(improved['test_loss'])
# normality_test(improved_more_iter['test_loss'])
# Run the normality tests for every tracked quantity, first with the default
# p-value threshold of test_param and then with 0.5.  This is best effort: a
# failure aborts the remaining tests but must not stop the script.
_normality_params = (
    'test_loss',
    'train_loss',
    'test_accuracy',
    'train_batch_avg_loss_improvement',
    'train_lower_loss_batch_ratio',
    'train_same_loss_batch_ratio',
    'train_higher_loss_batch_ratio',
)
try:
    print()
    for _normality_param in _normality_params:
        test_param(_normality_param, verbose=False)
    print()
    print('-------------------------------------------------------------------')
    print()
    for _normality_param in _normality_params:
        test_param(_normality_param, p_value=0.5, verbose=False)
except Exception as exc:
    # Catch Exception rather than using a bare except so that Ctrl-C and
    # SystemExit still propagate, and report what actually went wrong.
    print('Exception occurred in normality tests: ' + repr(exc))
print()
print('-------------------------------------------------------------------')
print()
def last_epoch_stats(param):
    """Print and return last-epoch normality-test p-values for ``param``.

    For each result set (standard, improved, improved with more iterations)
    that is loaded and contains ``param``, ``utils.single_normality_test`` is
    applied to the values stored under the last key of ``data[param]``.

    The last key is looked up explicitly instead of using ``len(...) - 1``,
    which would select the second-to-last entry whenever the keys start at 1.
    NOTE(review): assumes ``data[param]`` is a mapping whose insertion order
    follows the epochs -- confirm against the loader.

    :param param: key into the result dictionaries, e.g. ``'test_loss'``.
    :return: tuple ``(not_improved_pval, improved_pval, improved_more_iter_pval)``;
        an entry is ``float('inf')`` when the corresponding data is unavailable.
    """
    # 'test_loss' -> 'test loss', so the message names the tested quantity.
    readable_param = param.replace('_', ' ')

    def _last_epoch_pval(run_label, data):
        if data is None or param not in data:
            return float('inf')
        per_epoch = data[param]
        if not per_epoch:
            # Nothing recorded for this parameter: nothing to test.
            return float('inf')
        last_epoch = list(per_epoch)[-1]
        pval = utils.single_normality_test(per_epoch[last_epoch])
        print(run_label + ' - ' + 'Last epoch ' + readable_param + ' p-val: ' + str(pval))
        return pval

    not_improved_pval = _last_epoch_pval('not improved', not_improved)
    improved_pval = _last_epoch_pval('improved', improved)
    improved_more_iter_pval = _last_epoch_pval('improved more iter', improved_more_iter)
    return (not_improved_pval,improved_pval,improved_more_iter_pval)

# Best effort: a failing normality test must not stop the rest of the script.
try:
    last_epoch_stats('test_loss')
except Exception as exc:
    # Catch Exception rather than using a bare except so that Ctrl-C and
    # SystemExit still propagate, and report what actually went wrong.
    print('Exception occurred in normality tests: ' + repr(exc))

# Epoch-wise comparison of the mean training loss of each improved method
# against standard training.  A method counts as "significantly better" in an
# epoch when mean + SEM lies below the standard mean - SEM, and as
# "significantly worse" when mean - SEM lies above the standard mean + SEM.
# NOTE(review): the raw SEM is used here, not SEM * utils.confidence_z_score
# as in the plots -- confirm this is intended.
not_improved_to_plot=not_improved['train_loss']
improved_to_plot=improved['train_loss']
improved_more_iter_to_plot=improved_more_iter['train_loss'] if improved_more_iter is not None else None
improved_means,improved_sem=utils.means_and_sem_err(improved_to_plot)
not_improved_means,not_improved_sem=utils.means_and_sem_err(not_improved_to_plot)
improved_more_iter_means,improved_more_iter_sem=utils.means_and_sem_err(improved_more_iter_to_plot) if improved_more_iter is not None else (None,None)


def _report_epoch_share(description, epoch_keys, is_hit):
    """Print the percentage (and count) of epochs for which ``is_hit`` holds."""
    epoch_keys = list(epoch_keys)
    total = len(epoch_keys)
    hits = sum(1 for epoch in epoch_keys if is_hit(epoch))
    # Avoid a ZeroDivisionError when no epochs were recorded.
    share = 100.*hits/total if total else float('nan')
    print('Percentage of epochs where '+description+': '+str(share)+'  '+str(hits)+'/'+str(total))


def _improved_better(i):
    """True when the improved method is significantly better in epoch ``i``."""
    return improved_means[i]+improved_sem[i]<-not_improved_sem[i]+not_improved_means[i]


def _more_iter_better(i):
    """True when the more-iterations method is significantly better in epoch ``i``."""
    return improved_more_iter_means[i]+improved_more_iter_sem[i]<-not_improved_sem[i]+not_improved_means[i]


def _improved_worse(i):
    """True when the improved method is significantly worse in epoch ``i``."""
    return improved_means[i]-improved_sem[i]>not_improved_sem[i]+not_improved_means[i]


def _more_iter_worse(i):
    """True when the more-iterations method is significantly worse in epoch ``i``."""
    return improved_more_iter_means[i]-improved_more_iter_sem[i]>not_improved_sem[i]+not_improved_means[i]


if improved_more_iter is not None:
    _report_epoch_share('any improved method is significantly better',
                        improved_more_iter_means.keys(),
                        lambda i: _more_iter_better(i) or _improved_better(i))
    _report_epoch_share('improved method with more iterations is significantly better',
                        improved_more_iter_means.keys(), _more_iter_better)
# This statistic only involves the improved and the standard runs, so iterate
# over the improved epochs; iterating improved_more_iter_means here crashed
# with AttributeError whenever no more-iterations run was loaded.
_report_epoch_share('improved method is significantly better',
                    improved_means.keys(), _improved_better)
if improved_more_iter is not None:
    _report_epoch_share('any of both improved method is significantly worse',
                        improved_more_iter_means.keys(),
                        lambda i: _more_iter_worse(i) or _improved_worse(i))
    _report_epoch_share('improved method with more iterations is significantly worse',
                        improved_more_iter_means.keys(), _more_iter_worse)
_report_epoch_share('improved method is significantly worse',
                    improved_means.keys(), _improved_worse)


if plot_test_loss:
    # Figure 4: test loss per epoch (means with optional error bands, medians).
    plt.figure(4, **args)
    not_improved_to_plot = not_improved['test_loss']
    improved_to_plot = improved['test_loss']
    if improved_more_iter is not None:
        improved_more_iter_to_plot = improved_more_iter['test_loss']
        improved_more_iter_means, improved_more_iter_sem = utils.means_and_sem_err(improved_more_iter_to_plot)
    else:
        improved_more_iter_to_plot = None
        improved_more_iter_means, improved_more_iter_sem = None, None
    improved_means, improved_sem = utils.means_and_sem_err(improved_to_plot)
    not_improved_means, not_improved_sem = utils.means_and_sem_err(not_improved_to_plot)

    def _loss_band(means, sems):
        """Return (lower, upper) lists: mean -/+ confidence_z_score * SEM."""
        pairs = list(zip(means.values(), sems.values()))
        lower = [y - utils.confidence_z_score * e for y, e in pairs]
        upper = [y + utils.confidence_z_score * e for y, e in pairs]
        return lower, upper

    if plot_means:
        if plot_means_err:
            band_low, band_high = _loss_band(not_improved_means, not_improved_sem)
            plt.fill_between(x=not_improved_means.keys(), y1=band_low, y2=band_high,
                             alpha=.25, color='darkcyan')
            band_low, band_high = _loss_band(improved_means, improved_sem)
            plt.fill_between(x=improved_means.keys(), y1=band_low, y2=band_high,
                             alpha=.25, color='magenta')
        plt.plot(not_improved_means.keys(), not_improved_means.values(),
                 color='darkcyan',
                 label='Standard Training (Mean' + count_not_improved + ')',
                 linestyle=means_linestyle, **plot_means_args)
        plt.plot(improved_means.keys(), improved_means.values(),
                 color='magenta',
                 label='My Algorithm (2 Iterations; Mean' + count_improved + ')',
                 linestyle=means_linestyle, **plot_means_args)
        if improved_more_iter is not None:
            if plot_means_err:
                band_low, band_high = _loss_band(improved_more_iter_means, improved_more_iter_sem)
                plt.fill_between(x=improved_more_iter_means.keys(), y1=band_low, y2=band_high,
                                 alpha=.25, color='limegreen')
            plt.plot(improved_more_iter_means.keys(), improved_more_iter_means.values(),
                     color='limegreen',
                     label='My Algorithm (5 Iterations; Mean' + count_improved_more_iter + ')',
                     linestyle=means_linestyle, **plot_means_args)

    if plot_medians:
        medians_plot(plt)

    plt.xlabel("Epoch")
    plt.ylabel("Test Loss")
    plt.legend()
    change_legend_linewidth()

    if not log_scale:
        # Linear axis: anchor the y-axis at zero.
        plt.ylim([0, None])
    else:
        plt.yscale("log")
    if save_plots:
        loss_plot_path = save_directory + 'test_loss_' + not_improved_file_name + "." + save_format
        plt.savefig(loss_plot_path, bbox_inches='tight', dpi=saved_dpi, format=save_format)
if plot_test_acc:
    # Figure 5: test accuracy per epoch (means with optional error bands, medians).
    plt.figure(5, **args)
    not_improved_to_plot = not_improved['test_accuracy']
    improved_to_plot = improved['test_accuracy']
    improved_more_iter_to_plot = improved_more_iter['test_accuracy'] if improved_more_iter is not None else None
    improved_more_iter_means, improved_more_iter_sem = utils.means_and_sem_err(
        improved_more_iter_to_plot) if improved_more_iter is not None else (None, None)
    improved_means, improved_sem = utils.means_and_sem_err(improved_to_plot)
    not_improved_means, not_improved_sem = utils.means_and_sem_err(not_improved_to_plot)

    def _accuracy_band(means, sems):
        """Return (lower, upper) lists: mean -/+ confidence_z_score * SEM."""
        pairs = list(zip(means.values(), sems.values()))
        lower = [y - utils.confidence_z_score * e for y, e in pairs]
        upper = [y + utils.confidence_z_score * e for y, e in pairs]
        return lower, upper

    if plot_means:
        # Error bands honour plot_means_err, consistent with the test-loss
        # figure; previously they were drawn here regardless of that switch.
        if plot_means_err:
            band_low, band_high = _accuracy_band(not_improved_means, not_improved_sem)
            plt.fill_between(x=not_improved_means.keys(), y1=band_low, y2=band_high,
                             alpha=.25, color='darkcyan')
            band_low, band_high = _accuracy_band(improved_means, improved_sem)
            plt.fill_between(x=improved_means.keys(), y1=band_low, y2=band_high,
                             alpha=.25, color='magenta')
        plt.plot(not_improved_means.keys(), not_improved_means.values(), color='darkcyan',
                 label='Standard Training (Mean'+count_not_improved+')',linestyle=means_linestyle,**plot_means_args)
        plt.plot(improved_means.keys(), improved_means.values(), color='magenta',
                 label='My Algorithm (2 Iterations; Mean'+count_improved+')',linestyle=means_linestyle,**plot_means_args)
        if improved_more_iter is not None:
            if plot_means_err:
                band_low, band_high = _accuracy_band(improved_more_iter_means, improved_more_iter_sem)
                plt.fill_between(x=improved_more_iter_means.keys(), y1=band_low, y2=band_high,
                                 alpha=.25, color='limegreen')
            plt.plot(improved_more_iter_means.keys(), improved_more_iter_means.values(), color='limegreen',
                     label='My Algorithm (5 Iterations; Mean'+count_improved_more_iter+')',linestyle=means_linestyle,**plot_means_args)

    if plot_medians:
        medians_plot(plt)

    plt.xlabel("Epoch")
    plt.ylabel("Test Accuracy")
    plt.legend()
    change_legend_linewidth()
    # The accuracy axis is deliberately left linear and auto-scaled
    # (log_scale is not applied to this figure).
    if save_plots:
        plt.savefig(save_directory+'test_accuracy_'+not_improved_file_name+"."+save_format,bbox_inches='tight',dpi=saved_dpi,format=save_format)
# Display every figure created above.
plt.show()

#plt.savefig('filename.png', dpi=300)

