"""
Top-2 PolyILR Feature Visualization (Generalized)
Generates scatter plot colored by class labels for any dataset/task.

Usage:
    python plot.py --dataset hmp --task body_sites
    python plot.py --dataset disco_blood --task healthy_vs_leukemia
    python plot.py --dataset cmd3 --task westernized
    python plot.py --all  # Generate all available plots
"""

import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

from tree_utils import phylo_to_tree, force_binary_random
from polyilr import construct_V_with_mapping, ilr_transform

# ============================================
# PATHS
# ============================================
# Resolve the script location before walking up: when ``__file__`` is a bare
# relative name (a script launched as ``python plot.py`` on Python < 3.9),
# ``Path('plot.py').parent.parent`` collapses to ``.`` instead of the
# directory above the script's folder.
BASE_DIR = Path(__file__).resolve().parent.parent
DATA_DIR = BASE_DIR / "data"
OUT_DIR = BASE_DIR / "out"
# Created at import time so the save step of the plotting functions always
# has a destination directory.
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ============================================
# PLOT CONFIG
# ============================================
# Font sizes applied globally to every figure drawn after this module loads.
_RC_FONT_SIZES = {
    # base text and axes
    'font.size': 18,
    'axes.labelsize': 22,
    'axes.titlesize': 22,
    # legend entries and legend title
    'legend.fontsize': 15,
    'legend.title_fontsize': 16,
    # tick labels
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
}
plt.rcParams.update(_RC_FONT_SIZES)

# Colour tables keyed by palette name. A palette whose key equals a task name
# is looked up by class label when colouring that task's scatter plot.

# HMP body sites (5 classes), keyed by the label string.
_BODY_SITE_COLORS = {
    'Airways': '#E41A1C',
    'Gastrointestinal Tract': '#377EB8',
    'Oral': '#FF7F00',
    'Skin': '#4DAF4A',
    'Urogenital Tract': '#984EA3',
}

# Two-class tasks: 0 -> blue, 1 -> red.
_BINARY_COLORS = {
    0: '#377EB8',
    1: '#E41A1C',
}

COLOR_PALETTES = dict(
    body_sites=_BODY_SITE_COLORS,
    binary=_BINARY_COLORS,
    # Default multiclass colour cycle (sequence, not a mapping).
    default=plt.cm.tab10.colors,
)

# ============================================
# DATASET CONFIGS
# ============================================
def _healthy_vs_disease(y):
    """Collapse disease annotations to ``'healthy'`` / ``'disease'``.

    Missing annotations (NaN/None) are passed through unchanged instead of
    being counted as ``'disease'``, so a later missing-label mask can still
    drop them. The result is an object array so that missing values are not
    coerced to the string ``'nan'``.
    """
    return np.array(
        [('healthy' if d == 'healthy' else 'disease') if pd.notna(d) else d for d in y],
        dtype=object,
    )


