#!/usr/bin/env python3
"""Batch runner for JAXBench optimization with Autocomp.

Usage:
    cd /path/to/autocomp

    # Run all Group A (Pallas improvement):
    python run_batch.py --group A

    # Run specific benchmarks:
    python run_batch.py --probs 1p_Flash_Attention 12p_RMSNorm 30k_Matmul_Scaling_ResidualAdd

    # Run first N from a group:
    python run_batch.py --group C --limit 3

    # Dry run (print what would be executed):
    python run_batch.py --group B --dry-run
"""
import argparse
import os
import pathlib
import random
import sys
import traceback

from autocomp.common import logger
from autocomp.search.search import (
    create_backend_and_agents,
    load_initial_code,
    BeamSearchStrategy,
)
from autocomp.search.prob import Prob
from autocomp.hw_config import TpuHardwareConfig

# ======================================================================
# Benchmark groups
# ======================================================================

# Group A: problems shipped with an existing Pallas implementation
# (``jaxbench-pallas``). run_one optimizes these directly; there is no
# translation phase.
GROUP_A = [
    ("jaxbench-pallas", prob_id)
    for prob_id in (
        "1p_Flash_Attention",
        "2p_GQA_Attention",
        "3p_MLA_Attention",
        "4p_Sparse_Attention",
        "6p_Paged_Attention",
        "7p_Ragged_Paged_Attention",
        "8p_GEMM",
        "11p_Megablox_GMM",
    )
]

# Group B: ``p``-series problems that only have a ``jaxbench-baseline``
# starting point, so run_one translates them first and then optimizes.
GROUP_B = [
    ("jaxbench-baseline", prob_id)
    for prob_id in (
        "5p_Flex_Attention",
        "9p_SwiGLU_MLP",
        "10p_Sparse_MoE",
        "12p_RMSNorm",
        "13p_Cross_Entropy",
        "14p_Ragged_Dot",
        "15p_RetNet_Retention",
        "16p_Mamba2_SSD",
        "17p_Triangle_Multiplication",
    )
]

# Group C: ``k``-series op-chain problems (e.g. Conv2D_ReLU_BiasAdd), also
# starting from ``jaxbench-baseline`` and going through translation.
GROUP_C = [
    ("jaxbench-baseline", prob_id)
    for prob_id in (
        "18k_Conv2D_ReLU_BiasAdd",
        "19k_Matmul_Subtract_Multiply_ReLU",
        "20k_Gemm_Multiply_LeakyReLU",
        "21k_Gemm_Divide_Sum_Scaling",
        "22k_Conv2d_InstanceNorm_Divide",
        "23k_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp",
        "24k_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish",
        "25k_Conv3d_GroupNorm_Mean",
        "26k_BMM_InstanceNorm_Sum_ResidualAdd_Multiply",
        "27k_Matmul_Mish_Mish",
        "28k_ConvTranspose3d_LayerNorm_GELU_Scaling",
        "29k_Matmul_Swish_Sum_GroupNorm",
        "30k_Matmul_Scaling_ResidualAdd",
        "31k_Gemm_BatchNorm_GELU_ReLU",
        "32k_Gemm_Sigmoid_LogSumExp",
        "33k_Conv3d_Mish_Tanh",
        "34k_Conv2d_Activation_BatchNorm",
        "35k_Gemm_Scaling_Hardtanh_GELU",
        "36k_Matmul_Sigmoid_Sum",
        "37k_Matmul_Swish_Scaling",
        "38k_Matmul_Dropout_Softmax",
        "39k_Conv2d_GELU_GlobalAvgPool",
        "40k_Gemm_GroupNorm_Min_BiasAdd",
        "41k_Gemm_Add_ReLU",
        "42k_Gemm_Max_Subtract_GELU",
        "43k_Gemm_BatchNorm_Scaling_Softmax",
        "44k_Matmul_Divide_GELU",
        "45k_Gemm_GroupNorm_Swish_Multiply_Swish",
        "46k_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp",
        "47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh",
        "48k_Matmul_BatchNorm_BiasAdd_Divide_Swish",
        "49k_Matmul_AvgPool_GELU_Scale_Max",
        "50k_Matmul_GELU_Softmax",
    )
]

