"""Gemini CLI baseline harness: sets up a per-benchmark workspace and launches
Gemini in headless mode to autonomously optimize a JAX kernel.

Usage:
    python -m autocomp.baselines.gemini_cli_harness \
        --prob_id 12p_RMSNorm --budget 144 \
        --output_dir output/baselines/gemini_cli/12p_RMSNorm
"""
from __future__ import annotations

import argparse
import json
import os
import pathlib
import shutil
import subprocess
import sys
import tempfile
import time

from autocomp.baselines.common import load_baseline_code, evaluate_many, write_summary
from autocomp.backend.jaxbench.jaxbench_eval import JaxBenchEvalBackend, extract_workload_code
from autocomp.common import logger

# Directory that contains this harness module.
_THIS_DIR = pathlib.Path(__file__).resolve().parent
# Two directory levels above this module's directory. It is exported as
# PYTHONPATH to the Gemini subprocess and substituted into eval.sh.
_AUTOCOMP_ROOT = _THIS_DIR.parent.parent
# Initial prompt template; its "{budget}" placeholder is filled in by
# _setup_workspace.
_PROMPT_TEMPLATE = _THIS_DIR / "PROMPT_TEMPLATE.md"
# Template for the eval.sh script; its __NAME__ placeholders are filled in by
# _setup_workspace.
_EVAL_SH_TEMPLATE = _THIS_DIR / "eval_sh_template.sh"


def _setup_workspace(
    workspace: pathlib.Path,
    prob_id: str,
    budget: int,
    output_dir: pathlib.Path,
) -> pathlib.Path:
    """Create the workspace Gemini CLI will operate in.

    Writes ``baseline.py``, ``solution.py``, an executable ``eval.sh`` and
    ``PROMPT.md`` into *workspace*, and resets the bookkeeping files
    (``eval_counter.txt``, ``trajectory.jsonl``) in *output_dir*.

    Args:
        workspace: Directory to populate; created if it does not exist.
        prob_id: Benchmark problem id passed to ``load_baseline_code`` and
            substituted into ``eval.sh``.
        budget: Maximum number of ``eval.sh`` invocations; substituted into
            both ``eval.sh`` and the prompt.
        output_dir: Directory holding the eval counter and trajectory files;
            created if it does not exist.

    Returns:
        The *workspace* path that was passed in.
    """
    workspace.mkdir(parents=True, exist_ok=True)

    # eval.sh is later executed with its working directory set to an isolated
    # temp dir, so the paths baked into it must be absolute; a relative
    # output_dir would otherwise point at a location the harness never reads.
    output_dir = output_dir.resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    prob, baseline_code = load_baseline_code(prob_id)
    (workspace / "baseline.py").write_text(baseline_code)

    # The agent's starting solution is the workload snippet of the problem.
    snippet = extract_workload_code(prob)
    (workspace / "solution.py").write_text(snippet)

    # Start every run from a zeroed counter and an empty trajectory.
    counter_file = output_dir / "eval_counter.txt"
    counter_file.write_text("0")
    trajectory_file = output_dir / "trajectory.jsonl"
    trajectory_file.unlink(missing_ok=True)

    # Placeholders are substituted in this order (same as a chain of
    # str.replace calls).
    substitutions = {
        "__AUTOCOMP_ROOT__": str(_AUTOCOMP_ROOT),
        "__PROB_ID__": prob_id,
        "__TRAJECTORY_FILE__": str(trajectory_file),
        "__COUNTER_FILE__": str(counter_file),
        "__BUDGET__": str(budget),
        "__PYTHON__": sys.executable,
    }
    eval_sh = _EVAL_SH_TEMPLATE.read_text()
    for placeholder, value in substitutions.items():
        eval_sh = eval_sh.replace(placeholder, value)
    eval_sh_path = workspace / "eval.sh"
    eval_sh_path.write_text(eval_sh)
    eval_sh_path.chmod(0o755)

    prompt_text = _PROMPT_TEMPLATE.read_text().replace("{budget}", str(budget))
    (workspace / "PROMPT.md").write_text(prompt_text)

    return workspace


