import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd


def smooth(values, weight=0.5):
    """Exponentially smooth a series of numbers.

    Every output value is ``weight * previous_output + (1 - weight) * value``.
    The running value is seeded with the first input, so the first output
    equals the first input (up to floating-point rounding).

    Args:
        values: Sized, indexable series of numbers (list, tuple, 1-D array).
        weight: Share of the previous smoothed value kept at each step.
            ``0`` reproduces the input values; values closer to ``1`` smooth
            more strongly.

    Returns:
        A list with one smoothed value per input value. An empty input
        yields an empty list.
    """
    # An empty series has no first value to seed the running average with;
    # return an empty result instead of failing on values[0].
    if len(values) == 0:
        return []

    last = values[0]
    smoothed = []
    for point in values:
        last = last * weight + (1 - weight) * point
        smoothed.append(last)
    return smoothed


# Figure geometry: a 24 x 38 inch canvas, sized for the 5 x 3 grid of panels
# created further down with plt.subplot.
plt.rcParams['figure.figsize']=[24, 38]
# Horizontal / vertical spacing between panels. This runs before any subplot
# is created, so it acts on pyplot's current figure -- the same figure the
# later plt.subplot calls draw into.
plt.subplots_adjust(wspace=0.3, hspace=0.3)

# Tick positions on the rescaled x axis, which spans [-1, 1] in every panel.
# Only len(x_ticks) is read by the panel code: the ticks are placed with
# np.linspace(-1, 1, len(x_ticks)), which gives these same eight evenly
# spaced positions. x_labels holds the text shown at each tick; -3/7 ("0")
# is where draw_subplot draws the offline/online divider.
x_ticks = [-1, -5.0/7, -3.0/7, -1.0/7, 1.0/7, 3.0/7, 5.0/7, 1.0]
x_labels = ["-1e6", "-5e5", "0", "2e4", "4e4", "6e4", "8e4", "1e5"]

# Data layout expected by draw_subplot:
#   n_trials          -- seeds read per algorithm for the online phase
#   n_trials_offline  -- seeds read per algorithm for the offline phase
#   n_samples         -- entries stored per online curve (row width of the
#                        per-seed summary array)
#   offline_samples   -- nominal entries per offline curve; draw_subplot does
#                        not read this value, it sets its own row count per
#                        algorithm
n_trials = 3
n_trials_offline = 2
n_samples = 100
offline_samples = 1000

# Line colors and legend labels, one per algorithm. Both lists are indexed in
# the same order as the algorithm directories listed inside draw_subplot.
colors = ["#ff007f", "#990000", "#3498db", "#9b59b6", "#2ecc71", "orange"]
algo_labels = ['HiP3', 'AWAC', 'CQL', 'IQL', 'MBPO', 'PEX']


def _load_smoothed_runs(algo, phase, n_seeds, n_rows, column, weight):
    """Read one training log per seed and smooth the requested column.

    Logs are read from ``<algo>/<phase>/seed_<k>/log.txt`` for
    ``k = 0 .. n_seeds - 1`` as tab-delimited text with a header row.

    Args:
        algo: Directory prefix of the algorithm's results.
        phase: Sub-directory name, ``'offline'`` or ``'online'``.
        n_seeds: Number of seed directories to read.
        n_rows: Number of entries stored per curve. Each log is expected to
            hold exactly this many entries, because every curve is written
            into a fixed-width row.
        column: Name of the log column to smooth.
        weight: Smoothing weight passed to ``smooth``.

    Returns:
        ``(curves, iterations)``: an ``(n_seeds, n_rows)`` array of smoothed
        curves, and the ``'Iterations'`` column of the last log read.
    """
    curves = np.zeros((n_seeds, n_rows))
    for seed in range(n_seeds):
        log_path = '{}/{}/seed_{}/log.txt'.format(algo, phase, seed)
        log = np.genfromtxt(log_path, dtype=None, delimiter='\t', names=True)
        curves[seed] = smooth(log[column], weight)
    # The step axis is taken from the last seed's log, as the x values for
    # the whole group of curves.
    return curves, log['Iterations']


