import pickle
import argparse
import numpy as np
import matplotlib.pyplot as plt
from os.path import exists
from matplotlib.ticker import MaxNLocator
parser = argparse.ArgumentParser(description='Plot')
parser.add_argument('--env-name', metavar='G',
                    help='name of the environment to run')
args = parser.parse_args()

#plt.style.use('seaborn')
# Enable LaTeX rendering
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
})
index = 0
Y_max = 0
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
expert_values = {"CartPole-v1": 500, "Acrobot-v1": -63.678, "LunarLander-v2": 318.7}
min_values = {"CartPole-v1": 25, "Acrobot-v1": -500, "LunarLander-v2": -200}
for method,c in zip([ "iq","ppil", "bc", "spoil"], [ "goldenrod","blue", "green", "brown"]):
    x_tot = [1,2,3,5,10]
    y_tot = []
    y_std = []
    for n_trajs in x_tot:
        ys_max=[]
        for seed in range(10):
            if method=="ppil" or method == "spoil":
                if args.env_name in ["CartPole-v1", "Acrobot-v1", "LunarLander-v2"]:
                    with open("pickle_results/"+method+"/"+args.env_name+str(seed)
                                +"n_trajs"+str(n_trajs)
                                +"_lr_w0.005"
                                +"_lr_theta0.005.pt","rb") as f:
                        _, y, _ = pickle.load(f)
                        #ax.plot(x_tot, y_tot, color = c)
                        ys_max.append(np.max(y))
                else:
                    if exists("pickle_results/"+method+"/init_temp_0.01_lunar_lander/"+args.env_name+str(seed)
                                +"n_trajs"+str(n_trajs)
                                +"_lr_w0.0001" #+"_lr_w0.005"
                                +"_lr_theta0.0001.pt"):
                        with open("pickle_results/"+method+"/init_temp_0.01_lunar_lander/"+args.env_name+str(seed)
                                    +"n_trajs"+str(n_trajs)
                                    +"_lr_w0.0001" #+"_lr_w0.005"
                                    +"_lr_theta0.0001.pt","rb") as f: #+"_lr_theta0.005.pt","rb") as f:
                            _, y, _ = pickle.load(f)
                            #ax.plot(x_tot, y_tot, color = c)
                            ys_max.append(np.max(y))
            elif method=="bc":
                with open("pickle_results/"+method+"/"+args.env_name+str(seed)
                                +"n_trajs"+str(n_trajs)
                                +"_lr_w0.005"
                                +"_lr_theta0.005.pt","rb") as f:
                        y, _ = pickle.load(f)
                        #ax.plot(x_tot, y_tot, color = c)
                        ys_max.append(np.max(y))
            else:
                if exists("pickle_results/"+method+"/"+args.env_name+str(seed)
                            +"n_trajs"+str(n_trajs)
                            +".pt"):
                    with open("pickle_results/"+method+"/"+args.env_name+str(seed)
                                +"n_trajs"+str(n_trajs)
                                +".pt","rb") as f:
            
                        _, y, _ = pickle.load(f)
                        #ax.plot(x_tot, y_tot, color = c)
                        ys_max.append(np.max(y))
        average_max=np.mean(ys_max)
        std = np.std(ys_max) 
        y_tot.append(average_max)
        y_std.append(std)
    if method == "ppil":
        label = r"\texttt{P2IL}"  
    elif method == "spoil":
        label=r"\texttt{SPOIL (Ours)}"
    elif method == "bc":
        label=r"\texttt{BC}"
    else:
        label = r"\texttt{IQLearn}"
    ax.plot(x_tot,y_tot, "-*", color=c, markersize=15, label=label)
    
    ax.fill_between(x_tot,
            np.array(y_tot) + np.array(y_std),
            np.array(y_tot) - np.array(y_std), alpha=0.05, facecolor=c)
hor=np.arange(1,11)
ax.plot(hor,expert_values[args.env_name]*np.ones_like(hor), "--", color="gray")         
ax.xaxis.set_major_locator(MaxNLocator(6)) 
plt.xticks(fontsize=30)
plt.yticks(fontsize=30)
plt.title(f"{args.env_name}",fontsize=30)
plt.xlabel(r"Expert Trajectories",fontsize=30)
plt.ylabel(r"Total Return",fontsize=30)
plt.legend(fontsize=15)
plt.tight_layout()
file_name = "figs/"+args.env_name+"_offline.pdf"
plt.savefig(file_name)