import glob
import matplotlib.pyplot as plt
import numpy as np
import torch
import sys
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter, AutoMinorLocator)


# Environment whose results are plotted. It is the prefix of every result
# file read from ./results and of the saved figure's filename.
# Other values used here previously:
#   FetchReach, FetchPush, FetchSlide, FetchPick, HandReach,
#   HandManipulateBlockRotateZ, HandManipulateBlockRotateParallel,
#   HandManipulateBlockRotateXYZ, HandManipulateBlockFull,
#   HandManipulateEggRotate, HandManipulateEggFull,
#   HandManipulatePenRotate, HandManipulatePenFull
env = "HandManipulatePenFull"

# Right edge of the x-axis (epoch index) in the final plot.
xlim = 200

# Result-file name suffixes; "{}" is filled with the seed.
# `templates` serves gcqs/her/mher/ddpg, `gcsl_templates` serves gcsl/am.
templates = "(-)rew_monolithic_lr0.001_sd{}.pt"
gcsl_templates = "(-)rew_lr0.001_sd{}.pt"

# Methods to plot and their line colours, paired by position.
methods = "gcqs her mher gcsl am ddpg".split()
colors = [
    "xkcd:light green",   # gcqs
    "xkcd:orange",        # her
    "xkcd:purple",        # mher
    "xkcd:light orange",  # gcsl
    "xkcd:bright pink",   # am
    "xkcd:red",           # ddpg
]

# Seeds whose runs are averaged for each method.
seeds = [100 * k for k in range(1, 6)]

def smooth(x, delta=2):
    """Return a centred moving average of ``x``.

    Output element ``i`` is the mean of ``x[i - delta : i + delta + 1]``:
    a symmetric window of up to ``2 * delta + 1`` samples, truncated at the
    array boundaries (edge elements therefore average fewer samples).

    Args:
        x: 1-D array-like of values (here, per-epoch success rates).
        delta: Half-width of the window. ``0`` returns ``x`` unchanged
            (as a float array).

    Returns:
        A float ``np.ndarray`` with the same length as ``x``.

    Raises:
        ValueError: If ``delta`` is negative.
    """
    if delta < 0:
        raise ValueError(f"delta must be non-negative, got {delta}")
    x = np.asarray(x)
    n = x.shape[0]
    smoothed = np.zeros((n,))
    for i in range(n):
        # Slice ends are exclusive, so +1 keeps the window centred on i.
        # Without it the window is lopsided and delta=0 is an empty slice (NaN).
        smoothed[i] = x[max(0, i - delta):min(n, i + delta + 1)].mean()
    return smoothed

# Filename template used by each method's result files.
_TEMPLATE_FOR_METHOD = {
    "gcqs": templates,
    "her": templates,
    "mher": templates,
    "ddpg": templates,
    "gcsl": gcsl_templates,
    "am": gcsl_templates,
}

# method -> list of smoothed per-seed success curves that loaded successfully.
success = {}
for j, method in enumerate(methods):
    success[method] = []
    template = _TEMPLATE_FOR_METHOD.get(method)
    if template is None:
        # Fail visibly instead of reusing an unbound/stale filename.
        print("[error] no filename template for method", method)
        continue

    for seed in seeds:
        filename = f"./results/{env}_{method}_{template.format(seed)}"
        try:
            # NOTE: torch.load unpickles the file -- only load trusted results.
            # NOTE(review): torch>=2.6 defaults to weights_only=True; if these
            # files hold non-tensor objects loading may fail -- confirm.
            res = torch.load(filename)
            curve = smooth(np.array(res['stats']['successes']))
        except Exception as err:  # best effort: a missing/bad seed is skipped
            print("[error] ", env, method, seed, repr(err))
            continue
        success[method].append(curve)
        print("[success] ", filename)

    if success[method]:
        max_len = max(len(c) for c in success[method])
        print(env, method, max_len)
        # Curves shorter than the longest one are dropped so the rest can be
        # stacked and averaged epoch by epoch.
        runs = np.stack([c for c in success[method] if len(c) == max_len])
        mean = runs.mean(0)
        plt.plot(mean, color=colors[j], linewidth=3.0, label=method)
        # A +/-1 std band is only meaningful with more than one stacked run.
        if runs.shape[0] > 1:
            std = runs.std(0)
            plt.fill_between(np.arange(runs.shape[1]), mean - std, mean + std,
                             color=colors[j], alpha=0.3)

# Axis styling is applied once, after all curves are drawn. plt.grid() with no
# arguments *toggles* the grid, so the state is passed explicitly.
plt.xlim(0, xlim)
plt.ylim(0, 1.05)
plt.yticks(np.linspace(0, 1, 21))  # ticks every 0.05
plt.tick_params(which='both', direction='out')
plt.xlabel("Epoch", fontsize=19)
plt.ylabel("Success Rate", fontsize=19)
plt.grid(True)
# NOTE(review): curves carry `label=method` but no legend is drawn; add
# plt.legend() if the figure should identify the methods itself.

plt.tight_layout()
plt.savefig(env+"results.png")
plt.close()
