"""mini-swe-agent baseline harness: sets up a per-benchmark workspace and
launches mini-swe-agent to autonomously optimize a JAX kernel.

Each agent step is one LLM turn whose response issues one or more bash
commands; mini-swe-agent caps the number of turns via step_limit.  The eval.sh
budget (144 by default) is enforced separately by eval_single.py, just like the
Gemini CLI harness.

Usage:
    python -m autocomp.baselines.mini_swe_harness \
        --prob_id 12p_RMSNorm --budget 144 \
        --output_dir output/baselines/mini_swe/12p_RMSNorm
"""
from __future__ import annotations

import argparse
import json
import os
import pathlib
import shutil
import sys
import tempfile
import time

from autocomp.baselines.common import load_baseline_code, write_summary
from autocomp.backend.jaxbench.jaxbench_eval import JaxBenchEvalBackend, extract_workload_code
from autocomp.common import logger

# Absolute directory of this harness module; the templates live beside it.
_THIS_DIR = pathlib.Path(__file__).resolve().parent
# Two levels above this module's directory; substituted into eval.sh as
# __AUTOCOMP_ROOT__.
_AUTOCOMP_ROOT = _THIS_DIR.parent.parent
# Template files rendered into each per-problem workspace.
_PROMPT_TEMPLATE = _THIS_DIR.joinpath("PROMPT_TEMPLATE.md")
_EVAL_SH_TEMPLATE = _THIS_DIR.joinpath("eval_sh_template.sh")


def _setup_workspace(
    workspace: pathlib.Path,
    prob_id: str,
    prob_type: str,
    budget: int,
    output_dir: pathlib.Path,
) -> pathlib.Path:
    """Populate *workspace* with the files the agent needs for one problem.

    Files written to *workspace*:
      * ``baseline.py`` - baseline code returned by ``load_baseline_code``.
      * ``solution.py`` - workload snippet from ``extract_workload_code``
        (``main()`` later copies this file to ``final_solution.py``).
      * ``eval.sh`` - rendered from ``eval_sh_template.sh``, mode 0o755.
      * ``PROMPT.md`` - rendered from ``PROMPT_TEMPLATE.md``.

    Also resets ``<output_dir>/eval_counter.txt`` to ``0`` and deletes any
    stale ``<output_dir>/trajectory.jsonl``, so a rerun into the same output
    directory starts from a fresh eval budget.

    Args:
        workspace: Directory to populate (created if missing).
        prob_id: Benchmark problem id; substituted into eval.sh.
        prob_type: ``"jaxbench-baseline"`` or ``"jaxbench-pallas"``. The
            latter swaps the prompt's task sentence for one asking to
            optimize an existing Pallas implementation.
        budget: Max eval.sh invocations; substituted into eval.sh and into
            the prompt's ``{budget}`` placeholder.
        output_dir: Harness output directory holding the eval counter and
            trajectory file (created if missing).

    Returns:
        *workspace*, unchanged.
    """
    workspace.mkdir(parents=True, exist_ok=True)
    output_dir.mkdir(parents=True, exist_ok=True)

    prob, baseline_code = load_baseline_code(prob_id, prob_type)
    (workspace / "baseline.py").write_text(baseline_code)
    (workspace / "solution.py").write_text(extract_workload_code(prob))

    # Reset per-run eval state; both paths are baked into eval.sh below.
    counter_file = output_dir / "eval_counter.txt"
    counter_file.write_text("0")
    trajectory_file = output_dir / "trajectory.jsonl"
    trajectory_file.unlink(missing_ok=True)

    # Applied in insertion order (same order as the original chained replaces).
    substitutions = {
        "__AUTOCOMP_ROOT__": str(_AUTOCOMP_ROOT),
        "__PROB_ID__": prob_id,
        "__TRAJECTORY_FILE__": str(trajectory_file),
        "__COUNTER_FILE__": str(counter_file),
        "__BUDGET__": str(budget),
        "__PYTHON__": sys.executable,
    }
    eval_sh = _EVAL_SH_TEMPLATE.read_text()
    for placeholder, value in substitutions.items():
        eval_sh = eval_sh.replace(placeholder, value)
    eval_sh_path = workspace / "eval.sh"
    eval_sh_path.write_text(eval_sh)
    eval_sh_path.chmod(0o755)

    prompt_text = _PROMPT_TEMPLATE.read_text().replace("{budget}", str(budget))
    if prob_type == "jaxbench-pallas":
        xla_task = "the XLA/JAX reference implementation. Rewrite `workload()` as a Pallas kernel that is correct and faster."
        pallas_task = "an existing Pallas implementation. Optimize it to run faster while keeping it correct."
        if xla_task not in prompt_text:
            # str.replace would silently no-op, leaving a Pallas problem with
            # the XLA-rewrite task wording; make that visible.
            logger.warning(
                "PROMPT_TEMPLATE.md lacks the expected XLA task sentence; "
                "Pallas-specific task wording was not applied for %s",
                prob_id,
            )
        prompt_text = prompt_text.replace(xla_task, pallas_task)
    (workspace / "PROMPT.md").write_text(prompt_text)

    return workspace