# Group D: the same problem ids as Group A, but starting from their
# ``jaxbench-baseline`` version so they go through translation as well.
GROUP_D = [("jaxbench-baseline", prob_id) for _, prob_id in GROUP_A]

# CLI ``--group`` letter -> run list.
GROUPS = dict(A=GROUP_A, B=GROUP_B, C=GROUP_C, D=GROUP_D)

# ======================================================================
# Search configuration
# ======================================================================

# --- Translation phase: rewrite the XLA baseline as Pallas -----------------
TRANSLATE_ITERATIONS = 4          # search iterations spent translating
TRANSLATE_USE_EDITS = False       # translation emits whole programs, not edits
TRANSLATE_PERF_THRESHOLD = 15     # forwarded as translate_perf_threshold
TRANSLATE_DROP_ORIGINAL = True    # forwarded as translate_drop_original
TRANSLATE_SCORE = True            # forwarded as translate_score

# --- Optimization phase: improve existing Pallas code ----------------------
OPT_ITERATIONS = 4
# Group A has no translation phase, so it is given a larger optimization budget.
PALLAS_OPT_ITERATIONS = 8
OPT_USE_EDITS = True

# --- Settings shared by both phases ---------------------------------------
SEARCH_STRATEGY = "beam"  # NOTE(review): not referenced in this script; _make_optimizer always builds a BeamSearchStrategy
# When SKIP_PLANNING is on, NUM_PLAN_CANDIDATES is the number of distinct
# prompts per parent and NUM_CODE_CANDIDATES the samples drawn per prompt.
NUM_PLAN_CANDIDATES = 3
NUM_CODE_CANDIDATES = 2
BEAM_SIZE = 3
SKIP_PLANNING = True
# -1 permits candidates from the same parent; evaluation is the bottleneck.
PREVENT_DUPLICATE_LEVEL = -1
# Early stopping: after EARLY_STOP_ITERS iterations whose ratio is
# >= EARLY_STOP_THRESHOLD (i.e. under 2% improvement), the search is stalled.
EARLY_STOP_ITERS = 3
EARLY_STOP_THRESHOLD = 0.98
DROPOUT_MENU_OPTIONS = 0.25

# Search-time evaluation uses fewer warmup runs and trials (the final
# evaluation uses the full 100 trials separately). These two values are
# assigned unconditionally and override anything already in the environment;
# profiling is only enabled by default.
os.environ.update({
    "AUTOCOMP_TPU_NUM_WARMUP": "3",
    "AUTOCOMP_TPU_NUM_TRIALS": "20",
})
os.environ.setdefault("AUTOCOMP_JAXBENCH_PROFILE", "1")

# LLM backends to sample from; main() replaces this list when --models is given.
MODELS = [
    "gcp::gemini-3.1-pro-preview",
    "gcp::gemini-3-flash-preview",
]

# Agent / hardware description handed to create_backend_and_agents.
AGENT_NAME = "built:tpu-v6e"
HW_CONFIG = TpuHardwareConfig("v6e-1")
MENU_STRATEGY = "one-shot"
FINE_GRAINED_ISA = True
EXAMPLE_RATE = 0.25


# Root of all run directories; main() replaces it when --output-base is given.
OUTPUT_BASE = pathlib.Path("output", "jaxbench-sweep")


def _build_dir(prob_id: str, tag: str, phase: str = "") -> pathlib.Path:
    """Return the output directory ``OUTPUT_BASE / "<prob_id>_<tag>[_<phase>]"``.

    The name depends only on the arguments, so reruns of the same benchmark
    and phase land in the same directory. An empty *phase* adds no suffix.
    """
    name_parts = [prob_id, tag]
    if phase:
        name_parts.append(phase)
    return OUTPUT_BASE / "_".join(name_parts)


