"""Iterative-refinement baseline: `--num_chains` parallel chains of `--turns` turns each.

Each "step" advances all chains in lock-step: one batched LLM call for all
chains, then one batched TPU evaluation. Each chain sees only its own history
(no cross-pollination). Total LLM budget = num_chains * turns.

On failure within a chain we continue using the failing code + its error log
as feedback; we never roll back.

Default matches Autocomp's Group B config (18 parallel × 8 turns = 144 samples).

Usage:
    AUTOCOMP_JAXBENCH_PROFILE=1 python -m autocomp.baselines.iterative \
        --prob_id 12p_RMSNorm --num_chains 18 --turns 8 \
        --model gemini-3.1-pro-preview \
        --output_dir output/baselines/iterative/12p_RMSNorm
"""
from __future__ import annotations

import argparse
import json
import os
import pathlib
import time

from autocomp.agents.llm_agent import extract
from autocomp.baselines.common import (
    evaluate_many,
    load_baseline_code,
    write_summary,
)
from autocomp.backend.jaxbench.jaxbench_eval import JaxBenchEvalBackend
from autocomp.common import LLMClient, REPO_ROOT, logger
from autocomp.hw_config import TpuHardwareConfig


# Preamble used when the chain starts from a non-Pallas reference (any
# prob_type other than "jaxbench-pallas", see `_get_preamble`). It asks the
# model to get a correct Pallas translation first and to optimize only after.
SYSTEM_PREAMBLE_TRANSLATE = (
    "You are optimizing a JAX kernel for TPU v6e (Trillium) using the Pallas "
    "programming model (`jax.experimental.pallas`).\n"
    "Target hardware: TPU v6e. VMEM = 128 MiB, 8 TensorCores per chip, "
    "peak ≈918 TFLOPS bf16, HBM ≈1526 GB/s.\n"
    "Rules for every turn:\n"
    "- Keep the public `workload(*inputs)` signature unchanged.\n"
    "- Output a single complete Python file inside one ```python ... ``` block.\n"
    "- Do not include explanatory prose outside the code block.\n"
    "Strategy: if you do not yet have a correct Pallas implementation, your "
    "first priority is to produce a correct, straightforward translation — "
    "even if it isn't faster than the XLA baseline. Only after you have a "
    "correct Pallas kernel should you focus on optimizing it for speed.\n"
)

# Preamble used when the chain starts from an existing Pallas kernel
# (prob_type == "jaxbench-pallas"). Same hardware description and output
# rules as the translate preamble; only the "Strategy" paragraph differs.
SYSTEM_PREAMBLE_OPTIMIZE = (
    "You are optimizing an existing Pallas TPU kernel for TPU v6e (Trillium) "
    "using the Pallas programming model (`jax.experimental.pallas`).\n"
    "Target hardware: TPU v6e. VMEM = 128 MiB, 8 TensorCores per chip, "
    "peak ≈918 TFLOPS bf16, HBM ≈1526 GB/s.\n"
    "Rules for every turn:\n"
    "- Keep the public `workload(*inputs)` signature unchanged.\n"
    "- Output a single complete Python file inside one ```python ... ``` block.\n"
    "- Do not include explanatory prose outside the code block.\n"
    "Strategy: the provided implementation is already correct. Focus on "
    "optimizing it for speed — better tiling, pipelining, memory access "
    "patterns, etc.\n"
)


def _get_preamble(prob_type: str) -> str:
    """Return the system preamble matching the problem type.

    "jaxbench-pallas" selects the optimize-an-existing-kernel preamble; every
    other value selects the translate-then-optimize preamble.
    """
    starts_from_pallas = prob_type == "jaxbench-pallas"
    return SYSTEM_PREAMBLE_OPTIMIZE if starts_from_pallas else SYSTEM_PREAMBLE_TRANSLATE


# --------------------------------------------------------------------------
# Optional "full context" mode: reuse Autocomp's built agent to load the exact
# architecture.md + per-problem-selected ISA + code examples + rules. Lets us
# isolate "context quality" from "search algorithm" in the paper.
# --------------------------------------------------------------------------

