import argparse
import json
import os
import warnings
from typing import Dict, Iterable, List, Optional, Tuple, Union

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# NOTE(review): ``warnings.catch_warnings()`` restores the previous filter state
# when the ``with`` block exits, so this filter only silences the networkx
# backend warning for imports placed *inside* the block. As written it has no
# lasting effect.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message="networkx backend defined more than once: nx-loopback")


parser = argparse.ArgumentParser('ABCDEFG')
# File IO
parser.add_argument('-o', '--out-dir', type=str, required=True, help='Root directory to save simulation results')
parser.add_argument('--plot', action='store_true', help='Plot the results.')
parser.add_argument('-f', '--fig-name', type=str, required=True, help='Figure name.')
# Result selection. Any --eval-metric value that is not a metric name
# (e.g. 'none') disables best-instance selection.
parser.add_argument(
    '-m',
    '--eval-metric', type=str, default='f1',
    help="Evaluation metric. If set to 'none', all model instances will be displayed."
)
parser.add_argument(
    '--fix-type', type=str, nargs='*', default=None,
    help='Fix the simulation type'
)
parser.add_argument(
    '--fix-model', type=str, nargs='*', default=None,
    help='Fix the model type'
)
# Restricted to the two supported modes: previously any other string was
# silently treated as 'min'.
parser.add_argument(
    '--eval-mode', type=str, default='max', choices=['max', 'min'],
    help='Evaluation mode (max/min)'
)
# NOTE(review): --latest is parsed but not used by this script's __main__ block.
parser.add_argument(
    '--latest', action='store_true',
    help='Keep the latest results.'
)
parser.add_argument(
    '--show-all', action='store_true',
    help='Keep all results.'
)
parser.add_argument(
    '--cache', action='store_true',
    help='Save the results to csv.'
)


# Canonical model identifiers used as table columns and plot labels.
ALL_MODELS = [
    'dcdi',
    'dcdfg',
    'enco',
    'sdcd',
    'ABCDEFG_Basic',
    'ABCDEFG_SPN',
]
# Result sub-directory read for each model, position-aligned with ALL_MODELS
# ('dcdfg_fine_log' holds the DCDFG results).
ALL_MODEL_DIRS = [
    'dcdi',
    'dcdfg_fine_log',
    'enco',
    'sdcd',
    'ABCDEFG_Basic',
    'ABCDEFG_SPN',
]
# Scatter markers, one per simulation-type group.
MARKERS = ['o', 's', 'd', 'x', 'v', '*', '+', 'p', 'h']


def _to_model_type(model_dir: str) -> str:
    """Map a result sub-directory name to its canonical model type.

    Names already listed in ``ALL_MODELS`` are returned unchanged; every other
    name (e.g. ``'dcdfg_fine_log'``) is treated as a DCDFG variant.
    """
    if model_dir not in ALL_MODELS:
        return 'dcdfg'
    return model_dir

def _to_model_display_name(model_type: str) -> str:
    if model_type == 'dcdi':
        return 'DCDI'
    elif model_type == 'dcdfg':
        return 'DCDFG'
    elif model_type == 'enco':
        return 'ENCO'
    elif model_type == 'sdcd':
        return 'SDCD'
    elif model_type == 'ABCDEFG_Basic':
        return 'ABCDEFG\n(Basic)'
    elif model_type == 'ABCDEFG_SPN':
        return 'ABCDEFG\n(SPN)'
    elif 'Intv.' in model_type:
        return model_type
    elif isinstance(model_type, Iterable):
        return '-'.join([str(x).upper() for x in model_type])
    else:
        return model_type


def _to_intv_display_name(model_type: str) -> str:
    """Return the display label of the interventional variant of *model_type*."""
    base_name = _to_model_display_name(model_type)
    return 'Intv. ' + base_name
    

def _to_sim_display_name(
    sim_type: str,
    show_attr: Optional[Iterable[str]] = None
) -> str:
    attrs = sim_type.split('_')
    out_attrs = []
    if show_attr is None or 'linearity' in show_attr:
        if attrs[0] == 'linear':
            out_attrs.append('Linear')
        elif attrs[0] == 'nn':
            out_attrs.append('NN')
        else:
            out_attrs.append(attrs[0])
    if show_attr is None or 'graph' in show_attr:
        out_attrs.append(attrs[1].upper())
    if show_attr is None or 'target' in show_attr:
        if attrs[2] in ['targeted', 'targetedmulti']:
            out_attrs.append('T')
        elif attrs[2] in ['untargeted', 'untargetedmulti']:
            out_attrs.append('U')
        else:
            out_attrs.append(attrs[2].upper())
    if (show_attr is None or 'intervention' in show_attr) and len(attrs) == 4:
        if attrs[3] == 'hard':
            out_attrs.append('Hard')
        else:
            out_attrs.append('Soft')
    return '-'.join(out_attrs)


