"""
PolyILR Experiments (Generalized)
=================================

Supports multiple data domains:
- Microbiome (HMP, cMD3): OTU tables + taxonomy + phylogenetic trees
- Cell biology (DISCO): Cell type proportions + ontology trees

Experiments:
1. Representation equivalence (CLR vs PolyILR)
2. Feature stability (PolyILR vs PhILR with random binarization)
3. Feature interpretation (top features → contrasts)
4. Tree-level inference (depth, nodes, subtrees)
5. Top-2 visualization data
6. Taxon/Feature importance (V² weighting)
7. Semantic stability

Usage:
    python run_experiments_new.py --all
    python run_experiments_new.py --dataset hmp
    python run_experiments_new.py --dataset disco_blood --task healthy_vs_covid
    python run_experiments_new.py --experiment representation stability
    python run_experiments_new.py --test    # run the built-in self-tests and exit
"""

import sys

# ============================================
# TESTS (run with --test)
# ============================================

def run_tests():
    """Self-test the tree helpers. Run with: python run_experiments_new.py --test

    This is invoked before the third-party imports further down the file, so
    it relies on the standard library only and checks a local copy of the
    breadth-first leaf collection used by ``get_node_leaves``.

    Returns:
        True when every check passes; a failed check raises AssertionError.
    """
    # Imported locally: the module-level imports have not executed yet when
    # the --test hook calls this function.
    from collections import deque

    print("=" * 70)
    print("TESTING BUG FIXES")
    print("=" * 70)

    # Test 1: get_node_leaves
    print("\n[Test 1] get_node_leaves fix...")
    # Leaves are 0..4, internal nodes are negative; -1 is the root.
    mock_tree = {-1: [-2, -3, 0], -2: [1, 2], -3: [3, 4]}

    def get_node_leaves_test(tree, node):
        if node >= 0:
            return [node]
        leaves = []
        queue = deque([node])
        while queue:
            n = queue.popleft()
            if n >= 0:
                leaves.append(n)
            else:
                queue.extend(tree.get(n, []))
        return leaves

    result = get_node_leaves_test(mock_tree, -2)
    assert set(result) == {1, 2}, f"Expected {{1,2}}, got {result}"
    print(f"  Node -2: {result} ✓")

    result = get_node_leaves_test(mock_tree, -1)
    assert set(result) == {0, 1, 2, 3, 4}, f"Expected {{0,...,4}}, got {result}"
    print(f"  Root -1: {result} ✓")

    result = get_node_leaves_test(mock_tree, 3)
    assert result == [3], f"Expected [3], got {result}"
    print(f"  Leaf 3: {result} ✓")

    # Test 2: Helmert contrast interpretation
    print("\n[Test 2] Contrast interpretation...")
    children = ['A', 'B', 'C', 'D']
    # A node with k children yields k-1 contrasts; contrast m compares child m
    # against all of its earlier siblings children[:m].
    for contrast_idx in range(1, len(children)):
        left = children[contrast_idx]
        right = children[:contrast_idx]
        assert len(right) == contrast_idx, f"Expected {contrast_idx} siblings, got {right}"
        assert left not in right, f"{left} must not appear in {right}"
        print(f"  contrast_idx={contrast_idx}: {left} vs {right}")

    print("\n" + "=" * 70)
    print("ALL TESTS PASSED ✓")
    print("=" * 70)
    return True


# Handle --test here, before numpy/pandas/sklearn and the project modules are
# imported below, so the self-tests can run without them; the process exits
# right after run_tests() returns (a failing check raises instead).
if '--test' in sys.argv:
    run_tests()
    sys.exit(0)

# ============================================
# IMPORTS
# ============================================

import argparse
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from collections import deque
from itertools import combinations
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, train_test_split

from tree_utils import phylo_to_tree, force_binary_random
from polyilr import construct_V, construct_V_with_mapping, ilr_transform

# Suppress every warning for the rest of the process: no category or module
# filter is given, so this is global, not limited to this file.
warnings.filterwarnings('ignore')

# ============================================
# PATHS
# ============================================

# BASE_DIR is the parent of the directory containing this script; inputs are
# read from BASE_DIR/data. OUT_DIR is the output directory (not referenced in
# the visible part of this file — NOTE(review): confirm its consumers).
BASE_DIR = Path(__file__).parent.parent
DATA_DIR = BASE_DIR / "data"
OUT_DIR = BASE_DIR / "out"
# Created at import time; without parents=True, BASE_DIR must already exist.
OUT_DIR.mkdir(exist_ok=True)

# ============================================
# CONFIGURATION
# ============================================

# NOTE(review): not referenced in the visible part of this file; confirm use.
SEEDS = [0, 1, 2, 3, 4]
# Number of RF seeds (and, for PhILR, random tree binarizations) in run_stability.
N_STABILITY_SEEDS = 25
# Top-k feature-set sizes whose pairwise Jaccard stability run_stability reports.
K_VALUES = [5, 10, 20, 50]
# Number of trees for every RandomForestClassifier built in this file.
N_ESTIMATORS = 500
# Number of folds passed as `cv` to cross_val_score.
CV_FOLDS = 5
# Default number of top coordinates reported by run_interpretation.
TOP_K_FEATURES = 10
# NOTE(review): not referenced in the visible part of this file; presumably
# used by a later tree-level experiment — confirm.
TOP_K_NODES = 20

# Model name -> zero-argument factory returning a fresh, unfitted estimator,
# so every evaluation starts from scratch. Keys are the names that
# run_representation accepts in its `models` argument.
AVAILABLE_MODELS = {
    'RF': lambda: RandomForestClassifier(n_estimators=N_ESTIMATORS, random_state=42, n_jobs=-1),
    'SVM': lambda: SVC(kernel='rbf', random_state=42),
    'LogReg': lambda: LogisticRegression(max_iter=2000, random_state=42, n_jobs=-1),
}

# ============================================
# ABSTRACT DATA LOADER
# ============================================

class DataLoader(ABC):
    """Common interface for the dataset loaders used by the experiments.

    A concrete loader reads one family of datasets from disk and returns a
    dictionary with a fixed set of keys (described on ``load``) so that every
    experiment function can treat all datasets the same way.
    """

    @abstractmethod
    def load(self) -> dict:
        """
        Read the dataset and return a dict containing (at least):
        {
            'name': str,
            'X_comp': np.array (n_samples, n_features) - compositions summing to 1,
            'tree': dict - {internal_node: [children]},
            'root': int - root node index,
            'edge_lengths': dict - {(parent, child): length},
            'D': int - number of leaves/features,
            'feature_index': list - feature names in order,
            'feature_names': dict - {leaf_idx: dict of display fields incl. 'name'},
            'tasks': dict - {task_name: {'y': array, 'mask': array, 'type': str,
                                         optionally 'X_comp': task compositions}},
            'get_clade_label': callable - function(leaf_indices) -> str,
        }
        Loaders may add further keys (e.g. 'otu_index', 'idx_to_name').
        """
        ...


# ============================================
# MICROBIOME LOADER (HMP, cMD3)
# ============================================