# Loading configuration per dataset, keyed by the dataset name used on the CLI.
#
# Common keys:
#   type      - selects the loader ('microbiome' or 'disco').
#   data_dir  - sub-directory of DATA_DIR holding the files.
#   tree_file - tree file name inside data_dir.
#   tasks     - mapping of task name -> task options.
#
# 'microbiome' entries also name the OTU, taxonomy and metadata CSV files.
# Their tasks take:
#   label_col - metadata column holding the class label.
#   transform - optional function applied to the label array.
#   filter    - optional function returning a boolean keep-mask for the labels.
#
# 'disco' entries map each condition name to its proportions CSV under
# 'conditions'; their tasks list which conditions are compared.
DATASET_CONFIGS = {
    'hmp': {
        'type': 'microbiome',
        'data_dir': 'hmp_taxa',
        'otu_file': 'hmp_otu_table.csv',
        'taxonomy_file': 'hmp_taxonomy.csv',
        'metadata_file': 'hmp_metadata.csv',
        'tree_file': 'hmp_tree_taxonomy_grafen.newick',
        'tasks': {
            'body_sites': {'label_col': 'HMP_BODY_SITE'},
            'body_subsites': {'label_col': 'HMP_BODY_SUBSITE'},
        },
    },
    'cmd3': {
        'type': 'microbiome',
        'data_dir': 'cmd3_taxa',
        'otu_file': 'cmd3_otu_table.csv',
        'taxonomy_file': 'cmd3_taxonomy.csv',
        'metadata_file': 'cmd3_metadata.csv',
        'tree_file': 'cmd3_tree_taxonomy_grafen.newick',
        'tasks': {
            'westernized': {'label_col': 'non_westernized'},
            'age_category': {'label_col': 'age_category', 'filter': lambda y: y != 'unknown'},
            'healthy_vs_disease': {
                'label_col': 'disease',
                'transform': _healthy_vs_disease,
            },
            'body_site': {'label_col': 'body_site'},
        },
    },
    'disco_blood': {
        'type': 'disco',
        'data_dir': 'disco_200subset',
        'tree_file': 'blood_tree.csv',
        'conditions': {
            'healthy': 'blood_healthy_proportions.csv',
            'covid': 'blood_covid_proportions.csv',
            'leukemia': 'blood_leukemia_proportions.csv',
        },
        'tasks': {
            'healthy_vs_covid': {'conditions': ['healthy', 'covid']},
            'healthy_vs_leukemia': {'conditions': ['healthy', 'leukemia']},
        },
    },
    'disco_liver': {
        'type': 'disco',
        'data_dir': 'disco_200subset',
        'tree_file': 'liver_tree.csv',
        'conditions': {
            'healthy': 'liver_healthy_proportions.csv',
            'hcc': 'liver_hcc_proportions.csv',
        },
        'tasks': {
            'healthy_vs_hcc': {'conditions': ['healthy', 'hcc']},
        },
    },
}

# ============================================
# DATA LOADERS
# ============================================

def load_microbiome(dataset_name, task_name):
    """Load a microbiome dataset (HMP, cMD3) and the class labels for one task.

    Reads the OTU table, taxonomy, metadata and Newick tree named in
    ``DATASET_CONFIGS[dataset_name]``, keeps the samples that appear both as
    OTU-table columns and as metadata rows, and converts counts to
    compositions using a pseudocount of 1.

    Samples with a missing label are dropped. Missingness is evaluated on the
    raw metadata column *before* the task's optional ``transform`` runs, so a
    transform that maps every non-matching value to a class cannot turn a
    missing annotation into a label.

    Args:
        dataset_name: Key of a microbiome entry in ``DATASET_CONFIGS``.
        task_name: Key in that entry's ``tasks`` mapping.

    Returns:
        dict with keys:
            ``X_comp``: samples x taxa array; each row sums to 1.
            ``y_raw``: label of each retained sample.
            ``tree``, ``root``, ``edge_lengths``: as returned by
                ``phylo_to_tree`` for the Newick tree and the OTU index.
            ``D``: number of taxa (rows of the OTU table).
            ``get_clade_label``: callable mapping a sequence of taxon indices
                to a taxonomic name shared by all of them (``"Mixed"`` if
                they share none).
    """
    from Bio import Phylo

    config = DATASET_CONFIGS[dataset_name]
    task_config = config['tasks'][task_name]
    data_dir = DATA_DIR / config['data_dir']

    print(f"Loading {dataset_name} / {task_name}...")

    # Load files (first CSV column is the index in all three tables)
    otu = pd.read_csv(data_dir / config['otu_file'], index_col=0)
    taxonomy = pd.read_csv(data_dir / config['taxonomy_file'], index_col=0)
    metadata = pd.read_csv(data_dir / config['metadata_file'], index_col=0)
    T = Phylo.read(data_dir / config['tree_file'], "newick")

    # Setup tree; taxa are the rows of the OTU table
    otu_index = list(otu.index)
    D = len(otu_index)
    tree, root, edge_lengths = phylo_to_tree(T, otu_index)

    # Align samples: compare IDs as strings so numeric metadata indices match
    otu_samples = [str(s) for s in otu.columns]
    metadata.index = metadata.index.astype(str)
    common_samples = [s for s in otu_samples if s in metadata.index]

    # Compositions: transpose to samples x taxa, add pseudocount, close to 1
    X_counts = otu.loc[otu_index, common_samples].values.T
    X_comp = (X_counts + 1) / (X_counts + 1).sum(axis=1, keepdims=True)

    # Labels
    y_source = metadata.loc[common_samples, task_config['label_col']].values

    # Record missing labels on the raw column; a transform may otherwise
    # relabel NaN as a real class and hide it from the mask.
    mask = pd.notna(y_source)

    y_raw = y_source
    if 'transform' in task_config:
        y_raw = task_config['transform'](y_source)
        # A transform may also introduce or preserve missing values.
        mask = mask & pd.notna(y_raw)

    if 'filter' in task_config:
        mask = mask & task_config['filter'](y_raw)

    X_comp = X_comp[mask]
    y_raw = y_raw[mask]

    # Clade label function
    def get_clade_label(leaf_indices):
        """Name the clade spanned by the given taxon indices."""
        taxa_names = [otu_index[i] for i in leaf_indices]
        sub_tax = taxonomy.loc[taxa_names]
        common_name = "Mixed"
        # Walk ranks from coarse to fine; stop at the first rank where the
        # taxa disagree (or are all unannotated) and keep the last shared name.
        for level in ['Kingdom', 'Phylum', 'Class', 'Order', 'Family', 'Genus']:
            if level not in sub_tax.columns:
                continue
            unique = sub_tax[level].dropna().unique()
            if len(unique) == 1:
                common_name = unique[0]
            else:
                break
        return common_name

    print(f"  Samples: {len(y_raw)}, Taxa: {D}, Classes: {len(np.unique(y_raw))}")

    return {
        'X_comp': X_comp,
        'y_raw': y_raw,
        'tree': tree,
        'root': root,
        'edge_lengths': edge_lengths,
        'D': D,
        'get_clade_label': get_clade_label,
    }