def _get_graph_metrics(log_dir: str) -> Optional[Dict[str, float]]:
    if not os.path.exists(os.path.join(log_dir, 'metrics.json')):
        # print(f'Log directory {log_dir} not found. Skipping...')
        return None
    with open(os.path.join(log_dir, 'metrics.json'), 'r', encoding='utf-8') as f:
        metrics = json.load(f)
    return metrics


def _evaluate(model_dir: str) -> pd.DataFrame:
    """Collect the scalar metrics of every instance under *model_dir*.

    Each entry of *model_dir* is treated as one model instance. Entries without
    a ``metrics.json`` are skipped, and dict-valued JSON entries are ignored.

    Args:
        model_dir: directory laid out as ``model_dir/<instance>/metrics.json``.

    Returns:
        DataFrame indexed by instance name (index name ``'instance'``, sorted),
        one column per scalar metric. A metric that some instances do not
        report is NaN for those rows. The frame is empty when no instance has
        metrics.
    """
    records = []
    valid_inst = []
    for inst in sorted(os.listdir(model_dir)):
        _metrics = _get_graph_metrics(os.path.join(model_dir, inst))
        if _metrics is None:
            continue
        valid_inst.append(inst)
        # One record per instance, so metrics missing from some instances
        # become NaN instead of misaligning per-metric lists.
        records.append({key: val for key, val in _metrics.items() if not isinstance(val, dict)})
    return pd.DataFrame(records, index=pd.Index(valid_inst, name='instance'))


def _get_hparam(log_dir: str, file_name: str = 'args.json') -> Optional[Dict[str, float]]:
    if not os.path.exists(os.path.join(log_dir, file_name)):
        print(f'Log file {log_dir}/{file_name} not found. Skipping...')
        return None
    with open(os.path.join(log_dir, file_name), 'r', encoding='utf-8') as f:
        hparam = json.load(f)
    return hparam


def get_hparam(result_dir: str, fix_type: Optional[Iterable] = None) -> Dict[str, Dict[str, Dict[str, Dict[str, float]]]]:
    """Load the ``args.json`` of every ABCDEFG instance under *result_dir*.

    Walks ``result_dir/<sim_type>/<dataset>/<model_dir>/<instance>`` in sorted
    order. Only model directories that map (via ``_to_model_type``) to an ABCDEFG
    variant are read, and instances without a hyper-parameter file are omitted.

    Args:
        result_dir: root results directory.
        fix_type: if given, only these simulation types are visited.

    Returns:
        ``{sim_type: {dataset: {model_type: {instance: hparams}}}}``.
    """
    tracked_models = ['ABCDEFG_Basic', 'ABCDEFG_SPN']
    collected = {}
    for sim_type in np.sort(os.listdir(result_dir)):
        if fix_type is not None and sim_type not in fix_type:
            continue
        sim_path = os.path.join(result_dir, sim_type)
        per_sim = collected[sim_type] = {}
        for dataset in np.sort(os.listdir(sim_path)):
            data_path = os.path.join(sim_path, dataset)
            per_data = per_sim[dataset] = {}
            for model_name in np.sort(os.listdir(data_path)):
                model_type = _to_model_type(model_name)
                if model_type not in tracked_models:
                    continue
                model_path = os.path.join(data_path, model_name)
                per_model = per_data[model_type] = {}
                for inst in np.sort(os.listdir(model_path)):
                    params = _get_hparam(os.path.join(model_path, inst))
                    if params is not None:
                        per_model[inst] = params
    return collected


