from __future__ import annotations

import json
import os
import queue
import selectors
import shutil
import subprocess
import threading
import time
from pathlib import Path

import modal


def _env(name: str, default: str) -> str:
    """Return environment variable *name*, or *default* when it is unset."""
    return os.environ.get(name, default)


# Modal app identity and the volume that receives results.
APP_NAME = "katago-large-dataset-build"
OUTPUT_STEM = _env("OUTPUT_STEM", "katago_large_dataset_v1")
OUTPUT_VOLUME_NAME = _env("OUTPUT_VOLUME_NAME", "katago-large-datasets")

# Fixed paths inside the remote container.
REMOTE_OUTPUT_DIR = "/outputs"
REMOTE_PROJECT_DIR = "/root/project"
REMOTE_SGF_DIR = "/root/gtlreviews"
REMOTE_WORK_DIR = "/root/work"
REMOTE_KATAGO_DIR = "/root/katago"

# Local sources baked into the image: this script's directory, and the SGF
# corpus directory two levels above it.
LOCAL_DIR = Path(__file__).resolve().parent
LOCAL_SGF_DIR = LOCAL_DIR.parent.parent / "gtlreviews"

# Where the KataGo executable, network weights and base config are downloaded from.
KATAGO_EXECUTABLE_URL = _env(
    "KATAGO_EXECUTABLE_URL",
    "https://github.com/lightvector/KataGo/releases/download/v1.16.4/katago-v1.16.4-cuda12.8-cudnn9.8.0-linux-x64.zip",
)
KATAGO_MODEL_URL = _env(
    "KATAGO_MODEL_URL",
    "https://media.katagotraining.org/uploaded/networks/models/kata1/kata1-b28c512nbt-s12434156288-d5719330235.bin.gz",
)
KATAGO_CONFIG_URL = _env(
    "KATAGO_CONFIG_URL",
    "https://raw.githubusercontent.com/lightvector/KataGo/master/cpp/configs/analysis_example.cfg",
)

# Settings written into the downloaded analysis config.
KATAGO_NUM_ANALYSIS_THREADS = _env("KATAGO_NUM_ANALYSIS_THREADS", "32")
KATAGO_NUM_SEARCH_THREADS_PER_ANALYSIS_THREAD = _env(
    "KATAGO_NUM_SEARCH_THREADS_PER_ANALYSIS_THREAD", "1"
)
KATAGO_REPORT_ANALYSIS_WINRATES_AS = _env("KATAGO_REPORT_ANALYSIS_WINRATES_AS", "BLACK")

# Builder options. Kept as strings because they are passed straight through
# as command-line arguments.
DATASET_MAX_POSITIONS = _env("DATASET_MAX_POSITIONS", "10000")
DATASET_MAX_GAMES = _env("DATASET_MAX_GAMES", "0")
DATASET_MAX_POSITIONS_PER_GAME = _env("DATASET_MAX_POSITIONS_PER_GAME", "8")
DATASET_MIN_COMMENT_CHARS = _env("DATASET_MIN_COMMENT_CHARS", "20")
KATAGO_ANALYSIS_VISITS = _env("KATAGO_ANALYSIS_VISITS", "2000")
KATAGO_INFLIGHT_LIMIT = _env("KATAGO_INFLIGHT_LIMIT", "64")
QUERY_TIMEOUT_SEC = _env("QUERY_TIMEOUT_SEC", "180")

# Minimum seconds between output-volume commits while the builder runs (numeric).
VOLUME_COMMIT_EVERY_SEC = float(_env("VOLUME_COMMIT_EVERY_SEC", "30"))

# Container image, built up one layer at a time:
#   - slim Debian with Python 3.11;
#   - curl/unzip/xz-utils to fetch and unpack the KataGo release;
#   - sgfmill plus pip wheels carrying CUDA 12 libraries (cuBLAS, CUDA runtime,
#     cuDNN, cuSOLVER).
#     NOTE(review): nothing in this file puts those wheels' lib directories on
#     the dynamic-linker path; confirm the katago binary can locate them.
#   - copies of this project directory and the SGF corpus.
image = modal.Image.debian_slim(python_version="3.11")
image = image.apt_install("curl", "unzip", "xz-utils")
image = image.pip_install(
    "sgfmill",
    "nvidia-cublas-cu12",
    "nvidia-cuda-runtime-cu12",
    "nvidia-cudnn-cu12",
    "nvidia-cusolver-cu12",
)
image = image.add_local_dir(LOCAL_DIR, remote_path=REMOTE_PROJECT_DIR, copy=True)
image = image.add_local_dir(LOCAL_SGF_DIR, remote_path=REMOTE_SGF_DIR, copy=True)

