from __future__ import annotations

import json
import os
import re
import shutil
import subprocess
import time
import argparse
from collections import Counter
from pathlib import Path
from typing import Any

import modal


# --- Static names and paths -------------------------------------------------
# Name given to the Modal app object created below.
APP_NAME = "openairft-katago-pv"
# Directory containing this file; it is copied into the image.
LOCAL_DIR = Path(__file__).resolve().parent
# Location of that copy inside the container.
REMOTE_PROJECT_DIR = "/root/openairft"
# JSONL export read by load_and_sample_rows().
REMOTE_INPUT_PATH = f"{REMOTE_PROJECT_DIR}/dan_reviews_1d_plus.jsonl"
# Mount points of the output and cache volumes inside the container.
REMOTE_OUTPUT_DIR = "/outputs"
REMOTE_CACHE_DIR = "/cache"
# Directory on the cache volume holding the KataGo binary, model and config.
REMOTE_KATAGO_DIR = f"{REMOTE_CACHE_DIR}/katago"

# --- Environment-overridable settings ---------------------------------------
# Everything below is read from os.environ once, at import time, falling back
# to the literal default.
# NOTE(review): the values seen inside the container depend on the container's
# environment, which may differ from the launching shell -- confirm how these
# variables are expected to propagate.
OUTPUT_VOLUME_NAME = os.environ.get("OUTPUT_VOLUME_NAME", "openairft-katago-pv")
CACHE_VOLUME_NAME = os.environ.get("CACHE_VOLUME_NAME", "openairft-katago-cache")
# Sub-directory of the output volume that receives the results and manifest.
OUTPUT_STEM = os.environ.get("OUTPUT_STEM", "katago_pv_1k")

# Download sources for the three KataGo assets (executable archive, network
# weights, analysis config template).
KATAGO_EXECUTABLE_URL = os.environ.get(
    "KATAGO_EXECUTABLE_URL",
    "https://github.com/lightvector/KataGo/releases/download/v1.16.4/katago-v1.16.4-cuda12.8-cudnn9.8.0-linux-x64.zip",
)
KATAGO_MODEL_URL = os.environ.get(
    "KATAGO_MODEL_URL",
    "https://media.katagotraining.org/uploaded/networks/models/kata1/kata1-b28c512nbt-s12434156288-d5719330235.bin.gz",
)
KATAGO_CONFIG_URL = os.environ.get(
    "KATAGO_CONFIG_URL",
    "https://raw.githubusercontent.com/lightvector/KataGo/master/cpp/configs/analysis_example.cfg",
)

# Sent as "maxVisits" in every analysis query.
KATAGO_VISITS = int(os.environ.get("KATAGO_VISITS", "800"))
# Maximum number of queries awaiting a response at any time.
KATAGO_INFLIGHT_LIMIT = int(os.environ.get("KATAGO_INFLIGHT_LIMIT", "32"))
# The two thread settings stay strings: they are only interpolated into the
# analysis config text by patch_analysis_config().
KATAGO_NUM_ANALYSIS_THREADS = os.environ.get("KATAGO_NUM_ANALYSIS_THREADS", "24")
KATAGO_SEARCH_THREADS_PER_ANALYSIS_THREAD = os.environ.get(
    "KATAGO_SEARCH_THREADS_PER_ANALYSIS_THREAD", "1"
)
# Seconds to wait for any KataGo output before the run is abandoned.
QUERY_TIMEOUT_SEC = float(os.environ.get("QUERY_TIMEOUT_SEC", "180"))
# Number of positions to sample and the seed that makes the sample repeatable.
SAMPLE_N = int(os.environ.get("SAMPLE_N", "1000"))
SAMPLE_SEED = int(os.environ.get("SAMPLE_SEED", "20260503"))
# How many candidate moves to keep per position, and how many plies of each
# principal variation to keep.
PV_COUNT = int(os.environ.get("PV_COUNT", "10"))
PV_PLIES = int(os.environ.get("PV_PLIES", "12"))

# Board column letters; the sequence has 19 entries and omits the letter "I".
GO_COLUMNS = "ABCDEFGHJKLMNOPQRST"

