import os
import tqdm
import argparse
import numpy as np
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Plot theme applied to every figure produced by this script.
sns.set_style(style='whitegrid')


# Command-line interface.
#   --root-dir    (required) directory whose entries are the per-checkpoint
#                 result directories to aggregate
#   --is-driving  switch sub-directory name parsing to the driving naming scheme
parser = argparse.ArgumentParser()
parser.add_argument(
    '--root-dir',
    required=True,
    type=str,
)
parser.add_argument(
    '--is-driving',
    default=False,
    action='store_true',
)
args = parser.parse_args()

# Gather the evaluation metrics of every checkpoint under --root-dir into
#   all_metrics[arch][run_idx][ckpt_idx] = {'leaf_impurity': ..., 'mig': ...,
#                                           'modularity': ..., 'performance': ...}
# Each sub-directory is one evaluated checkpoint and must contain dt/metrics.pkl.
MAX_CKPT_IDX = 100  # DEBUG: checkpoints with a larger index are ignored
RECOMPUTE_PERFORMANCE = False  # if True, rebuild 'performance' from results.pkl

all_metrics = dict()
for sdir_name in tqdm.tqdm(sorted(os.listdir(args.root_dir))):
    sdir_path = os.path.join(args.root_dir, sdir_name)
    # Only directories hold checkpoint results; a stray file would otherwise
    # crash the name parsing below or the metrics.pkl open.
    if not os.path.isdir(sdir_path):
        continue

    if args.is_driving:
        # Name: '<arch>_<...>' where the last '_' field contains '=<ckpt>.<ext>':
        # the checkpoint index is the second-to-last '.'-separated piece of the
        # text after the last '='. Driving names carry no run index, so every
        # checkpoint of an arch is stored under run 0.
        ckpt_idx = int(sdir_name.split('_')[-1].split('=')[-1].split('.')[-2])
        run_idx = 0
        arch = '_'.join(sdir_name.split('_')[:-1])
    else:
        # Name: '<arch>_<run_idx>_<ckpt_idx>' (arch itself may contain '_').
        name_parts = sdir_name.split('_')
        ckpt_idx = int(name_parts[-1])
        run_idx = int(name_parts[-2])
        arch = '_'.join(name_parts[:-2])

    if ckpt_idx > MAX_CKPT_IDX:
        continue

    # NOTE: pickle executes arbitrary code on load -- only point --root-dir at
    # results you produced yourself.
    with open(os.path.join(sdir_path, 'dt', 'metrics.pkl'), 'rb') as f:
        metrics = pickle.load(f)
    # 'leaf_impurity' is a nested sequence: flatten it and average all entries.
    leaf_impurity = np.mean([vv for v in metrics['leaf_impurity'] for vv in v])
    mig = np.mean(metrics['mig'])
    modularity = np.mean(metrics['modularity'])

    if RECOMPUTE_PERFORMANCE:
        # Mean over episodes of the summed per-step value at index -3
        # (presumably the step reward -- TODO confirm against the results writer).
        with open(os.path.join(sdir_path, 'results.pkl'), 'rb') as f:
            results = pickle.load(f)
        ep_reward = [np.sum([step_v[-3] for step_v in ep_v]) for ep_v in results]
        performance = np.mean(ep_reward)
    else:
        try:
            performance = metrics['performance']
        except KeyError:
            # Report the offending directory and record NaN so this checkpoint
            # never inherits another checkpoint's performance value.
            print(sdir_name)
            performance = np.nan

    all_metrics.setdefault(arch, dict()).setdefault(run_idx, dict())[ckpt_idx] = {
        'leaf_impurity': leaf_impurity,
        'mig': mig,
        'modularity': modularity,
        'performance': performance,
    }

# Flatten all_metrics into a long-form table with one row per
# (arch, run_id, ckpt_id), the layout seaborn's lineplot expects.
# Rows are collected in a list and the DataFrame is built once: growing it
# with df.loc[len(df.index)] = ... can reallocate on every row (quadratic)
# and tends to leave every column with object dtype.
metric_names = ['leaf_impurity', 'mig', 'modularity', 'performance']
rows = []
for arch, runs in all_metrics.items():
    for run_id, ckpts in runs.items():
        for ckpt_id, ckpt_metrics in ckpts.items():
            rows.append(
                [arch, run_id, ckpt_id]
                + [ckpt_metrics[m_name] for m_name in metric_names]
            )
df = pd.DataFrame(rows, columns=['arch', 'run_id', 'ckpt_id'] + metric_names)

# One stacked panel per metric: x = checkpoint index, one line per arch
# (seaborn aggregates repeated runs at the same ckpt_id into a mean + band).
plot_metric_names = ['leaf_impurity', 'mig', 'modularity', 'performance']
nrow, ncol = len(plot_metric_names), 1
fig, axes = plt.subplots(nrow, ncol, figsize=(6.4*ncol, 4.8*nrow))
for ax, m_name in zip(axes, plot_metric_names):
    sns.lineplot(ax=ax, data=df, x='ckpt_id', y=m_name, hue='arch')

out_path = './local/test.png'
# savefig does not create missing parent directories.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
fig.savefig(out_path)

# Drop into the debugger to inspect `df` / `all_metrics`; skipped when stdin is
# not a terminal so non-interactive runs are not stopped at a debugger prompt.
import sys
if sys.stdin is not None and sys.stdin.isatty():
    import pdb; pdb.set_trace()