def load_disco(dataset_name, task_name):
    """Load a DISCO cell-type proportion dataset and its condition labels.

    Reads the tree edge list (CSV with ``parent`` and ``child`` columns) and
    one proportions CSV per condition listed for the task. Cell types are the
    columns of those CSVs other than ``sample_id``, ``condition`` and
    ``tissue``. Only cell types that are also leaves of the tree are kept.

    Node numbering: kept leaves get indices ``0..D-1`` in sorted-name order;
    internal nodes get negative indices, with the node named ``'Root'``
    receiving the most negative one. Every edge length is set to 1.0.

    NOTE(review): the tree CSV is assumed to contain a parent literally named
    ``'Root'``; otherwise the returned ``root`` has no entry in ``tree``.

    Args:
        dataset_name: Key of a DISCO entry in ``DATASET_CONFIGS``.
        task_name: Key in that entry's ``tasks`` mapping.

    Returns:
        dict with keys:
            ``X_comp``: samples x cell-types array; each row sums to 1.
            ``y_raw``: condition name of each sample.
            ``tree``: mapping parent index -> list of child indices.
            ``root``: index of the ``'Root'`` node.
            ``edge_lengths``: mapping ``(parent, child)`` -> 1.0.
            ``D``: number of kept leaf cell types.
            ``get_clade_label``: callable mapping an iterable of leaf indices
                to the name of a tree node whose leaves include all of them.
    """
    config = DATASET_CONFIGS[dataset_name]
    task_config = config['tasks'][task_name]
    data_dir = DATA_DIR / config['data_dir']

    print(f"Loading {dataset_name} / {task_name}...")

    # Load tree
    tree_df = pd.read_csv(data_dir / config['tree_file'])

    # Load condition data; every non-metadata column is a cell type
    condition_data = {}
    all_cell_types = set()
    meta_cols = ['sample_id', 'condition', 'tissue']

    for cond_name in task_config['conditions']:
        cond_file = config['conditions'][cond_name]
        df = pd.read_csv(data_dir / cond_file)
        condition_data[cond_name] = df
        all_cell_types.update(c for c in df.columns if c not in meta_cols)

    # Find valid leaves: tree nodes that are never a parent and occur in the data
    tree_parents = set(tree_df['parent'])
    tree_children = set(tree_df['child'])
    tree_leaves = tree_children - tree_parents
    valid_leaves = all_cell_types & tree_leaves

    # Build tree: leaves are numbered >= 0, internal nodes < 0
    leaf_names = sorted(valid_leaves)
    leaf_to_idx = {name: i for i, name in enumerate(leaf_names)}
    D = len(leaf_names)

    internal_names = sorted(tree_parents - valid_leaves - {'Root'})
    internal_to_idx = {name: -(i+1) for i, name in enumerate(internal_names)}
    internal_to_idx['Root'] = -(len(internal_names) + 1)

    idx_to_name = {**{v: k for k, v in leaf_to_idx.items()},
                   **{v: k for k, v in internal_to_idx.items()}}

    def name_to_idx(name):
        """Return the node index for a name, or None if the node was not numbered."""
        if name in leaf_to_idx:
            return leaf_to_idx[name]
        return internal_to_idx.get(name, None)

    tree = {}
    for parent in tree_df['parent'].unique():
        parent_idx = name_to_idx(parent)
        if parent_idx is None:
            continue
        children = tree_df[tree_df['parent'] == parent]['child'].tolist()
        # Tree leaves that are absent from the data have no index and are
        # pruned; each child is looked up once.
        children_idx = [i for i in map(name_to_idx, children) if i is not None]
        if children_idx:
            tree[parent_idx] = children_idx

    root = internal_to_idx['Root']
    edge_lengths = {(p, c): 1.0 for p, children in tree.items() for c in children}

    # Build X and y
    X_list, y_list = [], []
    for cond in task_config['conditions']:
        df = condition_data[cond].copy()
        # A cell type seen only in another condition counts as zero here.
        for leaf in leaf_names:
            if leaf not in df.columns:
                df[leaf] = 0.0
        # Small offset keeps every part strictly positive before closure.
        X = df[leaf_names].values + 1e-10
        X = X / X.sum(axis=1, keepdims=True)
        X_list.append(X)
        y_list.extend([cond] * len(X))

    X_comp = np.vstack(X_list)
    y_raw = np.array(y_list)

    # Clade label function
    parent_map = {c: p for p, children in tree.items() for c in children}

    def get_ancestors(node_idx):
        """Return the chain of ancestors of a node, nearest first."""
        ancestors = []
        current = node_idx
        while current in parent_map:
            current = parent_map[current]
            ancestors.append(current)
        return ancestors

    def get_leaves_under(node):
        """Return the set of leaf indices in the subtree rooted at ``node``."""
        if node >= 0:
            return {node}
        leaves = set()
        for child in tree.get(node, []):
            leaves.update(get_leaves_under(child))
        return leaves

    def get_clade_label(leaf_indices):
        """Name the nearest ancestor of the first leaf that covers all given leaves."""
        # Materialise first so sets, arrays and generators can be indexed below.
        leaf_indices = list(leaf_indices)
        if not leaf_indices:
            return "Unknown"
        if len(leaf_indices) == 1:
            return idx_to_name.get(leaf_indices[0], "Unknown")
        wanted = set(leaf_indices)
        ancestors_0 = [leaf_indices[0]] + get_ancestors(leaf_indices[0])
        for ancestor in ancestors_0:
            if wanted.issubset(get_leaves_under(ancestor)):
                return idx_to_name.get(ancestor, "Unknown")
        # No covering ancestor found: fall back to the first leaf's own name.
        return idx_to_name.get(leaf_indices[0], "Unknown")

    print(f"  Samples: {len(y_raw)}, Cell types: {D}, Classes: {len(np.unique(y_raw))}")

    return {
        'X_comp': X_comp,
        'y_raw': y_raw,
        'tree': tree,
        'root': root,
        'edge_lengths': edge_lengths,
        'D': D,
        'get_clade_label': get_clade_label,
    }