# Container image: Debian slim with Python 3.11, the command-line tools used to
# fetch and unpack KataGo, and NVIDIA runtime wheels installed via pip (their
# library directories are located at run time by site_dirs_for_cuda()).
# The directory containing this file is copied into the image at
# REMOTE_PROJECT_DIR.
image = (
    modal.Image.debian_slim(python_version="3.11")
    .apt_install("curl", "unzip", "xz-utils")
    .pip_install(
        "nvidia-cublas-cu12",
        "nvidia-cuda-runtime-cu12",
        "nvidia-cudnn-cu12",
        "nvidia-cusolver-cu12",
    )
    .add_local_dir(LOCAL_DIR, remote_path=REMOTE_PROJECT_DIR, copy=True)
)

# The app plus the two named volumes (created on first use): one receives
# results, the other caches the downloaded KataGo assets.
app = modal.App(APP_NAME, image=image)
output_volume = modal.Volume.from_name(OUTPUT_VOLUME_NAME, create_if_missing=True)
cache_volume = modal.Volume.from_name(CACHE_VOLUME_NAME, create_if_missing=True)


def katago_paths() -> tuple[Path, Path, Path]:
    """Return the (binary, model, config) paths inside the KataGo cache directory."""
    cache_dir = Path(REMOTE_KATAGO_DIR)
    binary_path = cache_dir / "katago"
    model_path = cache_dir / "model.bin.gz"
    config_path = cache_dir / "analysis.cfg"
    return binary_path, model_path, config_path


def patch_analysis_config(config_path: Path) -> None:
    """Rewrite the analysis config in place with this job's overrides.

    Every ``key = value`` line whose key is one of the overridden settings is
    replaced by the override; overrides whose key never appears are appended at
    the end. Lines that do not start with one of those keys (including
    commented-out settings) are kept untouched.

    Args:
        config_path: Config file to read and overwrite.
    """
    overrides = {
        "numAnalysisThreads": f"numAnalysisThreads = {KATAGO_NUM_ANALYSIS_THREADS}",
        "numSearchThreadsPerAnalysisThread": (
            f"numSearchThreadsPerAnalysisThread = {KATAGO_SEARCH_THREADS_PER_ANALYSIS_THREAD}"
        ),
        "reportAnalysisWinratesAs": "reportAnalysisWinratesAs = BLACK",
        "nnMaxBatchSize": "nnMaxBatchSize = 64",
        "nnCacheSizePowerOfTwo": "nnCacheSizePowerOfTwo = 22",
    }
    applied: set[str] = set()
    patched: list[str] = []
    for original_line in config_path.read_text().splitlines():
        # The setting name is whatever precedes the first "=", if there is one.
        name_part, separator, _ = original_line.strip().partition("=")
        setting = name_part.strip() if separator else ""
        if setting not in overrides:
            patched.append(original_line)
            continue
        patched.append(overrides[setting])
        applied.add(setting)
    # Settings absent from the template still need to take effect.
    patched.extend(line for setting, line in overrides.items() if setting not in applied)
    config_path.write_text("\n".join(patched) + "\n")


def install_or_reuse_katago_assets() -> dict[str, str]:
    """Make the KataGo binary, model and config available on the cache volume.

    When all three files already exist they are reused (the config is
    re-patched so the current thread settings apply). Otherwise the executable
    archive, the model and the config template are downloaded, the config is
    patched, and the cache volume is committed.

    Returns:
        A dict with the string paths ``katago_binary``, ``katago_model`` and
        ``katago_config``, plus ``cache_hit`` set to ``"true"`` or ``"false"``.

    Raises:
        subprocess.CalledProcessError: If ``curl`` or ``unzip`` fails,
            including HTTP error responses from the download servers.
        RuntimeError: If the extracted archive contains no ``katago`` file.
    """
    katago_bin, katago_model, katago_config = katago_paths()
    katago_root = katago_bin.parent
    katago_root.mkdir(parents=True, exist_ok=True)

    def asset_info(cache_hit: str) -> dict[str, str]:
        return {
            "katago_binary": str(katago_bin),
            "katago_model": str(katago_model),
            "katago_config": str(katago_config),
            "cache_hit": cache_hit,
        }

    def download(url: str, destination: Path) -> None:
        # Fetch into a sibling ".part" file and rename only on success, so an
        # interrupted transfer can never be mistaken for a cached asset.
        partial = destination.with_name(destination.name + ".part")
        try:
            # --fail: exit non-zero on HTTP errors instead of saving the
            # server's error page as if it were the requested file.
            subprocess.run(
                ["curl", "--fail", "-L", url, "-o", str(partial)],
                check=True,
            )
            partial.replace(destination)
        finally:
            partial.unlink(missing_ok=True)

    if katago_bin.exists() and katago_model.exists() and katago_config.exists():
        katago_bin.chmod(0o755)
        patch_analysis_config(katago_config)
        return asset_info("true")

    archive_path = katago_root / "katago_download.zip"
    download(KATAGO_EXECUTABLE_URL, archive_path)
    subprocess.run(["unzip", "-o", str(archive_path), "-d", str(katago_root)], check=True)
    # The archive is no longer needed once unpacked.
    archive_path.unlink(missing_ok=True)

    # Only regular files qualify: the archive may also contain a directory
    # that happens to be called "katago".
    found_bins = [path for path in katago_root.rglob("katago") if path.is_file()]
    if not found_bins:
        raise RuntimeError("Could not find katago binary after extracting archive")
    source_bin = found_bins[0]
    if source_bin.resolve() != katago_bin.resolve():
        shutil.copy2(source_bin, katago_bin)
    katago_bin.chmod(0o755)

    download(KATAGO_MODEL_URL, katago_model)
    download(KATAGO_CONFIG_URL, katago_config)
    patch_analysis_config(katago_config)
    cache_volume.commit()
    return asset_info("false")


