
import os
import os
# Restrict CUDA to device 10. Nothing in this script visibly uses a GPU;
# presumably kept so that any CUDA-aware library pulled in by the imports
# below does not grab a device -- confirm before removing.
os.environ.update(CUDA_VISIBLE_DEVICES='10')
import glob
import json
import numpy as np
import  matplotlib.pyplot as plt
# matplotlib 3.6 renamed the seaborn styles to 'seaborn-v0_8-*' and 3.8
# removed the old names, so use whichever one this matplotlib provides.
plt.style.use('seaborn-v0_8-paper'
              if 'seaborn-v0_8-paper' in plt.style.available
              else 'seaborn-paper')
import seaborn as sns

import warnings
# Silence every warning category process-wide (this also hides numpy's
# "Mean of empty slice" RuntimeWarnings when a run directory is empty).
warnings.simplefilter("ignore")

from tensorboard.backend.event_processing import event_accumulator

# Per-method plot styling: plot() picks entry i from each of these lists, so
# all three must stay the same length.
# NOTE: indices 1 and 7 both use 'green'.
COLORS = [
    'red', 'green', 'c', 'm', 'brown', 'coral', '#6e9ece', 'green', 'blue',
]
SHOWLAB = [
    'HER', 'CHER(Ours)', 'DDPG+HER', 'GCAL', 'ESIL', 'GCSL', 'WGCSL',
    'GRSIL(Ours)', 'GRSILe',
]
STYLE = [
    '-', '--', '-.', ':', ':', 'solid', 'dashed', 'dashdot', 'dotted',
]

# Where run logs are read from and where figures are written.
base_path = '/opt/code/HER/runDouble'
# base_path = '/opt/code/HER/runGSP3'
save_path = '/opt/code/HER/runDouble'

# Half-width of the moving-average window applied to every curve.
radius = 3
# NOTE(review): not referenced anywhere in the code visible here.
num_resample1 = 51
num_resample2 = 101

def _smooth(values, half_width):
    """Return the centred moving average of *values* over 2*half_width+1 samples.

    Near both ends the window is truncated and each output is the mean of the
    samples that actually fall inside it. The result always has
    ``len(values)`` elements. By contrast, ``np.convolve(..., mode='same')``
    returns ``max(len(values), window)`` elements, which padded and
    misaligned curves shorter than the window.
    """
    values = np.asarray(values, dtype=float)
    n = len(values)
    if n == 0:
        # np.convolve raises on empty input; an empty curve stays empty.
        return values
    kernel = np.ones(2 * half_width + 1)
    # The 'full' convolution is n + 2*half_width long. This slice centres the
    # window on each input sample, which is identical to mode='same' whenever
    # n >= window.
    window = slice(half_width, half_width + n)
    sums = np.convolve(values, kernel, mode='full')[window]
    counts = np.convolve(np.ones(n), kernel, mode='full')[window]
    return sums / counts

def extract_data(log_files, env, i, max_length=50):
    """Load, smooth and length-align the evaluation curves of several runs.

    Args:
        log_files: paths handed to TensorBoard's ``EventAccumulator``, one per run.
        env: environment name. For names starting with ``'Point'``, the scalar
            tag falls back to ``'reward_eval1'`` when ``'reward_eval'`` is absent.
            Other environments always read ``'reward_eval'``.
        i: unused; kept for call compatibility.
        max_length: upper bound on the number of points kept per curve
            (default 50, the previously hard-coded cap).

    Returns:
        A 2-D float array with one row per log file. Each row is the
        moving-average-smoothed curve (window 2*radius+1). All rows are
        truncated to the shortest run's length and to at most *max_length*
        points. Smoothing is applied before truncation. Returns an array of
        shape ``(0, 0)`` when *log_files* is empty.
    """
    curves = []
    for log_path in log_files:
        acc = event_accumulator.EventAccumulator(log_path)
        acc.Reload()
        scalar_tag = 'reward_eval'
        # 'Point*' runs may have logged under 'reward_eval1' instead.
        if env.startswith('Point') and scalar_tag not in acc.scalars.Keys():
            scalar_tag = 'reward_eval1'
        values = [event.value for event in acc.scalars.Items(scalar_tag)]
        max_length = min(max_length, len(values))
        curves.append(_smooth(values, radius))

    if not curves:
        # Keep the result 2-D so callers can always index [:, -1:].
        return np.empty((0, 0))
    return np.array([curve[:max_length] for curve in curves])