def load_data(dataset_name, task_name):
    """Dispatch to the loader that matches the dataset's configured ``type``.

    Raises:
        ValueError: If the configured type has no loader.
    """
    loaders = {
        'microbiome': load_microbiome,
        'disco': load_disco,
    }
    kind = DATASET_CONFIGS[dataset_name]['type']
    loader = loaders.get(kind)
    if loader is None:
        raise ValueError(f"Unknown dataset type: {kind}")
    return loader(dataset_name, task_name)


# ============================================
# PLOTTING
# ============================================

def get_contrast_label(idx, column_mapping, get_clade_label):
    """Build a readable "A $vs$ B" label for one contrast column.

    ``column_mapping[idx]`` supplies ``child_leaves`` (one leaf collection per
    child clade) and ``contrast_idx`` (position of the clade on the left-hand
    side). The right-hand side shows the first remaining clade, followed by
    " + ..." when more than one remains.
    """
    entry = column_mapping[idx]
    pivot = entry['contrast_idx']
    names = [get_clade_label(leaves) for leaves in entry['child_leaves']]

    # Left-hand side: the contrasted clade, or a placeholder if out of range.
    if pivot < len(names):
        lhs = names[pivot]
    else:
        lhs = "?"

    # Right-hand side: first of the remaining clades, with an ellipsis if
    # others were left out.
    remaining = [name for pos, name in enumerate(names) if pos != pivot]
    rhs = "".join(remaining[:1])
    ellipsis = " + ..." if len(remaining) > 1 else ""
    return f"{lhs} $vs$ {rhs}{ellipsis}"