def parse_board_state(board_state: str) -> dict[str, Any]:
    """Parse a ``board_state`` string into stones and the side to move.

    The text is a comma-separated list of tokens ``<color><turn>:<point>``
    (for example ``B12:D4``), optionally preceded by a ``move:<n>|`` prefix.
    A point is a column letter A-T without I followed by a row 1-19, or the
    word ``pass``. Tokens are matched case-insensitively; tokens that do not
    fit the pattern or lie off a 19x19 board are ignored.

    The side to move is the opponent of whoever made the highest-numbered
    turn. Passes place no stone but do count as a turn, so a trailing pass
    flips the side to move. Without any turn information the ``move:<n>``
    prefix decides (even -> Black, odd -> White); without that, Black moves.

    Args:
        board_state: Raw board description; ``None`` or empty is allowed.

    Returns:
        A dict with ``initial_stones`` (``[color, coord]`` pairs sorted by
        coordinate string), ``initial_player`` (``"B"`` or ``"W"``),
        ``inferred_move_number`` and ``parsed_stone_count``.
    """
    text = str(board_state or "").strip()
    prefix_move = None
    if "|" in text and text.lower().startswith("move:"):
        prefix, text = text.split("|", 1)
        prefix_match = re.search(r"move:(\d+)", prefix, re.IGNORECASE)
        if prefix_match:
            prefix_move = int(prefix_match.group(1))

    # Either "<col><row>" or a bare "pass"; the column class already excludes
    # the letter I, so no separate column check is needed.
    token_pattern = r"([BW])(\d+):(?:([A-HJ-T])(\d{1,2})|pass)"

    turns = []
    occupied = {}
    for item in text.split(","):
        match = re.fullmatch(token_pattern, item.strip(), re.IGNORECASE)
        if not match:
            continue
        color_text, turn_text, col, row_text = match.groups()
        color = color_text.upper()
        turn = int(turn_text)
        if col is None:
            # A pass: nothing lands on the board, but the turn was taken.
            turns.append((turn, color))
            continue
        row_num = int(row_text)
        if not 1 <= row_num <= 19:
            continue
        # A later token for the same point replaces the earlier one.
        occupied[f"{col.upper()}{row_num}"] = color
        turns.append((turn, color))

    stones = [[color, coord] for coord, color in sorted(occupied.items())]
    if turns:
        last_turn, last_color = max(turns, key=lambda pair: pair[0])
        initial_player = "W" if last_color == "B" else "B"
        inferred_move_number = last_turn
    elif prefix_move is not None:
        initial_player = "B" if prefix_move % 2 == 0 else "W"
        inferred_move_number = prefix_move
    else:
        initial_player = "B"
        inferred_move_number = 0
    return {
        "initial_stones": stones,
        "initial_player": initial_player,
        "inferred_move_number": inferred_move_number,
        "parsed_stone_count": len(stones),
    }