class _IterativeAgentContext:
    """Holds a BuiltLLMAgent so its prompt context can be reused in --context full mode.

    The agent is only used as a source of prompt text (architecture, ISA,
    code examples, rules); this class never asks it to generate code.
    """

    def __init__(self, model: str, agent_dir: str | pathlib.Path = "built:tpu-v6e"):
        """Resolve the agent profile directory and construct the agent.

        Args:
            model: Model name forwarded to ``BuiltLLMAgent``.
            agent_dir: Either ``"built:<name>"``, resolved to
                ``REPO_ROOT/autocomp/agent_builder/.built/<name>``, or a
                filesystem path to a profile directory.

        Raises:
            ValueError: If the resolved profile directory does not exist.
        """
        # Imported here rather than at module level so the dependency is only
        # loaded when this class is actually instantiated.
        from autocomp.agent_builder.built_agent import BuiltLLMAgent

        built_prefix = "built:"
        names_built_profile = isinstance(agent_dir, str) and agent_dir.startswith(built_prefix)
        if names_built_profile:
            profile_name = agent_dir[len(built_prefix):]
            config_dir = REPO_ROOT / "autocomp" / "agent_builder" / ".built" / profile_name
        else:
            config_dir = pathlib.Path(agent_dir)

        if not config_dir.is_dir():
            raise ValueError(f"Built agent config not found: {config_dir}")

        # Match run_batch.py: no fine-grained ISA, no code-example dropout.
        agent_kwargs = dict(
            model=model,
            config_dir=config_dir,
            hw_config=TpuHardwareConfig("v6e-1"),
            eval_backend=JaxBenchEvalBackend(),
            fine_grained_isa=False,
            example_rate=1.0,
        )
        self.agent = BuiltLLMAgent(**agent_kwargs)
        logger.info("Iterative+context: loaded agent profile from %s", config_dir)

    def prefix_for(self, prob, code: str) -> str:
        """Return architecture, ISA and example text joined by newlines.

        Empty sections are dropped, so the result never contains blank
        separators for missing pieces.
        """
        isa_text, examples_text = self.agent._get_problem_context(prob, code)
        sections = (self.agent._architecture, isa_text, examples_text)
        return "\n".join(filter(None, sections))

    def rules_for(self, prob, translate: bool) -> str:
        """Return the agent's rules text with coding rules on and planning rules off."""
        rule_options = {
            "planning": False,
            "coding": True,
            "prob": prob,
            "translate": translate,
        }
        return self.agent._get_prompt_rules(**rule_options)


def _initial_prompt(prob_id: str, baseline_code: str, prob_type: str = "jaxbench-baseline",
                    ctx: _IterativeAgentContext | None = None, prob=None) -> str:
    """Build the first-turn prompt shown to every chain.

    Layout: preamble, optional agent context, benchmark id, task statement,
    the starting code in a fenced block, then optional rules. The context and
    rules sections appear only when both ``ctx`` and ``prob`` are given.
    """
    starts_from_pallas = prob_type == "jaxbench-pallas"
    use_agent_context = ctx is not None and prob is not None

    segments = [f"{_get_preamble(prob_type)}\n"]
    if use_agent_context:
        segments.append(ctx.prefix_for(prob, baseline_code) + "\n\n")

    segments.append(f"Benchmark: {prob_id}\n\n")
    if starts_from_pallas:
        segments.append("Below is an existing Pallas implementation. Optimize it to "
                        "run faster while producing identical outputs.\n\n")
    else:
        segments.append("Below is the XLA/JAX reference implementation. Rewrite it as a Pallas "
                        "kernel that produces identical outputs and runs faster.\n\n")
    segments.append(f"```python\n{baseline_code}\n```\n")

    if use_agent_context:
        # Translation rules apply whenever the starting point is not already Pallas.
        rules_text = ctx.rules_for(prob, translate=not starts_from_pallas)
        segments.append("\nMake sure to follow these rules:\n" + rules_text)

    return "".join(segments)