class MicrobiomeLoader(DataLoader):
    """Loader for microbiome datasets (OTU table + taxonomy + Newick tree)."""
    
    CONFIGS = {
        'hmp': {
            'data_dir': 'hmp_taxa',
            'otu_file': 'hmp_otu_table.csv',
            'taxonomy_file': 'hmp_taxonomy.csv',
            'metadata_file': 'hmp_metadata.csv',
            'tree_file': 'hmp_tree_taxonomy_grafen.newick',
            'tasks': {
                'body_sites': {
                    'label_col': 'HMP_BODY_SITE',
                    'type': 'multiclass',
                },
                'body_subsites': {
                    'label_col': 'HMP_BODY_SUBSITE',
                    'type': 'multiclass',
                },
            },
        },
        'cmd3': {
            'data_dir': 'cmd3_taxa',
            'otu_file': 'cmd3_otu_table.csv',
            'taxonomy_file': 'cmd3_taxonomy.csv',
            'metadata_file': 'cmd3_metadata.csv',
            'tree_file': 'cmd3_tree_taxonomy_grafen.newick',
            'tasks': {
                'body_site': {
                    'label_col': 'body_site',
                    'type': 'multiclass',
                },
                'westernized': {
                    'label_col': 'non_westernized',
                    'type': 'binary',
                },
                'age_category': {
                    'label_col': 'age_category',
                    'type': 'multiclass',
                    'filter': lambda y: y != 'unknown',
                },
                'healthy_vs_disease': {
                    'label_col': 'disease',
                    'type': 'binary',
                    'transform': lambda y: np.array(['healthy' if d == 'healthy' else 'disease' for d in y]),
                },
                'stool_healthy_vs_disease': {
                    'label_col': 'disease',
                    'type': 'binary',
                    'transform': lambda y: np.array(['healthy' if d == 'healthy' else 'disease' for d in y]),
                    'filter_col': 'body_site',
                    'filter_val': 'stool',
                },
            },
        },
    }
    
    def __init__(self, dataset_name):
        if dataset_name not in self.CONFIGS:
            raise ValueError(f"Unknown microbiome dataset: {dataset_name}")
        self.dataset_name = dataset_name
        self.config = self.CONFIGS[dataset_name]
    
    def load(self):
        """Read OTU table, taxonomy, metadata and Newick tree for this dataset.

        Returns the standard dataset dict (see ``DataLoader.load``) plus the
        legacy keys ``'otu_index'`` and ``'taxonomy'``. Each task's ``mask``
        excludes samples whose raw label is missing, plus any rows rejected by
        the task's ``filter`` / ``filter_col`` settings.
        """
        from Bio import Phylo
        
        config = self.config
        data_dir = DATA_DIR / config['data_dir']
        
        print(f"\n[Loading {self.dataset_name} dataset from {data_dir}...]")
        
        # Load files
        otu = pd.read_csv(data_dir / config['otu_file'], index_col=0)
        taxonomy = pd.read_csv(data_dir / config['taxonomy_file'], index_col=0)
        metadata = pd.read_csv(data_dir / config['metadata_file'], index_col=0)
        T = Phylo.read(data_dir / config['tree_file'], "newick")
        
        # Setup tree: leaf i is the i-th row of the OTU table.
        otu_index = list(otu.index)
        D = len(otu_index)
        tree, root, edge_lengths = phylo_to_tree(T, otu_index)
        
        # Align samples on string IDs. Both the OTU columns and the metadata
        # index are cast to str so the .loc lookup below uses the same keys as
        # the membership test.
        otu.columns = otu.columns.astype(str)
        metadata.index = metadata.index.astype(str)
        common_samples = [s for s in otu.columns if s in metadata.index]
        
        # Compositions: a pseudocount of 1 keeps every part strictly positive
        # (downstream code takes logs), then each row is closed to sum to 1.
        X_counts = otu.loc[otu_index, common_samples].values.T
        X_comp = (X_counts + 1) / (X_counts + 1).sum(axis=1, keepdims=True)
        
        # Build tasks
        tasks = {}
        for task_name, task_config in config['tasks'].items():
            y_raw = metadata.loc[common_samples, task_config['label_col']].values
            
            # Detect missing labels on the RAW column: the healthy/disease
            # transforms map every non-'healthy' value -- NaN included -- to
            # 'disease', which would otherwise hide unlabeled samples.
            mask = pd.notna(y_raw)
            
            transform = task_config.get('transform')
            if transform is not None:
                y_raw = transform(y_raw)
            
            label_filter = task_config.get('filter')
            if label_filter is not None:
                mask = mask & label_filter(y_raw)
            
            if 'filter_col' in task_config:
                filter_col = metadata.loc[common_samples, task_config['filter_col']].values
                mask = mask & (filter_col == task_config['filter_val'])
            
            tasks[task_name] = {
                'y': y_raw,
                'mask': mask,
                'type': task_config['type'],
            }
        
        # Feature naming function (microbiome-specific); closes over the OTU
        # order and taxonomy table of this load.
        _otu_index = otu_index
        _taxonomy = taxonomy
        
        def get_clade_label(leaf_indices):
            """Name the deepest taxonomy rank shared by all given leaves.

            Ranks are checked from Kingdom down to Genus, stopping at the first
            rank with more than one distinct (non-null) value. Returns
            "<name> (<rank>)", or "Mixed" if not even the first rank agrees.
            """
            taxa_names = [_otu_index[i] for i in leaf_indices]
            sub_tax = _taxonomy.loc[taxa_names]
            common_level, common_name = None, None
            for level in ['Kingdom', 'Phylum', 'Class', 'Order', 'Family', 'Genus']:
                if level not in sub_tax.columns:
                    continue
                unique = sub_tax[level].dropna().unique()
                if len(unique) == 1:
                    common_level, common_name = level, unique[0]
                else:
                    break
            return f"{common_name} ({common_level})" if common_name else "Mixed"
        
        # Feature names for leaf-level reporting. A duplicated taxonomy index
        # makes .loc return a DataFrame, in which case names fall back to 'Unknown'.
        feature_names = {}
        for i, otu_name in enumerate(otu_index):
            tax = taxonomy.loc[otu_name]
            genus = tax.get('Genus', 'Unknown') if isinstance(tax, pd.Series) else 'Unknown'
            family = tax.get('Family', 'Unknown') if isinstance(tax, pd.Series) else 'Unknown'
            feature_names[i] = {'name': otu_name, 'genus': genus, 'family': family}
        
        print(f"  Samples: {len(common_samples)}, Taxa: {D}")
        for name, task in tasks.items():
            n_cls = len(np.unique(task['y'][task['mask']]))
            n_samp = task['mask'].sum()
            print(f"  Task '{name}': {n_cls} classes, {n_samp} samples")
        
        return {
            'name': self.dataset_name,
            'X_comp': X_comp,
            'tree': tree,
            'root': root,
            'edge_lengths': edge_lengths,
            'D': D,
            'feature_index': otu_index,
            'feature_names': feature_names,
            'tasks': tasks,
            'get_clade_label': get_clade_label,
            # Keep for backward compat
            'otu_index': otu_index,
            'taxonomy': taxonomy,
        }


# ============================================
# DISCO LOADER (Cell biology)
# ============================================