def evaluate(
    result_dir: str,
    fix_type: Optional[Iterable] = None,
    fix_model: Optional[Iterable] = None
) -> Tuple[Dict[str, Dict[str, Dict[str, pd.DataFrame]]], Dict[str, Dict[str, Dict[str, Dict[str, Optional[dict]]]]]]:
    """Load metrics (and ABCDEFG hyper-parameters) for every run under *result_dir*.

    Expected layout: ``result_dir/<sim_type>/<dataset>/<model_dir>/<instance>/``.
    Only model directories listed in ``ALL_MODEL_DIRS`` are read. Non-directory
    entries (e.g. stray files) at the simulation, dataset and model levels are
    skipped.

    Args:
        result_dir: root results directory.
        fix_type: if given, only these simulation types are loaded.
        fix_model: if given, only these model types (after ``_to_model_type``)
            are loaded.

    Returns:
        ``(res, hparams)``.

        - ``res[sim_type][dataset][model_type]`` is the instance-indexed metrics
          DataFrame from ``_evaluate``, with lower-cased columns and rows sorted
          by instance. Models without any metrics are omitted.
        - ``hparams`` has the same nesting, down to ``[model_type][instance]``.
          It holds the result of ``_get_hparam`` (``None`` when the file is
          missing), and only the ABCDEFG models are included.
    """
    res = {}
    hparams = {}

    for sim_type in np.sort(os.listdir(result_dir)):
        if fix_type is not None and sim_type not in fix_type:
            continue
        sim_dir = os.path.join(result_dir, sim_type)
        if not os.path.isdir(sim_dir):
            continue

        res[sim_type] = {}
        hparams[sim_type] = {}

        for dataset in np.sort(os.listdir(sim_dir)):
            data_dir = os.path.join(sim_dir, dataset)
            if not os.path.isdir(data_dir):
                continue

            res[sim_type][dataset] = {}
            hparams[sim_type][dataset] = {}

            for _model_dir in np.sort(os.listdir(data_dir)):
                if _model_dir not in ALL_MODEL_DIRS:
                    continue
                model_type = _to_model_type(_model_dir)
                if fix_model is not None and model_type not in fix_model:
                    continue

                model_dir = os.path.join(data_dir, _model_dir)
                if not os.path.isdir(model_dir):
                    continue
                _res = _evaluate(model_dir)
                if _res.empty:
                    continue
                _res.columns = [col.lower() for col in _res.columns]
                _res = _res.sort_index(axis=0)
                res[sim_type][dataset][model_type] = _res

                # Hyper-parameters are only recorded for the ABCDEFG variants
                if model_type in ['ABCDEFG_Basic', 'ABCDEFG_SPN']:
                    hparams[sim_type][dataset][model_type] = {
                        inst: _get_hparam(os.path.join(model_dir, inst))
                        for inst in _res.index
                    }

    return res, hparams


def collect_res(
    res: Dict[str, Dict[str, Dict[str, pd.DataFrame]]],
    latest: bool = False
) -> Tuple[pd.DataFrame, Dict[str, pd.DataFrame]]:
    """Combine individual results, one instance per model type, into per-metric tables.

    Args:
        res: raw results from ``evaluate``, nested as
            ``res[sim_type][dataset][model_type]`` -> DataFrame indexed by instance.
        latest: if True use the last row of each model's result frame, otherwise
            the first. ``evaluate`` sorts rows by instance name, so "last" is the
            lexicographically greatest name. Presumably that is the most recent
            run if names are timestamps; confirm.

    Returns:
        ``(inst, metrics)``.

        - ``inst`` holds the chosen instance name per ``(sim_type, dataset)``
          row and model column.
        - ``metrics`` maps each metric name (precision, recall, f1, shd and
          their ``*_int`` counterparts) to a DataFrame with the same layout.
          Model columns that are entirely NaN are dropped from the metric
          tables.
    """
    metric_names = ['precision', 'recall', 'f1', 'shd', 'precision_int', 'recall_int', 'f1_int', 'shd_int']

    def _f1(precision, recall):
        # Harmonic mean of precision and recall; 0 when both are 0.
        denom = precision + recall
        return 0.0 if denom == 0 else 2 * precision * recall / denom

    def _row_metrics(item: pd.Series) -> Dict[str, float]:
        # Metrics of one result row. F1 scores are derived when absent or NaN,
        # without mutating the row itself.
        metrics = {'precision': item['precision'], 'recall': item['recall']}
        if 'f1' in item and pd.notna(item['f1']):
            metrics['f1'] = item['f1']
        else:
            metrics['f1'] = _f1(item['precision'], item['recall'])
        metrics['shd'] = item['shd']
        for key in ('precision_int', 'recall_int'):
            if key in item:
                metrics[key] = item[key]
        if 'f1_int' in item and pd.notna(item['f1_int']):
            metrics['f1_int'] = item['f1_int']
        elif 'precision_int' in item and 'recall_int' in item:
            metrics['f1_int'] = _f1(item['precision_int'], item['recall_int'])
        if 'shd_int' in item:
            metrics['shd_int'] = item['shd_int']
        return metrics

    # Explicit level names let an empty ``res`` build a (0-row) two-level index
    # instead of raising in ``from_tuples``.
    data_index = pd.MultiIndex.from_tuples(
        [(sim_type, dataset) for sim_type in res for dataset in res[sim_type]],
        names=[None, None],
    )
    frames = {name: pd.DataFrame(index=data_index, columns=ALL_MODELS) for name in metric_names}
    inst = pd.DataFrame(index=data_index, columns=ALL_MODELS)

    for sim_type, _res_per_sim in res.items():
        for dataset, _res_per_data in _res_per_sim.items():
            row = (sim_type, dataset)
            for model_type, _res in _res_per_data.items():
                _item = _res.iloc[-1] if latest else _res.iloc[0]
                # Report the instance actually used for the metrics.
                inst.loc[row, model_type] = _item.name
                for name, value in _row_metrics(_item).items():
                    frames[name].loc[row, model_type] = value

    metrics = {
        name: frame.sort_index(axis=1).dropna(axis=1, how='all')
        for name, frame in frames.items()
    }
    return inst.sort_index(axis=1), metrics