def _collect_results(output_dir: pathlib.Path) -> list[dict]:
    """Parse ``<output_dir>/trajectory.jsonl`` into per-eval result dicts.

    Each non-blank line that decodes to a JSON object yields one result with
    keys ``correct`` (default ``False``), ``latency`` and ``failure_type``
    (default ``None``). Lines that fail to decode (e.g. a partially written
    record) and JSON values that are not objects are skipped. Collection is
    best-effort, so one bad line never aborts the summary.

    Args:
        output_dir: Harness output directory containing ``trajectory.jsonl``.

    Returns:
        Results in file order; an empty list if the trajectory file does not
        exist.
    """
    trajectory_file = output_dir / "trajectory.jsonl"
    try:
        fh = trajectory_file.open(encoding="utf-8")
    except FileNotFoundError:
        return []

    results: list[dict] = []
    with fh:
        for raw_line in fh:
            line = raw_line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                continue  # best-effort: ignore malformed/truncated records
            if not isinstance(entry, dict):
                # A bare number/list/string has no .get(); previously this
                # raised AttributeError and lost every result.
                continue
            results.append({
                "correct": entry.get("correct", False),
                "latency": entry.get("latency"),
                "failure_type": entry.get("failure_type"),
            })
    return results


# System prompt handed to mini-swe-agent as DefaultAgent's `system_template`.
# The backslash right after the opening quotes drops the leading newline, and
# the backslash ending the "Each command runs in a" line is a line
# continuation, so that sentence renders as a single line.
# The submission line tells the agent to echo
# COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT; presumably mini-swe-agent watches
# command output for this marker to end the run. TODO confirm against the
# installed mini-swe-agent version.
_SYSTEM_TEMPLATE = """\
You are an expert JAX/Pallas kernel engineer optimizing code for TPU v6e (Trillium).

You interact with the system by issuing bash commands. Each command runs in a \
fresh subshell, so use `cd /path && ...` when you need a specific working directory.

When you are done, submit with: `echo COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT`
"""

# Per-task prompt handed to mini-swe-agent as DefaultAgent's
# `instance_template`. The double-brace fields are presumably rendered by
# mini-swe-agent's template engine:
#   {{task}} is the text passed to `agent.run(task=...)`, i.e. the rendered
#     PROMPT.md contents;
#   {{system}} {{release}} {{version}} {{machine}} are assumed to be
#     host-platform fields supplied by the environment.
# TODO confirm these variable names against the installed mini-swe-agent.
# The backslash ending the "Directory/environment changes" line is a line
# continuation.
_INSTANCE_TEMPLATE = """\
{{task}}

## Command Execution Rules

Each response must include at least one bash tool call.
Directory/environment changes are not persistent across commands — \
prefix with `cd /path && ...` as needed.

When finished, run: `echo COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT`
(do not combine it with other commands).

<system_information>
{{system}} {{release}} {{version}} {{machine}}
</system_information>
"""