def get_colors(y_raw, task_name):
    """Return a ``{class label: colour}`` mapping for the labels in ``y_raw``.

    Resolution order: a palette registered under the task's own name (labels
    it does not list become grey), then a fixed blue/red pair when there are
    exactly two classes, otherwise the tab10 cycle.
    """
    labels = np.unique(y_raw)

    # Task-specific palette
    if task_name in COLOR_PALETTES:
        task_palette = COLOR_PALETTES[task_name]
        return {label: task_palette.get(label, '#999999') for label in labels}

    # Two classes: blue for the first sorted label, red for the second
    if len(labels) == 2:
        return dict(zip(labels, ('#377EB8', '#E41A1C')))

    # Anything else: cycle through tab10
    cycle = plt.cm.tab10.colors
    n_colors = len(cycle)
    return {label: cycle[pos % n_colors] for pos, label in enumerate(labels)}


def generate_plot(dataset_name, task_name, figsize=(7, 6), save=True):
    """Generate the top-2 PolyILR feature scatter plot for one dataset/task.

    Loads the data, applies the PolyILR transform, trains a random forest on
    a stratified 90% split, and scatters all samples on the two coordinates
    with the highest feature importance, coloured by class.

    Args:
        dataset_name: Key in ``DATASET_CONFIGS``.
        task_name: Task key for that dataset.
        figsize: Figure size ``(width, height)`` in inches.
        save: If true, write ``<dataset>_<task>_top2.pdf`` and ``.png`` to
            ``OUT_DIR``.

    Returns:
        ``(fig, ax)`` of the created figure.
    """

    # Load data
    data = load_data(dataset_name, task_name)
    X_comp = data['X_comp']
    y_raw = data['y_raw']
    tree = data['tree']
    root = data['root']
    edge_lengths = data['edge_lengths']
    D = data['D']
    get_clade_label = data['get_clade_label']

    # PolyILR transform
    print("Computing PolyILR transform...")
    V, column_mapping = construct_V_with_mapping(tree, root, D, edge_lengths)
    X_ilr = ilr_transform(X_comp, V)

    # Encode labels
    le = LabelEncoder()
    y = le.fit_transform(y_raw)

    # Train RF; the held-out 10% is used only for the printed accuracy
    print("Training RF...")
    X_train, X_test, y_train, y_test = train_test_split(
        X_ilr, y, test_size=0.1, stratify=y, random_state=42
    )
    clf = RandomForestClassifier(n_estimators=500, random_state=42, n_jobs=-1)
    clf.fit(X_train, y_train)

    test_acc = clf.score(X_test, y_test)
    print(f"Test accuracy: {test_acc:.3f}")

    # Get top-2 features, most important first
    importance = clf.feature_importances_
    top2_idx = np.argsort(importance)[-2:][::-1]

    # Get labels - break both into two lines
    label1 = 'Coord 1: ' + get_contrast_label(top2_idx[0], column_mapping, get_clade_label)
    label2 = 'Coord 2: ' + get_contrast_label(top2_idx[1], column_mapping, get_clade_label)
    label1 = label1.replace(' $vs$ ', '\n$vs$ ')
    label2 = label2.replace(' $vs$ ', '\n$vs$ ')

    print(f"Feature 1: {label1} (imp={importance[top2_idx[0]]:.4f})")
    print(f"Feature 2: {label2} (imp={importance[top2_idx[1]]:.4f})")

    # Plot
    X_top2 = X_ilr[:, top2_idx]
    colors = get_colors(y_raw, task_name)

    fig, ax = plt.subplots(figsize=figsize)

    for cls in le.classes_:
        mask_cls = y_raw == cls
        ax.scatter(
            X_top2[mask_cls, 0],
            X_top2[mask_cls, 1],
            # ``color=`` rather than ``c=``: an RGB tuple passed as ``c`` is
            # ambiguous and is treated as per-point values to colormap when
            # the class happens to have exactly 3 points.
            color=colors.get(cls, '#999999'),
            label=cls,
            alpha=0.6,
            s=20,
            edgecolors='none'
        )

    ax.set_xlabel(label1)
    ax.set_ylabel(label2)

    # Legend title chosen from the task name
    legend_title = "Body Site" if "body" in task_name else "Condition" if "vs" in task_name else "Class"
    ax.legend(title=legend_title, loc='best', framealpha=0.9,
              markerscale=0.8, handletextpad=0.3, borderpad=0.3,
              labelspacing=0.3, handlelength=1.0)

    plt.tight_layout()

    if save:
        outfile = OUT_DIR / f"{dataset_name}_{task_name}_top2.pdf"
        plt.savefig(outfile, dpi=300, bbox_inches='tight')
        plt.savefig(outfile.with_suffix('.png'), dpi=300, bbox_inches='tight')
        print(f"Saved: {outfile}")

    return fig, ax


