from matplotlib import pyplot as plt
from os import path as osp
from itertools import zip_longest, product
import numpy as np
import argparse, sys, os
import pandas as pd
from scipy.ndimage import uniform_filter
import matplotlib as mpl
# The bare 'seaborn' style name was deprecated in matplotlib 3.6 and removed in
# later releases, where it is only available as 'seaborn-v0_8'.  Try the
# original name first so older installations keep their exact look.
try:
    mpl.style.use('seaborn')
except OSError:
    mpl.style.use('seaborn-v0_8')

def unpack(s):
    """Render every item of the iterable ``s`` as text, separated by single spaces."""
    pieces = [str(item) for item in s]
    return " ".join(pieces)

def remove_nan(raw_data):
    """Return the entries of the array ``raw_data`` that are not NaN, in their original order."""
    keep = np.logical_not(np.isnan(raw_data))
    return raw_data[keep]

# Command line: ``--dir`` is the root directory of one environment's results.
# It is required because everything below indexes into ``args.dir``; without it
# the script used to fail with an opaque TypeError on ``None``.
parser = argparse.ArgumentParser(sys.argv[0])
parser.add_argument('--dir', type=str, required=True)
args = parser.parse_args()

# Drop any trailing slashes so the last path component is the environment
# name, even for inputs such as ``logs/Hopper//``.
args.dir = args.dir.rstrip('/')
parts = args.dir.split('/')
env_name = parts[-1]

# Sub-directories of ``args.dir`` to plot, one figure each.  The integer after
# the last '-' is parsed as the number of expert trajectories for that task.
tasks = ['exp-1', 'exp-4', 'exp-16']
# Sub-directories of each task directory; each one that exists becomes one
# curve, and the name is used as its legend label.
divs = ['maxentirl', 'f-max-rkl', 'airl', 'fkl', 'rkl', 'js', ]
# Columns of progress.csv to plot, one subplot per entry.
metrics = ['Real Sto Return'] #'Real Det Return', 
# Column of progress.csv used for the x axis.
x_axis = 'Running Env Steps'

# Hard-coded expert returns, keyed by environment name (64 values per
# environment).  The plotting loop below takes the first N entries, where N is
# the trajectory count encoded in the task name, and reports their mean/std.
# NOTE(review): presumably one value per expert trajectory, in the same order
# as the trajectories used for training — confirm against the expert data.
expert_returns = {}
expert_returns['Ant'] = np.array([5926.179, 5730.677, 5824.137, 5955.398, 5847.839, 6066.736,
       5706.077, 6105.023, 5927.851, 5927.464, 6050.575, 6162.145,
       6043.106, 5888.204, 5999.199, 5701.368, 6121.578, 5922.609,
       5877.748, 6044.576, 5977.024, 6041.073, 6080.149, 5890.946,
       5928.336, 5914.364, 5925.36 , 5944.162, 5971.436, 5997.144,
       5805.856, 5923.695, 5842.542, 5970.315, 5723.259, 5949.78 ,
       6146.484, 5984.89 , 5841.663, 5960.484, 5700.377, 5696.356,
       5890.842, 5784.999, 5955.16 , 5949.988, 5641.807, 5986.947,
       6136.71 , 6118.511, 5846.73 , 6042.562, 5944.697, 5787.549,
       6102.113, 5956.218, 5979.683, 5969.493, 5816.752, 5685.732,
       5871.238, 5883.844, 6048.18 , 5831.78 ])
expert_returns['HalfCheetah'] = np.array([12258.711, 12900.472, 10293.933, 12324.679, 12312.501, 12536.946,
       12579.71 , 12439.509, 12037.562, 13143.215, 12696.208, 12410.66 ,
       12714.029, 12244.488, 12662.132, 12945.88 , 12545.556, 11920.783,
       12760.411, 12655.707, 12542.178, 12514.93 , 12357.07 , 12783.59 ,
       12130.195, 12822.858, 11546.802, 12826.838, 12990.03 , 11033.234,
       13208.215, 12378.109, 12457.043, 12198.133, 12560.387, 12830.687,
       12058.786, 12250.469, 11458.793, 12540.908, 12651.069, 12433.801,
       12668.406, 12161.804, 12577.811, 12098.953, 11933.758, 12452.899,
       12710.076, 12281.293, 12285.43 , 11511.206, 12455.976, 12447.731,
       12200.233, 12241.609, 12865.379, 13039.216, 12443.003, 13084.916,
       12650.83 , 12622.516, 12621.77 , 13047.351])
expert_returns['Walker2d'] = np.array([5468.355, 5355.007, 5318.677, 5209.376, 5386.716, 5345.665,
       5374.757, 5456.733, 5392.301, 5393.831, 5434.147, 5408.793,
       5196.358, 5284.467, 5458.316, 5404.685, 5399.014, 5155.62 ,
       5387.207, 5305.797, 5358.304, 5413.285, 5440.836, 5368.105,
       5348.113, 5374.864, 5368.985, 5447.216, 5153.679, 5471.13 ,
       5374.028, 5401.202, 5397.05 , 5292.218, 5350.855, 5335.233,
       5345.407, 5257.378, 5286.984, 5326.554, 5309.479, 5422.897,
       5448.194, 5223.849, 5425.741, 5349.688, 5370.058, 5343.503,
       5418.851, 5407.414, 5372.199, 5289.034, 5409.694, 5373.184,
       5179.679, 5256.259, 5181.917, 5377.656, 5224.78 , 5334.359,
       5144.533, 5177.68 , 5394.084, 5347.327])