def plot(data, i):
    """Draw the mean curve of *data* with a shaded +/- one standard-error band.

    *data* is averaged over axis 0, i.e. one row per run as produced by
    extract_data. The standard error uses np.std's default (ddof=0).
    *i* selects the colour, line style and legend label from COLORS, STYLE
    and SHOWLAB.
    """
    n_runs = len(data)
    mean_curve = np.mean(data, axis=0)
    half_band = np.std(data, axis=0) / np.sqrt(n_runs)
    epochs = np.arange(len(mean_curve))
    color = COLORS[i]

    plt.plot(epochs, mean_curve, color=color, linestyle=STYLE[i],
             label=SHOWLAB[i], linewidth=3)
    plt.fill_between(epochs, mean_curve - half_band, mean_curve + half_band,
                     color=color, alpha=.2)
#
def extract_log(env, tag, i):
    """Find every run of variant *tag* on environment *env* and load its curves.

    Args:
        env: environment name to match against each run's path.
        tag: sub-directory holding the runs of one algorithm/variant.
        i: source selector. ``i >= 1`` reads entries under
            ``base_path/<tag>/`` whose basename's second '-'-separated field
            equals *env*. Otherwise it reads
            ``/resall/run3/<tag>/<env>-seed*/progress.csv``.

    Returns:
        The array produced by extract_data, with one smoothed curve per run.
    """
    if i >= 1:
        candidates = glob.glob(os.path.join(base_path, tag, '*'))

        def _matches(path):
            fields = os.path.basename(path).split('-')
            # Entries with no '-' (stray files) used to raise IndexError here.
            return len(fields) > 1 and fields[1] == env

        log_files = [p for p in candidates if _matches(p)]
    else:
        # NOTE(review): this branch collects progress.csv files, but
        # extract_data hands every path to TensorBoard's EventAccumulator,
        # which reads event files -- it likely does not work as-is; confirm.
        candidates = glob.glob(
            os.path.join('/resall/run3/', tag, '*', 'progress.csv'))
        log_files = [p for p in candidates
                     if p.split(tag + '/')[1].split('-seed')[0] == env]

    # glob order is arbitrary; sort so the run (row) order is reproducible.
    return np.array(extract_data(sorted(log_files), env, i))

def plot_res(envs):
    """Plot HER against the 'HERDouGen_1_1' variant for each environment.

    For each env, this function:
    - loads both variants' runs with extract_log;
    - prints their array shapes and the run-averaged value at the last kept
      epoch;
    - draws one mean +/- stderr curve per variant, labelled SHOWLAB[0] and
      SHOWLAB[1];
    - saves the figure as a PDF under ``save_path``.

    A single env keeps the historical ``res.pdf`` name. With several envs each
    figure is written to ``res_<env>.pdf``, so earlier figures are no longer
    overwritten.
    """
    for env in envs:
        print(env)
        HER = extract_log(env, "HER", 1)
        GCAL = extract_log(env, "HERDouGen_1_1", 1)
        print(HER.shape)
        print(GCAL.shape)
        res = [HER, GCAL]

        for i, curves in enumerate(res):
            # Run-averaged value at the final kept epoch.
            print(env, SHOWLAB[i], np.mean(curves[:, -1:]))
            plot(curves, i)

        # Axis decoration only needs to be applied once per figure.
        plt.xlabel('Training Epoch', fontsize=25)
        plt.ylabel('Success Rate', fontsize=25)
        plt.title(env, fontsize=25)
        plt.xticks(fontsize=25)
        plt.yticks(fontsize=25)
        plt.grid(True)

        fname = 'res.pdf' if len(envs) == 1 else 'res_{}.pdf'.format(env)
        plt.savefig(os.path.join(save_path, fname), bbox_inches='tight')
        plt.close()

if __name__ == '__main__':
    # Other environment groups that have been plotted with this script:
    # plot_res(envs=['HandManipulatePenRotate', 'HandManipulateBlockRotateXYZ', 'HandManipulateEggFull'])
    # plot_res(envs=['FetchPush', 'FetchPickAndPlace', 'FetchSlide', 'HandReach'])
    # plot_res(envs=['HandReach'])
    plot_res(envs=['FetchPickAndPlace'])