def _combined_legend_title(dataset_name, task_name):
    """Return the mathtext legend heading (spaces escaped) for one panel."""
    if "body" in task_name:
        return "Body\\ Site"
    if "vs" in task_name:
        # DISCO datasets are named ``disco_<tissue>``; name the tissue so a
        # liver or non-DISCO panel is not headed "Blood Condition".
        tissue = dataset_name[len("disco_"):] if dataset_name.startswith("disco_") else ""
        if tissue.isalpha():
            return f"{tissue.capitalize()}\\ Condition"
        return "Condition"
    return "Class"


def generate_combined_plot(dataset_task_pairs, figsize=(14, 5), save=True, outname='combined_top2', trim_percentile=None):
    """Generate side-by-side top-2 PolyILR feature plots.

    Each ``(dataset, task)`` pair gets its own panel, built like the single
    plot: PolyILR transform, random forest on a stratified 90% split, scatter
    of the two most important coordinates. Above the panels, one legend row
    is drawn per panel, each starting with a bold heading.

    Args:
        dataset_task_pairs: Sequence of ``(dataset_name, task_name)`` pairs.
        figsize: Figure size ``(width, height)`` in inches.
        save: If true, write ``<outname>.pdf`` and ``.png`` to ``OUT_DIR``.
        outname: Output file stem.
        trim_percentile: If set (e.g., 1), drop points outside the 1st-99th
            percentile on either plotted axis. Must be in ``[0, 50)``.

    Returns:
        ``(fig, axes)``; ``axes`` is a list for a single panel, otherwise the
        array returned by ``plt.subplots``.

    Raises:
        ValueError: If no pairs are given or ``trim_percentile`` is outside
            ``[0, 50)``.
    """
    from matplotlib.lines import Line2D
    from matplotlib.patches import FancyBboxPatch

    n = len(dataset_task_pairs)
    if n == 0:
        raise ValueError("dataset_task_pairs must contain at least one (dataset, task) pair")
    # At 50 or above the lower bound passes the upper one and every point
    # would be dropped silently.
    if trim_percentile is not None and not 0 <= trim_percentile < 50:
        raise ValueError(f"trim_percentile must be in [0, 50), got {trim_percentile}")

    fig, axes = plt.subplots(1, n, figsize=figsize)
    if n == 1:
        axes = [axes]

    for idx, (dataset_name, task_name) in enumerate(dataset_task_pairs):
        ax = axes[idx]

        # Load data
        data = load_data(dataset_name, task_name)
        X_comp = data['X_comp']
        y_raw = data['y_raw']
        tree = data['tree']
        root = data['root']
        edge_lengths = data['edge_lengths']
        D = data['D']
        get_clade_label = data['get_clade_label']

        # PolyILR transform
        print(f"[{idx+1}/{n}] {dataset_name}/{task_name}: Computing PolyILR...")
        V, column_mapping = construct_V_with_mapping(tree, root, D, edge_lengths)
        X_ilr = ilr_transform(X_comp, V)

        # Encode labels
        le = LabelEncoder()
        y = le.fit_transform(y_raw)

        # Train RF; the held-out 10% is used only for the printed accuracy
        X_train, X_test, y_train, y_test = train_test_split(
            X_ilr, y, test_size=0.1, stratify=y, random_state=42
        )
        clf = RandomForestClassifier(n_estimators=500, random_state=42, n_jobs=-1)
        clf.fit(X_train, y_train)

        test_acc = clf.score(X_test, y_test)
        print(f"  Test accuracy: {test_acc:.3f}")

        # Get top-2 features, most important first
        importance = clf.feature_importances_
        top2_idx = np.argsort(importance)[-2:][::-1]

        # Get labels - break both into two lines
        label1 = get_contrast_label(top2_idx[0], column_mapping, get_clade_label)
        label2 = get_contrast_label(top2_idx[1], column_mapping, get_clade_label)
        label1 = label1.replace(' $vs$ ', '\n$vs$ ')
        label2 = label2.replace(' $vs$ ', '\n$vs$ ')

        # Plot
        X_top2 = X_ilr[:, top2_idx]

        # Trim outliers if requested
        if trim_percentile is not None:
            lo, hi = trim_percentile, 100 - trim_percentile
            x_lo, x_hi = np.percentile(X_top2[:, 0], [lo, hi])
            y_lo, y_hi = np.percentile(X_top2[:, 1], [lo, hi])
            mask_inlier = (
                (X_top2[:, 0] >= x_lo) & (X_top2[:, 0] <= x_hi) &
                (X_top2[:, 1] >= y_lo) & (X_top2[:, 1] <= y_hi)
            )
            n_trimmed = (~mask_inlier).sum()
            print(f"  Trimmed {n_trimmed} outliers ({trim_percentile}-{100-trim_percentile} percentile)")
            X_top2 = X_top2[mask_inlier]
            y_raw_plot = y_raw[mask_inlier]
        else:
            y_raw_plot = y_raw

        # Colours come from the untrimmed labels so every class keeps its colour.
        class_colors = get_colors(y_raw, task_name)

        for cls in le.classes_:
            mask_cls = y_raw_plot == cls
            ax.scatter(
                X_top2[mask_cls, 0],
                X_top2[mask_cls, 1],
                # ``color=`` rather than ``c=``: an RGB tuple passed as ``c``
                # is ambiguous and is treated as per-point values to colormap
                # when the class happens to have exactly 3 points.
                color=class_colors.get(cls, '#999999'),
                label=cls,
                alpha=0.6,
                s=70,
                edgecolors='none'
            )

        ax.set_xlabel(label1)
        ax.xaxis.set_label_coords(0.4, -0.15)
        ax.set_ylabel(label2)

    # One legend row per panel: bold heading inline on the left, then one
    # round marker per class. Rows are stacked downwards from the top.
    row_step = 0.07
    for row, (ax, (dataset_name, task_name)) in enumerate(zip(axes, dataset_task_pairs)):
        handles, labels = ax.get_legend_handles_labels()
        face_colors = [h.get_facecolor()[0] for h in handles]
        title = _combined_legend_title(dataset_name, task_name)
        # Invisible handle carries the heading; real handles follow.
        row_handles = [Line2D([0], [0], marker='none', linestyle='none')] + \
                      [Line2D([0], [0], marker='o', color='w', markerfacecolor=c, markersize=14)
                       for c in face_colors]
        row_labels = [f'$\\bf{{{title}:}}$'] + list(labels)
        fig.legend(row_handles, row_labels, loc='upper left',
                   bbox_to_anchor=(-0.05, 1.04 - row_step * row), ncol=len(row_labels),
                   frameon=False, fontsize=18,
                   handletextpad=0.1, columnspacing=0.5)

    # The layout constants below fit up to two rows; each further row extends
    # the box and lowers the top of the panels by one row height.
    extra = row_step * max(n - 2, 0)

    # Add tight box around legend area
    legend_box = FancyBboxPatch((0.02, 0.88 - extra), 0.94, 0.12 + extra,
                                 boxstyle="round,pad=0.01,rounding_size=0.01",
                                 transform=fig.transFigure,
                                 facecolor='white', edgecolor='#cccccc',
                                 linewidth=1, zorder=0)
    fig.patches.append(legend_box)

    plt.tight_layout()
    # Make room for the legend rows, reduce left margin; the floor keeps the
    # panels from vanishing when many rows are requested.
    plt.subplots_adjust(top=max(0.82 - extra, 0.5), left=0.06)

    if save:
        outfile = OUT_DIR / f"{outname}.pdf"
        plt.savefig(outfile, dpi=300, bbox_inches='tight')
        plt.savefig(outfile.with_suffix('.png'), dpi=300, bbox_inches='tight')
        print(f"Saved: {outfile}")

    return fig, axes


