from matplotlib import pyplot as plt
import numpy as np
from tueplots import bundles, figsizes, cycler
from tueplots.constants.color import palettes

def change_figsize(fig_dict, fraction_width, fraction_height=1):
    """Rescale the ``'figure.figsize'`` entry of *fig_dict* in place.

    Args:
        fig_dict: Mapping holding a ``'figure.figsize'`` entry that can be
            indexed at positions 0 (width) and 1 (height).
        fraction_width: Factor the width is multiplied by.
        fraction_height: Factor the height is multiplied by (default 1).

    Returns:
        The same *fig_dict* object, with ``'figure.figsize'`` replaced by a
        ``(width, height)`` tuple of the scaled values.

    Note:
        The argument itself is modified, so calling this repeatedly on one
        dict compounds the factors.
    """
    size = fig_dict['figure.figsize']
    scaled = (size[0] * fraction_width, size[1] * fraction_height)
    fig_dict['figure.figsize'] = scaled
    return fig_dict

# Apply the ICLR 2023 style bundle, then override the figure size with a
# scaled copy of the ICLR 2023 default size.
plt.rcParams.update(bundles.iclr2023())
fig_dict = figsizes.iclr2023()
# NOTE: change_figsize rescales the dict it is given, so after this call
# fig_dict itself holds the scaled size, not the ICLR default.
plt.rcParams.update(change_figsize(fig_dict, fraction_width=0.3, fraction_height=0.75))
# Alternative settings, currently disabled:
#plt.rcParams.update(change_figsize(fig_dict, fraction_width=0.82, fraction_height=0.78))
#bundle['figure.figsize'] = (4.0, 3.3)
## bundle['font.size'] = 14
#bundle['axes.labelsize'] = 12
## bundle['legend.fontsize'] = 8
#bundle['xtick.labelsize'] = 12
#bundle['ytick.labelsize'] = 12
#bundle['axes.titlesize'] = 18

# Switches selecting which plot sections below are executed (truthy = run).
PLOT_STEPDOWN = 0
PLOT_SLOPETROTTER = 0
PLOT_SPEED = 0
PLOT_HUMAN_TROTTER = 1
width = 0.7  # not used by any active code visible in this file
window = 10  # chunk length used when the step-down section slices loaded arrays

# Bar colours, in the bar order DEP-MPO, MPO, TD4 used by the bar plots below.
colors = ['tab:orange', 'tab:blue', 'tab:green']
# Figure and axes created with the rcParams configured above.
fig, axs = plt.subplots(1, 1)
if PLOT_STEPDOWN:
    # Step-down robustness plot: success rate against step height, drawn as
    # one error-bar curve per agent and saved to 'stepdownostr.pdf'.
    distances = ['005', '01', '015', '02', '025']  # file-name suffixes
    x = [5, 10, 15, 20, 25]  # x positions, one per suffix above

    # (file-name stem, legend label, line colour) for every plotted agent.
    agents = [
        ('stepdown_depmpo_robust_no_correction', 'DEP-MPO', 'tab:orange'),
        ('stepdown_mpo_robust_no_correction', 'MPO', 'tab:blue'),
        ('stepdown_td4_no_correction', 'TD4', 'tab:green'),
    ]
    for stem, label, line_color in agents:
        mpo_means = []
        mpo_stds = []
        for d in distances:
            data = np.load(f'./{stem}_{d}.npy')
            # Cut the array into consecutive chunks of `window` entries
            # along its first axis; a trailing partial chunk is dropped.
            # The chunk count is taken from the same axis that is sliced.
            # NOTE(review): the aggregation below yields one scalar per
            # file only for 1-D arrays -- confirm the stored shape.
            n_packets = data.shape[0] // window
            data_packets = np.array(
                [data[k * window:(k + 1) * window] for k in range(n_packets)]
            )
            # Average across chunks first; both the plotted value and its
            # error bar are derived from that chunk-averaged array.
            packet_mean = np.mean(data_packets, axis=0)
            mpo_means.append(np.mean(packet_mean, axis=-1))
            mpo_stds.append(np.std(packet_mean))
        axs.errorbar(x, mpo_means, yerr=mpo_stds, label=label, color=line_color)

    axs.set_xticks(x)
    axs.set_xticklabels([str(height) for height in x])
    axs.set_ylim([0, 1.1])
    axs.set_xlabel('step height (cm)')
    axs.set_ylabel('success rate')
    plt.savefig('stepdownostr.pdf')