class DISCOLoader(DataLoader):
    """Loader for DISCO cell type datasets (proportions + ontology tree).

    Each condition (healthy, covid, ...) is a CSV with one row per sample and
    one column per cell type, plus metadata columns. The ontology CSV lists
    ``parent``/``child`` edges. Each task stacks the samples of its conditions
    and labels every row with the condition it came from.
    """
    
    CONFIGS = {
        'disco_blood': {
            'data_dir': 'disco_200subset',  # or 'disco_200'
            'tree_file': 'blood_tree.csv',
            'conditions': {
                'healthy': 'blood_healthy_proportions.csv',
                'covid': 'blood_covid_proportions.csv',
                'leukemia': 'blood_leukemia_proportions.csv',
            },
            'tasks': {
                'healthy_vs_covid': {
                    'conditions': ['healthy', 'covid'],
                    'type': 'binary',
                },
                'healthy_vs_leukemia': {
                    'conditions': ['healthy', 'leukemia'],
                    'type': 'binary',
                },
            },
        },
        'disco_liver': {
            'data_dir': 'disco_200subset',
            'tree_file': 'liver_tree.csv',
            'conditions': {
                'healthy': 'liver_healthy_proportions.csv',
                'hcc': 'liver_hcc_proportions.csv',
            },
            'tasks': {
                'healthy_vs_hcc': {
                    'conditions': ['healthy', 'hcc'],
                    'type': 'binary',
                },
            },
        },
    }
    
    def __init__(self, dataset_name):
        if dataset_name not in self.CONFIGS:
            raise ValueError(f"Unknown DISCO dataset: {dataset_name}")
        self.dataset_name = dataset_name
        self.config = self.CONFIGS[dataset_name]
    
    @staticmethod
    def _prune_empty_internal(tree):
        """Drop internal nodes that have no leaf (index >= 0) anywhere below them.

        Ontology branches whose cell types are all absent from the data would
        otherwise stay in the tree as childless internal nodes, i.e. empty
        subtrees. Returns a new ``{parent: [children]}`` dict; the order of
        the surviving children is preserved.
        """
        memo = {}

        def has_leaf(node):
            if node >= 0:
                return True
            if node not in memo:
                memo[node] = False  # provisional value; stops runaway recursion on cycles
                memo[node] = any(has_leaf(child) for child in tree.get(node, []))
            return memo[node]

        pruned = {}
        for parent, children in tree.items():
            kept = [child for child in children if has_leaf(child)]
            if kept:
                pruned[parent] = kept
        return pruned
    
    def load(self):
        """Read the ontology and condition tables and build the dataset dict.

        Returns the standard dict (see ``DataLoader.load``) plus
        ``'idx_to_name'`` mapping node indices to ontology names. Every task
        carries its own ``'X_comp'``; the top-level ``'X_comp'`` is the first
        task's matrix.

        Raises:
            ValueError: if the node named 'Root' has no data-backed leaf below it.
        """
        config = self.config
        data_dir = DATA_DIR / config['data_dir']
        
        print(f"\n[Loading {self.dataset_name} dataset from {data_dir}...]")
        
        # Ontology edges: one row per (parent, child) pair.
        tree_df = pd.read_csv(data_dir / config['tree_file'])
        
        # Load all condition tables and collect every cell-type column seen.
        condition_data = {}
        all_cell_types = set()
        meta_cols = ['sample_id', 'condition', 'tissue']
        
        for cond_name, cond_file in config['conditions'].items():
            df = pd.read_csv(data_dir / cond_file)
            condition_data[cond_name] = df
            all_cell_types.update(c for c in df.columns if c not in meta_cols)
        
        # Tree leaves are nodes that never appear as a parent; only those that
        # also have a data column become compositional parts.
        tree_parents = set(tree_df['parent'])
        tree_children = set(tree_df['child'])
        tree_leaves = tree_children - tree_parents
        valid_leaves = all_cell_types & tree_leaves
        
        # Index scheme: leaves are 0..D-1 (sorted by name); internal nodes are
        # -1, -2, ... (sorted by name) and 'Root' takes the last index.
        leaf_names = sorted(valid_leaves)
        leaf_to_idx = {name: i for i, name in enumerate(leaf_names)}
        D = len(leaf_names)
        
        internal_names = sorted(tree_parents - valid_leaves - {'Root'})
        internal_to_idx = {name: -(i + 1) for i, name in enumerate(internal_names)}
        internal_to_idx['Root'] = -(len(internal_names) + 1)
        root = internal_to_idx['Root']
        
        idx_to_name = {**{v: k for k, v in leaf_to_idx.items()},
                       **{v: k for k, v in internal_to_idx.items()}}
        
        def name_to_idx(name):
            if name in leaf_to_idx:
                return leaf_to_idx[name]
            return internal_to_idx.get(name, None)
        
        # Build {parent_idx: [child_idx, ...]} keeping CSV row order of the
        # children (it fixes the contrast order). Children without an index
        # (tree leaves that have no data column) are dropped.
        tree = {}
        for parent in tree_df['parent'].unique():
            parent_idx = name_to_idx(parent)
            if parent_idx is None:
                continue
            children = tree_df.loc[tree_df['parent'] == parent, 'child']
            children_idx = [idx for idx in map(name_to_idx, children) if idx is not None]
            if children_idx:
                tree[parent_idx] = children_idx
        
        # Remove internal branches left without any data-backed leaf.
        tree = self._prune_empty_internal(tree)
        if root not in tree:
            raise ValueError(
                f"{self.dataset_name}: ontology node 'Root' has no data-backed descendants"
            )
        
        # Ontology edges carry no length information, so every edge gets 1.0.
        edge_lengths = {
            (parent_idx, child_idx): 1.0
            for parent_idx, children_idx in tree.items()
            for child_idx in children_idx
        }
        
        def to_composition(df):
            """Align a condition table to leaf_names and close rows to sum to 1.

            Absent columns and empty cells count as 0; a 1e-10 offset keeps
            every part strictly positive before renormalising.
            """
            props = df.reindex(columns=leaf_names, fill_value=0.0).fillna(0.0)
            X = props.to_numpy(dtype=float) + 1e-10
            return X / X.sum(axis=1, keepdims=True)
        
        # Build per-task compositions; each condition is processed only once
        # even when several tasks share it (e.g. 'healthy').
        compositions = {}
        tasks = {}
        for task_name, task_config in config['tasks'].items():
            task_conditions = task_config['conditions']
            for cond in task_conditions:
                if cond not in compositions:
                    compositions[cond] = to_composition(condition_data[cond])
            
            X_task = np.vstack([compositions[cond] for cond in task_conditions])
            y_task = np.array([
                cond for cond in task_conditions for _ in range(len(compositions[cond]))
            ])
            
            tasks[task_name] = {
                'y': y_task,
                'mask': np.ones(len(y_task), dtype=bool),
                'type': task_config['type'],
                'X_comp': X_task,  # Task-specific compositions
            }
        
        # Use first task's data as default X_comp
        first_task = list(tasks.keys())[0]
        X_comp = tasks[first_task]['X_comp']
        
        # child_idx -> parent_idx, for walking towards the root.
        parent_map = {}
        for parent_idx, children_idx in tree.items():
            for child_idx in children_idx:
                parent_map[child_idx] = parent_idx
        
        def get_ancestors(node_idx):
            """Get list of ancestors from node to root."""
            ancestors = []
            current = node_idx
            while current in parent_map:
                current = parent_map[current]
                ancestors.append(current)
            return ancestors
        
        leaf_sets = {}
        
        def leaves_under(node):
            """Set of leaf indices below ``node`` (memoised; do not mutate)."""
            if node >= 0:
                return {node}
            if node not in leaf_sets:
                collected = set()
                for child in tree.get(node, []):
                    collected |= leaves_under(child)
                leaf_sets[node] = collected
            return leaf_sets[node]
        
        def get_clade_label(leaf_indices):
            """Ontology name of the lowest common ancestor of ``leaf_indices``.

            A single leaf returns its own name. Walking up from the first leaf,
            the first node whose subtree covers every requested leaf is the
            LCA. If none does, the first leaf's name is returned; an empty
            input returns "Unknown".
            """
            leaf_indices = list(leaf_indices)
            if not leaf_indices:
                return "Unknown"
            if len(leaf_indices) == 1:
                return idx_to_name.get(leaf_indices[0], "Unknown")
            
            targets = set(leaf_indices)
            for node in [leaf_indices[0]] + get_ancestors(leaf_indices[0]):
                if targets <= leaves_under(node):
                    return idx_to_name.get(node, "Unknown")
            
            # Last resort: just return first leaf name
            return idx_to_name.get(leaf_indices[0], "Unknown")
        
        # Feature names
        feature_names = {i: {'name': name, 'cell_type': name} for i, name in enumerate(leaf_names)}
        
        print(f"  Cell types: {D}")
        print(f"  Internal nodes: {len(internal_to_idx)}")
        for name, task in tasks.items():
            n_cls = len(np.unique(task['y']))
            n_samp = len(task['y'])
            print(f"  Task '{name}': {n_cls} classes, {n_samp} samples")
        
        return {
            'name': self.dataset_name,
            'X_comp': X_comp,
            'tree': tree,
            'root': root,
            'edge_lengths': edge_lengths,
            'D': D,
            'feature_index': leaf_names,
            'feature_names': feature_names,
            'tasks': tasks,
            'get_clade_label': get_clade_label,
            'idx_to_name': idx_to_name,
        }


# ============================================
# UNIFIED LOADER
# ============================================

def get_loader(dataset_name):
    """Instantiate the loader class that knows ``dataset_name``.

    Raises:
        ValueError: if no loader recognises the name; the message lists every
            available dataset name.
    """
    loader_classes = (MicrobiomeLoader, DISCOLoader)
    for loader_cls in loader_classes:
        if dataset_name in loader_cls.CONFIGS:
            return loader_cls(dataset_name)
    available = [name for loader_cls in loader_classes for name in loader_cls.CONFIGS]
    raise ValueError(f"Unknown dataset: {dataset_name}. Available: {available}")


def load_dataset(dataset_name):
    """Resolve the loader for ``dataset_name`` and return its loaded data dict."""
    return get_loader(dataset_name).load()