def collect_all_res(
    res: Dict[str, Dict[str, Dict[str, pd.DataFrame]]],
    eval_metric: Optional[str] = None,
    mode: str = 'max'
) -> Tuple[pd.DataFrame, Dict[str, pd.DataFrame]]:
    """Combine every model instance's results into per-metric tables.

    Optionally keeps only the best instance of each model per simulation type.

    Args:
        res: raw results from ``evaluate``, nested as
            ``res[sim_type][dataset][model_type]`` -> DataFrame indexed by instance.
        eval_metric: metric used to pick the best instance. If it is not one of
            the output metric names (e.g. ``None`` or ``'none'``), no selection
            happens and all instances are returned.
        mode: ``'max'`` keeps, per simulation type, the instance with the highest
            mean ``eval_metric`` over that type's datasets. Any other value keeps
            the lowest.

    Returns:
        ``(inst, out)``.

        Without selection:

        - ``inst`` holds instance names.
        - Each ``out[metric]`` has ``(sim_type, dataset)`` rows and
          ``(model, i)`` columns. ``i`` is the instance's position in that
          model's result frame.

        With selection, both are collapsed to one column per model (the models
        present in ``out[eval_metric]``).

        All-NaN columns are dropped. Missing non-SHD scores are filled with 0.
    """
    metric_names = ['precision', 'recall', 'f1', 'shd', 'precision_int', 'recall_int', 'f1_int', 'shd_int']

    def _f1(precision, recall):
        # Harmonic mean of precision and recall; 0 when both are 0.
        denom = precision + recall
        return 0.0 if denom == 0 else 2 * precision * recall / denom

    def _row_metrics(item: pd.Series) -> Dict[str, float]:
        # Metrics of one result row. F1 scores are derived when absent or NaN,
        # without mutating the row itself.
        metrics = {'precision': item['precision'], 'recall': item['recall']}
        if 'f1' in item and pd.notna(item['f1']):
            metrics['f1'] = item['f1']
        else:
            metrics['f1'] = _f1(item['precision'], item['recall'])
        metrics['shd'] = item['shd']
        for key in ('precision_int', 'recall_int'):
            if key in item:
                metrics[key] = item[key]
        if 'f1_int' in item and pd.notna(item['f1_int']):
            metrics['f1_int'] = item['f1_int']
        elif 'precision_int' in item and 'recall_int' in item:
            metrics['f1_int'] = _f1(item['precision_int'], item['recall_int'])
        if 'shd_int' in item:
            metrics['shd_int'] = item['shd_int']
        return metrics

    data_index = pd.MultiIndex.from_tuples(
        [(sim_type, dataset) for sim_type in res for dataset in res[sim_type]],
        names=['Simulation type', 'Dataset'],
    )
    # Columns start empty and are added per (model, instance position) below.
    col_index = pd.MultiIndex.from_product([ALL_MODELS, []], names=['model', 'instance'])
    frames = {name: pd.DataFrame(index=data_index, columns=col_index) for name in metric_names}
    inst = pd.DataFrame(index=data_index, columns=col_index)

    for sim_type, _res_per_sim in res.items():
        for dataset, _res_per_data in _res_per_sim.items():
            row = (sim_type, dataset)
            for model_type, _res in _res_per_data.items():
                for i, (_inst, _item) in enumerate(_res.iterrows()):
                    col = (model_type, i)
                    inst.loc[row, col] = _inst
                    for name, value in _row_metrics(_item).items():
                        frames[name].loc[row, col] = value

    inst = inst.sort_index(axis=1)
    out = {}
    for name, frame in frames.items():
        frame = frame.sort_index(axis=1).dropna(axis=1, how='all')
        if name not in ('shd', 'shd_int'):
            frame = frame.fillna(0)
        out[name] = frame

    if eval_metric not in out:
        return inst, out

    # Keep the best instance of each model per simulation type, judged by the
    # mean of ``eval_metric`` over that type's datasets.
    metric_df = out[eval_metric]
    # Model names in sorted order (columns were sorted above), matching the
    # sorted groups produced by groupby below.
    selected_models = list(metric_df.columns.get_level_values('model').unique())
    best_insts = pd.DataFrame(index=inst.index, columns=pd.Index(selected_models, name='model'))
    best_out = {}
    for name, df in out.items():
        # A metric table may cover only some models (e.g. *_int metrics).
        present = set(df.columns.get_level_values('model'))
        best_out[name] = pd.DataFrame(
            index=df.index,
            columns=pd.Index([m for m in selected_models if m in present], name='model'),
        )

    for sim_type in inst.index.get_level_values(0).unique():
        mean = metric_df.loc[sim_type].astype(float).mean(axis=0)
        # All-NaN instances (possible for SHD) must not break idxmax/idxmin;
        # rank them last.
        mean = mean.fillna(-np.inf if mode == 'max' else np.inf)
        by_model = mean.groupby(level='model')
        best_inst = by_model.idxmax() if mode == 'max' else by_model.idxmin()
        for name, df in out.items():
            models = list(best_out[name].columns)
            if not models:
                continue
            chosen = [best_inst[m] for m in models]
            # reindex tolerates a chosen instance that has no column in this table.
            best_out[name].loc[sim_type] = df.loc[sim_type].reindex(columns=chosen).values
        if selected_models:
            chosen = [best_inst[m] for m in selected_models]
            best_insts.loc[sim_type] = inst.loc[sim_type].reindex(columns=chosen).values
    return best_insts, best_out
    

