"""Best-of-N baseline: sample N independent Pallas implementations from a single prompt.

No iteration, no menu, no feedback: the same "rewrite this XLA kernel as Pallas"
prompt is sampled N times, every candidate is evaluated, and the best correct
result is picked.

Usage:
    python -m autocomp.baselines.best_of_n \
        --prob_id 12p_RMSNorm --n 144 \
        --model gemini-3.1-pro-preview \
        --output_dir output/baselines/best_of_n/12p_RMSNorm
"""
from __future__ import annotations

import argparse
import pathlib
import time

from autocomp.agents.llm_agent import extract
from autocomp.baselines.common import (
    build_prompt, evaluate_many, load_baseline_code,
    save_candidates, write_summary,
)
from autocomp.common import LLMClient, logger


def _parse_args() -> argparse.Namespace:
    """Parse and validate the command-line arguments for the best-of-N baseline."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--prob_id", required=True)
    parser.add_argument("--prob_type", default="jaxbench-baseline",
                        choices=["jaxbench-baseline", "jaxbench-pallas"],
                        help="jaxbench-baseline starts from baseline.py; jaxbench-pallas starts from optimized.py")
    parser.add_argument("--n", type=int, default=144)
    parser.add_argument("--model", default="gemini-3.1-pro-preview")
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument(
        "--batch_size", type=int, default=18,
        help="Number of candidates to evaluate per TPU batch.",
    )
    args = parser.parse_args()

    # Reject non-positive counts before any side effects (output dir, LLM calls).
    # A zero batch size would make range() raise only after all samples were
    # generated; a negative one yields an empty range, so no candidate would be
    # evaluated and an empty/mismatched result list would be saved silently.
    if args.n < 1:
        parser.error("--n must be a positive integer")
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")
    return args


def _evaluate_in_batches(prob, codes: list, batch_size: int) -> list[dict]:
    """Evaluate ``codes`` against ``prob`` in consecutive chunks of ``batch_size``.

    Results from each ``evaluate_many`` call are concatenated in chunk order
    (assumes ``evaluate_many`` returns one result per input code, in order).
    """
    results: list[dict] = []
    for start in range(0, len(codes), batch_size):
        batch = codes[start : start + batch_size]
        logger.info("Evaluating batch %d–%d / %d", start, start + len(batch), len(codes))
        results.extend(evaluate_many(prob, batch))
    return results


def main():
    """Run the best-of-N baseline end to end.

    Builds one prompt from the problem's starting code, samples ``--n``
    responses for it, extracts code from each response, evaluates the
    candidates in batches of ``--batch_size``, evaluates the unmodified
    starting code for reference, and writes the prompt, candidates and a
    summary into ``--output_dir``.
    """
    args = _parse_args()

    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    prob, baseline_code = load_baseline_code(args.prob_id, args.prob_type)
    prompt = build_prompt(args.prob_id, baseline_code, args.prob_type)
    # Keep the exact prompt alongside the candidates it produced.
    (output_dir / "prompt.txt").write_text(prompt)

    logger.info("Best-of-N: prob=%s n=%d model=%s", args.prob_id, args.n, args.model)
    t0 = time.perf_counter()

    client = LLMClient(args.model)
    # One prompt, n samples; presumably chat_async returns one list of samples
    # per prompt, so [0] takes the samples for our single prompt.
    responses = client.chat_async([prompt], num_samples=args.n, temperature=args.temperature)[0]
    logger.info("Generated %d responses in %.1fs", len(responses), time.perf_counter() - t0)

    codes = [extract(r) for r in responses]

    t_eval = time.perf_counter()
    results = _evaluate_in_batches(prob, codes, args.batch_size)
    logger.info("Evaluation took %.1fs", time.perf_counter() - t_eval)

    save_candidates(output_dir, responses, results)

    # Evaluate the untouched starting code so the summary has a reference latency.
    baseline_eval = evaluate_many(prob, [baseline_code])[0]
    baseline_lat = baseline_eval.get("latency")
    summary = write_summary(
        output_dir, args.prob_id, args.n, baseline_lat, results,
        # Wall time since sampling started: generation + evaluation + baseline.
        runtime_s=time.perf_counter() - t0,
        extra={"model": args.model, "baseline_correct": baseline_eval.get("correct")},
    )
    logger.info("Summary: %s", summary)


if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with status 0.
    raise SystemExit(main())