if PLOT_SLOPETROTTER:
    # Bar chart of the mean (with standard deviation as error bar) of each
    # agent's obstacle-run results, saved to 'slopetrotterostr.pdf'.

    # Re-apply the ICLR 2023 bundle and a freshly fetched figure size, then
    # open a new figure for this plot.
    plt.rcParams.update(bundles.iclr2023())
    fig_dict = figsizes.iclr2023()
    plt.rcParams.update(change_figsize(fig_dict, fraction_width=0.3, fraction_height=0.65))
    fig, axs = plt.subplots(1, 1)

    # One result file per bar, in the order of the tick labels below.
    result_files = (
        './obstacles_depmpo_robust.npy',
        './obstacles_mpo_robust.npy',
        './obstacles_td4.npy',
    )
    mpo_means = []
    mpo_stds = []
    for result_file in result_files:
        data = np.load(result_file)
        # Statistics are taken over every entry of the loaded array.
        mpo_means.append(np.mean(data))
        mpo_stds.append(np.std(data))

    x = np.arange(len(mpo_means))
    bars = axs.bar(x, height=mpo_means, yerr=mpo_stds, color=colors)
    axs.set_xticks(x)
    axs.set_xticklabels(['DEP-MPO', 'MPO', ' TD4'])
    axs.set_ylabel('max distance (m)')

    plt.savefig('slopetrotterostr.pdf')

if PLOT_SPEED:
    # Bar chart of the mean (with standard deviation as error bar) of each
    # agent's speed results, saved to 'robustness_plots_speed.pdf'.

    # Start from a freshly fetched ICLR 2023 size: change_figsize rescales
    # the dict it is given in place, so reusing an already scaled fig_dict
    # would compound the factors and make the figure size depend on which
    # other sections ran before this one.
    plt.rcParams.update(bundles.iclr2023())
    fig_dict = figsizes.iclr2023()
    plt.rcParams.update(change_figsize(fig_dict, fraction_width=0.50, fraction_height=1.0))
    fig, axs = plt.subplots(1, 1)

    # One result file per bar, in the order of the tick labels below.
    result_files = (
        './speed_depmpo_fast.npy',
        './speed_mpo_fast.npy',
        './speed_td4_fast.npy',
    )
    mpo_means = []
    mpo_stds = []
    for result_file in result_files:
        data = np.load(result_file)
        # Statistics are taken over every entry of the loaded array.
        mpo_means.append(np.mean(data))
        mpo_stds.append(np.std(data))

    x = np.arange(len(mpo_means))
    bars = axs.bar(x, height=mpo_means, yerr=mpo_stds, color=colors)
    axs.set_xticks(x)
    axs.set_xticklabels(['DEP-MPO', 'MPO', ' TD4'])

    axs.set_ylabel('max velocity (m/s)')
    # The label is drawn in white; presumably an invisible spacer that keeps
    # the layout aligned with plots that show an x label -- TODO confirm.
    axs.set_xlabel('sample', color='white')
    axs.set_title('speed')

    plt.savefig('robustness_plots_speed.pdf')

if PLOT_HUMAN_TROTTER:
    # Grouped bar chart comparing two result files ('ID' vs. 'OOD') side by
    # side, saved to 'humanhopstacle.pdf'.
    plt.rcParams.update(bundles.iclr2023())
    fig_dict = figsizes.iclr2023()
    plt.rcParams.update(change_figsize(fig_dict, fraction_width=0.3, fraction_height=0.75))
    # Draw on a figure of its own so the bars never land on axes that
    # another enabled section has already drawn on and saved.
    fig, axs = plt.subplots(1, 1)
    dx = 0.45  # width of one bar; the two bars of a group sit dx apart

    # (result file, legend label, bar colour, horizontal offset in a group)
    series = [
        ('./normal_hyfydy_dep_iclr.npy', 'ID', 'tab:orange', -dx / 2),
        ('./hobstacle_hyfydy_dep_iclr.npy', 'OOD', 'tab:red', dx / 2),
    ]
    for d, label, color, offset in series:
        data = np.load(d)
        # Reduce over the last axis: one bar (mean) and one error bar (std)
        # per entry along the leading axis.
        # NOTE(review): the x label suggests that axis indexes random seeds
        # -- confirm against the script that wrote the files.
        means = np.mean(data, axis=-1)
        stds = np.std(data, axis=-1)
        x = np.arange(len(means))
        axs.bar(x + offset, means, width=dx, yerr=stds, color=color, label=label)

    axs.set_ylim([0, 1.5])
    # Tick positions come from the last file loaded; both files are expected
    # to yield the same number of bars.
    axs.set_xticks(x)
    axs.set_xlabel('random seed')
    axs.set_ylabel('success rate')
    axs.legend(loc='upper left', fontsize=4)

    plt.savefig('humanhopstacle.pdf')