def load_and_sample_rows(sample_n: int, sample_seed: int) -> list[dict[str, Any]]:
    """Load usable positions from the input JSONL and draw a seeded sample.

    A row is usable when it has a non-empty ``board_state`` from which at least
    one stone can be parsed; the parsed fields are merged into the row.

    Args:
        sample_n: Number of rows to return.
        sample_seed: Seed for the random generator that picks the sample.

    Returns:
        ``sample_n`` rows chosen without replacement.

    Raises:
        RuntimeError: If fewer than ``sample_n`` usable rows exist.
    """
    import random

    usable: list[dict[str, Any]] = []
    with Path(REMOTE_INPUT_PATH).open(encoding="utf-8") as source:
        for raw_line in source:
            if not raw_line.strip():
                continue
            record = json.loads(raw_line)
            if not record.get("board_state"):
                continue
            position = parse_board_state(record["board_state"])
            if not position["initial_stones"]:
                continue
            record.update(position)
            usable.append(record)

    if len(usable) < sample_n:
        raise RuntimeError(f"Only {len(usable)} usable rows; requested {sample_n}")
    return random.Random(sample_seed).sample(usable, sample_n)


def subprocess_env() -> dict[str, str]:
    """Build the environment for the KataGo subprocess.

    Starts from a copy of the current environment, sets
    ``APPIMAGE_EXTRACT_AND_RUN=1``, and puts the pip-installed NVIDIA library
    directories at the front of ``LD_LIBRARY_PATH``. Entries already present in
    ``LD_LIBRARY_PATH`` are kept after them rather than discarded, so libraries
    that were reachable before stay reachable.

    Returns:
        The environment mapping to pass to ``subprocess.Popen``.
    """
    env = os.environ.copy()
    env["APPIMAGE_EXTRACT_AND_RUN"] = "1"
    site_dirs = site_dirs_for_cuda()
    if site_dirs:
        inherited = [entry for entry in env.get("LD_LIBRARY_PATH", "").split(":") if entry]
        env["LD_LIBRARY_PATH"] = ":".join([*site_dirs, *inherited])
    return env


def site_dirs_for_cuda() -> list[str]:
    """List every ``nvidia/*/lib`` directory found under the site-packages roots.

    Returns:
        The matching directories as strings, in discovery order; empty when no
        site-packages root contains an ``nvidia`` directory.
    """
    import site

    library_dirs: list[str] = []
    for packages_root in site.getsitepackages():
        vendor_root = Path(packages_root) / "nvidia"
        if vendor_root.exists():
            library_dirs.extend(str(match) for match in vendor_root.glob("*/lib"))
    return library_dirs


