from matplotlib import pyplot as plt
from os import path as osp
from itertools import zip_longest, product
import numpy as np
import argparse, sys, os
import pandas as pd
from scipy.ndimage import uniform_filter
import matplotlib as mpl
# Matplotlib 3.6 renamed its bundled seaborn style to 'seaborn-v0_8' and 3.8
# removed the old 'seaborn' name, so use whichever this installation provides.
mpl.style.use('seaborn-v0_8' if 'seaborn-v0_8' in mpl.style.available else 'seaborn')

def unpack(s):
    """Join the string form of every item in ``s`` with single spaces."""
    return " ".join(str(item) for item in s)

def remove_nan(raw_data):
    """Return the entries of the NumPy array ``raw_data`` that are not NaN.

    Boolean-mask indexing is used, so multi-dimensional input comes back flattened.
    """
    keep_mask = np.logical_not(np.isnan(raw_data))
    return raw_data[keep_mask]

# Command line: --dir points at a run directory whose last two path components
# are taken as <env_name>/<date_str>; together they name the output figure folder.
parser = argparse.ArgumentParser(sys.argv[0])
# required=True: without it an omitted --dir is None and every later path op fails.
parser.add_argument('--dir', type=str, required=True)
args = parser.parse_args()
# normpath drops any number of trailing separators (and redundant './' or '//'),
# so 'a/b//' no longer yields an empty date_str.
args.dir = osp.normpath(args.dir)
# Resolve against the cwd so a path without a visible parent (e.g. '--dir .')
# still has an env_name component instead of raising IndexError.
abs_dir = osp.abspath(args.dir)
env_name = osp.basename(osp.dirname(abs_dir))
date_str = osp.basename(abs_dir)
save_name = env_name + '/' + date_str
print(save_name)

def _subdirs(parent):
    """Return the sorted names of the immediate sub-directories of ``parent``.

    Plain files (e.g. '.DS_Store', notes) are skipped. Otherwise os.listdir on a
    file raises NotADirectoryError, and stray entries would shift the index-based
    pairing of trials across divisions done with zip_longest.
    """
    return sorted(name for name in os.listdir(parent)
                  if osp.isdir(osp.join(parent, name)))


# Each sub-directory of the run dir is a "division", drawn as its own labelled
# curve; each sub-directory of a division is one trial.
divs = _subdirs(args.dir)
log_lists = [_subdirs(osp.join(args.dir, div)) for div in divs]
print(log_lists)

# Columns of progress.csv to plot, one panel each, against x_axis.
metrics = ['Running Forward KL', 'Running Reverse KL', 'Real Det Return', 'Real Sto Return']

x_axis = 'Running Env Steps'

SMOOTH_WINDOW = 20  # width, in logged rows, of the moving average applied to each curve
EXPERT_METRIC = 'Real Sto Return'  # panel that gets the expert reference line


def _read_expert_return(env_name):
    """Return the expert's average return for ``env_name``, or None if the file is missing.

    Parses the first line of ``samples/expert_data/meta/{env_name}.txt``. The value
    is the whitespace-separated token right after ``'Avg:'`` with its last character
    removed (NOTE(review): presumably a trailing ',' - confirm the file format).
    Raises ValueError if the line has no 'Avg:' token.
    """
    expert_path = f'samples/expert_data/meta/{env_name}.txt'
    try:
        with open(expert_path, 'r') as expert_file:
            first_line = expert_file.readline().rstrip().split()
    except FileNotFoundError:
        print(f'no expert file at {expert_path}; skipping expert line')
        return None
    loc = first_line.index('Avg:')
    return float(first_line[loc + 1][:-1])


def _finite_xy(df, x_col, y_col):
    """Return float arrays of ``df[x_col]`` and ``df[y_col]`` over rows where both are non-NaN.

    One shared mask keeps x/y points paired. Dropping NaNs from each column on its
    own can leave arrays of different lengths or misaligned points. Casting to float
    also keeps the later moving average from truncating integer columns.
    """
    x = df[x_col].to_numpy(dtype=float)
    y = df[y_col].to_numpy(dtype=float)
    keep = ~(np.isnan(x) | np.isnan(y))
    return x[keep], y[keep]


# Loop invariants: the expert reference and the output folder are the same for every figure.
expert_return = _read_expert_return(env_name) if EXPERT_METRIC in metrics else None
fig_dir = f'./data/figures/{save_name}/'
os.makedirs(fig_dir, exist_ok=True)
n_cols = 2
n_rows = -(-len(metrics) // n_cols)  # ceil division: one panel per metric

# One figure per trial index. zip_longest pairs the idx-th trial of every division
# and pads with None where a division has fewer trials.
for idx, trials in enumerate(zip_longest(*log_lists)):
    fig = plt.figure(figsize=(7 * n_cols, 5 * n_rows))
    axes = [fig.add_subplot(n_rows, n_cols, i + 1) for i in range(len(metrics))]

    for div, trial in zip(divs, trials):
        if trial is None:
            continue
        file_path = osp.join(args.dir, div, trial, 'progress.csv')
        print(file_path)
        if not os.path.isfile(file_path) or os.path.getsize(file_path) == 0:
            continue
        df = pd.read_csv(file_path)
        if x_axis not in df.columns:
            print(f'{file_path} has no {x_axis!r} column; skipping')
            continue

        for metric, ax in zip(metrics, axes):
            if metric not in df.columns:
                continue
            x, raw_data = _finite_xy(df, x_axis, metric)
            if raw_data.size == 0:
                continue  # nothing to plot; also avoids indexing [-1] on an empty array
            data = uniform_filter(raw_data, SMOOTH_WINDOW)
            ax.plot(x, data, label=div)
            ax.text(x[-1], data[-1], f'{div}')  # tag the end of each curve with its division

    if expert_return is not None:
        axes[metrics.index(EXPERT_METRIC)].axhline(y=expert_return, label='expert')

    for metric, ax in zip(metrics, axes):
        ax.set_xlabel(x_axis)
        ax.set_ylabel(metric)
        # legend() on a panel with no labelled artists only emits a warning.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()

    fig.suptitle(f"{save_name}-{idx}", y=0.98)
    fig.tight_layout()
    fig.savefig(f'{fig_dir}{idx}.pdf', bbox_inches='tight', pad_inches=0.1)
    plt.show()
    plt.close(fig)