def _refine_prompt(
    prob_id: str,
    current_code: str,
    prev_result: dict,
    baseline_latency: float | None,
    best_so_far_ms: float | None,
    prob_type: str = "jaxbench-baseline",
    ctx: _IterativeAgentContext | None = None,
    prob=None,
) -> str:
    """Build the follow-up prompt for one chain from its previous turn.

    Args:
        prob_id: Benchmark identifier echoed into the prompt.
        current_code: The chain's current implementation, shown to the model.
        prev_result: Evaluation result of the previous turn. Reads the keys
            "correct", "latency", "stdout" and "profile"; missing or ``None``
            values are tolerated.
        baseline_latency: Latency of the starting code in ms; when falsy the
            baseline/speedup note is omitted.
        best_so_far_ms: Best correct latency seen in this chain, or ``None``.
        prob_type: Selects the preamble and whether translation rules apply.
        ctx: Optional agent context; its sections are added only when
            ``prob`` is given too.
        prob: Problem object passed through to ``ctx``.

    Returns:
        The prompt text, with sections joined by newlines.
    """
    preamble = _get_preamble(prob_type)
    correct = prev_result.get("correct")
    latency = prev_result.get("latency")
    # `or ""` (not a .get default) so a key that is present but None cannot
    # reach .strip() below; this mirrors how "profile" is read.
    stdout = prev_result.get("stdout") or ""
    profile = prev_result.get("profile") or ""

    if correct and latency is not None:
        fb = f"✅ Previous attempt was correct. Latency: {latency:.3f} ms"
        if baseline_latency:
            fb += f" (XLA baseline: {baseline_latency:.3f} ms"
            # A reported latency of 0 would make the ratio undefined; leave
            # the speedup out instead of raising ZeroDivisionError.
            if latency > 0:
                fb += f", speedup: {baseline_latency/latency:.2f}x"
            fb += ")"
        if best_so_far_ms is not None:
            fb += f". Best-so-far in this chain: {best_so_far_ms:.3f} ms"
        fb += ".\n"
    else:
        # Only the tail of the log is included to keep the prompt bounded.
        tail = "\n".join(stdout.strip().splitlines()[-40:])
        fb = (
            "❌ Previous attempt was incorrect or failed to run.\n"
            f"Failure log (last 40 lines):\n```\n{tail}\n```\n"
        )

    parts = [preamble]
    if ctx is not None and prob is not None:
        parts.append(ctx.prefix_for(prob, current_code))
    parts.extend([f"Benchmark: {prob_id}", "", fb])
    if profile:
        parts.append("Profiler summary from previous run:")
        parts.append("```")
        parts.append(profile)
        parts.append("```")
        parts.append("")
    parts.extend([
        "Previous implementation:",
        "```python",
        current_code,
        "```",
        "",
        "Produce a new complete Pallas implementation. If the previous attempt "
        "was incorrect, first get it correct (simplify if needed). If it was "
        "correct, make it faster. Output one ```python ... ``` block.",
    ])
    if ctx is not None and prob is not None:
        # Translation rules apply whenever the starting point is not already Pallas.
        translate = (prob_type != "jaxbench-pallas")
        parts.extend([
            "",
            "Make sure to follow these rules:",
            ctx.rules_for(prob, translate=translate),
        ])
    return "\n".join(parts)


def _save_turn(turn_dir: pathlib.Path, prompt: str, response: str, code: str, result: dict) -> None:
    """Write one chain-turn's artifacts into ``turn_dir``.

    Creates the directory (and parents) if needed and writes four files:
    ``prompt.txt``, ``response.txt``, ``code.py`` and ``result.json``. The
    JSON file holds ``result`` without its "stderr" entry.

    Args:
        turn_dir: Destination directory for this turn.
        prompt: Prompt text sent to the model.
        response: Raw model response.
        code: Code extracted from the response (may be empty).
        result: Evaluation result; must be JSON-serializable apart from the
            dropped "stderr" key.
    """
    turn_dir.mkdir(parents=True, exist_ok=True)
    # Prompts and model output contain non-ASCII text (the preambles alone use
    # "≈" and "—"), so pin UTF-8 rather than relying on the locale default,
    # which can raise UnicodeEncodeError on non-UTF-8 systems.
    (turn_dir / "prompt.txt").write_text(prompt, encoding="utf-8")
    (turn_dir / "response.txt").write_text(response, encoding="utf-8")
    (turn_dir / "code.py").write_text(code, encoding="utf-8")
    slim = {k: v for k, v in result.items() if k != "stderr"}
    (turn_dir / "result.json").write_text(json.dumps(slim, indent=2), encoding="utf-8")


