import numpy as np
import scipy.stats
from matplotlib import pyplot as plt
from tueplots import bundles, figsizes, cycler
from tueplots.constants.color import palettes
# Bar fill colors, one per bar in left-to-right order; passed as `color=` to
# ax.bar() in the __main__ section of this file.
colors = ['tab:orange', 'tab:blue', 'tab:green']

def change_figsize(fig_dict, fraction_width, fraction_height=1):
    """Rescale the ``'figure.figsize'`` entry of *fig_dict* in place.

    The stored width is multiplied by *fraction_width* and the stored height
    by *fraction_height*; the result is written back as a ``(width, height)``
    tuple. The same dict object that was passed in is returned, so the call
    can be nested inside e.g. ``plt.rcParams.update(...)``.

    :param fig_dict: mapping holding a ``'figure.figsize'`` sequence whose
        first two items are width and height
    :param fraction_width: factor applied to the width
    :param fraction_height: factor applied to the height (default: unchanged)
    :return: *fig_dict* itself, after modification
    """
    current_size = fig_dict['figure.figsize']
    scaled_width = current_size[0] * fraction_width
    scaled_height = current_size[1] * fraction_height
    fig_dict['figure.figsize'] = (scaled_width, scaled_height)
    return fig_dict

# Global matplotlib styling, applied at import time.
# Start from the tueplots ICLR 2023 bundle and its single-column figure size.
# (Earlier alternative: bundles.neurips2022() with fraction_width=0.56 and
# fraction_height=0.78.)
plt.rcParams.update(bundles.iclr2023())
fig_dict = figsizes.iclr2023(ncols=1)

# Shrink the figure to 30% of the width and 40% of the height, then push the
# resized settings into rcParams and echo them for inspection.
fig_dict = change_figsize(fig_dict, fraction_width=0.3, fraction_height=0.4)
plt.rcParams.update(fig_dict)
print(fig_dict)


def barplot_annotate_brackets(fig, ax, num1, num2, data, center, height, yerr=None, dh=.05, barh=.05, fs=None, maxasterix=None):
    """
    Annotate barplot with p-values.

    Draws a black bracket on ``ax`` spanning bars ``num1`` and ``num2`` and
    writes a label centered just above it.

    :param fig: figure that owns ``ax``; not used for drawing, kept so that
                existing callers keep working
    :param ax: axes to draw on; its current y-limits are used to scale
               ``dh`` and ``barh``
    :param num1: number of left bar to put bracket over
    :param num2: number of right bar to put bracket over
    :param data: string to write or number for generating asterixes
    :param center: centers of all bars (like plt.bar() input)
    :param height: heights of all bars (like plt.bar() input)
    :param yerr: yerrs of all bars (like plt.bar() input): None, a scalar
                 applied to every bar, or a per-bar sequence / array
    :param dh: height offset over bar / bar + yerr, as a fraction of the
               y-axis range of ``ax`` (0 to 1)
    :param barh: bar height, as a fraction of the y-axis range of ``ax`` (0 to 1)
    :param fs: font size
    :param maxasterix: maximum number of asterixes to write (for very small p-values)
    """

    if isinstance(data, str):
        text = data
    else:
        # * is p < 0.05
        # ** is p < 0.005
        # *** is p < 0.0005
        # etc.
        text = ''
        p = .05

        while data < p:
            text += '*'
            p /= 10.

            if maxasterix and len(text) == maxasterix:
                break

        if not text:
            text = 'n. s.'

    lx, ly = center[num1], height[num1]
    rx, ry = center[num2], height[num2]

    # `if yerr:` would raise for multi-element NumPy arrays, so test for
    # presence explicitly. An empty sequence is still treated as "no error bars".
    if yerr is not None:
        if np.ndim(yerr) == 0:
            # Scalar error: the same value applies to every bar.
            ly = ly + yerr
            ry = ry + yerr
        elif len(yerr) > 0:
            ly = ly + yerr[num1]
            ry = ry + yerr[num2]

    # Scale the fractional offsets by the y-range of the axes we actually draw
    # on; fig.gca() may be a different axes when the figure has several.
    ax_y0, ax_y1 = ax.get_ylim()
    y_span = ax_y1 - ax_y0
    dh *= y_span
    barh *= y_span

    # Bracket starts dh above the taller of the two bars (including error).
    y = max(ly, ry) + dh

    barx = [lx, lx, rx, rx]
    bary = [y, y+barh, y+barh, y]
    mid = ((lx+rx)/2, y+barh)

    ax.plot(barx, bary, c='black')

    kwargs = dict(ha='center', va='bottom')
    if fs is not None:
        kwargs['fontsize'] = fs

    ax.text(*mid, text, **kwargs)

if __name__ == '__main__':
    # Script entry point: load the per-method speed samples, draw a bar chart
    # of mean +/- standard deviation, annotate the first two bars with the
    # significance of an independent two-sample t-test, and save it as a PDF.

    # One .npy file per bar, in the same left-to-right order as the tick
    # labels set further down.
    speed_files = [
        './speed_depmpo_fast.npy',
        './speed_mpo_fast.npy',
        './speed_td4_fast.npy',
    ]
    samples = [np.load(path) for path in speed_files]
    data1, data2 = samples[0], samples[1]

    mpo_means = [np.mean(sample) for sample in samples]
    mpo_stds = [np.std(sample) for sample in samples]

    # Computed once; the p-value drives the bracket label and the full result
    # is printed at the end.
    ttest_result = scipy.stats.ttest_ind(data1, data2)

    x = np.arange(len(mpo_means))
    fig, ax = plt.subplots(1, 1)
    ax.bar(x, height=mpo_means, yerr=mpo_stds, color=colors)
    # NOTE: the bracket is positioned using the y-limits in effect here, i.e.
    # before ax.set_ylim() below is applied.
    barplot_annotate_brackets(fig, ax, data=ttest_result.pvalue, num1=0, num2=1, center=x,
                              height=mpo_means, yerr=mpo_stds, maxasterix=3, fs=8)
    ax.set_xticks(x)
    ax.set_xticklabels(['DEP-MPO', 'MPO', ' TD4'])

    ax.set_ylabel('max velocity (m/s)')
    # The x label is drawn in white; presumably an invisible placeholder to
    # keep the layout spacing -- confirm before removing.
    ax.set_xlabel('sample', color='white')
    ax.set_ylim([0, 10])

    plt.savefig('robustness_plots_speed.pdf')
    print(data1.shape)
    print(data2.shape)
    print(ttest_result)