def build_output_dir(prob_type: str, prob_id: str, translate_iters: int = 0) -> pathlib.Path:
    """Return the optimization-phase output directory for a benchmark.

    Kept for backwards compatibility. The result is
    ``OUTPUT_BASE / f"{prob_id}_{tag}"``, where *tag* is ``"pallas"`` for
    ``jaxbench-pallas`` problems and ``"baseline"`` for everything else.
    For baseline problems this is the optimization (phase 2) directory, not
    the ``_translate`` directory; it is the same directory the batch
    skip-completed check looks at.

    Args:
        prob_type: Problem type, e.g. ``"jaxbench-pallas"``.
        prob_id: Benchmark identifier, e.g. ``"8p_GEMM"``.
        translate_iters: Unused. Retained (now with a default) so existing
            three-argument call sites keep working.

    Returns:
        The output directory path; it is not created here.
    """
    tag = "pallas" if prob_type == "jaxbench-pallas" else "baseline"
    return _build_dir(prob_id, tag)


def is_phase_complete(output_dir: pathlib.Path, iterations: int) -> bool:
    """Report whether a search phase has produced its final candidates.

    A phase run for *iterations* iterations counts as finished once the
    ``candidates-iter-<iterations>`` entry exists inside *output_dir*.
    """
    final_candidates = output_dir.joinpath("candidates-iter-" + str(iterations))
    return final_candidates.exists()


def is_complete(output_dir: pathlib.Path, iterations: int | None = None) -> bool:
    """Check whether the run stored in *output_dir* finished every iteration.

    Kept for backwards compatibility with callers that pass only a directory.

    Args:
        output_dir: A run directory as produced by ``_build_dir``.
        iterations: Iteration budget to check for. When omitted it is inferred
            from the directory-name suffix written by ``_build_dir``:
            ``_translate`` -> ``TRANSLATE_ITERATIONS``,
            ``_pallas`` -> ``PALLAS_OPT_ITERATIONS``,
            anything else -> ``OPT_ITERATIONS``.

    Returns:
        True if ``candidates-iter-<iterations>`` exists in *output_dir*.
    """
    if iterations is None:
        # A fixed OPT_ITERATIONS would report Group A (Pallas) runs, which
        # target PALLAS_OPT_ITERATIONS, as complete halfway through.
        dir_name = output_dir.name
        if dir_name.endswith("_translate"):
            iterations = TRANSLATE_ITERATIONS
        elif dir_name.endswith("_pallas"):
            iterations = PALLAS_OPT_ITERATIONS
        else:
            iterations = OPT_ITERATIONS
    return is_phase_complete(output_dir, iterations)