def main():
    """Command-line entry point for the parallel iterative-refinement baseline.

    Parses arguments, evaluates the starting code once to get a reference
    latency, then runs ``--turns`` lock-step rounds over ``--num_chains``
    chains. Each round builds one prompt per chain, makes one batched LLM
    call, evaluates the extracted candidates in batches of at most
    ``--eval_batch_size``, and saves per-chain artifacts under
    ``<output_dir>/chain_XX/turn_XXX``. After the last round it writes
    ``turns.json`` and a summary via ``write_summary``.

    A failed LLM call or evaluation batch is logged and recorded as a failed
    attempt for the affected chains; the run continues.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--prob_id", required=True)
    parser.add_argument("--prob_type", default="jaxbench-baseline",
                        choices=["jaxbench-baseline", "jaxbench-pallas"],
                        help="jaxbench-baseline starts from baseline.py; jaxbench-pallas starts from optimized.py")
    parser.add_argument("--num_chains", type=int, default=18,
                        help="Number of independent parallel chains.")
    parser.add_argument("--turns", type=int, default=8,
                        help="Number of refinement turns per chain.")
    parser.add_argument("--model", default="gemini-3.1-pro-preview")
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--eval_batch_size", type=int, default=18,
                        help="Max candidates to send to TPU in one batch.")
    parser.add_argument("--context", choices=["minimal", "full"], default="minimal",
                        help="'minimal' = current preamble only; "
                             "'full' = prepend Autocomp's agent profile "
                             "(architecture + per-problem ISA + code examples + rules).")
    parser.add_argument("--agent_dir", default="built:tpu-v6e",
                        help="Agent profile to load when --context=full "
                             "(e.g. 'built:tpu-v6e' or a path).")
    args = parser.parse_args()

    # The batching loop below steps by eval_batch_size: 0 makes range() raise
    # and a negative value yields no batches at all. Reject both up front,
    # before any LLM or TPU work is done.
    if args.eval_batch_size < 1:
        parser.error("--eval_batch_size must be >= 1")

    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    if os.getenv("AUTOCOMP_JAXBENCH_PROFILE") != "1":
        logger.warning(
            "AUTOCOMP_JAXBENCH_PROFILE is not set; profile feedback will be empty. "
            "Export AUTOCOMP_JAXBENCH_PROFILE=1 to enable."
        )

    prob, baseline_code = load_baseline_code(args.prob_id, args.prob_type)
    total_budget = args.num_chains * args.turns
    logger.info("Iterative (parallel): prob=%s chains=%d turns=%d total=%d model=%s",
                args.prob_id, args.num_chains, args.turns, total_budget, args.model)

    t0 = time.perf_counter()

    # Evaluate the starting code once; its latency is the reference quoted in
    # refinement prompts and recorded in the summary.
    backend = JaxBenchEvalBackend()
    baseline_eval = evaluate_many(prob, [baseline_code], backend=backend)[0]
    baseline_lat = baseline_eval.get("latency")
    logger.info("Baseline latency: %.3f ms (correct=%s)",
                baseline_lat or -1.0, baseline_eval.get("correct"))

    client = LLMClient(args.model)

    ctx: _IterativeAgentContext | None = None
    if args.context == "full":
        ctx = _IterativeAgentContext(args.model, args.agent_dir)

    # Per-chain state.
    chain_current_code = [baseline_code] * args.num_chains
    chain_current_result: list[dict] = [{} for _ in range(args.num_chains)]
    chain_best_latency: list[float | None] = [None] * args.num_chains
    chain_best_turn: list[int | None] = [None] * args.num_chains
    turns_log: list[dict] = []

    for turn in range(args.turns):
        # 1. Build the prompt for every chain.
        prompts = []
        for c in range(args.num_chains):
            if turn == 0:
                prompts.append(_initial_prompt(
                    args.prob_id, baseline_code, args.prob_type,
                    ctx=ctx, prob=prob,
                ))
            else:
                prompts.append(_refine_prompt(
                    args.prob_id, chain_current_code[c],
                    chain_current_result[c], baseline_lat, chain_best_latency[c],
                    prob_type=args.prob_type,
                    ctx=ctx, prob=prob,
                ))

        # 2. Batched LLM call: one prompt per chain, num_samples=1 each.
        logger.info("turn %d/%d: asking LLM for %d chains…",
                    turn + 1, args.turns, args.num_chains)
        try:
            batched = client.chat_async(
                prompts, num_samples=1, temperature=args.temperature
            )
            responses = [(batch[0] if batch else "") for batch in batched]
        except Exception as e:
            # Best-effort: a failed LLM round costs every chain this turn but
            # does not abort the run.
            logger.warning("turn %d: LLM batch call failed: %s", turn, e)
            responses = [""] * args.num_chains

        codes = [extract(r) if r else "" for r in responses]

        # 3. Batched TPU eval: one candidate per chain.
        # Use placeholder baseline code for empty candidates (we'll mark them failed below).
        placeholder = baseline_code
        eval_codes = [c if c.strip() else placeholder for c in codes]
        results: list[dict] = []
        for i in range(0, len(eval_codes), args.eval_batch_size):
            batch = eval_codes[i : i + args.eval_batch_size]
            logger.info("turn %d: evaluating batch %d–%d / %d",
                        turn, i, i + len(batch), len(eval_codes))
            try:
                results.extend(evaluate_many(prob, batch, backend=backend))
            except Exception as e:
                logger.warning("turn %d: eval batch failed: %s", turn, e)
                # Build one fresh dict per candidate. The bookkeeping below
                # writes per-chain keys ("_turn", "_chain") into each result
                # and keeps it as that chain's state, so the entries must not
                # be the same object (as `[d] * n` would make them).
                results.extend(
                    {"correct": False, "latency": None,
                     "stdout": f"[eval exception: {e}]",
                     "stderr": "", "profile": "",
                     "failure_type": "eval_exception"}
                    for _ in batch
                )

        # 4. Per-chain bookkeeping + save artifacts.
        for c in range(args.num_chains):
            result = results[c]
            if not codes[c].strip():
                # The placeholder's evaluation is discarded: this chain
                # produced no code, so the turn counts as a failure.
                result = {"correct": False, "latency": None,
                          "stdout": "[no code extracted]", "stderr": "", "profile": "",
                          "failure_type": "no_code_extracted"}
            result["_turn"] = turn
            result["_chain"] = c

            turn_dir = output_dir / f"chain_{c:02d}" / f"turn_{turn:03d}"
            _save_turn(turn_dir, prompts[c], responses[c], codes[c], result)

            correct = bool(result.get("correct"))
            latency = result.get("latency")
            if correct and latency is not None:
                if chain_best_latency[c] is None or latency < chain_best_latency[c]:
                    chain_best_latency[c] = latency
                    chain_best_turn[c] = turn

            # Never roll back: the chain continues from the newest code even
            # if it failed; only an empty extraction keeps the previous code.
            chain_current_code[c] = codes[c] if codes[c].strip() else chain_current_code[c]
            chain_current_result[c] = result
            turns_log.append({
                "chain": c, "turn": turn,
                "correct": correct, "latency": latency,
                "failure_type": result.get("failure_type"),
            })

        # Log per-step summary.
        correct_this_step = sum(1 for c in range(args.num_chains)
                                if results[c].get("correct") and codes[c].strip())
        best_this_step = min(
            (r["latency"] for r, cd in zip(results, codes)
             if r.get("correct") and r.get("latency") and cd.strip()),
            default=None,
        )
        logger.info("turn %d: %d/%d correct, best-this-step=%s ms",
                    turn, correct_this_step, args.num_chains,
                    f"{best_this_step:.3f}" if best_this_step else "—")

    (output_dir / "turns.json").write_text(json.dumps(turns_log, indent=2), encoding="utf-8")

    eval_results = [
        {"correct": t["correct"], "latency": t["latency"], "failure_type": t.get("failure_type")}
        for t in turns_log
    ]
    # Pick the chain holding the lowest correct latency; ties keep the
    # lowest chain index.
    best_chain_idx = None
    best_overall = None
    for c, lat in enumerate(chain_best_latency):
        if lat is not None and (best_overall is None or lat < best_overall):
            best_overall = lat
            best_chain_idx = c

    summary = write_summary(
        output_dir, args.prob_id, total_budget, baseline_lat, eval_results,
        runtime_s=time.perf_counter() - t0,
        extra={
            "model": args.model,
            "baseline_correct": baseline_eval.get("correct"),
            "num_chains": args.num_chains,
            "turns": args.turns,
            "context": args.context,
            "agent_dir": args.agent_dir if args.context == "full" else None,
            "best_chain": best_chain_idx,
            "best_turn_in_best_chain": chain_best_turn[best_chain_idx] if best_chain_idx is not None else None,
            "per_chain_best_ms": chain_best_latency,
            "baseline": "iterative_parallel",
        },
    )
    logger.info("Summary: %s", summary)


# Script entry point: run the CLI only when this module is executed directly
# (e.g. `python -m autocomp.baselines.iterative`), not when it is imported.
if __name__ == "__main__":
    main()
