from matplotlib import pyplot as plt
from os import path as osp
from itertools import zip_longest, product
import numpy as np
import argparse, sys, os
import pandas as pd
from scipy.ndimage import uniform_filter
import matplotlib as mpl
# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# 3.8 dropped the old names, so fall back to the new spelling when the
# historical one is rejected (style.use raises OSError for unknown styles).
try:
    mpl.style.use('seaborn-whitegrid')
except OSError:
    mpl.style.use('seaborn-v0_8-whitegrid')
# Base font size for axis labels; tick and legend text are derived from it.
label_fontsize = 18
def unpack(s):
    """Return the items of ``s``, each converted with ``str``, joined by single spaces."""
    return " ".join(str(item) for item in s)

def remove_nan(raw_data):
    """Return the entries of ``raw_data`` that are not NaN, in their original order."""
    keep_mask = np.logical_not(np.isnan(raw_data))
    return raw_data[keep_mask]

# Command-line interface. ``--dir`` is the root directory of the logged runs;
# it is joined with <env>/<task>/<divergence>/<trial>/progress.csv when
# plotting, so it is mandatory: without it every path join would fail later
# with an opaque TypeError on ``None``.
_parser = argparse.ArgumentParser(sys.argv[0])
_parser.add_argument(
    '--dir',
    type=str,
    required=True,
    help='root directory containing <env>/<task>/<divergence>/<trial>/progress.csv',
)
args = _parser.parse_args()

# Name of the experiment sub-directory that is read for every environment.
task = 'exp-16'

def _read_expert_return(env_name):
    """Parse the expert's average return from the environment's meta file.

    Reads ``samples/expert_data/meta/<env>FH-v0.txt`` (relative to the current
    working directory) and returns, as a float, the token that follows
    ``'Avg:'`` on its first line.
    """
    pt_name = env_name if 'FH-v0' in env_name else env_name + 'FH-v0'
    # Context manager so the handle is always closed, even if parsing fails.
    with open(f'samples/expert_data/meta/{pt_name}.txt', 'r') as expert_file:
        first_line = expert_file.readline().rstrip().split()
    loc = first_line.index('Avg:')
    # The token's last character is dropped before parsing (presumably
    # trailing punctuation -- confirm against the meta files).
    return float(first_line[loc+1][:-1])


def _load_trials(div_dir, x_axis, metric):
    """Collect the step and metric series of every usable trial in ``div_dir``.

    A trial is used when its directory name contains '2020', its
    ``progress.csv`` exists and is non-empty, and the CSV has a ``metric``
    column. Returns two parallel lists of NaN-free numpy arrays
    ``(steps, returns)``, one entry per trial, in sorted directory order.
    """
    steps, returns = [], []
    for trial in sorted(os.listdir(div_dir)):
        if '2020' not in trial: continue # .DS_store
        file_path = osp.join(div_dir, trial, 'progress.csv')
        print(file_path)
        if not os.path.isfile(file_path) or os.path.getsize(file_path) == 0:
            continue
        df = pd.read_csv(file_path)
        # Skip trials without the metric so steps and returns stay aligned.
        if metric not in df.columns:
            continue
        steps.append(remove_nan(df.loc[:, x_axis].values))
        returns.append(remove_nan(df.loc[:, metric].values))
    return steps, returns


def plot_in_subplot(fig,ax,img_id,env_name, title='Learning Curve',label="Unlabelled"):
    """Draw the smoothed learning curve of every divergence for one environment.

    For each divergence directory under ``<args.dir>/<env_name>/<task>/`` the
    'Real Sto Return' series of all usable trials are truncated to a common
    length, divided by the expert return, averaged over trials, smoothed with
    a width-20 uniform filter and plotted with a +/- one (smoothed) standard
    deviation band. A dashed horizontal line at 1.0 marks the expert.

    Args:
        fig: Unused; kept for call compatibility.
        ax: Matplotlib axis that receives the curves.
        img_id: Unused; kept for call compatibility.
        env_name: Environment name; selects the log directory, the expert
            meta file and the x-axis limit.
        title: Unused (the title call is disabled).
        label: Unused; kept for call compatibility.

    Reads the module-level ``args``, ``task`` and ``label_fontsize``.
    """
    divs = ['fkl', 'rkl', 'js', 'maxentirl', 'f-max-rkl', 'airl']
    metrics = ['Real Sto Return'] #'Real Det Return',
    metric = metrics[0]
    x_axis = 'Running Env Steps'
    max_steps = 1000000 if 'Hopper' in env_name else 3000000

    if 'Real Sto Return' in metrics: # the stochastic-policy returns are used
        expert_return = _read_expert_return(env_name)
        ax.axhline(y=1.0, label='expert', linestyle='--')

    # Tick styling does not depend on the divergence, so apply it once.
    ax.xaxis.offsetText.set_fontsize(label_fontsize-9)
    ax.ticklabel_format(style='sci', axis='x', scilimits=(0,0))
    ax.tick_params(axis='x', labelsize=label_fontsize-4)
    ax.tick_params(axis='y', labelsize=label_fontsize-4)

    for div in divs:
        div_dir = osp.join(args.dir, env_name, task, div)
        if not os.path.isdir(div_dir):
            print(div_dir)
            continue

        steps, returns = _load_trials(div_dir, x_axis, metric)
        if not steps:
            # No usable trial: min() over an empty sequence would raise.
            continue

        # Truncate to the shortest series among both steps and returns so the
        # stacked array is rectangular and matches the x values.
        min_len = min(min(map(len, steps)), min(map(len, returns)))
        if min_len == 0:
            continue
        x = steps[0][:min_len]

        # Float dtype so the division below is valid for integer logs too.
        curves = np.array([trial[:min_len] for trial in returns], dtype=float)
        if metric == 'Real Sto Return':
            curves = curves / expert_return # ratio
        returns_std = uniform_filter(curves.std(0), 20)
        data = uniform_filter(curves.mean(0), 20) # avg over trials, smoothed

        ax.plot(x, data, label=div)
        ax.fill_between(x, data-returns_std, data+returns_std,alpha=0.1)

    ax.set_xlabel('Environment steps',fontsize=label_fontsize)
    if metric == 'Real Sto Return':
        ax.set_ylabel('Return' + ' Ratio',fontsize=label_fontsize)
    else:
        ax.set_ylabel(metric,fontsize=label_fontsize)
    ax.set_xlim([0,max_steps])

# Environments to plot, one panel each, laid out left to right.
env_names = ['Hopper','HalfCheetah','Walker2d','Ant']
fig, axs = plt.subplots(1,4, figsize=(20, 4))

plt.rcParams.update({'font.size': 20})
for itr, (ax, env_name) in enumerate(zip(axs, env_names)):
    ax.ticklabel_format(style='sci', axis='x', scilimits=(0,0))
    plot_in_subplot(fig,ax,itr,env_name,title=env_name)

# Legend entries are taken from the last panel only; fetch them once and
# reuse them for the stand-alone legend figure below.
handles, labels = ax.get_legend_handles_labels()

fig.tight_layout()
fig.savefig('IL.pdf',format='pdf')

# The legend is written to its own thin PDF so it can be placed separately
# from the curves.
figLegend = plt.figure(figsize = (11,0.5))
figLegend.legend(handles, labels, loc = 'upper left',mode='expand',ncol=5,fontsize=label_fontsize-4,borderaxespad=0, frameon=False)

figLegend.savefig('IL_legend.pdf',format='pdf')

plt.show()