# The Modal app running on that image, and the named volume holding the outputs
# (created on first use).
app = modal.App(APP_NAME, image=image)
output_volume = modal.Volume.from_name(OUTPUT_VOLUME_NAME, create_if_missing=True)


def katago_paths() -> tuple[Path, Path, Path]:
    """Return the install locations ``(binary, model, config)`` under ``REMOTE_KATAGO_DIR``."""
    root = Path(REMOTE_KATAGO_DIR)
    binary, model, config = (root / name for name in ("katago", "model.bin.gz", "analysis.cfg"))
    return binary, model, config


def _download(url: str, dest: Path) -> None:
    """Fetch *url* into *dest* with curl; raise ``CalledProcessError`` on failure.

    ``--fail`` makes curl exit non-zero on HTTP 4xx/5xx responses. Without it,
    an error page is written to *dest*, ``check=True`` does not notice, and the
    problem only surfaces later as a corrupt model, config or archive.
    ``--retry`` rides out transient network errors. ``--silent --show-error``
    suppresses the progress meter but still reports errors.
    """
    subprocess.run(
        [
            "curl",
            "--fail",
            "--location",
            "--silent",
            "--show-error",
            "--retry",
            "3",
            "-o",
            str(dest),
            url,
        ],
        check=True,
    )


def install_katago_assets() -> dict[str, str]:
    """Download the KataGo executable, model and analysis config into ``REMOTE_KATAGO_DIR``.

    The executable archive is unzipped in place. The extracted ``katago`` file
    is copied to the fixed path given by ``katago_paths()`` and made executable.
    The downloaded config is then rewritten by ``patch_analysis_config``.

    Returns:
        Mapping with ``katago_binary``, ``katago_model`` and ``katago_config``
        path strings.

    Raises:
        subprocess.CalledProcessError: if a download or the unzip fails.
        RuntimeError: if no ``katago`` file exists after extraction.
    """
    katago_bin, katago_model, katago_config = katago_paths()
    katago_root = katago_bin.parent
    katago_root.mkdir(parents=True, exist_ok=True)

    archive_path = katago_root / "katago_download"
    _download(KATAGO_EXECUTABLE_URL, archive_path)
    subprocess.run(["unzip", "-o", str(archive_path), "-d", str(katago_root)], check=True)
    # Only regular files qualify (a directory named "katago" is not the binary).
    # Prefer the shallowest match so the choice does not depend on scan order.
    found_bins = [path for path in katago_root.rglob("katago") if path.is_file()]
    if not found_bins:
        raise RuntimeError("Could not find `katago` binary after extracting the archive")
    source_bin = min(found_bins, key=lambda path: (len(path.parts), str(path)))
    if source_bin.resolve() != katago_bin.resolve():
        shutil.copy2(source_bin, katago_bin)
    katago_bin.chmod(0o755)

    _download(KATAGO_MODEL_URL, katago_model)
    _download(KATAGO_CONFIG_URL, katago_config)
    patch_analysis_config(katago_config)
    return {
        "katago_binary": str(katago_bin),
        "katago_model": str(katago_model),
        "katago_config": str(katago_config),
    }


def patch_analysis_config(config_path: Path, overrides: dict[str, str] | None = None) -> None:
    """Rewrite selected ``key = value`` settings of a KataGo config file in place.

    A line is replaced by ``key = value`` when its key, meaning the text before
    the first ``=`` with whitespace stripped, exactly equals a key in
    *overrides*. Commented-out lines and keys that merely share a prefix are
    left untouched. Override keys never seen in the file are appended at the
    end, so each setting always ends up present.

    Args:
        config_path: Config file to patch; it is overwritten.
        overrides: Mapping of config key to replacement value. When omitted,
            uses ``numAnalysisThreads``, ``numSearchThreadsPerAnalysisThread``
            and ``reportAnalysisWinratesAs`` from this module's
            environment-driven constants.
    """
    if overrides is None:
        overrides = {
            "numAnalysisThreads": KATAGO_NUM_ANALYSIS_THREADS,
            "numSearchThreadsPerAnalysisThread": KATAGO_NUM_SEARCH_THREADS_PER_ANALYSIS_THREAD,
            "reportAnalysisWinratesAs": KATAGO_REPORT_ANALYSIS_WINRATES_AS,
        }

    patched_lines: list[str] = []
    replaced: set[str] = set()
    for line in config_path.read_text().splitlines():
        key = line.split("=", 1)[0].strip()
        if key in overrides:
            patched_lines.append(f"{key} = {overrides[key]}")
            replaced.add(key)
        else:
            patched_lines.append(line)

    # Ensure every override takes effect even if the base config lacks the key.
    patched_lines.extend(
        f"{key} = {value}" for key, value in overrides.items() if key not in replaced
    )
    config_path.write_text("\n".join(patched_lines) + "\n")