# ============================================
# CLI
# ============================================

def get_all_dataset_tasks():
    """Enumerate every configured (dataset, task) pair, in configuration order."""
    return [
        (dataset, task)
        for dataset, config in DATASET_CONFIGS.items()
        for task in config['tasks']
    ]


def main():
    """Parse command-line arguments and generate the requested plot(s).

    Modes, in order of precedence: ``--combine`` (side-by-side figure),
    ``--all`` (one figure per configured dataset/task; failures are reported
    and skipped), or ``--dataset`` with ``--task`` (single figure). With no
    mode selected, the help text and the available combinations are printed.

    Invalid arguments are reported through ``parser.error``, which prints the
    usage line and exits with a non-zero status, the same way argparse itself
    rejects an unknown ``--dataset``.
    """
    parser = argparse.ArgumentParser(
        description='Generate Top-2 PolyILR Feature Plots',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python plot.py --dataset hmp --task body_sites
    python plot.py --dataset disco_blood --task healthy_vs_leukemia
    python plot.py --combine hmp:body_sites disco_blood:healthy_vs_leukemia
    python plot.py --all
        """
    )

    parser.add_argument('--dataset', type=str, choices=list(DATASET_CONFIGS.keys()),
                        help='Dataset to plot')
    parser.add_argument('--task', type=str, help='Task to plot')
    parser.add_argument('--all', action='store_true', help='Generate all plots')
    parser.add_argument('--combine', nargs='+', metavar='DATASET:TASK',
                        help='Combine multiple plots (e.g., hmp:body_sites disco_blood:healthy_vs_leukemia)')
    parser.add_argument('--outname', type=str, default='combined_top2',
                        help='Output filename for combined plot (default: combined_top2)')
    parser.add_argument('--trim', type=float, default=None,
                        help='Trim outliers: percentile to clip (e.g., 1 clips to 1st-99th percentile)')
    parser.add_argument('--figsize', type=float, nargs=2, default=[7, 6],
                        help='Figure size (width height) for single plot')
    parser.add_argument('--combsize', type=float, nargs=2, default=[12, 5.5],
                        help='Figure size (width height) for combined plot')
    parser.add_argument('--show', action='store_true', help='Show plot interactively')

    args = parser.parse_args()

    if args.combine:
        # Parse and validate every dataset:task pair before doing any work
        pairs = []
        for item in args.combine:
            if ':' not in item:
                parser.error(f"Invalid format '{item}'. Use DATASET:TASK")
            dataset, task = item.split(':', 1)
            if dataset not in DATASET_CONFIGS:
                parser.error(f"Unknown dataset '{dataset}'")
            if task not in DATASET_CONFIGS[dataset]['tasks']:
                available = list(DATASET_CONFIGS[dataset]['tasks'].keys())
                parser.error(f"Task '{task}' not found in {dataset}. Available: {available}")
            pairs.append((dataset, task))
        generate_combined_plot(pairs, figsize=tuple(args.combsize), outname=args.outname, trim_percentile=args.trim)
    elif args.all:
        pairs = get_all_dataset_tasks()
        print(f"Generating {len(pairs)} plots...")
        for dataset, task in pairs:
            # Best effort: one failing dataset (e.g. missing files) must not
            # stop the remaining plots.
            try:
                generate_plot(dataset, task, figsize=tuple(args.figsize))
            except Exception as e:
                print(f"  Error with {dataset}/{task}: {e}")
    elif args.dataset and args.task:
        # Validate task
        if args.task not in DATASET_CONFIGS[args.dataset]['tasks']:
            available = list(DATASET_CONFIGS[args.dataset]['tasks'].keys())
            parser.error(f"Task '{args.task}' not found. Available: {available}")
        generate_plot(args.dataset, args.task, figsize=tuple(args.figsize))
    elif args.dataset or args.task:
        # Only one half of the pair was given; say so instead of showing help.
        parser.error("--dataset and --task must be given together")
    else:
        parser.print_help()
        print("\nAvailable dataset/task combinations:")
        for dataset, config in DATASET_CONFIGS.items():
            tasks = list(config['tasks'].keys())
            print(f"  {dataset}: {tasks}")
        return

    if args.show:
        plt.show()


# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()