def _copy_dir_contents(src: pathlib.Path, dst: pathlib.Path) -> None:
    """Copy every top-level entry of *src* into *dst*.

    Subdirectories already present in *dst* are replaced wholesale, and files
    are overwritten with metadata preserved (``shutil.copy2``). Entries that
    exist only in *dst* are left untouched.
    """
    for entry in src.iterdir():
        target = dst / entry.name
        if entry.is_dir():
            if target.exists():
                shutil.rmtree(target)
            shutil.copytree(entry, target)
        else:
            shutil.copy2(entry, target)


def _run_agent(
    workspace: pathlib.Path,
    output_dir: pathlib.Path,
    model: str,
    step_limit: int,
    cost_limit: float,
    timeout_s: int,
) -> dict:
    """Run mini-swe-agent on a temporary copy of *workspace*, then sync back.

    The workspace is copied into a fresh ``mkdtemp`` directory, and the agent
    runs with that directory as its working directory. Afterwards the
    directory's contents are copied back into *workspace*. Files the agent
    deleted are not removed from *workspace*. The scratch directory is always
    deleted. Both the sync-back and the cleanup also happen when the agent
    run raises, so partial work is kept and no temp dir leaks.

    Args:
        workspace: Directory prepared by ``_setup_workspace``; must contain
            ``PROMPT.md``, which becomes the agent's task text.
        output_dir: Where mini-swe-agent writes ``mini_swe_trajectory.json``.
        model: litellm model name.
        step_limit: Max agent steps (LLM calls).
        cost_limit: Cost limit in USD.
        timeout_s: Per-command timeout for the local environment, in seconds.

    Returns:
        ``{"exit_status": ..., "n_calls": ..., "cost": ...}``; ``exit_status``
        is ``"unknown"`` if the agent result lacks it.

    Raises:
        ImportError: if mini-swe-agent is not installed.
    """
    from minisweagent.agents.default import DefaultAgent
    from minisweagent.environments.local import LocalEnvironment
    from minisweagent.models.litellm_model import LitellmModel

    isolated_dir = pathlib.Path(tempfile.mkdtemp(prefix="mini_swe_bench_"))
    try:
        _copy_dir_contents(workspace, isolated_dir)
        logger.info("Launching mini-swe-agent in isolated dir %s", isolated_dir)

        task_text = (isolated_dir / "PROMPT.md").read_text()

        # Vertex AI auth for litellm. setdefault only fills unset variables.
        # NOTE(review): "YOUR_GCP_PROJECT" is a placeholder; export the real
        # project in the environment for actual runs.
        os.environ.setdefault("GOOGLE_CLOUD_PROJECT", "YOUR_GCP_PROJECT")
        os.environ.setdefault("VERTEXAI_PROJECT", "YOUR_GCP_PROJECT")
        os.environ.setdefault("VERTEXAI_LOCATION", "global")

        env = LocalEnvironment(
            cwd=str(isolated_dir),
            timeout=timeout_s,
            # Keep tool output non-interactive and free of progress bars.
            env={
                "PAGER": "cat",
                "MANPAGER": "cat",
                "LESS": "-R",
                "PIP_PROGRESS_BAR": "off",
                "TQDM_DISABLE": "1",
            },
        )

        mdl = LitellmModel(
            model_name=model,
            model_kwargs={"drop_params": True},
            cost_tracking="ignore_errors",
        )

        agent = DefaultAgent(
            model=mdl,
            env=env,
            system_template=_SYSTEM_TEMPLATE,
            instance_template=_INSTANCE_TEMPLATE,
            step_limit=step_limit,
            cost_limit=cost_limit,
            output_path=output_dir / "mini_swe_trajectory.json",
        )

        try:
            result = agent.run(task=task_text)
        finally:
            # Sync back even if the run raised, so partial edits survive.
            _copy_dir_contents(isolated_dir, workspace)
    finally:
        # mkdtemp never cleans up after itself; always remove the scratch copy.
        shutil.rmtree(isolated_dir, ignore_errors=True)

    exit_status = result.get("exit_status", "unknown")
    logger.info(
        "mini-swe-agent finished: exit_status=%s, steps=%d, cost=$%.4f",
        exit_status,
        agent.n_calls,
        agent.cost,
    )
    return {
        "exit_status": exit_status,
        "n_calls": agent.n_calls,
        "cost": agent.cost,
    }