def _make_clean_env() -> dict[str, str]:
    """Assemble the reduced environment handed to the Gemini CLI subprocess.

    Only an allow-list of variables is copied from the current process
    environment (shell basics, Vertex AI auth, gcloud config and the TPU /
    JaxBench settings), so everything else in our environment stays hidden
    from the agent. A few values are then forced regardless of what was
    inherited: the Vertex AI location and flag, the GCP project (with a
    placeholder fallback) and ``PYTHONPATH``.
    """
    allowed_names = frozenset((
        # Shell / locale / XDG basics
        "PATH", "HOME", "USER", "SHELL", "TERM", "LANG", "LC_ALL",
        "TMPDIR", "XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_CACHE_HOME",
        # Vertex AI auth for Gemini CLI
        "GOOGLE_GENAI_USE_VERTEXAI", "GOOGLE_CLOUD_PROJECT",
        "GOOGLE_APPLICATION_CREDENTIALS",
        # gcloud / SSH for eval.sh -> TPU
        "CLOUDSDK_CONFIG", "CLOUDSDK_PYTHON",
        # TPU backend config used by eval.sh -> eval_single.py
        "AUTOCOMP_TPU_SSH_HOST", "AUTOCOMP_TPU_SSH_USER",
        "AUTOCOMP_TPU_SSH_PORT", "AUTOCOMP_TPU_SSH_IDENTITY_FILE",
        "AUTOCOMP_TPU_SSH_EXTRA_ARGS", "AUTOCOMP_TPU_TRANSPORT",
        "AUTOCOMP_TPU_PYTHON", "AUTOCOMP_TPU_NUM_WARMUP",
        "AUTOCOMP_TPU_NUM_TRIALS", "AUTOCOMP_JAXBENCH_PROFILE",
        "AUTOCOMP_JAXBENCH_PROFILE_STEPS", "AUTOCOMP_JAXBENCH_IMPL_TIMEOUT",
        "JAXBENCH_DIR",
    ))

    child_env: dict[str, str] = {}
    for name, value in os.environ.items():
        if name in allowed_names:
            child_env[name] = value

    # Forced settings; these override anything copied above.
    child_env.update(
        GOOGLE_CLOUD_LOCATION="global",
        GOOGLE_GENAI_USE_VERTEXAI="true",
        GOOGLE_CLOUD_PROJECT=os.environ.get("GOOGLE_CLOUD_PROJECT", "YOUR_GCP_PROJECT"),
        PYTHONPATH=str(_AUTOCOMP_ROOT),
    )
    return child_env


def _get_budget_remaining(output_dir: pathlib.Path, budget: int) -> int:
    """Return *budget* minus the eval count stored in ``eval_counter.txt``.

    When the counter file is missing or its content cannot be read as an
    integer, no evals are considered used and *budget* is returned as is.
    The result is negative if the counter exceeds *budget*.
    """
    try:
        consumed = int((output_dir / "eval_counter.txt").read_text().strip())
    except (FileNotFoundError, ValueError):
        return budget
    return budget - consumed


def _get_best_latency(output_dir: pathlib.Path) -> float | None:
    """Return the smallest latency among correct trajectory entries.

    Entries that are not marked correct, or whose latency is missing or
    falsy, are ignored. Returns ``None`` when no entry qualifies.
    """
    usable_latencies = (
        record["latency"]
        for record in _collect_results(output_dir)
        if record.get("correct") and record.get("latency")
    )
    return min(usable_latencies, default=None)


# Prompt for follow-up sessions: main() uses it from the second session onward
# once at least one correct solution has been recorded. It is filled in with
# str.format using the fields best_latency, speedup, baseline_latency,
# remaining and budget.
# NOTE(review): the prompt points the agent at a `docs/` directory, but
# nothing in this file creates it in the workspace -- confirm it is
# provisioned elsewhere.
_CONTINUE_PROMPT = """\
You are continuing to optimize a JAX Pallas kernel for TPU v6e.

Your current best solution achieves **{best_latency:.3f} ms** ({speedup:.2f}x speedup \
over the {baseline_latency:.3f} ms XLA baseline). \
You have **{remaining}** eval.sh calls remaining out of {budget} total.

## Files
- `baseline.py` — the original XLA/JAX reference.
- `solution.py` — your current best implementation (already correct and fast).
- `eval.sh` — run `bash eval.sh` to test. Prints correctness, latency, and profiler info.
- `docs/` — TPU v6e architecture docs, Pallas API reference, code examples.

## Goal
Make `solution.py` even faster while keeping it correct. \
Read the profiler output from eval.sh carefully — it shows MXU utilization, \
memory bandwidth, and top HLO ops to guide your next optimization.

Consult `docs/` and these links if needed:
- https://docs.jax.dev/en/latest/pallas/tpu/index.html
- https://docs.jax.dev/en/latest/jax.experimental.pallas.tpu.html

Start by running `bash eval.sh` to see the current profiler summary, then optimize.
"""