def _make_optimizer(
    output_dir: pathlib.Path,
    prob: Prob,
    initial_code: str,
    *,
    translate_iters: int,
    use_edits: bool,
    continue_from: str | pathlib.Path = "",
) -> BeamSearchStrategy:
    """Build a ``BeamSearchStrategy`` wired with this sweep's shared settings.

    Only the phase-specific knobs differ between callers:

    * ``translate_iters``: translation iteration budget (run_one passes 0 for
      optimization-only runs),
    * ``use_edits``: whether candidates are produced as edits,
    * ``continue_from``: directory of a previous run to continue from, or
      ``""`` for none.

    Everything else comes from the module-level configuration constants. The
    evaluation backend and agents are created per call, caching under
    *output_dir*.
    """
    # Model names have '/' replaced with '_' before being handed to the backend.
    sanitized_models = [name.replace("/", "_") for name in MODELS]
    evaluator, search_agent, coder = create_backend_and_agents(
        "jaxbench",
        AGENT_NAME,
        HW_CONFIG,
        prob,
        sanitized_models,
        None,
        menu_strategy=MENU_STRATEGY,
        fine_grained_isa=FINE_GRAINED_ISA,
        example_rate=EXAMPLE_RATE,
        cache_dir=output_dir,
    )

    # Prompt / feedback content.
    feedback_opts = dict(
        give_score_feedback=1,
        give_util_feedback=0,
        give_hw_feedback=0.5,
        include_ancestors=False,
        plan_icl_examples=False,
        code_icl_examples=False,
    )
    # Translation-phase behaviour.
    translation_opts = dict(
        translate_iters=translate_iters,
        translate_perf_threshold=TRANSLATE_PERF_THRESHOLD,
        translate_drop_original=TRANSLATE_DROP_ORIGINAL,
        translate_score=TRANSLATE_SCORE,
    )
    # Beam shape, candidate counts, early stopping and exhaustive-mode triggers.
    search_opts = dict(
        dropout_menu_options=DROPOUT_MENU_OPTIONS,
        prevent_duplicate_level=PREVENT_DUPLICATE_LEVEL,
        early_stop_iters=EARLY_STOP_ITERS,
        early_stop_threshold=EARLY_STOP_THRESHOLD,
        num_analyses=0,
        num_plan_candidates=NUM_PLAN_CANDIDATES,
        num_code_candidates=NUM_CODE_CANDIDATES,
        beam_size=BEAM_SIZE,
        num_pairs_to_combine=0,
        num_gen_per_combine=0,
        trigger_exhaustive_threshold=1,
        trigger_exhaustive_iters=20,
        start_exhaustive_iters=0,
        reimplement_failed=False,
        skip_planning=SKIP_PLANNING,
    )

    return BeamSearchStrategy(
        output_dir=output_dir,
        eval_backend=evaluator,
        agent=search_agent,
        code_agent=coder,
        orig_code=initial_code,
        prob=prob,
        metric="latency",
        simulator=None,
        continue_from=continue_from,
        use_edits=use_edits,
        **feedback_opts,
        **translation_opts,
        **search_opts,
    )


def run_one(prob_type: str, prob_id: str) -> None:
    """Run the search pipeline for a single benchmark.

    ``jaxbench-pallas`` problems (Group A) get one optimization-only search of
    ``PALLAS_OPT_ITERATIONS`` iterations in ``<prob_id>_pallas``.

    Every other problem type (Groups B-D) runs in two phases:

    1. Translation: full code generation, no edits, ``TRANSLATE_ITERATIONS``
       iterations in ``<prob_id>_baseline_translate``. Skipped if that phase
       already completed, or for legacy runs (see below).
    2. Optimization: edit-based, ``OPT_ITERATIONS`` iterations in
       ``<prob_id>_baseline``, continuing from the translation directory when
       it exists.

    The RNG is reseeded with a fixed value before each benchmark.
    """
    import autocomp.common.my_logging as my_logging

    random.seed(1111)
    is_pallas = prob_type == "jaxbench-pallas"
    tag = "pallas" if is_pallas else "baseline"

    task = Prob(prob_type, prob_id)
    seed_code = load_initial_code("jaxbench", task)
    search_dir = _build_dir(prob_id, tag)

    if is_pallas:
        # Pallas kernels already exist, so there is nothing to translate.
        search_dir.mkdir(parents=True, exist_ok=True)
        my_logging.move_log(search_dir, tag="search")
        logger.info("=" * 60)
        logger.info("Starting: %s / %s", prob_type, prob_id)
        logger.info("Output: %s", search_dir)
        logger.info("=" * 60)
        _make_optimizer(
            search_dir, task, seed_code,
            translate_iters=0,
            use_edits=OPT_USE_EDITS,
        ).optimize(PALLAS_OPT_ITERATIONS)
        return

    translate_dir = _build_dir(prob_id, tag, "translate")

    # ---- Phase 1: translation ----
    # Older runs translated and optimized in one combined search writing into
    # search_dir. Candidates there plus no translate dir means such a run:
    # translation is skipped entirely.
    legacy_marker = search_dir / "candidates-iter-0"
    if legacy_marker.exists() and not translate_dir.exists():
        logger.info(
            "Legacy run detected for %s (has %s, no translate dir) — "
            "skipping translation, resuming optimization only",
            prob_id, legacy_marker,
        )
    else:
        translate_dir.mkdir(parents=True, exist_ok=True)
        if is_phase_complete(translate_dir, TRANSLATE_ITERATIONS):
            logger.info("Translation already complete for %s, skipping", prob_id)
        else:
            my_logging.move_log(translate_dir, tag="translate")
            logger.info("=" * 60)
            logger.info("Translation: %s / %s", prob_type, prob_id)
            logger.info("Output: %s", translate_dir)
            logger.info("=" * 60)
            _make_optimizer(
                translate_dir, task, seed_code,
                translate_iters=TRANSLATE_ITERATIONS,
                use_edits=TRANSLATE_USE_EDITS,
            ).optimize(TRANSLATE_ITERATIONS)

    # ---- Phase 2: edit-based optimization ----
    search_dir.mkdir(parents=True, exist_ok=True)
    my_logging.move_log(search_dir, tag="optimize")
    logger.info("=" * 60)
    logger.info("Optimization: %s / %s (continuing from translation)", prob_type, prob_id)
    logger.info("Output: %s", search_dir)
    logger.info("=" * 60)

    # NOTE(review): in the legacy case translate_dir does not exist, so this
    # passes continue_from="" — confirm BeamSearchStrategy resumes from
    # search_dir's own candidates as the log message above implies.
    resume_source = translate_dir if translate_dir.exists() else ""
    _make_optimizer(
        search_dir, task, seed_code,
        translate_iters=0,
        use_edits=OPT_USE_EDITS,
        continue_from=resume_source,
    ).optimize(OPT_ITERATIONS)