def reset_output_dir(out_dir: Path) -> None:
    """Ensure *out_dir* exists and is empty.

    A missing directory is created, including any missing parents. Otherwise
    every child is removed. Real subdirectories are deleted recursively.
    Symlinks are only unlinked, never followed, so a link to a directory
    elsewhere does not delete that target. This also avoids ``shutil.rmtree``
    raising on symlinks.
    """
    if not out_dir.exists():
        out_dir.mkdir(parents=True, exist_ok=True)
        return
    for child in out_dir.iterdir():
        if child.is_dir() and not child.is_symlink():
            shutil.rmtree(child)
        else:
            child.unlink()


def _pump_lines(stream, stream_name: str, sink: queue.Queue) -> None:
    """Forward each line of *stream* to *sink* as ``(stream_name, line)``.

    Runs on a reader thread. A final ``(stream_name, None)`` sentinel is always
    queued, even if reading fails, so the consumer knows this stream is finished.
    """
    try:
        for line in stream:
            sink.put((stream_name, line))
    finally:
        sink.put((stream_name, None))


def stream_builder_process(cmd: list[str], cwd: Path, out_dir: Path) -> int:
    """Run *cmd*, tee its stdout/stderr into log files in *out_dir*, and return its exit code.

    Each output line is appended to ``build_stdout.txt`` or ``build_stderr.txt``
    and echoed to this process's stdout. While the child runs, the output volume
    is committed at most every ``VOLUME_COMMIT_EVERY_SEC`` seconds, and once more
    after it exits, so partial logs are visible during long builds.

    Each pipe gets its own reader thread rather than using ``select`` plus
    ``readline``. ``readline`` on a buffered text pipe can pull several lines
    into Python's buffer, which ``select`` then no longer reports, delaying
    them. It also blocks on an unterminated line, which can stop the other pipe
    from being drained.

    Raises:
        Any error raised while writing logs or committing the volume. The child
        process is killed before the error propagates.
    """
    stdout_path = out_dir / "build_stdout.txt"
    stderr_path = out_dir / "build_stderr.txt"
    env = os.environ.copy()
    env["PYTHONUNBUFFERED"] = "1"
    proc = subprocess.Popen(
        cmd,
        cwd=cwd,
        text=True,
        # Undecodable bytes must not kill a reader thread and leave its pipe undrained.
        errors="replace",
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        bufsize=1,
        env=env,
    )
    assert proc.stdout is not None
    assert proc.stderr is not None

    lines: queue.Queue = queue.Queue()
    readers = [
        threading.Thread(target=_pump_lines, args=(proc.stdout, "stdout", lines), daemon=True),
        threading.Thread(target=_pump_lines, args=(proc.stderr, "stderr", lines), daemon=True),
    ]
    for reader in readers:
        reader.start()
    open_streams = {"stdout", "stderr"}
    last_commit = time.time()

    try:
        with stdout_path.open("w", encoding="utf-8") as stdout_f, stderr_path.open("w", encoding="utf-8") as stderr_f:
            writers = {"stdout": stdout_f, "stderr": stderr_f}
            while open_streams:
                try:
                    stream_name, line = lines.get(timeout=1.0)
                except queue.Empty:
                    idle = True
                else:
                    idle = False
                    if line is None:
                        open_streams.discard(stream_name)
                    else:
                        writers[stream_name].write(line)
                        writers[stream_name].flush()
                        print(line, end="")
                if time.time() - last_commit >= VOLUME_COMMIT_EVERY_SEC:
                    output_volume.commit()
                    last_commit = time.time()
                # The child has exited and produced nothing for a full timeout.
                # Stop waiting: something else (e.g. a grandchild) may still be
                # holding the pipes open.
                if idle and proc.poll() is not None:
                    break
    except BaseException:
        if proc.poll() is None:
            proc.kill()
        proc.wait()
        raise

    returncode = proc.wait()
    if not open_streams:
        # Both readers reached EOF and are exiting, so the pipes can be released safely.
        for reader in readers:
            reader.join()
        proc.stdout.close()
        proc.stderr.close()
    output_volume.commit()
    return returncode