def _sync_dir_contents(src_dir: pathlib.Path, dst_dir: pathlib.Path) -> None:
    """Copy every top-level entry of *src_dir* into *dst_dir*.

    A directory that already exists at the destination is replaced wholesale;
    files are overwritten. Entries present only in *dst_dir* are left alone.
    """
    for entry in src_dir.iterdir():
        dest = dst_dir / entry.name
        if entry.is_dir():
            if dest.exists():
                shutil.rmtree(dest)
            shutil.copytree(entry, dest)
        else:
            shutil.copy2(entry, dest)


def _run_gemini(
    workspace: pathlib.Path,
    output_dir: pathlib.Path,
    model: str | None = None,
    timeout_s: int = 3600,
    prompt_override: str | None = None,
    session_idx: int = 0,
) -> subprocess.CompletedProcess:
    """Launch Gemini CLI in a temp directory copy of the workspace.

    The workspace is copied into a fresh temp directory, which becomes the
    working directory of the Gemini process, so the agent does not start
    inside the autocomp source tree. When the process ends -- normally, by
    timeout, or through any other error raised while running it -- the temp
    directory's files are copied back into *workspace* and the temp directory
    is removed.

    Args:
        workspace: Prepared workspace (see ``_setup_workspace``).
        output_dir: Where ``gemini_stdout_<n>.json`` and
            ``gemini_stderr_<n>.txt`` are written.
        model: Optional model name forwarded as ``--model``.
        timeout_s: Wall-clock limit for the Gemini process, in seconds.
        prompt_override: Prompt to use instead of ``PROMPT.md``; an empty or
            ``None`` value falls back to ``PROMPT.md``.
        session_idx: Session number, used in log messages and output names.

    Returns:
        The ``CompletedProcess`` of the Gemini CLI invocation.

    Raises:
        subprocess.TimeoutExpired: If the process exceeds *timeout_s*. The
            agent's files are still copied back before this propagates.
    """
    prompt_text = prompt_override or (workspace / "PROMPT.md").read_text()

    cmd = ["gemini", "--prompt", prompt_text, "--approval-mode", "yolo"]
    if model:
        cmd.extend(["--model", model])
    cmd.extend(["--output-format", "json"])

    stdout_path = output_dir / f"gemini_stdout_{session_idx}.json"
    stderr_path = output_dir / f"gemini_stderr_{session_idx}.txt"
    env = _make_clean_env()

    isolated_dir = pathlib.Path(tempfile.mkdtemp(prefix="gemini_bench_"))
    try:
        # Copy workspace files into the isolated temp dir.
        _sync_dir_contents(workspace, isolated_dir)

        logger.info("Launching Gemini CLI session %d in isolated dir %s",
                    session_idx, isolated_dir)

        try:
            with open(stdout_path, "w") as fout, open(stderr_path, "w") as ferr:
                proc = subprocess.run(
                    cmd,
                    cwd=str(isolated_dir),
                    stdout=fout,
                    stderr=ferr,
                    timeout=timeout_s,
                    env=env,
                )
        finally:
            # Copy the agent's final files back even when the run timed out
            # or failed; otherwise the session's solution.py would be lost.
            _sync_dir_contents(isolated_dir, workspace)
    finally:
        # Always remove the temp dir so aborted sessions do not leak it.
        shutil.rmtree(isolated_dir, ignore_errors=True)

    logger.info("Gemini CLI session %d exited with code %d",
                session_idx, proc.returncode)
    return proc


def _collect_results(output_dir: pathlib.Path) -> list[dict]:
    """Parse ``trajectory.jsonl`` into a list of eval result dicts.

    Each non-blank line is expected to hold one JSON object. For every such
    object a dict with exactly the keys ``correct`` (default ``False``),
    ``latency`` and ``failure_type`` (both default ``None``) is produced, in
    file order.

    Lines that are not valid JSON, or that decode to something other than a
    JSON object, are skipped rather than aborting the whole parse.

    Args:
        output_dir: Directory that may contain ``trajectory.jsonl``.

    Returns:
        The parsed entries; an empty list when the file does not exist.
    """
    trajectory_file = output_dir / "trajectory.jsonl"
    try:
        raw_text = trajectory_file.read_text()
    except FileNotFoundError:
        return []

    results: list[dict] = []
    for line in raw_text.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            entry = json.loads(line)
        except json.JSONDecodeError:
            # Deliberately best-effort: one unparseable line must not
            # discard the rest of the trajectory.
            continue
        if not isinstance(entry, dict):
            # Valid JSON but not an object (e.g. null, a list, a number);
            # it has no fields to read, so treat it like a malformed line.
            continue
        results.append({
            "correct": entry.get("correct", False),
            "latency": entry.get("latency"),
            "failure_type": entry.get("failure_type"),
        })
    return results


