#!/usr/bin/env python3
"""
Text Distribution Experiment Runner

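This script runs a systematic experiment to test the effect of the input text
distribution on attention convergence rates. Every run includes an English
Wikipedia baseline plus, by default, these categories (each can be skipped with
the matching --skip_<category> flag):
    1. Shuffle ablation (sanity check)
    2. Domain variation
    3. Language variation
    4. Mixed sources
Opt-in categories: MR-NIAH haystacks (--include_niah), mixed languages
(--include_mixed_language) and repetition (--include_repetition).

Results are stored in: results_text_distribution/run_YYYYMMDD_HHMMSS/
(the base directory can be changed with --output_dir).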

Usage:
    python run_distribution_sweep.py
    python run_distribution_sweep.py --quick          # Reduced params for testing
    python run_distribution_sweep.py --layer_id 0     # Specific layer
    python run_distribution_sweep.py --skip_shuffle   # Skip shuffle tests
    python run_distribution_sweep.py --skip_language  # Skip language tests
"""

import os
import sys
import argparse
import datetime as dt
import pandas as pd
import json
from pathlib import Path
from typing import List, Optional

# Import the main experiment function
from run_convergence_experiment import (
    run_experiment,
    set_force_cpu,
    set_mc_on_gpu,
    set_limits_on_gpu,
)


# ============== EXPERIMENT CONFIGURATIONS ==============

def _make_experiment(name, description, category, *, text_source="wiki",
                     text_language="en", text_shuffle="none",
                     text_mix_config=None, text_language_mix_config=None,
                     text_repeat_config=None):
    """Return one experiment config dict with the suite's standard key layout.

    The key order is fixed because it becomes the column order of
    ``experiment_plan.csv``. ``text_mix_strategy`` is always ``"concat"``.
    """
    return {
        "name": name,
        "description": description,
        "text_source": text_source,
        "text_language": text_language,
        "text_shuffle": text_shuffle,
        "text_mix_config": text_mix_config,
        "text_mix_strategy": "concat",
        "text_language_mix_config": text_language_mix_config,
        "text_repeat_config": text_repeat_config,
        "category": category,
    }


def get_experiment_suite(include_shuffle=True, include_domain=True,
                         include_language=True, include_mixed=True,
                         include_niah=False, include_mixed_language=False,
                         include_repetition=False):
    """
    Build the list of experiment configs to run.

    The English-Wikipedia baseline is always first; each ``include_*`` flag
    appends its category in this fixed order: shuffle, domain, language,
    mixed, niah, mixed_language, repetition.

    Every config is a freshly created dict (callers may mutate it) with keys:
        name, description, text_source, text_language, text_shuffle,
        text_mix_config, text_mix_strategy, text_language_mix_config,
        text_repeat_config, category
    """
    def block_repeat():
        # New dict per experiment so configs never share mutable state.
        return {"mode": "block", "block_chars": 20000,
                "block_strategy": "random_span", "seed": 42}

    def tile_repeat():
        return {"mode": "tile", "seed": 42}

    # 1. Baseline (always present)
    suite = [
        _make_experiment("baseline_wiki_en", "Baseline: English Wikipedia",
                         "baseline"),
    ]

    # 2. Shuffle ablation (sanity check)
    if include_shuffle:
        suite += [
            _make_experiment("shuffle_sentence",
                             "Sanity check: Sentence-level shuffle",
                             "shuffle_ablation", text_shuffle="sentence"),
            _make_experiment("shuffle_word",
                             "Sanity check: Word-level shuffle (no grammar)",
                             "shuffle_ablation", text_shuffle="word"),
        ]

    # 3. Domain variation
    if include_domain:
        suite += [
            _make_experiment("domain_news", "Domain test: News articles",
                             "domain", text_source="news"),
            _make_experiment("domain_scientific",
                             "Domain test: ArXiv scientific papers",
                             "domain", text_source="scientific"),
        ]

    # 4. Language variation
    if include_language:
        suite += [
            _make_experiment("lang_german",
                             "Language test: German (morphologically rich)",
                             "language", text_language="de"),
            _make_experiment("lang_chinese",
                             "Language test: Chinese (logographic)",
                             "language", text_language="zh"),
        ]

    # 5. Mixed sources
    if include_mixed:
        suite += [
            _make_experiment("mixed_wiki_news_50_50",
                             "Mixed: 50% Wiki + 50% News (concatenated)",
                             "mixed", text_source="mixed",
                             text_mix_config={"wiki": 0.5, "news": 0.5}),
            _make_experiment("mixed_wiki_news_95_5",
                             "Mixed imbalanced: 95% Wiki + 5% News",
                             "mixed", text_source="mixed",
                             text_mix_config={"wiki": 0.95, "news": 0.05}),
            _make_experiment("mixed_code_news_90_10",
                             "Mixed imbalanced: 90% Code + 10% News",
                             "mixed", text_source="mixed",
                             text_mix_config={"code": 0.9, "news": 0.1}),
        ]

    # 6. MR-NIAH benchmark haystacks
    if include_niah:
        suite += [
            _make_experiment("niah_english",
                             "MR-NIAH benchmark haystack (English)",
                             "niah", text_source="mr_niah"),
            _make_experiment("niah_chinese",
                             "MR-NIAH benchmark haystack (Chinese)",
                             "niah", text_source="mr_niah", text_language="zh"),
        ]

    # 7. Mixed languages
    if include_mixed_language:
        suite += [
            _make_experiment("mixed_lang_en_de",
                             "Mixed languages: 50% English + 50% German",
                             "mixed_language", text_language="mixed",
                             text_language_mix_config={"en": 0.5, "de": 0.5}),
            _make_experiment("mixed_lang_en_zh",
                             "Mixed languages: 50% English + 50% Chinese",
                             "mixed_language", text_language="mixed",
                             text_language_mix_config={"en": 0.5, "zh": 0.5}),
            _make_experiment("mixed_lang_en_fr_de",
                             "Mixed languages: 34% English + 33% French + 33% German",
                             "mixed_language", text_language="mixed",
                             text_language_mix_config={"en": 0.34, "fr": 0.33, "de": 0.33}),
        ]

    # 8. Repetition (high-correlation text)
    if include_repetition:
        suite += [
            _make_experiment("repeat_wiki_block",
                             "Repetition: Wikipedia with repeated block",
                             "repetition", text_repeat_config=block_repeat()),
            _make_experiment("repeat_wiki_block_shuffle_sentence",
                             "Repetition + Shuffle: Wikipedia with repeated block and sentence shuffle",
                             "repetition", text_shuffle="sentence",
                             text_repeat_config=block_repeat()),
            _make_experiment("repeat_wiki_tile",
                             "Repetition (tile): Wikipedia tiled to fill length",
                             "repetition", text_repeat_config=tile_repeat()),
            _make_experiment("repeat_wiki_tile_shuffle_sentence",
                             "Repetition (tile) + Shuffle: Wikipedia tiled and sentences shuffled",
                             "repetition", text_shuffle="sentence",
                             text_repeat_config=tile_repeat()),
            _make_experiment("repeat_news_block",
                             "Repetition: News with repeated block",
                             "repetition", text_source="news",
                             text_repeat_config=block_repeat()),
        ]

    return suite


# ============== MAIN RUNNER ==============

def _summary_row(exp, layer, *, status="success", A_real=None,
                 theoretical_rate=None, slopes=(None, None, None, None)):
    """Build one ``experiment_summary.csv`` row for an (experiment, layer) run.

    Args:
        exp: Experiment config dict (as produced by ``get_experiment_suite``).
        layer: Layer index the run analysed.
        status: ``"success"`` or an ``"error: ..."`` message.
        A_real, theoretical_rate: Values taken from the run result (None on error).
        slopes: ``(slope_mean, slope_mean_se, slope_cov, slope_cov_se)``.
    """
    slope_mean, slope_mean_se, slope_cov, slope_cov_se = slopes
    lang_mix = exp.get("text_language_mix_config")
    repeat_cfg = exp.get("text_repeat_config")
    return {
        "experiment": exp["name"],
        "category": exp["category"],
        "description": exp["description"],
        "layer_id": layer,
        "text_source": exp["text_source"],
        "text_language": exp["text_language"],
        "text_shuffle": exp["text_shuffle"],
        # Nested configs are stringified so they fit in a single CSV cell.
        "text_language_mix_config": str(lang_mix) if lang_mix else None,
        "text_repeat_config": str(repeat_cfg) if repeat_cfg else None,
        "A_real": A_real,
        "theoretical_rate": theoretical_rate,
        "slope_mean": slope_mean,
        "slope_mean_se": slope_mean_se,
        "slope_cov": slope_cov,
        "slope_cov_se": slope_cov_se,
        "status": status,
    }


def _extract_slopes(result, A_sweep):
    """Return ``(slope_mean, slope_mean_se, slope_cov, slope_cov_se)`` from a run result.

    With an A-sweep, the values come from the ``result['results_by_A']`` entry
    flagged ``is_real_A`` (falling back to the first entry). Without one, they
    come from the top-level ``slope_mean``/``slope_mean_std``/``slope_cov``/
    ``slope_cov_std`` keys, any of which may be absent (-> None).
    """
    if A_sweep:
        by_A = result['results_by_A']
        real_A = next((r for r in by_A if r.get('is_real_A')), by_A[0])
        return (real_A['slope_m_mean'], real_A['slope_m_se'],
                real_A['slope_c_mean'], real_A['slope_c_se'])
    return (result.get('slope_mean'), result.get('slope_mean_std'),
            result.get('slope_cov'), result.get('slope_cov_std'))


def run_text_distribution_experiment(
    N: int = 3,
    layer_id: int = 0,
    layers: Optional[List[int]] = None,
    all_layers: bool = False,
    n_min: int = 100,
    n_max: int = 2000,
    nb_tot: int = 15,
    k: int = 10,
    A_sweep: bool = False,
    A_cap: float = 1e4,
    n_A_points: int = 5,
    output_base: str = "results_text_distribution",
    include_shuffle: bool = True,
    include_domain: bool = True,
    include_language: bool = True,
    include_mixed: bool = True,
    include_niah: bool = False,
    include_mixed_language: bool = False,
    include_repetition: bool = False,
    experiments_to_run: Optional[List[str]] = None,  # If provided, only run these by name
    model_size: str = "base",  # "base" or "large"
    device_map: Optional[str] = None,  # HF device_map for sharded multi-GPU inference
    shuffle_override: Optional[str] = None,  # Override shuffle for ALL experiments
    aggressive_cleanup: bool = False,
):
    """
    Run the full text distribution experiment suite.

    Each selected experiment is run once per layer via ``run_experiment``. A
    failing (experiment, layer) run is logged and recorded with an
    ``"error: ..."`` status instead of aborting the sweep.

    Args:
        N: Multiplier for max_length (max_length = 4096 * N)
        layer_id: Layer to analyze (used if layers/all_layers not provided)
        layers: Explicit list of layers to run (deduplicated and sorted)
        all_layers: Run all layers for the selected model size (24 for "large", else 12)
        n_min, n_max, nb_tot, k: Monte Carlo parameters
        A_sweep: Whether to do A-sweep
        A_cap, n_A_points: A-sweep parameters
        output_base: Base directory for results
        include_shuffle: Include shuffle ablation experiments
        include_domain: Include domain variation experiments
        include_language: Include language variation experiments
        include_mixed: Include mixed source experiments
        include_niah: Include MR-NIAH benchmark experiments
        include_mixed_language: Include mixed language experiments
        include_repetition: Include repetition experiments
        experiments_to_run: If provided, only run experiments with these names.
            Requested names absent from the selected suite are reported and ignored.
        model_size: "base" (768 hidden, 12 heads) or "large" (1024 hidden, 16 heads)
        device_map: Optional HF device_map (e.g., "auto") for multi-GPU sharding
        shuffle_override: Override shuffle for ALL experiments ("none", "sentence",
            "paragraph", "word"). When set, ALL experiments use this shuffle
            level regardless of their default setting ("none" disables shuffling
            everywhere, including the shuffle-ablation experiments).
        aggressive_cleanup: Free as much memory as possible after saving outputs

    Returns:
        Dictionary with keys "output_dir", "summary" (DataFrame), "results"
        (per-run list) and "global_params".
    """
    import traceback

    # Create timestamped output directory
    timestamp = dt.datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(output_base, f"run_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    print(f"\n{'='*70}")
    print("TEXT DISTRIBUTION EXPERIMENT")
    print(f"{'='*70}")
    print(f"Output directory: {output_dir}")
    print(f"Model: BigBird-RoBERTa-{model_size.capitalize()}")
    num_layers = 24 if model_size == "large" else 12
    if all_layers:
        layers_to_run = list(range(num_layers))
    elif layers:
        layers_to_run = sorted(set(layers))
    else:
        layers_to_run = [layer_id]

    print(f"Layers: {layers_to_run}, N: {N} (max_length: {4096*N})")
    print(f"MC params: n_min={n_min}, n_max={n_max}, nb_tot={nb_tot}, k={k}")
    print(f"A_sweep: {A_sweep}")
    if shuffle_override:
        print(f"Shuffle override: {shuffle_override}")
    print(f"{'='*70}\n")

    # Get experiment suite (fresh dicts, safe to mutate below)
    experiments = get_experiment_suite(
        include_shuffle=include_shuffle,
        include_domain=include_domain,
        include_language=include_language,
        include_mixed=include_mixed,
        include_niah=include_niah,
        include_mixed_language=include_mixed_language,
        include_repetition=include_repetition
    )

    # Apply shuffle override if specified. "none" is a real override too: it
    # must switch off shuffling in experiments that shuffle by default.
    if shuffle_override:
        print(f"Applying shuffle override '{shuffle_override}' to all experiments...")
        for exp in experiments:
            exp["text_shuffle"] = shuffle_override

    # Filter if specific experiments requested
    if experiments_to_run:
        requested = set(experiments_to_run)
        unknown = sorted(requested - {e["name"] for e in experiments})
        if unknown:
            print(f"WARNING: requested experiment(s) not in the selected suite, "
                  f"ignored: {', '.join(unknown)}")
        experiments = [e for e in experiments if e["name"] in requested]

    print(f"Experiments to run ({len(experiments)}):")
    for i, exp in enumerate(experiments, start=1):
        print(f"  {i}. [{exp['category']}] {exp['name']}: {exp['description']}")
    print()

    # Save experiment plan
    plan_df = pd.DataFrame(experiments)
    plan_df.to_csv(os.path.join(output_dir, "experiment_plan.csv"), index=False)

    # Save global hyperparameters (everything needed to reproduce the run)
    global_params = {
        "N": N,
        "max_length": 4096 * N,
        "layers": layers_to_run,
        "n_min": n_min,
        "n_max": n_max,
        "nb_tot": nb_tot,
        "k": k,
        "A_sweep": A_sweep,
        "A_cap": A_cap,
        "n_A_points": n_A_points,
        "model_size": model_size,
        "device_map": device_map,
        "timestamp": timestamp,
        "num_experiments": len(experiments),
        "shuffle_override": shuffle_override,
        "aggressive_cleanup": aggressive_cleanup,
        "experiments": [e["name"] for e in experiments],
    }
    with open(os.path.join(output_dir, "global_hyperparameters.json"), "w") as f:
        json.dump(global_params, f, indent=2)

    # Run experiments
    all_results = []
    summary_rows = []

    for i, exp in enumerate(experiments, start=1):
        print(f"\n{'#'*70}")
        print(f"# EXPERIMENT {i}/{len(experiments)}: {exp['name']}")
        print(f"# {exp['description']}")
        print(f"{'#'*70}\n")

        exp_output_dir = os.path.join(output_dir, exp["name"])
        os.makedirs(exp_output_dir, exist_ok=True)

        for layer in layers_to_run:
            layer_output_dir = os.path.join(exp_output_dir, f"layer{layer}")

            try:
                result = run_experiment(
                    N=N,
                    layer_id=layer,
                    n_min=n_min,
                    n_max=n_max,
                    nb_tot=nb_tot,
                    k=k,
                    replace=False,
                    tiling=None,
                    plot=True,
                    plot_layer=False,
                    A_sweep=A_sweep,
                    A_cap=A_cap,
                    n_A_points=n_A_points,
                    sanity_plot=A_sweep,
                    output_dir=layer_output_dir,
                    add_timestamp=False,  # We already have timestamp in parent dir
                    # Model size
                    model_size=model_size,
                    device_map=device_map,
                    # Text parameters
                    text_source=exp["text_source"],
                    text_language=exp["text_language"],
                    text_mix_config=exp["text_mix_config"],
                    text_mix_strategy=exp["text_mix_strategy"],
                    text_shuffle=exp["text_shuffle"],
                    text_language_mix_config=exp.get("text_language_mix_config"),
                    text_repeat_config=exp.get("text_repeat_config"),
                    text_experiment_config=None,
                    aggressive_cleanup=aggressive_cleanup,
                )
                # Metric extraction stays inside the try so a malformed result
                # is reported as a failed run rather than aborting the sweep.
                row = _summary_row(
                    exp, layer,
                    A_real=result.get('A_real', None),
                    theoretical_rate=result.get('theoretical_rate', None),
                    slopes=_extract_slopes(result, A_sweep),
                )
            except Exception as e:  # sweep boundary: one bad run must not stop the rest
                print(f"\n[ERROR] {exp['name']} layer {layer} failed: {e}")
                traceback.print_exc()
                summary_rows.append(_summary_row(exp, layer, status=f"error: {str(e)[:50]}"))
                all_results.append({"config": exp, "result": None, "error": str(e), "layer_id": layer})
            else:
                # Only reached when the run and its metric extraction both succeeded,
                # so each (experiment, layer) gets exactly one summary row.
                summary_rows.append(row)
                all_results.append({"config": exp, "result": result, "layer_id": layer})
                print(f"\n[OK] {exp['name']} layer {layer} completed successfully")

    # Save summary
    summary_df = pd.DataFrame(summary_rows)
    summary_path = os.path.join(output_dir, "experiment_summary.csv")
    summary_df.to_csv(summary_path, index=False)

    # Print summary
    print(f"\n{'='*70}")
    print("EXPERIMENT SUMMARY")
    print(f"{'='*70}")
    print(f"\nResults saved to: {output_dir}")
    print("\nSummary table:")
    if len(summary_df) > 0:
        display_cols = ["experiment", "layer_id", "category", "slope_mean", "slope_cov", "A_real", "status"]
        available_cols = [c for c in display_cols if c in summary_df.columns]
        if available_cols:
            print(summary_df[available_cols].to_string())
        else:
            print(summary_df.to_string())
    else:
        print("  (no results)")
    print(f"\nFull summary saved to: {summary_path}")
    print(f"{'='*70}")

    return {
        "output_dir": output_dir,
        "summary": summary_df,
        "results": all_results,
        "global_params": global_params,
    }


# ============== CLI ==============

def parse_args(argv: Optional[List[str]] = None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like ``argparse``.

    Returns:
        An ``argparse.Namespace`` with one attribute per option defined below.

    The ``--only_*`` flags are mutually exclusive; combining two of them is a
    usage error (exit status 2) instead of one silently winning.
    """
    parser = argparse.ArgumentParser(
        description="Run text distribution experiment suite",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    # Full experiment suite
    python run_distribution_sweep.py

    # Quick test with reduced parameters
    python run_distribution_sweep.py --quick

    # Specific layer
    python run_distribution_sweep.py --layer_id 5

    # Only shuffle ablation
    python run_distribution_sweep.py --only_shuffle

    # Skip language tests (faster)
    python run_distribution_sweep.py --skip_language

    # Run specific experiments by name
    python run_distribution_sweep.py --experiments baseline_wiki_en,shuffle_word
        """
    )

    # Basic parameters. Help strings use %(default)s so they cannot drift
    # from the real defaults again.
    parser.add_argument("--N", type=int, default=50,
                        help="Multiplier for max_length = 4096 * N (default: %(default)s)")
    parser.add_argument("--layer_id", type=int, default=0,
                        help="Layer to analyze when --layers/--all_layers are not given "
                             "(default: %(default)s)")
    parser.add_argument("--layers", type=str, default=None,
                        help="Comma-separated list of layers (e.g., '0,1,2' or '0-3,7')")
    parser.add_argument("--all_layers", action="store_true",
                        help="Run all layers for the selected model size")

    # MC parameters
    parser.add_argument("--n_min", type=int, default=1000,
                        help="Minimum sample size (default: %(default)s)")
    parser.add_argument("--n_max", type=int, default=4000,
                        help="Maximum sample size (default: %(default)s)")
    parser.add_argument("--nb_tot", type=int, default=15,
                        help="Number of sample sizes (default: %(default)s)")
    parser.add_argument("--k", type=int, default=100,
                        help="MC repetitions (default: %(default)s)")

    # A-sweep
    parser.add_argument("--A_sweep", action="store_true",
                        help="Enable A-sweep for each experiment")
    parser.add_argument("--A_cap", type=float, default=1e4,
                        help="A-sweep upper bound (default: 1e4)")
    parser.add_argument("--n_A_points", type=int, default=5,
                        help="Number of A values in sweep (default: %(default)s)")

    # Model size
    parser.add_argument("--bigbird_large", action="store_true",
                        help="Use BigBird-RoBERTa-Large (1024 hidden, 16 heads) instead of Base (768 hidden, 12 heads)")

    # Experiment selection
    parser.add_argument("--skip_shuffle", action="store_true",
                        help="Skip shuffle ablation experiments")
    parser.add_argument("--skip_domain", action="store_true",
                        help="Skip domain variation experiments")
    parser.add_argument("--skip_language", action="store_true",
                        help="Skip language variation experiments")
    parser.add_argument("--skip_mixed", action="store_true",
                        help="Skip mixed source experiments")
    parser.add_argument("--skip_repetition", action="store_true",
                        help="Skip repetition experiments")

    # At most one "--only_*" selector may be given.
    only_group = parser.add_mutually_exclusive_group()
    only_group.add_argument("--only_shuffle", action="store_true",
                            help="Only run shuffle ablation (+ baseline)")
    only_group.add_argument("--only_domain", action="store_true",
                            help="Only run domain variation (+ baseline)")
    only_group.add_argument("--only_language", action="store_true",
                            help="Only run language variation (+ baseline)")
    only_group.add_argument("--only_niah", action="store_true",
                            help="Only run MR-NIAH benchmark (+ baseline)")
    only_group.add_argument("--only_mixed_language", action="store_true",
                            help="Only run mixed language experiments (+ baseline)")
    only_group.add_argument("--only_repetition", action="store_true",
                            help="Only run repetition experiments (+ baseline)")

    parser.add_argument("--skip_niah", action="store_true",
                        help="Skip MR-NIAH benchmark experiments")
    parser.add_argument("--skip_mixed_language", action="store_true",
                        help="Skip mixed language experiments")
    parser.add_argument("--include_niah", action="store_true",
                        help="Include MR-NIAH benchmark experiments")
    parser.add_argument("--include_mixed_language", action="store_true",
                        help="Include mixed language experiments")
    parser.add_argument("--include_repetition", action="store_true",
                        help="Include repetition experiments")

    parser.add_argument("--shuffle_override", type=str, default=None,
                        choices=["none", "sentence", "paragraph", "word"],
                        help="Override shuffle for ALL experiments in this run")
    parser.add_argument("--aggressive_cleanup", action="store_true",
                        help="Free as much memory as possible after saving outputs")

    parser.add_argument("--experiments", type=str, default=None,
                        help="Comma-separated list of experiment names to run")

    # Quick mode
    parser.add_argument("--quick", action="store_true",
                        help="Quick mode: reduced N, nb_tot, k and sample-size range for testing")

    # Output
    parser.add_argument("--output_dir", type=str, default="results_text_distribution",
                        help="Base output directory (default: %(default)s)")
    parser.add_argument("--device_map", type=str, default=None,
                        help="HF device_map for sharded multi-GPU inference (e.g., 'auto')")
    parser.add_argument("--force_cpu", action="store_true",
                        help="Force CPU for the entire run (disables CUDA even if available)")
    parser.add_argument("--mc_on_gpu", action="store_true",
                        help="Run Monte Carlo on GPU even if main compute is on CPU")
    parser.add_argument("--limits_on_gpu", action="store_true",
                        help="Run full-attention limits with GPU matmuls via CPU offload")

    return parser.parse_args(argv)


def _parse_layer_range(layers_str: str) -> List[int]:
    """Parse a layer spec such as ``"0,1,2"`` or ``"0-3,7"`` into sorted unique ints.

    Empty comma-separated parts are ignored; ``a-b`` ranges are inclusive.

    Raises:
        ValueError: if a part is not an integer or an ``a-b`` range with ``a <= b``.
    """
    parsed = set()
    for part in layers_str.split(','):
        part = part.strip()
        if not part:
            continue
        if '-' in part:
            start_str, _, end_str = part.partition('-')
            start, end = int(start_str), int(end_str)
            if start > end:
                # Previously this silently produced no layers at all.
                raise ValueError(f"layer range '{part}' has start > end")
            parsed.update(range(start, end + 1))
        else:
            parsed.add(int(part))
    return sorted(parsed)


# Categories whose "--only_<cat>" flag forces them on; the other categories
# keep honouring their "--skip_<cat>" flag even under "--only_<cat>".
_OPT_IN_CATEGORIES = ("niah", "mixed_language", "repetition")
# Checked in this order; the first "--only_*" flag set wins.
_ONLY_CATEGORIES = ("shuffle", "domain", "language", "niah", "mixed_language", "repetition")


def main() -> None:
    """Command-line entry point: parse flags, resolve the selection, run the suite."""
    args = parse_args()
    if args.force_cpu:
        set_force_cpu(True)
        print("Force CPU enabled: all computation will run on CPU.")
    if args.mc_on_gpu:
        set_mc_on_gpu(True)
        print("MC-on-GPU enabled: Monte Carlo will use CUDA if available.")
    if args.limits_on_gpu:
        set_limits_on_gpu(True)
        print("Limits-on-GPU enabled: full-attention limits will use GPU offload.")

    # Handle quick mode
    if args.quick:
        print("QUICK MODE: Using reduced parameters for testing")
        args.N = 1
        args.nb_tot = 8
        args.k = 5
        # n_min must shrink too: its CLI default (1000) exceeds the quick n_max.
        args.n_min = 100
        args.n_max = 500

    if args.n_min > args.n_max:
        sys.exit(f"error: --n_min ({args.n_min}) must not exceed --n_max ({args.n_max})")

    # Default selection: core categories on unless skipped, opt-in ones off
    # unless explicitly included (and not skipped).
    include = {
        "shuffle": not args.skip_shuffle,
        "domain": not args.skip_domain,
        "language": not args.skip_language,
        "mixed": not args.skip_mixed,
        "niah": args.include_niah and not args.skip_niah,
        "mixed_language": args.include_mixed_language and not args.skip_mixed_language,
        "repetition": args.include_repetition and not args.skip_repetition,
    }

    # Handle "only_*" flags: keep just that category (the baseline always runs).
    for category in _ONLY_CATEGORIES:
        if getattr(args, f"only_{category}"):
            for other in include:
                if other != category:
                    include[other] = False
            if category in _OPT_IN_CATEGORIES:
                include[category] = True
            break

    # Parse specific experiments if provided
    experiments_to_run = None
    if args.experiments:
        experiments_to_run = [name.strip() for name in args.experiments.split(",")
                              if name.strip()]
        if not experiments_to_run:
            sys.exit("error: --experiments was given but contains no experiment names")
        # Named experiments may belong to any category, so enable every
        # category and let the name filter do the selection.
        include = dict.fromkeys(include, True)

    # Determine model size
    model_size = "large" if args.bigbird_large else "base"

    try:
        layers = _parse_layer_range(args.layers) if args.layers else None
    except ValueError as err:
        sys.exit(f"error: invalid --layers value {args.layers!r}: {err}")

    # Run experiment suite
    results = run_text_distribution_experiment(
        N=args.N,
        layer_id=args.layer_id,
        layers=layers,
        all_layers=args.all_layers,
        n_min=args.n_min,
        n_max=args.n_max,
        nb_tot=args.nb_tot,
        k=args.k,
        A_sweep=args.A_sweep,
        A_cap=args.A_cap,
        n_A_points=args.n_A_points,
        output_base=args.output_dir,
        include_shuffle=include["shuffle"],
        include_domain=include["domain"],
        include_language=include["language"],
        include_mixed=include["mixed"],
        include_repetition=include["repetition"],
        include_niah=include["niah"],
        include_mixed_language=include["mixed_language"],
        experiments_to_run=experiments_to_run,
        model_size=model_size,
        device_map=args.device_map,
        shuffle_override=args.shuffle_override,
        aggressive_cleanup=args.aggressive_cleanup,
    )

    print(f"\nDone! Results in: {results['output_dir']}")


if __name__ == "__main__":
    main()