# ============================================
# HELPER FUNCTIONS
# ============================================

def get_node_depth(tree, root):
    """Return the depth of every node reachable from ``root``.

    Args:
        tree: mapping ``{internal_node: [children]}``; nodes absent from it
            are treated as having no children.
        root: node assigned depth 0.

    Returns:
        dict ``{node: depth}`` where each child's depth is its parent's + 1,
        filled in breadth-first order.
    """
    depth = {root: 0}
    # deque gives O(1) pops from the left; list.pop(0) made the BFS O(n^2).
    queue = deque([root])
    while queue:
        node = queue.popleft()
        child_depth = depth[node] + 1
        for child in tree.get(node, []):
            depth[child] = child_depth
            queue.append(child)
    return depth


def get_node_leaves(tree, node, D):
    """Return the leaf indices in the subtree rooted at ``node``.

    Leaves are non-negative indices; internal nodes are negative and are
    expanded through ``tree`` breadth-first. An internal node missing from
    ``tree`` contributes no leaves.

    Args:
        tree: mapping ``{internal_node: [children]}``.
        node: subtree root; a leaf index simply returns ``[node]``.
        D: number of leaves. Unused; kept for call compatibility.

    Returns:
        list of leaf indices in breadth-first discovery order.
    """
    if node >= 0:
        return [node]
    leaves = []
    # deque gives O(1) pops from the left; list.pop(0) made this O(n^2).
    queue = deque([node])
    while queue:
        current = queue.popleft()
        if current >= 0:
            leaves.append(current)
        else:
            queue.extend(tree.get(current, []))
    return leaves


def jaccard(set1, set2):
    """Return |A ∩ B| / |A ∪ B|; two empty sets count as identical (1.0)."""
    union_size = len(set1 | set2)
    if not union_size:
        return 1.0
    return len(set1 & set2) / union_size


# ============================================
# EXPERIMENT 1: REPRESENTATION EQUIVALENCE
# ============================================

def run_representation(data, task_name, models):
    """Compare cross-validated accuracy on CLR vs PolyILR coordinates.

    Args:
        data: dataset dict returned by a ``DataLoader``.
        task_name: key into ``data['tasks']``.
        models: iterable of model names; names missing from
            ``AVAILABLE_MODELS`` are reported and skipped.

    Returns:
        DataFrame with one row per (representation, model) pair holding the
        mean/std of ``CV_FOLDS``-fold accuracy, or None when the masked task
        has fewer than two classes.
    """
    print(f"\n  [Representation: {task_name}]")
    
    task = data['tasks'][task_name]
    mask = task['mask']
    
    # Tasks may carry their own compositions (DISCO). They are row-aligned
    # with task['y'], so the task mask applies to them exactly as it does to
    # the shared matrix; this keeps X and y the same length.
    X_comp = task.get('X_comp')
    if X_comp is None:
        X_comp = data['X_comp']
    X_comp = X_comp[mask]
    y_raw = task['y'][mask]
    
    le = LabelEncoder()
    y = le.fit_transform(y_raw)
    
    if len(np.unique(y)) < 2:
        print("    Skipping: fewer than 2 classes")
        return None
    
    # Materialise once so a generator is not exhausted after the first
    # representation, and report unknown names instead of dropping them silently.
    models = list(models)
    unknown = [name for name in models if name not in AVAILABLE_MODELS]
    if unknown:
        print(f"    Skipping unknown models: {unknown}")
    model_names = [name for name in models if name in AVAILABLE_MODELS]
    
    tree, root, D = data['tree'], data['root'], data['D']
    edge_lengths = data['edge_lengths']
    
    # CLR: log parts centred by their per-sample mean log.
    log_X = np.log(X_comp)
    X_clr = log_X - log_X.mean(axis=1, keepdims=True)
    
    # PolyILR
    V, _ = construct_V_with_mapping(tree, root, D, edge_lengths)
    X_polyilr = ilr_transform(X_comp, V)
    
    representations = {'CLR': X_clr, 'PolyILR': X_polyilr}
    
    results = []
    for rep_name, X in representations.items():
        for model_name in model_names:
            clf = AVAILABLE_MODELS[model_name]()
            scores = cross_val_score(clf, X, y, cv=CV_FOLDS, scoring='accuracy')
            
            results.append({
                'representation': rep_name,
                'model': model_name,
                'accuracy_mean': scores.mean(),
                'accuracy_std': scores.std(),
                'n_samples': len(y),
                'n_classes': len(np.unique(y)),
            })
            print(f"    {rep_name:8s} + {model_name:6s}: {scores.mean():.3f} ± {scores.std():.3f}")
    
    return pd.DataFrame(results)


# ============================================
# EXPERIMENT 2: FEATURE STABILITY
# ============================================

def run_stability(data, task_name):
    """Compare top-k feature stability of PolyILR vs PhILR coordinates.

    For each of ``N_STABILITY_SEEDS`` seeds a random forest is fitted on
    (a) the deterministic PolyILR basis and (b) a PhILR basis built from a
    seed-specific random binarization of the tree. For every k in
    ``K_VALUES`` the pairwise Jaccard similarity of the top-k importance sets
    across seeds is summarised.

    Returns:
        DataFrame with columns k, method, jaccard_mean, jaccard_std, or None
        when the masked task has fewer than two classes.
    """
    print(f"\n  [Stability: {task_name}]")
    
    task = data['tasks'][task_name]
    mask = task['mask']
    # Task-specific compositions (DISCO) are row-aligned with task['y'], so
    # the same mask applies to them as to the shared matrix.
    X_comp = task.get('X_comp')
    if X_comp is None:
        X_comp = data['X_comp']
    X_comp = X_comp[mask]
    y_raw = task['y'][mask]
    
    le = LabelEncoder()
    y = le.fit_transform(y_raw)
    
    if len(np.unique(y)) < 2:
        print("    Skipping: fewer than 2 classes")
        return None
    
    tree, root, D = data['tree'], data['root'], data['D']
    edge_lengths = data['edge_lengths']
    
    def importance_order(X, seed):
        """Feature indices sorted by ascending RF importance for one seed."""
        clf = RandomForestClassifier(n_estimators=N_ESTIMATORS, random_state=seed, n_jobs=-1)
        clf.fit(X, y)
        return np.argsort(clf.feature_importances_)
    
    # The importances do not depend on k, so each forest (and each random
    # binarization) is computed once per seed and reused for every k rather
    # than being refitted inside the k loop. Results are identical.
    seeds = range(N_STABILITY_SEEDS)
    
    # PolyILR (deterministic basis; only the RF seed varies)
    V_poly, _ = construct_V_with_mapping(tree, root, D, edge_lengths)
    X_poly = ilr_transform(X_comp, V_poly)
    poly_orders = [importance_order(X_poly, seed) for seed in seeds]
    
    # PhILR: a different random binarization of the tree for each seed
    philr_orders = []
    for seed in seeds:
        rng = np.random.default_rng(seed)
        tree_binary, root_binary, edge_lengths_binary = force_binary_random(tree, root, edge_lengths, rng)
        V_philr = construct_V(tree_binary, root_binary, D, edge_lengths=edge_lengths_binary)
        X_philr = ilr_transform(X_comp, V_philr)
        philr_orders.append(importance_order(X_philr, seed))
    
    def pairwise_jaccards(orders, k):
        """Jaccard similarity between the top-k sets of every pair of seeds."""
        top_sets = [set(order[-k:]) for order in orders]
        return [jaccard(f1, f2) for f1, f2 in combinations(top_sets, 2)]
    
    results = []
    for k in K_VALUES:
        print(f"    k={k}...", end=" ")
        
        poly_jaccards = pairwise_jaccards(poly_orders, k)
        philr_jaccards = pairwise_jaccards(philr_orders, k)
        
        results.append({
            'k': k,
            'method': 'PolyILR',
            'jaccard_mean': np.mean(poly_jaccards),
            'jaccard_std': np.std(poly_jaccards),
        })
        results.append({
            'k': k,
            'method': 'PhILR',
            'jaccard_mean': np.mean(philr_jaccards),
            'jaccard_std': np.std(philr_jaccards),
        })
        
        print(f"PolyILR={np.mean(poly_jaccards):.3f}, PhILR={np.mean(philr_jaccards):.3f}")
    
    return pd.DataFrame(results)