expert_returns['Hopper'] = np.array([3506.584, 3490.196, 3501.847, 3479.176, 3483.817, 3493.678,
       3490.362, 3502.376, 3512.96 , 3503.75 , 3480.46 , 3509.224,
       3493.809, 3489.614, 3501.114, 3507.032, 3487.71 , 3510.834,
       3501.506, 3512.044, 3493.558, 3503.494, 3507.378, 3506.112,
       3502.106, 3481.233, 3497.378, 3491.08 , 3494.132, 3488.929,
       3494.967, 3482.417, 3481.459, 3489.951, 3488.257, 3492.08 ,
       3484.725, 3502.682, 3499.977, 3493.042, 3514.677, 3495.574,
       3504.513, 3493.816, 3494.608, 3484.887, 3488.873, 3486.479,
       3487.311, 3497.588, 3492.103, 3482.593, 3479.339, 3489.158,
       3490.726, 3491.861, 3504.798, 3494.202, 3482.183, 3495.005,
       3492.167, 3495.345, 3502.05 , 3463.131])
# NOTE(review): the commented-out array below looks like an alternative
# set of Hopper values — confirm whether it is still needed before deleting.
# np.array([3570.866, 3575.767, 3597.924, 3597.796, 3580.849, 3630.259,
#        3577.996, 3623.137, 3592.353, 3579.305])

# Step budget per environment: curves are cut at the logged step closest to
# ``max_steps`` and the "last avg" statistics start at the logged step closest
# to ``start_steps``.  Hopper uses a smaller budget than every other environment.
if 'Hopper' in env_name:
    max_steps, start_steps = 2000000, 1800000
else:
    max_steps, start_steps = 3000000, 2700000

# For every task: draw one figure (one subplot per metric) with one curve per
# divergence, averaged over all trials found on disk, then save it as a PDF.
for task in tasks:
    fig = plt.figure(figsize=(7 * len(metrics), 5))
    axes = [fig.add_subplot(1, len(metrics), i + 1) for i in range(len(metrics))]

    if 'Real Sto Return' in metrics: # cuz we use sto data
        ax = axes[metrics.index('Real Sto Return')]
        # The task name ends in the number of expert trajectories (e.g.
        # 'exp-4'); the reference return is the mean over that many leading
        # entries of the expert table.
        expert_num_trajs = int(task.split('-')[-1])
        expert_subset = expert_returns[env_name][:expert_num_trajs]
        expert_return, expert_std = np.mean(expert_subset), np.std(expert_subset)
        print(expert_num_trajs, expert_return, expert_std)
        # The curve on this axis is plotted as a ratio to the expert return,
        # so the expert itself sits at 1.0.
        ax.axhline(y=1.0, label='expert')

    for div in divs:
        div_dir = osp.join(args.dir, task, div)
        if not osp.isdir(div_dir):
            continue
        returns = {metric: [] for metric in metrics}
        steps = []

        for trial in sorted(os.listdir(div_dir)):
            if '2020' not in trial:
                continue  # skip non-run entries such as .DS_Store
            file_path = osp.join(div_dir, trial, 'progress.csv')
            print(file_path)
            if not osp.isfile(file_path) or osp.getsize(file_path) == 0:
                continue
            df = pd.read_csv(file_path)

            steps.append(remove_nan(df.loc[:, x_axis].values))

            for metric in metrics:
                if metric not in df.columns:
                    continue
                returns[metric].append(remove_nan(df.loc[:, metric].values))

        # A divergence directory may exist without any finished trial; skip it
        # instead of crashing on min() of an empty sequence.
        min_len = min(map(len, steps)) if steps else 0
        if min_len == 0:
            print(f"no usable trials in {div_dir}, skipping")
            continue

        # Common x axis: truncate to the shortest trial, then cut at the
        # logged step closest to max_steps.
        x = steps[0][:min_len]
        end = np.argmin(abs(x - max_steps)) + 1
        x = x[:end]
        # Index where the "final performance" window begins.
        start = np.argmin(abs(x - start_steps))

        for metric, ax in zip(metrics, axes):
            curves = returns[metric]
            if not curves:
                continue  # no trial logged this metric
            # Trials may have logged a different number of rows; cut all of
            # them to a common length so they stack into a 2-D array.
            n_points = min(end, min(map(len, curves)))
            if n_points == 0:
                continue
            stacked = np.array([curve[:n_points] for curve in curves])

            # Final-performance statistics are reported on the raw returns.
            tail = stacked[:, start:]
            print(f"last avg {tail.mean():.2f} std {tail.std():.2f}")

            mean_curve = stacked.mean(0) # avg over trials
            if metric == 'Real Sto Return':
                # Plot as a ratio to the expert so the curve matches the
                # expert line at 1.0 and the ' ratio' axis label.
                mean_curve = mean_curve / expert_return

            data = uniform_filter(mean_curve, 2)
            ax.plot(x[:n_points], data, label=div)

    for metric, ax in zip(metrics, axes):
        ax.set_xlabel(x_axis)
        if metric == 'Real Sto Return':
            ax.set_ylabel(metric + ' ratio')
        else:
            ax.set_ylabel(metric)
        ax.legend()

    plt.suptitle(f"{env_name}-{task}", y=0.98)
    plt.tight_layout()
    os.makedirs(f'./data/figures/{env_name}/', exist_ok=True)
    plt.savefig(f'./data/figures/{env_name}/{task}.pdf', bbox_inches='tight', pad_inches=0.1)
    plt.show()
    plt.close()
