import pickle as pkl
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

def _load_runs(env_name, variant):
    """Load one pickled result set and return it as a NumPy array.

    Reads ``results/<env_name>_<variant>.pkl``.

    NOTE: unpickling can execute arbitrary code, so only point this at
    result files produced by your own experiments.
    """
    with open('results/{}_{}.pkl'.format(env_name, variant), 'rb') as f:
        return np.array(pkl.load(f))


def _plot_learning_curves(ax, env_name, variants, colors, expert_reward,
                          max_episode, fontsize, skip_step=1):
    """Draw the learning curves of one environment on ``ax``.

    For every entry of ``variants`` the matching result file is loaded and
    its mean over axis 0 is plotted, surrounded by a band of +/- half a
    standard deviation (also taken over axis 0). A dashed black horizontal
    line marks ``expert_reward``.

    Args:
        ax: Matplotlib axes to draw on.
        env_name: Environment name; used for the file names and the title.
        variants: File-name suffixes of the result sets, in plotting order.
        colors: Colours paired with ``variants`` position by position.
        expert_reward: y position of the dashed reference line.
        max_episode: Last x position; the x axis is episode 1 followed by
            every 5th episode up to and including this value.
        fontsize: Font size of the title and the axis labels.
        skip_step: Keep only every ``skip_step``-th point of each curve.

    Returns:
        ``(curve_handles, expert_handle)``: the plotted lines in the order
        of ``variants``, and the handle of the dashed reference line.
    """
    # Subsample the x positions together with the statistics so that both
    # keep the same length for any skip_step.
    # Presumably each result array is (runs, evaluation points) - confirm
    # against the script that writes the pickles.
    episodes = ([1] + list(range(5, max_episode + 1, 5)))[::skip_step]

    expert_handle = ax.hlines(expert_reward, 0, max_episode,
                              linestyles='dashed', color='k')

    curve_handles = []
    for color, variant in zip(colors, variants):
        rewards = _load_runs(env_name, variant)
        mean = rewards.mean(axis=0)[::skip_step]
        std = rewards.std(axis=0)[::skip_step]
        curve_handles.append(ax.plot(episodes, mean, color=color)[0])
        ax.fill_between(episodes, mean - 0.5 * std, mean + 0.5 * std,
                        alpha=0.2, color=color)

    ax.set_title(env_name, fontsize=fontsize)
    ax.set_xlabel('Number of episodes', fontsize=fontsize)
    ax.set_ylabel('Mean reward', fontsize=fontsize)
    return curve_handles, expert_handle


if __name__ == '__main__':
    # COLOR and line_labels are aligned index by index: four training
    # settings followed by the black expert reference line.
    COLOR = ['blue', 'red', 'orange', 'green', 'k']
    line_labels = ['Setting 1', 'Setting 2', 'Setting 3', 'Setting 4', 'Expert']
    ftsize = 35

    # Result-file suffixes in plotting order (paired with COLOR).
    variants = ['expert_prior', 'no_expert_batch', 'no_kl_reward', 'global_prior']

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(50, 10))

    _plot_learning_curves(ax1, 'MountainCar-v0', variants, COLOR,
                          expert_reward=-111.917, max_episode=200,
                          fontsize=ftsize)
    # Acrobot is plotted without the global_prior result set.
    _plot_learning_curves(ax2, 'Acrobot-v1', variants[:3], COLOR,
                          expert_reward=-85.137, max_episode=200,
                          fontsize=ftsize)
    # The shared legend is built from the CartPole panel because it shows
    # all four settings plus the expert line.
    curve_handles, expert_handle = _plot_learning_curves(
        ax3, 'CartPole-v1', variants, COLOR,
        expert_reward=500, max_episode=1000, fontsize=ftsize)

    # The expert handle goes last so that every handle lines up with its
    # entry in line_labels. Handles and labels are both passed as keywords:
    # mixing a positional handle list with labels= is not treated
    # consistently across Matplotlib versions.
    # NOTE(review): no `loc` is given, so Matplotlib's default figure-legend
    # location is applied to this anchor point - confirm the legend ends up
    # where intended in the saved PNG.
    fig.legend(handles=curve_handles + [expert_handle],
               labels=line_labels,
               bbox_to_anchor=(0.7, 0.0001),
               borderaxespad=0.2,    # Small spacing around legend box
               fontsize=ftsize, ncol=5)

    for ax in (ax1, ax2, ax3):
        plt.setp(ax.get_xticklabels(), fontsize=ftsize)
    for ax in (ax1, ax2, ax3):
        plt.setp(ax.get_yticklabels(), fontsize=ftsize)

    # Shrink the subplot area on the right-hand side
    # (a smaller value leaves more empty space there).
    plt.subplots_adjust(right=0.85)
    fig.savefig('results/learning_curve.png')
    plt.show()