def _resolve_prob_ids(
    prob_ids: list[str],
    known_runs: list[tuple[str, str]],
) -> list[tuple[str, str]]:
    """Turn ``--probs`` tokens into ``(prob_type, prob_id)`` run pairs.

    A token of the form ``prob_type:prob_id`` is used as given. A bare
    ``prob_id`` takes its type from the entries of *known_runs*:

    * no match: ``jaxbench-baseline``, with a warning;
    * matches under a single type: that type;
    * matches under several types (e.g. an id listed in both Group A and
      Group D): the type of the last matching entry, with a warning naming
      the alternatives and how to select one explicitly.

    Raises:
        ValueError: if a ``prob_type:prob_id`` token has an empty side.
    """
    resolved: list[tuple[str, str]] = []
    for token in prob_ids:
        if ":" in token:
            prob_type, _, prob_id = token.partition(":")
            if not prob_type or not prob_id:
                raise ValueError(
                    f"Invalid --probs entry {token!r}: expected prob_type:prob_id"
                )
            resolved.append((prob_type, prob_id))
            continue

        matches = [pt for pt, pid in known_runs if pid == token]
        if not matches:
            print(f"Warning: {token} not found in any group, assuming jaxbench-baseline")
            resolved.append(("jaxbench-baseline", token))
            continue

        # Last match wins. This keeps the long-standing result for ids shared
        # across groups, but the user is now told about the ambiguity.
        chosen = matches[-1]
        others = sorted(set(matches) - {chosen})
        if others:
            print(
                f"Warning: {token} exists as {', '.join([chosen, *others])}; "
                f"using {chosen} (pass {others[0]}:{token} to select another)"
            )
        resolved.append((chosen, token))
    return resolved


