from matplotlib import pyplot as plt
from os import path as osp
from itertools import zip_longest, product
import numpy as np
import argparse, sys, os
import pandas as pd
from scipy.ndimage import uniform_filter
import matplotlib as mpl
# The bare 'seaborn' style was renamed to 'seaborn-v0_8' in matplotlib 3.6 and
# removed in 3.8 (where using it raises OSError); prefer the new name when the
# installed matplotlib provides it so the script runs on old and new versions.
mpl.style.use('seaborn-v0_8' if 'seaborn-v0_8' in mpl.style.available else 'seaborn')

def unpack(s):
    """Return the elements of *s*, converted with ``str``, separated by single spaces."""
    pieces = [str(item) for item in s]
    return " ".join(pieces)

def remove_nan(raw_data):
    """Return the entries of the NumPy array *raw_data* that are not NaN.

    Boolean-mask indexing is used, so the result is a new 1-D array.
    """
    is_valid = np.logical_not(np.isnan(raw_data))
    return raw_data[is_valid]

# Parse --dir, expected to look like <...>/<env_name>/<date_str>; the last two
# path components name the experiment and are reused as the figure save prefix.
parser = argparse.ArgumentParser(sys.argv[0])
parser.add_argument('--dir', type=str, required=True,
                    help='experiment directory, e.g. logs/<env_name>/<date_str>')
args = parser.parse_args()
# Strip every trailing slash (not just one) so "a/b//" still splits correctly.
args.dir = args.dir.rstrip('/')
parts = args.dir.split('/')
if len(parts) < 2 or not parts[-2] or not parts[-1]:
    parser.error(f'--dir must look like <...>/<env_name>/<date_str>, got {args.dir!r}')
env_name, date_str = parts[-2], parts[-1]
save_name = env_name + '/' + date_str
print(save_name)

# Each sub-directory of --dir is plotted as its own figure. Plain files
# (e.g. .DS_Store) are skipped; listing them as trial directories later would
# raise NotADirectoryError.
divs = sorted(
    d for d in os.listdir(args.dir) if osp.isdir(osp.join(args.dir, d))
)

# progress.csv columns to plot, one per subplot of the 2x2 grid.
metrics = ['Running Forward KL', 'Running Reverse KL', 'Real Det Return', 'Real Sto Return']

# progress.csv column used as the x-axis of every subplot.
x_axis = 'Running Env Steps'

def _read_expert_return(env_name):
    """Return the expert average return for *env_name*, or None if its file is missing.

    Reads the first line of ``samples/expert_data/meta/<env_name>.txt``, finds
    the whitespace-separated token ``Avg:`` and parses the token after it with
    its last character dropped. NOTE(review): that character is presumably a
    trailing separator such as a comma; confirm against the file format.

    Raises:
        ValueError: if the first line has no ``Avg:`` token or the value is not a float.
    """
    expert_path = f'samples/expert_data/meta/{env_name}.txt'
    if not osp.isfile(expert_path):
        print(f'expert file not found, skipping expert line: {expert_path}')
        return None
    with open(expert_path, 'r') as expert_file:
        first_line = expert_file.readline().rstrip().split()
    loc = first_line.index('Avg:')
    return float(first_line[loc + 1][:-1])


def _plot_trial(df, axes, metric_names, x_name, label, tag):
    """Plot each column of *metric_names* present in *df* against column *x_name*.

    NaNs are dropped row by row: a row is kept only if both its x value and
    its metric value are present. This keeps every plotted point paired with
    the x value logged on the same CSV row. Each curve is smoothed with a
    width-20 ``uniform_filter`` and annotated with *tag* at its last plotted
    point.
    """
    if x_name not in df.columns:
        print(f'column {x_name!r} missing, skipping trial {tag}')
        return
    x_all = df[x_name].values
    for metric, ax in zip(metric_names, axes):
        if metric not in df.columns:
            continue
        y_all = df[metric].values
        keep = ~(np.isnan(x_all) | np.isnan(y_all))
        if not keep.any():
            # Metric never logged in this trial: nothing to plot or annotate.
            continue
        x = x_all[keep]
        data = uniform_filter(y_all[keep], 20)
        ax.plot(x, data, label=label)
        # Annotate at the last *plotted* point so the tag sits on the curve.
        ax.text(x[-1], data[-1], tag)


# The expert reference does not depend on the division, so read it once.
expert_return = None
if 'Real Sto Return' in metrics:  # cuz we use sto data
    expert_return = _read_expert_return(env_name)

for div in divs:
    fig = plt.figure(figsize=(7 * 2, 5 * 2))
    axes = [fig.add_subplot(2, 2, i + 1) for i in range(4)]
    div_dir = osp.join(args.dir, div)
    for trial_id, trial in enumerate(sorted(os.listdir(div_dir))):
        file_path = osp.join(div_dir, trial, 'progress.csv')
        print(file_path)
        if not osp.isfile(file_path) or osp.getsize(file_path) == 0:
            continue
        df = pd.read_csv(file_path)
        _plot_trial(df, axes, metrics, x_axis, label=f'{div}_{trial_id}', tag=f'{trial_id}')

    if expert_return is not None:
        axes[metrics.index('Real Sto Return')].axhline(y=expert_return, label='expert')

    for metric, ax in zip(metrics, axes):
        ax.set_xlabel(x_axis)
        ax.set_ylabel(metric)
        # ax.legend()  # uncomment to show per-trial labels

    plt.suptitle(f"{save_name}-{div}", y=0.98)
    plt.tight_layout()
    os.makedirs(f'./data/figures/{save_name}/', exist_ok=True)
    plt.savefig(f'./data/figures/{save_name}/{div}.pdf', bbox_inches='tight', pad_inches=0.1)
    plt.show()
    plt.close(fig)