def draw_subplot(ax, figure_title):
    """Draw the offline and online learning curves of every algorithm on *ax*.

    For each algorithm the mean over seeds is drawn as a line and the band
    mean +/- standard deviation is shaded. Both phases share one rescaled
    x axis: offline iterations are mapped so that iteration 0 lands on -1 and
    iteration 1e6 on -3/7; online iterations are mapped so that iteration 0
    lands on -3/7 and iteration 1e5 on 1. Algorithms whose directory contains
    ``'mbpo'`` are plotted for the online phase only.

    The function also shades the two phase regions, draws the vertical
    divider at -3/7 and sets the axis labels and the title.

    Args:
        ax: Matplotlib axes to draw on.
        figure_title: Task name; used both as the directory component of
            every log path and as the panel title.
    """
    hip3_dir = 'HiP3/' + figure_title
    algos = [hip3_dir,
             'CoRL/' + figure_title + '/AWAC',
             'CoRL/' + figure_title + '/CQL',
             'CoRL/' + figure_title + '/IQL',
             'mbpo/' + figure_title,
             'pex/' + figure_title]

    for index, (algo, algo_label) in enumerate(zip(algos, algo_labels)):
        color = colors[index]
        online_only = 'mbpo' in algo

        if not online_only:
            # HiP3 logs a different offline metric and a different number of
            # entries than the other algorithms.
            if algo == hip3_dir:
                n_offline_rows = 1155
                offline_column = 'Max_pool_normalized_return_rollout'
            else:
                n_offline_rows = 1000
                offline_column = 'Normalized_reward_mean'

            offline_curves, offline_steps = _load_smoothed_runs(
                algo, 'offline', n_trials_offline, n_offline_rows,
                offline_column, 0.8)
            # The mean curve is smoothed a second time; the spread is not.
            offline_mean_y = np.asarray(
                smooth(np.mean(offline_curves, axis=0), 0.8))
            offline_std_y = np.std(offline_curves, axis=0)
            offline_x = (offline_steps - 1e6)/1e6*4.0/7 - 3.0/7

            # The legend entry is attached to the offline line.
            ax.plot(offline_x, offline_mean_y, '-', label=algo_label,
                    linewidth=2.0, color=color, zorder=2)
            ax.fill_between(offline_x, offline_mean_y - offline_std_y,
                            offline_mean_y + offline_std_y, alpha=0.2,
                            color=color, zorder=2)

        online_curves, online_steps = _load_smoothed_runs(
            algo, 'online', n_trials, n_samples,
            'Normalized_reward_mean', 0.5)
        mean_y = np.mean(online_curves, axis=0)
        std_y = np.std(online_curves, axis=0)
        online_x = online_steps / 1e5 * 10.0 / 7 - 3.0 / 7

        # Online-only algorithms have no offline line, so their legend entry
        # goes on the online line; all others are hidden from the legend here
        # to avoid duplicate entries.
        online_label = algo_label if online_only else '_nolegend_'
        ax.plot(online_x, mean_y, '-', label=online_label, linewidth=2.0,
                color=color, zorder=2)
        ax.fill_between(online_x, mean_y - std_y, mean_y + std_y, alpha=0.2,
                        color=color, zorder=2)

    # Background shading: offline region [-1, -3/7], online region [-3/7, 1].
    square1 = plt.Rectangle(xy=(-1.0, 0), width=4.0/7, height=150.0, alpha=0.2, angle=0, color='#72BCD5', zorder=1)
    ax.add_patch(square1)
    square2 = plt.Rectangle(xy=(-3.0/7, 0), width=10.0/7, height=150.0, alpha=0.2, angle=0, color='#F7D058', zorder=1)
    ax.add_patch(square2)

    ax.set_xlabel("Number of Training Steps", fontsize=20, fontweight='bold', color='black')
    ax.set_ylabel("Normalized Score", fontsize=20, fontweight='bold', color='black')
    ax.set_title(figure_title, fontweight='bold', fontdict={'fontsize': 20})
    # Divider between the offline and the online phase.
    ax.axvline(x=-3.0/7, linewidth=4.0, color='black', zorder=3.5)


# One entry per panel, in row-major order of the 5 x 3 grid:
# (task name, y tick positions, upper y limit).
panel_specs = [
    ('walker2d-random-v0', [0, 15, 30, 45, 60], 70),
    ('hopper-random-v0', [0, 20, 40, 60, 80], 100),
    ('halfcheetah-random-v0', [0, 15, 30, 45, 60], 75),
    ('walker2d-medium-v0', [0, 16, 32, 48, 64], 80),
    ('hopper-medium-v0', [0, 24, 48, 72, 96], 120),
    ('halfcheetah-medium-v0', [0, 16, 32, 48, 64], 80),
    ('walker2d-medium-replay-v0', [0, 20, 40, 60, 80], 100),
    ('hopper-medium-replay-v0', [0, 20, 40, 60, 80], 100),
    ('halfcheetah-medium-replay-v0', [0, 16, 32, 48, 64], 80),
    ('walker2d-medium-expert-v0', [0, 24, 48, 72, 96], 120),
    ('hopper-medium-expert-v0', [0, 24, 48, 72, 96], 120),
    ('halfcheetah-medium-expert-v0', [0, 24, 48, 72, 96], 120),
    ('walker2d-expert-v0', [0, 24, 48, 72, 96], 120),
    ('hopper-expert-v0', [0, 24, 48, 72, 96], 120),
    ('halfcheetah-expert-v0', [0, 24, 48, 72, 96], 120),
]

# The single legend is attached to the top-centre panel and placed above it.
legend_panel = 'hopper-random-v0'
# NOTE(review): only this panel recolors its y ticks and tick labels blue;
# confirm that this is intentional and not a leftover.
blue_y_tick_panel = 'walker2d-medium-v0'

for panel_index, (figure_title, y_ticks, y_max) in enumerate(panel_specs, start=1):
    ax1 = plt.subplot(5, 3, panel_index)
    ax1.grid(axis="x", color="black", linewidth=1.0, zorder=0, linestyle='--')
    ax1.set_xlim([-1, 1])
    ax1.set_xticks(np.linspace(-1, 1, len(x_ticks)))
    ax1.set_xticklabels(x_labels)
    ax1.tick_params(labelsize=20)
    if figure_title == blue_y_tick_panel:
        ax1.tick_params(axis="y", labelsize=20, color="blue", labelcolor="blue")
    ax1.set_yticks(y_ticks)
    ax1.set_ylim([0, y_max])
    draw_subplot(ax1, figure_title)
    if figure_title == legend_panel:
        ax1.legend(fontsize=25, loc='upper center', bbox_to_anchor=(0.5, 1.3), ncol=7)

plt.savefig("same_online_tasks.pdf", bbox_inches='tight')
# plt.show()