# ============================================
# EXPERIMENT 3: FEATURE INTERPRETATION
# ============================================

def run_interpretation(data, task_name, top_k=TOP_K_FEATURES):
    """Report the top-``top_k`` PolyILR coordinates and the contrasts they encode.

    A random forest is trained on a stratified 90% split of the PolyILR
    coordinates. The most important coordinates are then translated into
    "child m vs earlier siblings" contrasts via ``data['get_clade_label']``.

    Returns:
        DataFrame with one row per top feature (ordered by decreasing
        importance) and the held-out accuracy in ``df.attrs['test_accuracy']``,
        or None when the masked task has fewer than two classes.
    """
    print(f"\n  [Interpretation: {task_name}]")
    
    task = data['tasks'][task_name]
    mask = task['mask']
    # Task-specific compositions (DISCO) are row-aligned with task['y'], so
    # the same mask applies to them as to the shared matrix.
    X_comp = task.get('X_comp')
    if X_comp is None:
        X_comp = data['X_comp']
    X_comp = X_comp[mask]
    y_raw = task['y'][mask]
    
    le = LabelEncoder()
    y = le.fit_transform(y_raw)
    
    if len(np.unique(y)) < 2:
        print("    Skipping: fewer than 2 classes")
        return None
    
    tree, root, D = data['tree'], data['root'], data['D']
    edge_lengths = data['edge_lengths']
    get_clade_label = data['get_clade_label']
    
    V, column_mapping = construct_V_with_mapping(tree, root, D, edge_lengths)
    X_ilr = ilr_transform(X_comp, V)
    
    # Train RF on a stratified 90/10 split
    X_train, X_test, y_train, y_test = train_test_split(
        X_ilr, y, test_size=0.1, stratify=y, random_state=42
    )
    clf = RandomForestClassifier(n_estimators=N_ESTIMATORS, random_state=42, n_jobs=-1)
    clf.fit(X_train, y_train)
    
    test_acc = clf.score(X_test, y_test)
    print(f"    Test accuracy: {test_acc:.3f}")
    
    importance = clf.feature_importances_
    # Descending importance, first top_k. Slicing from the front means
    # top_k=0 selects nothing (the old `[-top_k:]` selected everything).
    top_indices = np.argsort(importance)[::-1][:max(top_k, 0)]
    
    results = []
    for rank, idx in enumerate(top_indices, 1):
        info = column_mapping[idx]
        children_labels = [get_clade_label(leaves) for leaves in info['child_leaves']]
        # Contrast m compares child m ("left") against children 0..m-1 ("right").
        m = info['contrast_idx']
        left = children_labels[m] if m < len(children_labels) else "?"
        right_list = children_labels[:m]
        right = " + ".join(right_list[:3])
        if len(right_list) > 3:
            right += f" + ... ({len(right_list)} total)"
        
        results.append({
            'rank': rank,
            'feature_idx': idx,
            'importance': importance[idx],
            'node_id': info['node'],
            'contrast_idx': m,
            'n_contrasts': info['n_contrasts'],
            'left': left,
            'right': right,
            'contrast': f"{left} vs {right}",
        })
        print(f"    [{rank}] {left} vs {right} (imp={importance[idx]:.4f})")
    
    df = pd.DataFrame(results)
    df.attrs['test_accuracy'] = test_acc
    return df


# ============================================
# EXPERIMENT 4: TREE-LEVEL INFERENCE
# ============================================

def _rf_cv_scores(X_ilr, y, coord_idx):
    """Return CV accuracy scores of a fresh RF trained only on the ILR columns ``coord_idx``."""
    clf = RandomForestClassifier(n_estimators=N_ESTIMATORS, random_state=42, n_jobs=-1)
    return cross_val_score(clf, X_ilr[:, coord_idx], y, cv=CV_FOLDS, scoring='accuracy')


def _subtree_nodes(tree, start):
    """Return ``start`` followed by every node below it, in breadth-first order.

    ``tree`` maps a node to its list of children; nodes without an entry are
    treated as childless (leaves).
    """
    from collections import deque

    nodes = [start]
    queue = deque([start])
    while queue:
        for child in tree.get(queue.popleft(), []):
            nodes.append(child)
            queue.append(child)
    return nodes


def run_tree_inference(data, task_name):
    """Analyze how PolyILR coordinate importance is distributed over the tree.

    One RF is fitted on all PolyILR coordinates (all samples) to obtain
    impurity-based importances. Four views are then reported:

    * ``depth``: for each tree depth, the CV accuracy of an RF restricted to
      that depth's coordinates, and that depth's share of the total importance;
    * ``cumulative_depth``: the same for all coordinates at depths ``0..d``;
    * ``nodes``: the ``TOP_K_NODES`` nodes with the largest summed importance
      over their coordinates;
    * ``subtrees``: for each child of the root, CV accuracy and importance of
      the coordinates of all nodes inside that child's subtree.

    Args:
        data: Dataset dict. Reads ``tasks``, ``X_comp``, ``tree``, ``root``,
            ``D``, ``edge_lengths`` and ``get_clade_label``.
        task_name: Key into ``data['tasks']``.

    Returns:
        Dict of DataFrames keyed ``'depth'``, ``'cumulative_depth'``,
        ``'nodes'`` and ``'subtrees'``. Returns ``None`` when the task has
        fewer than two classes or the basis has no coordinates.
    """
    print(f"\n  [Tree Inference: {task_name}]")

    task = data['tasks'][task_name]
    mask = task['mask']
    # A task-level composition matrix is used as-is; otherwise the dataset-level one is masked.
    if 'X_comp' in task:
        X_comp = task['X_comp']
    else:
        X_comp = data['X_comp'][mask]
    y_raw = task['y'][mask]

    le = LabelEncoder()
    y = le.fit_transform(y_raw)

    if len(np.unique(y)) < 2:
        print("    Skipping: fewer than 2 classes")
        return None

    tree, root, D = data['tree'], data['root'], data['D']
    edge_lengths = data['edge_lengths']
    get_clade_label = data['get_clade_label']

    V, column_mapping = construct_V_with_mapping(tree, root, D, edge_lengths)
    if len(column_mapping) == 0:
        # Without coordinates, max() over the depth groups below would raise.
        print("    Skipping: no ILR coordinates")
        return None
    X_ilr = ilr_transform(X_comp, V)

    # Reference importances: one RF fitted on all coordinates and all samples.
    clf = RandomForestClassifier(n_estimators=N_ESTIMATORS, random_state=42, n_jobs=-1)
    clf.fit(X_ilr, y)
    importance = clf.feature_importances_
    total_importance = np.sum(importance)

    # Group coordinate indices by the depth and by the node they belong to.
    node_depth = get_node_depth(tree, root)
    depth_to_coords = defaultdict(list)
    node_to_coords = defaultdict(list)
    for idx, info in enumerate(column_mapping):
        node = info['node']
        depth_to_coords[node_depth.get(node, 0)].append(idx)
        node_to_coords[node].append(idx)
    max_depth = max(depth_to_coords)

    # --- Depth analysis ---
    print("    By Depth:")
    depth_results = []
    for d in range(max_depth + 1):
        coord_idx = depth_to_coords.get(d, [])
        if not coord_idx:
            continue

        scores = _rf_cv_scores(X_ilr, y, coord_idx)
        imp_sum = np.sum(importance[coord_idx])
        imp_pct = 100 * imp_sum / total_importance

        depth_results.append({
            'depth': d,
            'n_coords': len(coord_idx),
            'accuracy_mean': scores.mean(),
            'accuracy_std': scores.std(),
            'importance_sum': imp_sum,
            'importance_pct': imp_pct,
        })
        print(f"      Depth {d}: {len(coord_idx)} coords, acc={scores.mean():.3f}, imp={imp_pct:.1f}%")

    # Cumulative depth: coordinates at depths 0..d. The list grows incrementally,
    # so columns keep the same depth-ordered sequence as a from-scratch rebuild.
    print("    Cumulative Depth:")
    cumulative_results = []
    cumulative_idx = []
    for d in range(max_depth + 1):
        cumulative_idx.extend(depth_to_coords.get(d, []))
        if not cumulative_idx:
            continue

        scores = _rf_cv_scores(X_ilr, y, cumulative_idx)
        imp_sum = np.sum(importance[cumulative_idx])
        imp_pct = 100 * imp_sum / total_importance

        cumulative_results.append({
            'depth_up_to': d,
            'n_coords': len(cumulative_idx),
            'accuracy_mean': scores.mean(),
            'accuracy_std': scores.std(),
            'importance_sum': imp_sum,
            'importance_pct': imp_pct,
        })
        print(f"      Depth 0-{d}: {len(cumulative_idx)} coords, acc={scores.mean():.3f}, imp={imp_pct:.1f}%")

    # --- Node analysis ---
    print(f"    Top {TOP_K_NODES} Nodes by Importance:")
    node_importance = {node: np.sum(importance[idx]) for node, idx in node_to_coords.items()}
    top_nodes = sorted(node_importance.items(), key=lambda x: -x[1])[:TOP_K_NODES]

    node_results = []
    for rank, (node, imp) in enumerate(top_nodes, 1):
        coord_idx = node_to_coords[node]
        leaves = get_node_leaves(tree, node, D)
        node_label = get_clade_label(leaves)
        imp_pct = 100 * imp / total_importance

        node_results.append({
            'rank': rank,
            'node_id': node,
            'label': node_label,
            'n_coords': len(coord_idx),
            # NOTE(review): nodes missing from node_depth are reported as -1 here
            # but were grouped under depth 0 above; confirm which is intended.
            'depth': node_depth.get(node, -1),
            'n_leaves': len(leaves),
            'importance_sum': imp,
            'importance_pct': imp_pct,
        })
        print(f"      [{rank}] {node_label}: {len(coord_idx)} coords, imp={imp_pct:.1f}%")

    # --- Subtree analysis ---
    print("    Subtrees (children of root):")
    subtree_results = []
    for child in tree.get(root, []):
        # Coordinates of every node in the child's subtree, in BFS node order.
        coord_idx = [
            idx
            for n in _subtree_nodes(tree, child)
            for idx in node_to_coords.get(n, [])
        ]
        if not coord_idx:
            continue

        leaves = get_node_leaves(tree, child, D)
        subtree_label = get_clade_label(leaves)

        scores = _rf_cv_scores(X_ilr, y, coord_idx)
        imp_sum = np.sum(importance[coord_idx])
        imp_pct = 100 * imp_sum / total_importance

        subtree_results.append({
            'subtree_root': child,
            'label': subtree_label,
            'n_coords': len(coord_idx),
            'n_leaves': len(leaves),
            'accuracy_mean': scores.mean(),
            'accuracy_std': scores.std(),
            'importance_sum': imp_sum,
            'importance_pct': imp_pct,
        })
        print(f"      {subtree_label}: {len(coord_idx)} coords, acc={scores.mean():.3f}, imp={imp_pct:.1f}%")

    return {
        'depth': pd.DataFrame(depth_results),
        'cumulative_depth': pd.DataFrame(cumulative_results),
        'nodes': pd.DataFrame(node_results),
        'subtrees': pd.DataFrame(subtree_results),
    }