def print_hparams(hparams: Dict[str, Dict[str, Dict[str, Dict[str, float]]]]) -> pd.DataFrame:
    """Print one hyper-parameter table per (simulation type, dataset).

    Only the ``ABCDEFG_Basic`` and ``ABCDEFG_SPN`` entries are tabulated. Rows
    are ``(model, instance)`` pairs and columns are hyper-parameter names, both
    sorted. Instances whose hyper-parameters are ``None`` (``evaluate`` stores
    that when ``_get_hparam`` finds no file) are skipped.

    Args:
        hparams: ``{sim_type: {dataset: {model_type: {instance: params}}}}``,
            as returned by ``evaluate``.

    Returns:
        The last table printed, or an empty DataFrame when *hparams* contains
        no datasets.
    """
    tracked_models = ['ABCDEFG_Basic', 'ABCDEFG_SPN']
    index_names = ['model', 'instance']
    hparam_df = pd.DataFrame(index=pd.MultiIndex.from_tuples([], names=index_names))
    for sim_type, _hparam_per_sim in hparams.items():
        for dataset, _hparam_per_data in _hparam_per_sim.items():
            print('Simulation type:', sim_type)
            print('\tDataset:', dataset)
            keys = []
            rows = []
            for model_type, per_inst in _hparam_per_data.items():
                if model_type not in tracked_models:
                    continue
                for _inst, params in per_inst.items():
                    if params is None:
                        continue
                    keys.append((model_type, _inst))
                    rows.append(params)
            index = pd.MultiIndex.from_tuples(keys, names=index_names)
            # Building from records supports list/dict-valued parameters and
            # parameters missing from some instances (NaN).
            hparam_df = pd.DataFrame(rows, index=index) if rows else pd.DataFrame(index=index)
            hparam_df = hparam_df.sort_index(axis=0).sort_index(axis=1)
            print(hparam_df.to_string())
            print('------------------------------------------------------------------')
    return hparam_df


def _classify_sim_type(sim_type: str, dataset: str) -> str:
    _sim_type = sim_type.split('_')[0]
    _graph_type = sim_type.split('_')[1].upper()
    targeted = sim_type.split('_')[2] in ['targeted', 'targetedmulti']
    _data_size = int(dataset.split('_')[0][1:])
    if _sim_type == 'linear':
        if targeted:
            return f'L-{_graph_type}-{_data_size}'
        else:
            return f'LU-{_graph_type}-{_data_size}'
    elif _sim_type == 'nn':
        if targeted:
            return f'N-{_graph_type}-{_data_size}'
        else:
            return f'NU-{_graph_type}-{_data_size}'
    else:
        raise ValueError(f'Unknown simulation type: {sim_type}')