def ordered_move_infos(move_infos: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return the move infos sorted by ``order``, then by descending visits.

    Entries without ``order`` sort after all others (they are treated as order
    1,000,000). The visit count is taken from ``visits``, falling back to
    ``edgeVisits`` when that key is absent; a missing or null count is 0.
    """

    def ranking(info: dict[str, Any]) -> tuple[int, float]:
        position = int(info.get("order", 1_000_000))
        visit_count = info.get("visits", info.get("edgeVisits", 0)) or 0
        return position, -float(visit_count)

    ranked = list(move_infos)
    ranked.sort(key=ranking)
    return ranked


def trim_pv_lines(move_infos: list[dict[str, Any]], pv_count: int, pv_plies: int) -> list[dict[str, Any]]:
    """Summarise the best candidate moves with truncated principal variations.

    Args:
        move_infos: Per-move entries from an analysis response.
        pv_count: Maximum number of candidates to keep, taken in ranked order.
        pv_plies: Maximum number of plies kept from each candidate's ``pv``.

    Returns:
        One dict per kept candidate, with a 1-based ``rank``, the truncated
        ``pv`` and the candidate's evaluation fields (``None`` when absent).
    """

    def summarise(rank: int, candidate: dict[str, Any]) -> dict[str, Any]:
        variation = candidate.get("pv") or []
        return {
            "rank": rank,
            "move": candidate.get("move"),
            "pv": variation[:pv_plies],
            "pv_plies_returned": min(len(variation), pv_plies),
            "winrate": candidate.get("winrate"),
            "scoreLead": candidate.get("scoreLead"),
            "visits": candidate.get("visits", candidate.get("edgeVisits")),
            "prior": candidate.get("prior"),
            "order": candidate.get("order"),
        }

    best_candidates = ordered_move_infos(move_infos)[:pv_count]
    return [summarise(rank, candidate) for rank, candidate in enumerate(best_candidates, start=1)]


def drain_stderr(proc: subprocess.Popen[str]) -> str:
    """Return whatever the process has written to stderr so far, without blocking.

    Bytes are read straight from the pipe's file descriptor. Reading through
    the buffered text stream would pull everything available into Python's
    internal buffer on the first ``readline()``; a following ``select()`` on
    the descriptor would then report "nothing to read" and all but the first
    line would be left behind. A raw read also cannot block on a line that has
    not been terminated yet.

    Args:
        proc: Process whose ``stderr`` pipe (if any) should be drained.

    Returns:
        The available text with line endings normalised to ``"\\n"`` and one
        trailing newline removed; ``""`` when there is no pipe or no data.
    """
    stream = proc.stderr
    if stream is None:
        return ""

    import select

    collected = bytearray()
    try:
        descriptor = stream.fileno()
        while True:
            ready, _, _ = select.select([descriptor], [], [], 0)
            if not ready:
                break
            chunk = os.read(descriptor, 65536)
            if not chunk:  # writer closed the pipe
                break
            collected += chunk
    except (OSError, ValueError):
        # Best effort: a closed or invalid pipe just means there is nothing
        # (more) to report.
        pass

    encoding = getattr(stream, "encoding", None) or "utf-8"
    text = bytes(collected).decode(encoding, errors="replace")
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    return text[:-1] if text.endswith("\n") else text


def run_queries(
    rows: list[dict[str, Any]],
    katago_info: dict[str, str],
    katago_visits: int,
    inflight_limit: int,
    query_timeout_sec: float,
    pv_count: int,
    pv_plies: int,
) -> tuple[list[dict[str, Any]], Counter]:
    """Analyse ``rows`` with one KataGo analysis process and collect the results.

    Queries are written to the process's stdin as JSON lines, keeping at most
    ``inflight_limit`` unanswered at a time; responses are matched back to rows
    by query id. Two background threads read the process's stdout and stderr
    continuously: this keeps either pipe from filling up and stalling the
    process, and avoids polling the pipe's descriptor with ``select()`` while
    complete response lines are already sitting in Python's read buffer (which
    would look like a timeout and lose those responses).

    The run stops early, returning what has been collected, when no output
    arrives for ``query_timeout_sec`` seconds, when the process closes its
    stdout, or when nothing is left in flight after a failed stdin write.

    Args:
        rows: Positions to analyse. Each needs ``id``, ``initial_stones``,
            ``initial_player``, ``inferred_move_number``,
            ``parsed_stone_count`` and ``_sample_index``.
        katago_info: Paths under ``katago_binary``, ``katago_model`` and
            ``katago_config``.
        katago_visits: Value sent as ``maxVisits`` with every query.
        inflight_limit: Maximum number of unanswered queries.
        query_timeout_sec: Longest tolerated silence on stdout, in seconds.
        pv_count: Candidate moves kept per position.
        pv_plies: Plies kept per principal variation.

    Returns:
        ``(results, stats)``: results ordered by ``_sample_index``, and a
        counter with ``submitted``/``completed`` plus, when applicable,
        ``timeouts``, ``process_exited_without_stdout``,
        ``stdin_write_failures``, ``katago_errors``, ``last_error`` and
        ``stderr_tail``.
    """
    import queue
    import threading
    from collections import deque

    command = [
        katago_info["katago_binary"],
        "analysis",
        "-model",
        katago_info["katago_model"],
        "-config",
        katago_info["katago_config"],
    ]
    proc = subprocess.Popen(
        command,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        bufsize=1,
        env=subprocess_env(),
    )
    # All three are pipes by construction; the assert narrows the Optional types.
    assert proc.stdin is not None and proc.stdout is not None and proc.stderr is not None
    child_stdin, child_stdout, child_stderr = proc.stdin, proc.stdout, proc.stderr

    stats: Counter = Counter()
    pending = iter(rows)
    inflight: dict[str, dict[str, Any]] = {}
    results_by_index: dict[int, dict[str, Any]] = {}
    completed = 0
    submitted = 0
    last_report = time.monotonic()

    # Response lines in arrival order; None marks end of output.
    stdout_lines: queue.Queue = queue.Queue()
    # Only the most recent stderr lines are kept, for diagnostics.
    stderr_lines: deque = deque(maxlen=200)

    def pump_stdout() -> None:
        try:
            for out_line in child_stdout:
                stdout_lines.put(out_line)
        except (OSError, ValueError):
            pass
        finally:
            stdout_lines.put(None)

    def pump_stderr() -> None:
        try:
            for err_line in child_stderr:
                stderr_lines.append(err_line.rstrip("\n"))
        except (OSError, ValueError):
            pass

    stdout_thread = threading.Thread(target=pump_stdout, daemon=True)
    stderr_thread = threading.Thread(target=pump_stderr, daemon=True)
    stdout_thread.start()
    stderr_thread.start()

    def submit(row: dict[str, Any]) -> None:
        nonlocal submitted
        qid = str(row["id"])
        query = {
            "id": qid,
            "initialStones": row["initial_stones"],
            "moves": [],
            "initialPlayer": row["initial_player"],
            "rules": "japanese",
            "komi": 6.5,
            "boardXSize": 19,
            "boardYSize": 19,
            "maxVisits": katago_visits,
            "includePolicy": True,
            "includeOwnership": False,
        }
        child_stdin.write(json.dumps(query) + "\n")
        child_stdin.flush()
        inflight[qid] = row
        submitted += 1
        stats["submitted"] = submitted

    stdin_usable = True
    try:
        while True:
            # Top up the in-flight window.
            while stdin_usable and len(inflight) < inflight_limit:
                try:
                    row = next(pending)
                except StopIteration:
                    break
                try:
                    submit(row)
                except OSError:
                    # The process stopped reading (typically it has exited).
                    # Stop submitting but keep collecting what is in flight.
                    stats["stdin_write_failures"] += 1
                    stdin_usable = False
            if not inflight:
                break

            try:
                line = stdout_lines.get(timeout=query_timeout_sec)
            except queue.Empty:
                stats["timeouts"] += len(inflight)
                break
            if line is None:
                stats["process_exited_without_stdout"] += 1
                break

            try:
                payload = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(payload, dict) or "warning" in payload:
                continue
            qid = str(payload.get("id", ""))
            row = inflight.pop(qid, None)
            if row is None:
                continue
            if "error" in payload:
                stats["katago_errors"] += 1
                stats["last_error"] = payload["error"]
                continue
            pv_lines = trim_pv_lines(payload.get("moveInfos") or [], pv_count, pv_plies)
            root_info = payload.get("rootInfo") or {}
            result = {
                "id": row["id"],
                "game_file": row.get("game_file"),
                "move_number": row.get("move_number"),
                "reviewer_rank": row.get("reviewer_rank"),
                "initial_player": row["initial_player"],
                "inferred_move_number": row["inferred_move_number"],
                "parsed_stone_count": row["parsed_stone_count"],
                "rules": "japanese",
                "komi": 6.5,
                "katago_visits": katago_visits,
                "root_winrate": root_info.get("winrate"),
                "root_scoreLead": root_info.get("scoreLead"),
                "top_move_winrate": pv_lines[0].get("winrate") if pv_lines else None,
                "top_move_scoreLead": pv_lines[0].get("scoreLead") if pv_lines else None,
                "principal_variations": pv_lines,
            }
            results_by_index[int(row["_sample_index"])] = result
            completed += 1
            stats["completed"] = completed
            if time.monotonic() - last_report > 10:
                print(f"[INFO] completed={completed} submitted={submitted} inflight={len(inflight)}")
                last_report = time.monotonic()
    finally:
        try:
            child_stdin.close()
        except (OSError, ValueError):
            pass
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.terminate()
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                # Escalate instead of letting TimeoutExpired escape from the
                # cleanup path.
                proc.kill()
                proc.wait()
        for reader, stream in ((stdout_thread, child_stdout), (stderr_thread, child_stderr)):
            reader.join(timeout=5)
            # Close only once the reader is done with the stream.
            if not reader.is_alive():
                stream.close()

    run_failed = (
        stats["timeouts"]
        or stats["process_exited_without_stdout"]
        or stats["stdin_write_failures"]
    )
    if run_failed and stderr_lines:
        stats["stderr_tail"] = "\n".join(stderr_lines)[-4000:]

    return [results_by_index[idx] for idx in sorted(results_by_index)], stats


@app.function(
    gpu="A10G",
    cpu=8,
    memory=32768,
    timeout=60 * 60 * 8,
    volumes={REMOTE_OUTPUT_DIR: output_volume, REMOTE_CACHE_DIR: cache_volume},
)
def run_katago_pv_sample(
    sample_n: int = SAMPLE_N,
    sample_seed: int = SAMPLE_SEED,
    output_stem: str = OUTPUT_STEM,
    katago_visits: int = KATAGO_VISITS,
    inflight_limit: int = KATAGO_INFLIGHT_LIMIT,
    query_timeout_sec: float = QUERY_TIMEOUT_SEC,
    pv_count: int = PV_COUNT,
    pv_plies: int = PV_PLIES,
) -> dict[str, Any]:
    """Sample positions, analyse them with KataGo and store results plus a manifest.

    Writes ``katago_pv_1k.jsonl`` (one result per line) and ``manifest.json``
    into ``<output volume>/<output_stem>/``, then commits both volumes.

    Args:
        sample_n: Number of positions to sample.
        sample_seed: Seed for the sample.
        output_stem: Sub-directory of the output volume to write into.
        katago_visits: Visit budget sent with every query.
        inflight_limit: Maximum number of unanswered queries.
        query_timeout_sec: Longest tolerated silence from KataGo, in seconds.
        pv_count: Candidate moves kept per position.
        pv_plies: Plies kept per principal variation.

    Returns:
        The manifest that was written to ``manifest.json``.
    """
    target_dir = Path(REMOTE_OUTPUT_DIR) / output_stem
    target_dir.mkdir(parents=True, exist_ok=True)

    katago_info = install_or_reuse_katago_assets()
    sampled = load_and_sample_rows(sample_n, sample_seed)
    # Remember each row's position in the sample so results can be re-ordered.
    for position, sampled_row in enumerate(sampled):
        sampled_row["_sample_index"] = position

    results, stats = run_queries(
        sampled,
        katago_info,
        katago_visits=katago_visits,
        inflight_limit=inflight_limit,
        query_timeout_sec=query_timeout_sec,
        pv_count=pv_count,
        pv_plies=pv_plies,
    )

    results_file = target_dir / "katago_pv_1k.jsonl"
    with results_file.open("w", encoding="utf-8") as sink:
        sink.writelines(json.dumps(entry, ensure_ascii=False) + "\n" for entry in results)

    assumptions = [
        "Supabase board_state is treated as current stones, not a replayable full move list.",
        "Rules default to japanese and komi defaults to 6.5 because those fields are not present in the export.",
        "Winrate and scoreLead are reported from Black's perspective.",
    ]
    manifest = {
        "sample_n": sample_n,
        "sample_seed": sample_seed,
        "pv_count": pv_count,
        "pv_plies": pv_plies,
        "katago_visits": katago_visits,
        "katago_info": katago_info,
        "stats": dict(stats),
        "output_file": f"{output_stem}/katago_pv_1k.jsonl",
        "assumptions": assumptions,
    }
    (target_dir / "manifest.json").write_text(json.dumps(manifest, indent=2), encoding="utf-8")

    output_volume.commit()
    cache_volume.commit()
    return manifest


@app.local_entrypoint()
def main(
    sample_n: int = SAMPLE_N,
    sample_seed: int = SAMPLE_SEED,
    output_stem: str = OUTPUT_STEM,
    katago_visits: int = KATAGO_VISITS,
    inflight_limit: int = KATAGO_INFLIGHT_LIMIT,
    query_timeout_sec: float = QUERY_TIMEOUT_SEC,
    pv_count: int = PV_COUNT,
    pv_plies: int = PV_PLIES,
) -> None:
    """Run the remote analysis job and print the manifest it returns.

    Every tuning value the remote function accepts is exposed here and
    forwarded explicitly, so all of them can be chosen by the caller of this
    entrypoint rather than some being left to whatever defaults the remote
    side computes for itself.

    Args:
        sample_n: Number of positions to sample.
        sample_seed: Seed for the sample.
        output_stem: Sub-directory of the output volume to write into.
        katago_visits: Visit budget sent with every query.
        inflight_limit: Maximum number of unanswered queries.
        query_timeout_sec: Longest tolerated silence from KataGo, in seconds.
        pv_count: Candidate moves kept per position.
        pv_plies: Plies kept per principal variation.
    """
    manifest = run_katago_pv_sample.remote(
        sample_n=sample_n,
        sample_seed=sample_seed,
        output_stem=output_stem,
        katago_visits=katago_visits,
        inflight_limit=inflight_limit,
        query_timeout_sec=query_timeout_sec,
        pv_count=pv_count,
        pv_plies=pv_plies,
    )
    print(json.dumps(manifest, indent=2))
