#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Batch LM evaluation runner.

Usage examples:

1) Pass a timestamp directory to evaluate across all run subdirs (latest checkpoint only):
   python tools/run_lm_eval_batch.py \
       .../2025-09-18/17-50-03

   This finds run directories like 0/, 1/, 2/, ... and evaluates only the latest step_* checkpoint per run.

2) Pass a specific run dir (e.g., 2/) to evaluate only that run's latest checkpoint:
   python tools/run_lm_eval_batch.py \
       .../2025-09-18/17-50-03/2

Outputs:
  - For each run directory, a subdirectory `lm_eval_results/step_XXXXX` is created.
  - We mirror the parameters in `lm_eval.sh` and write stdout/stderr to files.
"""

from __future__ import annotations

# Replace filelock's default `FileLock` class with `SoftFileLock` before the
# remaining imports run.
# NOTE(review): presumably required on filesystems where OS-level file locks
# are unreliable — confirm the motivation.
import filelock
from filelock import SoftFileLock
filelock.FileLock = SoftFileLock
import os
# NOTE(review): SOFT_FILELOCK is not read anywhere in this file; presumably it
# is consumed by a downstream library or by the evaluation subprocess, which
# is launched without an explicit `env` — confirm.
os.environ["SOFT_FILELOCK"] = "1"

import argparse
import re
import shlex
import subprocess
import sys
import git
from pathlib import Path
from typing import List


# Matches checkpoint directory names of the exact form `step_<digits>`.
STEP_DIR_PATTERN = re.compile(r"^step_\d+$")
# Top-level directory of the enclosing git repository (the output of
# `git rev-parse --show-toplevel`); used as the working directory for the
# evaluation subprocess. Resolved once, at import time.
# NOTE(review): presumably the repository is located by searching upward from
# the process's current working directory, so the script must be started from
# inside the repository — confirm against GitPython's `Repo` defaults.
PROJECT_ROOT = Path(git.Repo(search_parent_directories=True).git.rev_parse("--show-toplevel"))


def find_run_directories(root: Path) -> List[Path]:
    """Return the run directories under ``root`` that may hold `step_*` checkpoints.

    ``root`` may be either:

    * a run directory itself (it contains `trainer_config.json` or at least one
      `step_<digits>` subdirectory), in which case ``[root]`` is returned; or
    * a timestamp directory (e.g., .../17-50-03) whose numerically named
      subdirectories (`0/`, `1/`, ...) are the runs; these are returned sorted
      by their integer value.

    Raises:
        FileNotFoundError: ``root`` does not exist.
        NotADirectoryError: ``root`` exists but is not a directory.
        RuntimeError: ``root`` is not a run directory and contains no
            numerically named subdirectories.
    """
    if not root.exists():
        raise FileNotFoundError(f"Path does not exist: {root}")
    if not root.is_dir():
        raise NotADirectoryError(f"Path is not a directory: {root}")

    # A trainer config marks a run directory even before any checkpoint exists.
    if (root / "trainer_config.json").exists():
        return [root]

    subdirs = [child for child in root.iterdir() if child.is_dir()]

    # A directory that directly holds step_* checkpoints is itself a run.
    if any(STEP_DIR_PATTERN.match(child.name) for child in subdirs):
        return [root]

    # Otherwise treat `root` as a timestamp directory. `isdecimal()` (rather
    # than `isdigit()`) only accepts names that `int()` can parse, so the
    # numeric sort below cannot fail on names such as "²".
    run_dirs = sorted(
        (child for child in subdirs if child.name.isdecimal()),
        key=lambda path: int(path.name),
    )
    if not run_dirs:
        raise RuntimeError(f"No run directories found under: {root}")
    return run_dirs


def find_checkpoints(run_dir: Path) -> List[Path]:
    """Return the `step_<digits>` checkpoint directories in ``run_dir``.

    The result is sorted ascending by the numeric step value, so the last
    element is the latest checkpoint. An empty list is returned when
    ``run_dir`` does not exist, is not a directory, or holds no checkpoints.
    """
    # `is_dir()` also covers the "does not exist" case and avoids calling
    # `iterdir()` on a regular file.
    if not run_dir.is_dir():
        return []

    checkpoints = [
        child
        for child in run_dir.iterdir()
        if child.is_dir() and STEP_DIR_PATTERN.match(child.name)
    ]
    # Every name matched `step_<digits>`, so the suffix always parses as int.
    checkpoints.sort(key=lambda path: int(path.name.split("_", 1)[1]))
    return checkpoints


def build_eval_command(
    checkpoint_dir: Path, output_dir: Path, *, batch_size: int = 64
) -> List[str]:
    r"""Build the argv list for one evaluation run, modelled on `lm_eval.sh`.

    The reference shell command is:

    python exp/lm_evals/harness.py \
        --model fla \
        --model_args pretrained="$LAST_CKPT",dtype=bfloat16,tokenizer=mistralai/Mistral-7B-v0.1 \
        --batch_size 256 \
        --tasks wikitext,lambada,piqa,hellaswag,winogrande,arc_easy,arc_challenge,swde,squad_completion,fda \
        --num_fewshot 0 \
        --device cuda \
        --show_config \
        --output_path "$OUTPUT_PATH" \
        --trust_remote_code \
        --seed 42

    Differences from that command: the current interpreter (`sys.executable`)
    is used instead of `python`, and the batch size defaults to 64 rather
    than 256.

    Args:
        checkpoint_dir: Checkpoint directory passed as `pretrained=` inside
            `--model_args`.
        output_dir: Directory passed as `--output_path`.
        batch_size: Value passed as `--batch_size`; defaults to 64.

    Returns:
        The command as a list of strings suitable for `subprocess` without a
        shell. The harness path is relative, so the command is expected to be
        run from the project root.
    """
    model_args = ",".join(
        (
            f"pretrained={checkpoint_dir}",
            "dtype=bfloat16",
            "tokenizer=mistralai/Mistral-7B-v0.1",
        )
    )
    tasks = ",".join(
        (
            "wikitext",
            "lambada",
            "piqa",
            "hellaswag",
            "winogrande",
            "arc_easy",
            "arc_challenge",
            "swde",
            "squad_completion",
            "fda",
        )
    )
    return [
        sys.executable,
        "exp/lm_evals/harness.py",
        "--model",
        "fla",
        "--model_args",
        model_args,
        "--batch_size",
        str(batch_size),
        "--tasks",
        tasks,
        "--num_fewshot",
        "0",
        "--device",
        "cuda",
        "--show_config",
        "--output_path",
        str(output_dir),
        "--trust_remote_code",
        "--seed",
        "42",
    ]


def run_eval_for_checkpoint(checkpoint_dir: Path, results_dir: Path) -> int:
    """Evaluate one checkpoint in a subprocess, capturing its output to log files.

    `results_dir` is created if needed (including parents). The child's stdout
    and stderr are written to `stdout.log` and `stderr.log` inside it, each
    opened in binary write mode. The child runs with the project root as its
    working directory.

    Returns:
        The exit code of the evaluation subprocess.
    """
    results_dir.mkdir(parents=True, exist_ok=True)

    command = build_eval_command(checkpoint_dir, results_dir)
    stdout_path = results_dir / "stdout.log"
    stderr_path = results_dir / "stderr.log"

    with stdout_path.open("wb") as stdout_sink, stderr_path.open("wb") as stderr_sink:
        child = subprocess.Popen(
            command,
            stdout=stdout_sink,
            stderr=stderr_sink,
            cwd=str(PROJECT_ROOT),  # the harness path in the command is relative to the project root
        )
        # Block until the evaluation finishes; the log files are closed on exit.
        return child.wait()


def main() -> None:
    """Parse CLI arguments and evaluate the latest checkpoint of each run.

    For every run directory found under the `target` argument, the checkpoint
    with the highest step number is evaluated and its results are written to
    `<run_dir>/lm_eval_results/<step_name>`. Runs without checkpoints are
    reported and skipped. A failing evaluation prints a warning and does not
    stop the remaining runs.

    With `--dry_run`, the command for each run is printed and nothing is
    executed or created on disk.
    """
    parser = argparse.ArgumentParser(
        description="Batch run LM evaluations across checkpoints."
    )
    parser.add_argument(
        "target",
        type=str,
        help=(
            "Path to a timestamp directory (e.g., .../2025-09-18/17-50-03) or a specific run directory (e.g., .../17-50-03/2)."
        ),
    )
    # Always evaluate only the latest checkpoint; no --limit option
    parser.add_argument(
        "--dry_run",
        action="store_true",
        help="Print actions without running evaluations.",
    )

    args = parser.parse_args()
    target_path = Path(args.target).expanduser().resolve()

    for run_dir in find_run_directories(target_path):
        checkpoints = find_checkpoints(run_dir)
        if not checkpoints:
            print(f"No checkpoints found in run: {run_dir}")
            continue

        # Checkpoints are sorted by step, so the last one is the latest.
        ckpt = checkpoints[-1]
        step_name = ckpt.name

        # Results live under the run directory. The directory is not created
        # here: run_eval_for_checkpoint creates it (with parents), which keeps
        # a dry run free of filesystem side effects.
        results_dir = run_dir / "lm_eval_results" / step_name
        print(f"Evaluating run={run_dir.name} checkpoint={step_name} -> {results_dir}")

        if args.dry_run:
            cmd_preview = " ".join(
                shlex.quote(part) for part in build_eval_command(ckpt, results_dir)
            )
            print(f"DRY RUN: {cmd_preview}")
            continue

        code = run_eval_for_checkpoint(ckpt, results_dir)
        if code != 0:
            print(f"WARNING: Eval failed for {ckpt} with code {code}")

if __name__ == "__main__":
    main()