# ============================================
# EXPERIMENT 5: TOP-2 VISUALIZATION
# ============================================

def run_top2_viz(data, task_name):
    """Project samples onto the two most important PolyILR coordinates.

    An RF is trained on a stratified 90/10 train/test split of the PolyILR
    coordinates. The two coordinates with the highest impurity importance are
    then extracted for every sample (train and test), together with the labels.

    Args:
        data: Dataset dict. Reads ``tasks``, ``X_comp``, ``tree``, ``root``,
            ``D``, ``edge_lengths`` and ``get_clade_label``.
        task_name: Key into ``data['tasks']``.

    Returns:
        DataFrame with columns ``coord1``, ``coord2``, ``label`` (raw) and
        ``label_encoded``. ``df.attrs`` holds ``test_accuracy`` and, for each
        of the two features, its index, contrast label and importance.
        Returns ``None`` when the task has fewer than two classes or the basis
        has fewer than two coordinates.
    """
    print(f"\n  [Top-2 Visualization: {task_name}]")

    task = data['tasks'][task_name]
    mask = task['mask']
    if 'X_comp' in task:
        X_comp = task['X_comp']
    else:
        X_comp = data['X_comp'][mask]
    y_raw = task['y'][mask]

    le = LabelEncoder()
    y = le.fit_transform(y_raw)

    if len(np.unique(y)) < 2:
        print("    Skipping: fewer than 2 classes")
        return None

    tree, root, D = data['tree'], data['root'], data['D']
    edge_lengths = data['edge_lengths']
    get_clade_label = data['get_clade_label']

    V, column_mapping = construct_V_with_mapping(tree, root, D, edge_lengths)
    X_ilr = ilr_transform(X_comp, V)

    # A 2-D scatter needs two coordinates. Without this guard, top2_idx[1]
    # below would raise IndexError when the basis has only one column.
    if X_ilr.shape[1] < 2:
        print("    Skipping: fewer than 2 ILR coordinates")
        return None

    # Train RF
    X_train, X_test, y_train, y_test = train_test_split(
        X_ilr, y, test_size=0.1, stratify=y, random_state=42
    )
    clf = RandomForestClassifier(n_estimators=N_ESTIMATORS, random_state=42, n_jobs=-1)
    clf.fit(X_train, y_train)

    test_acc = clf.score(X_test, y_test)
    importance = clf.feature_importances_
    top2_idx = np.argsort(importance)[-2:][::-1]

    def get_contrast_label(idx):
        """Render coordinate idx as '<child m> vs <children[:m]>' (first two shown)."""
        info = column_mapping[idx]
        children_labels = [get_clade_label(leaves) for leaves in info['child_leaves']]
        m = info['contrast_idx']
        left = children_labels[m] if m < len(children_labels) else "?"
        right_list = children_labels[:m]
        right = " + ".join(right_list[:2])
        if len(right_list) > 2:
            right += " + ..."
        return f"{left} vs {right}"

    label1 = get_contrast_label(top2_idx[0])
    label2 = get_contrast_label(top2_idx[1])

    print(f"    Test accuracy: {test_acc:.3f}")
    print(f"    Feature 1: {label1} (imp={importance[top2_idx[0]]:.4f})")
    print(f"    Feature 2: {label2} (imp={importance[top2_idx[1]]:.4f})")

    X_top2 = X_ilr[:, top2_idx]
    df = pd.DataFrame({
        'coord1': X_top2[:, 0],
        'coord2': X_top2[:, 1],
        'label': y_raw,
        'label_encoded': y,
    })

    df.attrs['test_accuracy'] = test_acc
    df.attrs['feature1_idx'] = int(top2_idx[0])
    df.attrs['feature2_idx'] = int(top2_idx[1])
    df.attrs['feature1_label'] = label1
    df.attrs['feature2_label'] = label2
    df.attrs['feature1_importance'] = float(importance[top2_idx[0]])
    df.attrs['feature2_importance'] = float(importance[top2_idx[1]])

    return df


# ============================================
# EXPERIMENT 6: FEATURE IMPORTANCE (V² weighting)
# ============================================