def main():
    """Run the Gemini CLI baseline end to end.

    Parses the command line, prepares the workspace, measures the baseline
    latency, then launches Gemini CLI sessions one after another until the
    eval budget is used up, a session times out, or a session consumes no
    evals. Finally copies the last ``solution.py`` to ``final_solution.py``
    in the output directory and writes a summary.
    """
    parser = argparse.ArgumentParser(description="Gemini CLI baseline harness")
    parser.add_argument("--prob_id", required=True)
    parser.add_argument("--budget", type=int, default=144,
                        help="Max eval.sh invocations (= sample budget)")
    parser.add_argument("--model", default="gemini-3.1-pro-preview",
                        help="Gemini model name (default: gemini-3.1-pro-preview)")
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--timeout", type=int, default=3600,
                        help="Gemini CLI process timeout in seconds")
    args = parser.parse_args()

    # Resolve to an absolute path: the counter/trajectory locations are baked
    # into eval.sh, which runs with its cwd set to an isolated temp dir, so a
    # relative --output_dir would otherwise point somewhere else from there.
    output_dir = pathlib.Path(args.output_dir).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    workspace = output_dir / "workspace"
    t0 = time.perf_counter()

    logger.info("Setting up workspace for %s (budget=%d)", args.prob_id, args.budget)
    _setup_workspace(workspace, args.prob_id, args.budget, output_dir)

    # Evaluate baseline for reference latency
    prob, baseline_code = load_baseline_code(args.prob_id)
    backend = JaxBenchEvalBackend()
    baseline_eval = evaluate_many(prob, [baseline_code], backend=backend)[0]
    baseline_lat = baseline_eval.get("latency")
    logger.info("Baseline latency: %.3f ms", baseline_lat or -1.0)

    session_idx = 0
    while True:
        remaining = _get_budget_remaining(output_dir, args.budget)
        if remaining <= 0:
            logger.info("Budget exhausted")
            break

        # None means "use PROMPT.md from the workspace". The first session
        # always does; later ones switch to the continuation prompt once a
        # correct solution exists.
        prompt = None
        if session_idx > 0:
            best_lat = _get_best_latency(output_dir)
            if best_lat is None:
                logger.info("No correct solution yet; relaunching with original prompt")
            else:
                prompt = _CONTINUE_PROMPT.format(
                    best_latency=best_lat,
                    speedup=(baseline_lat / best_lat) if baseline_lat else 0,
                    baseline_latency=baseline_lat or 0,
                    remaining=remaining,
                    budget=args.budget,
                )

        try:
            _run_gemini(workspace, output_dir, model=args.model,
                        timeout_s=args.timeout, prompt_override=prompt,
                        session_idx=session_idx)
        except subprocess.TimeoutExpired:
            logger.warning("Gemini CLI session %d timed out", session_idx)
            break

        # A session that spent no evals made no measurable progress;
        # relaunching would loop forever, so stop here.
        remaining_after = _get_budget_remaining(output_dir, args.budget)
        if remaining_after >= remaining:
            logger.info("Session %d used no evals; stopping", session_idx)
            break

        logger.info("Session %d done; %d evals remaining", session_idx, remaining_after)
        session_idx += 1

    eval_results = _collect_results(output_dir)
    runtime_s = time.perf_counter() - t0

    # Copy final solution.py
    final_solution = workspace / "solution.py"
    if final_solution.exists():
        shutil.copy2(final_solution, output_dir / "final_solution.py")

    # Reuse the shared counter parser: a missing or unreadable counter counts
    # as zero evals used.
    evals_used = args.budget - _get_budget_remaining(output_dir, args.budget)

    summary = write_summary(
        output_dir, args.prob_id, evals_used, baseline_lat, eval_results,
        runtime_s=runtime_s,
        extra={
            "model": args.model or "gemini-cli-default",
            "budget": args.budget,
            "evals_used": evals_used,
            "baseline": "gemini_cli",
        },
    )
    logger.info("Summary: %s", json.dumps(summary, indent=2))


# Run the harness only when executed as a script / with `python -m`, not on import.
if __name__ == "__main__":
    main()