def main():
    """Parse CLI arguments, work out which benchmarks to run, and run them.

    Benchmarks run sequentially. A failing benchmark is logged, its traceback
    is printed, and the batch continues. The process exits with status 1 if
    any benchmark failed, so shell scripts and CI can detect failures.
    """
    # --models / --output-base rebind these module-level settings, which
    # _make_optimizer and _build_dir read at call time.
    global MODELS, OUTPUT_BASE

    parser = argparse.ArgumentParser(description="Batch JAXBench optimization")
    parser.add_argument("--group", choices=["A", "B", "C", "D", "all"],
                        help="Run all benchmarks in a group")
    parser.add_argument("--probs", nargs="+",
                        help="Run specific prob_ids (auto-detects prob_type from group membership; "
                             "use prob_type:prob_id to choose the type explicitly)")
    parser.add_argument("--limit", type=int, default=0,
                        help="Only run first N benchmarks from the group")
    # Skipping is already the default; the flag exists for explicitness.
    parser.add_argument("--skip-completed", action="store_true", default=True,
                        help="Skip benchmarks with completed output (default: True)")
    parser.add_argument("--no-skip-completed", action="store_false", dest="skip_completed")
    parser.add_argument("--dry-run", action="store_true",
                        help="Print what would run without executing")
    parser.add_argument("--models", nargs="+",
                        help="Override MODELS list (e.g. --models gcp::gemini-3-flash-preview)")
    parser.add_argument("--output-base", default=None,
                        help="Override OUTPUT_BASE (default: output/jaxbench-sweep)")
    args = parser.parse_args()

    if args.models:
        MODELS = args.models

    if args.output_base:
        OUTPUT_BASE = pathlib.Path(args.output_base)

    if args.group and args.probs:
        parser.error("Use --group or --probs, not both")
    if not args.group and not args.probs:
        parser.error("Specify --group or --probs")

    # Build run list
    if args.group:
        if args.group == "all":
            runs = GROUP_A + GROUP_B + GROUP_D + GROUP_C
        else:
            runs = GROUPS[args.group]
    else:
        try:
            runs = _resolve_prob_ids(args.probs, GROUP_A + GROUP_B + GROUP_C + GROUP_D)
        except ValueError as err:
            parser.error(str(err))

    if args.limit > 0:
        runs = runs[:args.limit]

    # Filter completed: a benchmark counts as done once its optimization
    # directory has the final candidates for its iteration budget.
    pending = []
    for prob_type, prob_id in runs:
        tag = "pallas" if prob_type == "jaxbench-pallas" else "baseline"
        opt_dir = _build_dir(prob_id, tag)
        iters = PALLAS_OPT_ITERATIONS if prob_type == "jaxbench-pallas" else OPT_ITERATIONS
        if args.skip_completed and is_phase_complete(opt_dir, iters):
            print(f"  SKIP (complete): {prob_type} / {prob_id}")
        else:
            pending.append((prob_type, prob_id))

    print(f"\n{'DRY RUN — ' if args.dry_run else ''}Will run {len(pending)} benchmarks "
          f"(skipped {len(runs) - len(pending)} completed):\n")
    for i, (pt, pid) in enumerate(pending, 1):
        if pt != "jaxbench-pallas":
            print(f"  {i:3d}. [{pt}] {pid}  (translate={TRANSLATE_ITERATIONS}, optimize={OPT_ITERATIONS})")
        else:
            print(f"  {i:3d}. [{pt}] {pid}  (optimize={PALLAS_OPT_ITERATIONS})")

    if args.dry_run:
        return

    print(f"\n{'=' * 60}")
    print(f"Starting batch: {len(pending)} benchmarks")
    print(f"{'=' * 60}\n")

    results = []
    for i, (prob_type, prob_id) in enumerate(pending, 1):
        print(f"\n[{i}/{len(pending)}] {prob_type} / {prob_id}")
        print("-" * 60)
        try:
            run_one(prob_type, prob_id)
            results.append((prob_type, prob_id, "OK"))
        except Exception as e:
            # Batch boundary: record the failure and keep going.
            logger.error("FAILED: %s / %s: %s", prob_type, prob_id, e)
            traceback.print_exc()
            results.append((prob_type, prob_id, f"FAILED: {e}"))

    print(f"\n{'=' * 60}")
    print("BATCH COMPLETE")
    print(f"{'=' * 60}")
    for pt, pid, status in results:
        print(f"  [{pt}] {pid}: {status}")

    # Signal failures to the caller instead of always exiting 0.
    if any(status != "OK" for _, _, status in results):
        sys.exit(1)


if __name__ == "__main__":
    # Run the batch only when executed as a script, not when imported.
    main()