def _get_stats(df: pd.DataFrame, metric_name: str) -> pd.DataFrame:
    """Summarise a metric table as ``mean±std`` strings per simulation group.

    Rows of *df* are ``(sim_type, dataset)`` pairs. They are bucketed by the
    label from ``_classify_sim_type`` (e.g. ``'L-ER-16'``), and the NaN-aware
    mean and standard deviation of every column are taken within each bucket.

    Args:
        df: per-metric table from ``collect_res`` or ``collect_all_res``.
        metric_name: metric name; ``'shd'``/``'shd_int'`` are formatted with one
            decimal, everything else with two.

    Returns:
        DataFrame of ``'mean±std'`` strings, one column per bucket (in order of
        first appearance). Rows keep *df*'s ``(model, instance)`` column
        MultiIndex when it is non-empty. Otherwise rows are one per method,
        ordered as in ``ALL_MODELS`` and indexed by ``'method'``.
    """
    buckets = {}
    for idx in df.index:
        sim_type, dataset = idx
        label = _classify_sim_type(sim_type, dataset)
        buckets.setdefault(label, []).append(df.loc[idx].values)

    digits = 1 if metric_name in ['shd', 'shd_int'] else 2
    summary = {}
    for label, rows in buckets.items():
        stacked = np.stack(rows).astype(float)
        means = np.nanmean(stacked, axis=0)
        stds = np.nanstd(stacked, axis=0)
        summary[label] = [f'{m:.{digits}f}\u00B1{s:.{digits}f}' for m, s in zip(means, stds)]
    labels = list(summary)

    def _in_model_order(table: pd.DataFrame) -> pd.DataFrame:
        ordered = [model for model in ALL_MODELS if model in table.index]
        return table.reindex(pd.Index(ordered, name='method'))

    if not isinstance(df.columns, pd.MultiIndex):
        return _in_model_order(pd.DataFrame(summary, index=df.columns, columns=labels))
    if len(df.columns.get_level_values('instance')) > 0:
        return pd.DataFrame(summary, index=df.columns, columns=labels)
    table = pd.DataFrame(summary, index=df.columns.get_level_values('model'), columns=labels)
    return _in_model_order(table)
    

def get_stats(res: Dict[str, pd.DataFrame]) -> Dict[str, pd.DataFrame]:
    """Apply ``_get_stats`` to every metric table in *res*, keyed by metric name."""
    return {metric: _get_stats(df, metric) for metric, df in res.items()}


def plot_shd(
    df: pd.DataFrame,
    fig_name: str,
    interventional: bool = False
):
    """Draw one SHD box per method (column of *df*) and save the figure as a PNG.

    Args:
        df: SHD values with one column per method; all rows are pooled.
        fig_name: prefix of the output file under ``./figures`` (created if
            missing).
        interventional: if True, the columns are assumed to be the observational
            methods followed by their 'Intv.' counterparts (TODO confirm
            ordering). The second half then takes colours from the end of the
            palette, and the file name gets a ``-intv`` suffix.
    """
    colors = list(mcolors.TABLEAU_COLORS)
    if interventional:
        n_cols = len(df.columns)
        colors = colors[:n_cols // 2] + colors[-n_cols // 2:]
    box_colors = {
        'boxes': 'black', 'whiskers': 'black', 'medians': 'black', 'caps': 'black'
    }
    fig, ax = plt.subplots(figsize=(4, 2))
    ax, props = df.astype(float).plot.box(
        patch_artist=True,
        return_type='both',
        color=box_colors,
        ax=ax
    )
    for patch, color in zip(props['boxes'], colors):
        patch.set_facecolor(color)
    ax.set_xticklabels([_to_model_display_name(model) for model in df.columns])
    ax.set_ylabel('SHD', fontsize=14)
    ax.tick_params(axis='x', labelrotation=15, labelsize=12)
    ax.tick_params(axis='y', labelsize=12)
    fig.tight_layout()
    os.makedirs('./figures', exist_ok=True)
    if interventional:
        fig.savefig(f'./figures/{fig_name}-shd-intv.png', dpi=300)
    else:
        fig.savefig(f'./figures/{fig_name}-shd.png', dpi=300)
    # Free the figure; repeated calls would otherwise accumulate open figures.
    plt.close(fig)


def plot_shd_multi_sim(
    df: pd.DataFrame,
    fig_name: str,
    groupby: Optional[Union[str, List[str]]] = None,
    interventional: bool = False
):
    """Draw SHD box plots side by side, one panel per simulation-type group, and save them.

    Simulation types (first index level of *df*) are grouped by their
    ``_to_sim_display_name(sim_type, groupby)`` label. Each group gets its own
    panel with one box per method (column of *df*). Nothing is drawn or saved
    when *df* has no rows.

    Args:
        df: SHD values with ``(sim_type, dataset)`` rows and one column per method.
        fig_name: prefix of the output file under ``./figures`` (created if
            missing).
        groupby: attributes passed to ``_to_sim_display_name`` to form groups
            (``None`` uses all attributes).
        interventional: if True, the columns are assumed to be the observational
            methods followed by their 'Intv.' counterparts (TODO confirm). The
            second half then takes colours from the end of the palette, and the
            file name gets a ``-intv`` suffix.
    """
    colors = list(mcolors.TABLEAU_COLORS)
    if interventional:
        n_cols = len(df.columns)
        colors = colors[:n_cols // 2] + colors[-n_cols // 2:]
    box_colors = {
        'boxes': 'black', 'whiskers': 'black', 'medians': 'black', 'caps': 'black'
    }
    sim_types = {}
    for sim_type in df.index.get_level_values(0).unique():
        sim_types.setdefault(_to_sim_display_name(sim_type, groupby), []).append(sim_type)
    if not sim_types:
        return

    n_panels = len(sim_types)
    # One panel per group; squeeze=False keeps a 2-D axes array even for a single
    # panel. The column count was previously hard-coded to 2.
    fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 4), squeeze=False)
    tick_labels = [_to_model_display_name(model) for model in df.columns]
    for i, (sim_name, sim_type_group) in enumerate(sim_types.items()):
        ax, props = df.loc[sim_type_group].astype(float).plot.box(
            patch_artist=True,
            return_type='both',
            color=box_colors,
            ax=axes[0, i]
        )
        for patch, color in zip(props['boxes'], colors):
            patch.set_facecolor(color)
        if i == 0:
            ax.set_ylabel('SHD', fontsize=14)
        ax.set_xticks(
            range(1, df.shape[1] + 1), tick_labels,
            rotation=15,
            fontsize=7.5
        )
        ax.tick_params(axis='y', labelsize=8)
        xmin, _ = ax.get_xlim()
        _, ymax = ax.get_ylim()
        # Group name in the top-left corner of the panel
        ax.text(xmin + 0.02, ymax * 0.99, sim_name, fontsize=12, ha='left', va='top')

    fig.tight_layout()
    os.makedirs('./figures', exist_ok=True)
    if interventional:
        fig.savefig(f'./figures/{fig_name}-shd-intv.png', dpi=300)
    else:
        fig.savefig(f'./figures/{fig_name}-shd.png', dpi=300)
    plt.close(fig)