def _read_evals_used(counter_file: pathlib.Path) -> int:
    """Return the eval count stored in *counter_file*, or 0 if unavailable.

    A missing file yields 0 silently. A file that does not hold an integer
    also yields 0, but is logged, because the reported ``evals_used`` is then
    not trustworthy.
    """
    try:
        return int(counter_file.read_text().strip())
    except FileNotFoundError:
        return 0
    except ValueError:
        logger.warning(
            "Could not parse eval counter %s; reporting 0 evals used", counter_file
        )
        return 0


def main():
    """CLI entry point: set up the workspace, measure the baseline, run the
    agent, and write the run summary into ``--output_dir``."""
    parser = argparse.ArgumentParser(description="mini-swe-agent baseline harness")
    parser.add_argument("--prob_id", required=True)
    parser.add_argument("--prob_type", default="jaxbench-baseline",
                        choices=["jaxbench-baseline", "jaxbench-pallas"])
    parser.add_argument("--budget", type=int, default=144,
                        help="Max eval.sh invocations (enforced by eval_single.py)")
    parser.add_argument("--step_limit", type=int, default=300,
                        help="Max LLM turns (generous; eval budget is the real constraint)")
    parser.add_argument("--cost_limit", type=float, default=20.0,
                        help="Cost limit in USD")
    parser.add_argument("--model", default="vertex_ai/gemini-3.1-pro-preview",
                        help="litellm model name (vertex_ai/ prefix uses Vertex AI)")
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--timeout", type=int, default=600,
                        help="Per-command timeout in seconds")
    args = parser.parse_args()

    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    workspace = output_dir / "workspace"
    # runtime_s covers workspace setup, the baseline eval and the agent run.
    t0 = time.perf_counter()

    logger.info("Setting up workspace for %s (eval budget=%d, step_limit=%d)",
                args.prob_id, args.budget, args.step_limit)
    _setup_workspace(workspace, args.prob_id, args.prob_type, args.budget, output_dir)

    # NOTE: _setup_workspace loads the problem too, but only returns the
    # workspace path, so it is loaded again here for the baseline eval.
    prob, baseline_code = load_baseline_code(args.prob_id, args.prob_type)
    backend = JaxBenchEvalBackend()
    baseline_eval = backend.evaluate_code(prob, [baseline_code], simulator=None)[0]
    baseline_lat = baseline_eval.get("latency")
    if baseline_lat is None:
        # Previously logged as "-1.000 ms", which read like a measurement.
        logger.warning("Baseline latency unavailable for %s", args.prob_id)
    else:
        logger.info("Baseline latency: %.3f ms", baseline_lat)

    agent_result = _run_agent(
        workspace, output_dir,
        model=args.model,
        step_limit=args.step_limit,
        cost_limit=args.cost_limit,
        timeout_s=args.timeout,
    )

    eval_results = _collect_results(output_dir)
    runtime_s = time.perf_counter() - t0

    final_solution = workspace / "solution.py"
    if final_solution.exists():
        shutil.copy2(final_solution, output_dir / "final_solution.py")

    evals_used = _read_evals_used(output_dir / "eval_counter.txt")

    summary = write_summary(
        output_dir, args.prob_id, evals_used, baseline_lat, eval_results,
        runtime_s=runtime_s,
        extra={
            "model": args.model,
            "budget": args.budget,
            "evals_used": evals_used,
            "baseline": "mini_swe_agent",
            "step_limit": args.step_limit,
            "agent_steps": agent_result.get("n_calls", 0),
            "agent_cost": agent_result.get("cost", 0.0),
            "agent_exit_status": agent_result.get("exit_status", ""),
        },
    )
    logger.info("Summary: %s", json.dumps(summary, indent=2))


# Script entry point; see the Usage section of the module docstring.
if __name__ == "__main__":
    main()