def run_paths(data, task_name):
    """Attribute PolyILR coordinate importances to individual leaves (V² weighting).

    An RF is fitted on all PolyILR coordinates of the task. Leaf ``i`` then
    receives ``sum_j V[i, j]**2 * importance[j]``, so each coordinate's
    importance is spread over the leaves in proportion to their squared
    loadings.

    Args:
        data: Dataset dict. Reads ``tasks``, ``X_comp``, ``tree``, ``root``,
            ``D``, ``edge_lengths``, ``feature_index`` and ``feature_names``.
        task_name: Key into ``data['tasks']``.

    Returns:
        DataFrame with one row per leaf (``leaf_idx``, ``leaf_name``,
        ``label1``, ``label2``, ``importance``, ``importance_pct``), sorted by
        descending importance. Returns ``None`` when the task has fewer than
        two classes.
    """
    print(f"\n  [Feature Importance: {task_name}]")

    task = data['tasks'][task_name]
    mask = task['mask']
    X_comp = task['X_comp'] if 'X_comp' in task else data['X_comp'][mask]
    y_raw = task['y'][mask]

    y = LabelEncoder().fit_transform(y_raw)
    if len(np.unique(y)) < 2:
        print("    Skipping: fewer than 2 classes")
        return None

    tree, root, D = data['tree'], data['root'], data['D']
    edge_lengths = data['edge_lengths']
    # These are read before the model fit, so a missing key fails before any heavy work.
    feature_index = data['feature_index']
    feature_names = data['feature_names']

    V, _column_mapping = construct_V_with_mapping(tree, root, D, edge_lengths)
    X_ilr = ilr_transform(X_comp, V)

    forest = RandomForestClassifier(n_estimators=N_ESTIMATORS, random_state=42, n_jobs=-1)
    forest.fit(X_ilr, y)
    coord_importance = forest.feature_importances_

    print(f"    Computing feature importance for {D} features...")
    leaf_importance = (V ** 2) @ coord_importance

    total_imp = np.sum(leaf_importance)
    # Express as percent of the total; if the total is not positive, keep the raw scores.
    if total_imp > 0:
        leaf_pct = 100 * leaf_importance / total_imp
    else:
        leaf_pct = leaf_importance

    def describe(leaf_idx):
        """Build the output row for one leaf, labelled from feature_names when available."""
        name = feature_index[leaf_idx]
        meta = feature_names.get(leaf_idx, {})
        return {
            'leaf_idx': leaf_idx,
            'leaf_name': name,
            'label1': meta.get('genus', meta.get('cell_type', name)),
            'label2': meta.get('family', ''),
            'importance': leaf_importance[leaf_idx],
            'importance_pct': leaf_pct[leaf_idx],
        }

    rows = [describe(leaf_idx) for leaf_idx in range(D)]
    ranked = sorted(rows, key=lambda row: -row['importance'])

    print("    Top 20 Features by Importance:")
    for position, row in enumerate(ranked[:20], start=1):
        print(f"      [{position}] {row['label1']} ({row['label2']}): imp={row['importance_pct']:.2f}%")

    # Summed in leaf order (not ranked order) so the float total is accumulated identically.
    total_pct = sum(row['importance_pct'] for row in rows)
    print(f"    Total importance: {total_pct:.1f}%")

    return pd.DataFrame(ranked)


# ============================================
# EXPERIMENT 7: SEMANTIC STABILITY
# ============================================

def get_contrast_semantics(column_mapping, feature_idx, D, tree):
    """Return the leaf bipartition a basis coordinate contrasts, in canonical order.

    The "left" side is the leaf set of child ``contrast_idx`` of the
    coordinate's node; it is empty if that index is out of range. The "right"
    side is the union of the leaf sets of all other children of that node.

    The pair is returned as ``(smaller, larger)``. When both sides have the
    same size, the side with the lexicographically smaller sorted leaf ids
    comes first. The same bipartition therefore yields the same key
    regardless of which side was the ``contrast_idx`` child. This is required
    when matching contrasts across differently ordered random binarizations.
    Leaf ids are assumed to be mutually comparable (e.g. ints).

    NOTE(review): ``run_top2_viz`` labels a coordinate as child ``m`` vs
    ``children[:m]`` (Helmert-style), whereas the right side here includes
    *every* other child. The two agree for binary nodes and for the last
    contrast of a polytomy; confirm which partition construct_V_with_mapping
    actually encodes.

    Args:
        column_mapping: Per-coordinate dicts providing ``'child_leaves'``
            (one leaf collection per child) and ``'contrast_idx'``.
        feature_idx: Index of the coordinate in ``column_mapping``.
        D: Unused; kept for call-site compatibility.
        tree: Unused; kept for call-site compatibility.

    Returns:
        Tuple ``(frozenset, frozenset)``, smaller/canonically-first side first.
    """
    info = column_mapping[feature_idx]
    child_leaves = info['child_leaves']
    m = info['contrast_idx']

    left_leaves = frozenset(child_leaves[m]) if m < len(child_leaves) else frozenset()
    right_leaves = frozenset().union(
        *(leaves for i, leaves in enumerate(child_leaves) if i != m)
    )

    def order_key(side):
        # Size first; sorted ids break ties so equal-size sides order deterministically.
        return (len(side), sorted(side))

    if order_key(left_leaves) <= order_key(right_leaves):
        return (left_leaves, right_leaves)
    return (right_leaves, left_leaves)


def run_stability_semantic(data, task_name):
    """Compare PolyILR vs PhILR top-k stability in terms of leaf bipartitions.

    Two methods are fitted for every seed in ``range(N_STABILITY_SEEDS)``:

    * PolyILR: the fixed PolyILR basis, with an RF using ``random_state=seed``;
    * PhILR: the tree randomly binarized with ``np.random.default_rng(seed)``
      and its ILR basis, with an RF using ``random_state=seed``.

    The k most important coordinates of each fit are mapped to leaf
    bipartitions via ``get_contrast_semantics``, so fits built on different
    binarizations can be compared. Stability is the mean/std of pairwise
    Jaccard similarity between seeds, for each k in ``K_VALUES``.

    Args:
        data: Dataset dict. Reads ``tasks``, ``X_comp``, ``tree``, ``root``,
            ``D`` and ``edge_lengths``.
        task_name: Key into ``data['tasks']``.

    Returns:
        DataFrame with columns ``k``, ``method``, ``jaccard_mean`` and
        ``jaccard_std`` (two rows per k). Returns ``None`` when the task has
        fewer than two classes.
    """
    print(f"\n  [Semantic Stability: {task_name}]")

    task = data['tasks'][task_name]
    mask = task['mask']
    if 'X_comp' in task:
        X_comp = task['X_comp']
    else:
        X_comp = data['X_comp'][mask]
    y_raw = task['y'][mask]

    le = LabelEncoder()
    y = le.fit_transform(y_raw)

    if len(np.unique(y)) < 2:
        print("    Skipping: fewer than 2 classes")
        return None

    tree, root, D = data['tree'], data['root'], data['D']
    edge_lengths = data['edge_lengths']

    V_poly, mapping_poly = construct_V_with_mapping(tree, root, D, edge_lengths)
    X_poly = ilr_transform(X_comp, V_poly)

    def importance_ranking(X, seed):
        """Fit a seeded RF and return coordinate indices in ascending importance."""
        clf = RandomForestClassifier(n_estimators=N_ESTIMATORS, random_state=seed, n_jobs=-1)
        clf.fit(X, y)
        return np.argsort(clf.feature_importances_)

    def top_k_contrasts(run, k):
        """Semantic contrasts of the k highest-ranked coordinates of one fitted run."""
        ranking, column_mapping, run_tree = run
        return {get_contrast_semantics(column_mapping, idx, D, run_tree) for idx in ranking[-k:]}

    def pairwise_jaccards(contrast_sets):
        """Jaccard similarity for every pair of runs (two empty sets count as 1.0)."""
        scores = []
        for set1, set2 in combinations(contrast_sets, 2):
            if len(set1) == 0 and len(set2) == 0:
                scores.append(1.0)
            else:
                scores.append(len(set1 & set2) / len(set1 | set2))
        return scores

    # The importance ranking of a run does not depend on k, and every fit and
    # binarization is seeded. Each model is therefore fitted once per seed and
    # its ranking is reused for all k, instead of refitting inside the k loop.
    poly_runs = [
        (importance_ranking(X_poly, seed), mapping_poly, tree)
        for seed in range(N_STABILITY_SEEDS)
    ]

    philr_runs = []
    for seed in range(N_STABILITY_SEEDS):
        rng = np.random.default_rng(seed)
        tree_binary, root_binary, edge_lengths_binary = force_binary_random(tree, root, edge_lengths, rng)
        V_philr, mapping_philr = construct_V_with_mapping(tree_binary, root_binary, D, edge_lengths_binary)
        X_philr = ilr_transform(X_comp, V_philr)
        philr_runs.append((importance_ranking(X_philr, seed), mapping_philr, tree_binary))

    results = []

    for k in K_VALUES:
        print(f"    k={k}...", end=" ")

        poly_jaccards = pairwise_jaccards([top_k_contrasts(run, k) for run in poly_runs])
        philr_jaccards = pairwise_jaccards([top_k_contrasts(run, k) for run in philr_runs])

        for method, jaccards in (('PolyILR', poly_jaccards), ('PhILR', philr_jaccards)):
            results.append({
                'k': k,
                'method': method,
                'jaccard_mean': np.mean(jaccards),
                'jaccard_std': np.std(jaccards),
            })

        print(f"PolyILR={np.mean(poly_jaccards):.3f}, PhILR={np.mean(philr_jaccards):.3f}")

    return pd.DataFrame(results)