def plot_prec_rec(
    precision: pd.DataFrame,
    recall: pd.DataFrame,
    fig_name: str,
    interventional: bool = False
):
    """Scatter recall against precision and save the plot plus two legend images.

    Colours encode methods (columns) and markers encode simulation-type groups.
    Rows are grouped by linearity, graph and intervention type. The legends are
    saved as separate images under ``./figures`` (created if missing).

    Args:
        precision: precision values with ``(sim_type, dataset)`` rows and one
            column per method.
        recall: recall values with the same layout as *precision*.
        fig_name: prefix of the output files.
        interventional: if True, the columns are assumed to be the observational
            methods followed by their 'Intv.' counterparts (TODO confirm), and
            every output file name gets an ``-intv`` tag.
    """
    def group_sim_types(sim_types: Iterable[str]):
        # Map display name -> simulation types sharing it.
        groups = {}
        for sim_type in sim_types:
            name = _to_sim_display_name(sim_type, ['linearity', 'graph', 'intervention'])
            groups.setdefault(name, []).append(sim_type)
        return groups

    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    # Legends are drawn on their own figures
    fig_lgd_sim, ax_lgd_sim = plt.subplots(figsize=(8, 2))
    fig_lgd_method, ax_lgd_method = plt.subplots(figsize=(2, 4))

    colors = list(mcolors.TABLEAU_COLORS)
    if interventional:
        n_cols = len(precision.columns)
        colors = colors[:n_cols // 2] + colors[-n_cols // 2:]
    sim_type_group = group_sim_types(precision.index.get_level_values(0).unique())
    for i, ((_, sim_types), marker) in enumerate(zip(sim_type_group.items(), MARKERS)):
        _precision = precision.loc[sim_types]
        _recall = recall.loc[sim_types]
        for model, color in zip(_precision.columns, colors):
            ax.scatter(_precision[model], _recall[model], color=color, marker=marker)
            # Method legend entries only need to be added once
            if i == 0:
                ax_lgd_method.plot([], [], color=color, linewidth=5, label=_to_model_display_name(model))
    ax.plot([0, 1], [0, 1], color='red', linestyle='--')
    for (name, _), marker in zip(sim_type_group.items(), MARKERS):
        ax_lgd_sim.scatter([], [], color='black', marker=marker, label=name)

    ax.set_xlabel('Precision', fontsize=14)
    ax.set_ylabel('Recall', fontsize=14)
    ax.set_xlim(0, 1.1)
    ax.set_ylim(0, 1.1)
    # Style the scatter axes explicitly. The pyplot-level plt.xticks/plt.yticks
    # and plt.tight_layout would act on the last-created (legend) figure.
    ax.tick_params(axis='both', labelsize=12)
    fig.tight_layout()

    lgd_sim = ax_lgd_sim.legend(
        title='Simulation Type',
        ncols=len(sim_type_group),
        loc='center',
        fontsize=12,
        handlelength=1,
        handleheight=1
    )
    ax_lgd_sim.axis('off')
    ax_lgd_sim.set_facecolor('none')

    lgd_method = ax_lgd_method.legend(
        title='Method',
        loc='center',
        fontsize=12,
        handlelength=1,
        handleheight=1
    )
    ax_lgd_method.axis('off')
    ax_lgd_method.set_facecolor('none')

    # The '-intv' tag goes on every file of the interventional variant
    # (the legend names were previously swapped between the two cases).
    tag = '-intv' if interventional else ''
    os.makedirs('./figures', exist_ok=True)
    fig.savefig(f'./figures/{fig_name}-prec-rec{tag}.png', bbox_inches='tight', dpi=300)
    fig_lgd_sim.savefig(f'./figures/{fig_name}-sim{tag}-legend.png', bbox_extra_artists=(lgd_sim,), bbox_inches='tight', dpi=300)
    fig_lgd_method.savefig(f'./figures/{fig_name}-method{tag}-legend.png', bbox_extra_artists=(lgd_method,), bbox_inches='tight', dpi=300)
    for figure in (fig, fig_lgd_sim, fig_lgd_method):
        plt.close(figure)


if __name__ == '__main__':
    args = parser.parse_args()
    res, hparams = evaluate(args.out_dir, args.fix_type, args.fix_model)

    print('--------------------------------- Hyper-Parameter Setting ---------------------------------')
    print_hparams(hparams)

    # --show-all disables best-instance selection, as does any --eval-metric
    # that is not a metric name (e.g. 'none').
    eval_metric = None if args.show_all else args.eval_metric
    inst, out = collect_all_res(res, eval_metric, args.eval_mode)

    print('--------------------------------- Instance ---------------------------------')
    print(inst.to_string())
    print('\n\n')

    stats = get_stats(out)

    # (metric key, section header). The *_int metrics are optional: they are
    # printed and cached only when their table holds data. ``.empty`` is used
    # rather than ``len`` because a table can have rows but no columns.
    sections = [
        ('shd', '--------------------------------- SHD ---------------------------------'),
        ('f1', '--------------------------------- F1 ---------------------------------'),
        ('precision', '--------------------------------- Precision ---------------------------------'),
        ('recall', '--------------------------------- Recall ---------------------------------'),
        ('shd_int', '--------------------------------- SHD_int ---------------------------------'),
        ('f1_int', '--------------------------------- F1_int ---------------------------------'),
        ('precision_int', '--------------------------------- Precision_int ---------------------------------'),
        ('recall_int', '--------------------------------- Recall_int ---------------------------------'),
    ]
    for position, (key, header) in enumerate(sections):
        if key.endswith('_int') and out[key].empty:
            continue
        if position > 0:
            print('\n')
        print(header)
        print(out[key].to_string())
        print('******************************** stats ********************************')
        print(stats[key].to_string())

    # Save to csv
    if args.cache:
        os.makedirs('./logs', exist_ok=True)
        inst.to_csv(f'./logs/{args.fig_name}-inst.csv')
        for key, _ in sections:
            if key.endswith('_int') and out[key].empty:
                continue
            out[key].to_csv(f'./logs/{args.fig_name}-{key}.csv')

    if args.plot:
        os.makedirs('./figures', exist_ok=True)
        # Box plot comparing SHD of different methods
        if not out['shd_int'].empty:
            out['shd_int'].columns = pd.Index([_to_intv_display_name(model) for model in out['shd_int'].columns])
            df_plot = pd.concat([out['shd'], out['shd_int']], axis=1)
            plot_shd_multi_sim(df_plot, args.fig_name, ['graph'], interventional=True)
        else:
            plot_shd_multi_sim(out['shd'], args.fig_name, ['graph'])
        # Scatter plot of recall vs precision colored by different methods
        if not out['precision_int'].empty and not out['recall_int'].empty:
            out['precision_int'].columns = pd.Index([_to_intv_display_name(model) for model in out['precision_int'].columns])
            out['recall_int'].columns = pd.Index([_to_intv_display_name(model) for model in out['recall_int'].columns])
            prec_plot = pd.concat([out['precision'], out['precision_int']], axis=1)
            rec_plot = pd.concat([out['recall'], out['recall_int']], axis=1)
            plot_prec_rec(prec_plot, rec_plot, args.fig_name, interventional=True)
        else:
            plot_prec_rec(out['precision'], out['recall'], args.fig_name)