def _read_log_tail(path: Path, max_chars: int = 20_000) -> str:
    """Return at most the last *max_chars* characters of *path*, or "" if it is missing.

    Used to keep failure messages readable. The full log already lives in the
    output directory.
    """
    if not path.exists():
        return ""
    text = path.read_text(errors="replace")
    if len(text) <= max_chars:
        return text
    return f"[... {len(text) - max_chars} earlier characters truncated ...]\n" + text[-max_chars:]


@app.function(
    gpu="A10G",
    cpu=8,
    memory=32768,
    timeout=60 * 60 * 8,
    volumes={REMOTE_OUTPUT_DIR: output_volume},
)
def build_large_dataset() -> dict:
    """Build the KataGo-annotated dataset on a GPU worker and record a run manifest.

    Steps:
      1. Copy the builder script and README into the work dir.
      2. Install the KataGo assets.
      3. Empty ``<REMOTE_OUTPUT_DIR>/<OUTPUT_STEM>``.
      4. Run the builder while streaming its logs into that directory.
      5. Write ``run_manifest.json`` and commit the output volume.

    Returns:
        The run manifest dict, which is also written to ``run_manifest.json``.

    Raises:
        RuntimeError: if the builder exits non-zero. The message includes the
            tail of its stdout/stderr logs.
    """
    workdir = Path(REMOTE_WORK_DIR)
    workdir.mkdir(parents=True, exist_ok=True)
    out_dir = Path(REMOTE_OUTPUT_DIR) / OUTPUT_STEM

    for filename in ["katago_large_dataset_builder.py", "README.md"]:
        shutil.copy(Path(REMOTE_PROJECT_DIR) / filename, workdir / filename)

    katago_info = install_katago_assets()
    # Wipe the previous outputs only after the prerequisites have succeeded, so a
    # failed download or a missing project file leaves the last run's files intact.
    reset_output_dir(out_dir)
    cmd = [
        "python",
        "-u",
        "katago_large_dataset_builder.py",
        "--sgf-dir",
        REMOTE_SGF_DIR,
        "--output-dir",
        str(out_dir),
        "--max-positions",
        DATASET_MAX_POSITIONS,
        "--max-games",
        DATASET_MAX_GAMES,
        "--max-positions-per-game",
        DATASET_MAX_POSITIONS_PER_GAME,
        "--min-comment-chars",
        DATASET_MIN_COMMENT_CHARS,
        "--katago-binary",
        katago_info["katago_binary"],
        "--katago-model",
        katago_info["katago_model"],
        "--katago-config",
        katago_info["katago_config"],
        "--katago-visits",
        KATAGO_ANALYSIS_VISITS,
        "--katago-inflight-limit",
        KATAGO_INFLIGHT_LIMIT,
        "--query-timeout-sec",
        QUERY_TIMEOUT_SEC,
    ]
    returncode = stream_builder_process(cmd, cwd=workdir, out_dir=out_dir)
    if returncode != 0:
        # Only the tails go into the exception; hours of logs would make it unusable.
        stdout_text = _read_log_tail(out_dir / "build_stdout.txt")
        stderr_text = _read_log_tail(out_dir / "build_stderr.txt")
        raise RuntimeError(
            f"Large dataset build failed with return code {returncode}\nSTDOUT:\n{stdout_text}\nSTDERR:\n{stderr_text}"
        )

    # Log files always exist at this point; dataset files only if the builder wrote them.
    produced = ["build_stdout.txt", "build_stderr.txt"]
    produced.extend(
        name
        for name in ["full.incremental.jsonl", "full.jsonl", "train.jsonl", "eval.jsonl", "dataset_manifest.json"]
        if (out_dir / name).exists()
    )
    manifest = {
        "output_stem": OUTPUT_STEM,
        "files": [f"{OUTPUT_STEM}/{name}" for name in produced],
        "command": cmd,
        "katago_urls": {
            "executable": KATAGO_EXECUTABLE_URL,
            "model": KATAGO_MODEL_URL,
            "config": KATAGO_CONFIG_URL,
        },
        "effective_katago_config": {
            "numAnalysisThreads": KATAGO_NUM_ANALYSIS_THREADS,
            "numSearchThreadsPerAnalysisThread": KATAGO_NUM_SEARCH_THREADS_PER_ANALYSIS_THREAD,
            "reportAnalysisWinratesAs": KATAGO_REPORT_ANALYSIS_WINRATES_AS,
        },
        "returncode": returncode,
    }
    (out_dir / "run_manifest.json").write_text(json.dumps(manifest, indent=2))
    output_volume.commit()
    return manifest


@app.local_entrypoint()
def main() -> None:
    """Local entry point: run the build remotely and print its run manifest as indented JSON."""
    manifest = build_large_dataset.remote()
    rendered = json.dumps(manifest, indent=2)
    print(rendered)