# ============================================
# MAIN RUNNER
# ============================================

def _write_experiment_csv(df, outfile, header_keys=()):
    """Save ``df`` to ``outfile`` and print the saved path.

    If ``header_keys`` is non-empty, the file first receives one
    ``# key=value`` line per key (value taken from ``df.attrs``, ``'NA'`` when
    absent). The CSV body is then appended after those lines.
    """
    if header_keys:
        with open(outfile, 'w') as fh:
            for key in header_keys:
                fh.write(f"# {key}={df.attrs.get(key, 'NA')}\n")
        df.to_csv(outfile, index=False, mode='a')
    else:
        df.to_csv(outfile, index=False)
    print(f"    Saved: {outfile}")


def _run_if_missing(outfile, title, task_name, force, compute, header_keys=()):
    """Call ``compute()`` and save its result, unless ``outfile`` exists and ``force`` is off.

    A ``None`` result from ``compute`` (experiment skipped) writes nothing.
    """
    if outfile.exists() and not force:
        print(f"\n  [{title}: {task_name}] Skipping (file exists)")
        return
    df = compute()
    if df is not None:
        _write_experiment_csv(df, outfile, header_keys)


def run_experiments(datasets, tasks, experiments, models, force=False):
    """Run the selected experiments for each dataset/task and write CSVs to OUT_DIR.

    Args:
        datasets: Dataset names passed to ``load_dataset``.
        tasks: Task names to run, or a falsy value for every task of each
            dataset. Names a dataset does not define are reported and skipped.
        experiments: Experiment names. ``'all'`` selects every experiment
            except ``'paths'``, which only runs when named explicitly.
        models: Model names forwarded to ``run_representation``.
        force: Recompute even when the output file(s) already exist.
    """
    def selected(name, include_in_all=True):
        return name in experiments or (include_in_all and 'all' in experiments)

    top2_header = (
        'test_accuracy',
        'feature1_label',
        'feature2_label',
        'feature1_importance',
        'feature2_importance',
    )

    for dataset_name in datasets:
        data = load_dataset(dataset_name)
        task_names = tasks if tasks else list(data['tasks'].keys())

        for task_name in task_names:
            if task_name not in data['tasks']:
                print(f"  Warning: Task '{task_name}' not found in {dataset_name}, skipping.")
                continue

            prefix = f"{dataset_name}_{task_name}"

            # Experiment 1: representation equivalence.
            if selected('representation'):
                _run_if_missing(
                    OUT_DIR / f"{prefix}_representation.csv", 'Representation', task_name, force,
                    lambda: run_representation(data, task_name, models),
                )

            # Experiment 2: feature stability.
            if selected('stability'):
                _run_if_missing(
                    OUT_DIR / f"{prefix}_stability.csv", 'Stability', task_name, force,
                    lambda: run_stability(data, task_name),
                )

            # Experiment 3: interpretation (CSV preceded by a test-accuracy comment).
            if selected('interpretation'):
                _run_if_missing(
                    OUT_DIR / f"{prefix}_interpretation.csv", 'Interpretation', task_name, force,
                    lambda: run_interpretation(data, task_name),
                    header_keys=('test_accuracy',),
                )

            # Experiment 4: tree inference writes four files; skip only if all exist.
            if selected('tree_inference'):
                tree_names = ['depth', 'cumulative_depth', 'nodes', 'subtrees']
                tree_paths = [OUT_DIR / f"{prefix}_tree_{name}.csv" for name in tree_names]
                if all(path.exists() for path in tree_paths) and not force:
                    print(f"\n  [Tree Inference: {task_name}] Skipping (files exist)")
                else:
                    tree_tables = run_tree_inference(data, task_name)
                    if tree_tables is not None:
                        for name, df in tree_tables.items():
                            _write_experiment_csv(df, OUT_DIR / f"{prefix}_tree_{name}.csv")

            # Experiment 5: top-2 visualization (CSV preceded by metadata comments).
            if selected('top2_viz'):
                _run_if_missing(
                    OUT_DIR / f"{prefix}_top2_viz.csv", 'Top-2 Visualization', task_name, force,
                    lambda: run_top2_viz(data, task_name),
                    header_keys=top2_header,
                )

            # Experiment 6: leaf importance ("paths"); opt-in only, not part of 'all'.
            if selected('paths', include_in_all=False):
                _run_if_missing(
                    OUT_DIR / f"{prefix}_tree_paths.csv", 'Feature Importance', task_name, force,
                    lambda: run_paths(data, task_name),
                )

            # Experiment 7: semantic stability.
            if selected('stability_semantic'):
                _run_if_missing(
                    OUT_DIR / f"{prefix}_stability_semantic.csv", 'Semantic Stability', task_name, force,
                    lambda: run_stability_semantic(data, task_name),
                )


# ============================================
# CLI
# ============================================

def get_all_datasets():
    """Return every configured dataset name: microbiome loaders first, then DISCO."""
    microbiome_names = [*MicrobiomeLoader.CONFIGS]
    disco_names = [*DISCOLoader.CONFIGS]
    return microbiome_names + disco_names


def _build_arg_parser(all_datasets):
    """Create the experiment CLI parser; ``--dataset`` accepts only names in ``all_datasets``."""
    parser = argparse.ArgumentParser(
        description='PolyILR Experiments (Generalized)',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python run_experiments_new.py --all
    python run_experiments_new.py --dataset hmp
    python run_experiments_new.py --dataset disco_blood --task healthy_vs_covid
    python run_experiments_new.py --experiment representation stability
        """
    )
    parser.add_argument('--all', action='store_true',
                        help='Run all experiments on all datasets and tasks')
    parser.add_argument('--dataset', nargs='+', default=None,
                        choices=all_datasets,
                        help='Dataset(s) to run experiments on')
    parser.add_argument('--task', nargs='+', default=None,
                        help='Task(s) to run (dataset-specific)')
    parser.add_argument('--models', nargs='+', default=['RF', 'SVM', 'LogReg'],
                        choices=list(AVAILABLE_MODELS.keys()),
                        help='Models for representation comparison')
    parser.add_argument('--force', action='store_true',
                        help='Force re-run even if output files exist')
    parser.add_argument('--experiment', nargs='+', default=None,
                        choices=['representation', 'stability', 'stability_semantic',
                                'interpretation', 'tree_inference', 'top2_viz', 'paths', 'all'],
                        help='Experiment(s) to run')
    return parser


def main():
    """CLI entry point: parse options, echo the run configuration, then run the experiments.

    ``--all`` overrides ``--dataset``/``--task``/``--experiment``. Otherwise,
    omitted options default to every dataset, every task and every experiment.
    """
    all_datasets = get_all_datasets()
    args = _build_arg_parser(all_datasets).parse_args()

    if args.all:
        datasets, tasks, experiments = all_datasets, None, ['all']
    else:
        datasets = args.dataset or all_datasets
        tasks = args.task
        experiments = args.experiment or ['all']

    rule = "=" * 70
    banner = (
        rule,
        "PolyILR EXPERIMENTS (GENERALIZED)",
        rule,
        f"Datasets: {datasets}",
        f"Tasks: {tasks or 'all'}",
        f"Experiments: {experiments}",
        f"Models: {args.models}",
        f"Force re-run: {args.force}",
        f"Output directory: {OUT_DIR}",
        rule,
    )
    for line in banner:
        print(line)

    run_experiments(datasets, tasks, experiments, args.models, args.force)

    print("\n" + rule)
    print("DONE!")
    print(f"Results saved to: {OUT_DIR}")
    print(rule)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()