from __future__ import annotations

import json
import os
import shutil
import subprocess
from pathlib import Path

import modal


# --- Configuration ---------------------------------------------------------
# Almost every setting is read from the process environment when this module
# is imported, falling back to the default shown. Env-derived values stay as
# strings because they are passed straight through as CLI arguments.
# NOTE(review): this module is imported both on the launching machine and
# inside the Modal container; verify that overrides actually reach the
# container rather than reverting to these defaults there.

# Modal app name and output locations. OUTPUT_STEM names both the output
# subdirectory on the volume and the training CSV/Markdown result files.
APP_NAME = "katago-winprob-experiment"
OUTPUT_STEM = os.environ.get("OUTPUT_STEM", "katago_winprob_20260429")
OUTPUT_VOLUME_NAME = os.environ.get("OUTPUT_VOLUME_NAME", "katago-winprob-results")
# Fixed paths inside the Modal container.
REMOTE_OUTPUT_DIR = "/outputs"  # mount point of the output volume
REMOTE_PROJECT_DIR = "/root/project"  # copy of LOCAL_DIR baked into the image
REMOTE_SGF_DIR = "/root/gtlreviews"  # copy of LOCAL_SGF_DIR baked into the image
REMOTE_WORK_DIR = "/root/work"  # scratch dir used as cwd for the child scripts
REMOTE_KATAGO_DIR = "/root/katago"  # KataGo binary, model, and config live here

# Paths on the launching machine: this file's directory and the sibling
# ``gtlreviews`` directory of SGF files; both are copied into the image.
LOCAL_DIR = Path(__file__).resolve().parent
LOCAL_SGF_DIR = LOCAL_DIR.parent / "gtlreviews"

# Download sources for the KataGo executable archive, network weights, and
# analysis config. The default executable is the v1.16.4 CUDA 12.8 / cuDNN 9.8
# Linux x64 build; the default config is the upstream analysis example.
KATAGO_EXECUTABLE_URL = os.environ.get(
    "KATAGO_EXECUTABLE_URL",
    "https://github.com/lightvector/KataGo/releases/download/v1.16.4/katago-v1.16.4-cuda12.8-cudnn9.8.0-linux-x64.zip",
)
KATAGO_MODEL_URL = os.environ.get(
    "KATAGO_MODEL_URL",
    "https://media.katagotraining.org/uploaded/networks/models/kata1/kata1-b28c512nbt-s12434156288-d5719330235.bin.gz",
)
KATAGO_CONFIG_URL = os.environ.get(
    "KATAGO_CONFIG_URL",
    "https://raw.githubusercontent.com/lightvector/KataGo/master/cpp/configs/analysis_example.cfg",
)
# Labelling knobs, passed to ``katago_winprob_experiment.py preprocess-sgf``.
KATAGO_ANALYSIS_VISITS = os.environ.get("KATAGO_ANALYSIS_VISITS", "1000")
KATAGO_SAMPLE_EVERY_N_MOVES = os.environ.get("KATAGO_SAMPLE_EVERY_N_MOVES", "40")
KATAGO_MAX_POSITIONS_PER_GAME = os.environ.get("KATAGO_MAX_POSITIONS_PER_GAME", "4")
KATAGO_MAX_GAMES = os.environ.get("KATAGO_MAX_GAMES", "1000")
# Filename prefix for the *_full / *_train / *_eval .jsonl datasets.
KATAGO_DATASET_PREFIX = os.environ.get("KATAGO_DATASET_PREFIX", OUTPUT_STEM)
# Seed for the shuffle that produces the train/eval split.
KATAGO_SPLIT_SEED = os.environ.get("KATAGO_SPLIT_SEED", "42")
# Only the exact string "1" enables training; any other value skips it.
RUN_TRAINING = os.environ.get("RUN_TRAINING", "1") == "1"

# Training settings forwarded to ``run_katago_winprob_gpu.py`` by
# build_training_command. KATAGO_VARIANTS is a whitespace-separated list.
DEFAULT_VARIANTS = os.environ.get(
    "KATAGO_VARIANTS",
    "lm_only no_consistency_loss rationale_only full_consistency random_consistency",
)
DEFAULT_BATCH_SIZE = os.environ.get("BATCH_SIZE", "32")
DEFAULT_EPOCHS = os.environ.get("EPOCHS", "10")
DEFAULT_LR = os.environ.get("LR", "3e-4")
DEFAULT_D_MODEL = os.environ.get("D_MODEL", "256")
DEFAULT_N_LAYERS = os.environ.get("N_LAYERS", "4")
DEFAULT_N_HEADS = os.environ.get("N_HEADS", "8")
DEFAULT_D_FF = os.environ.get("D_FF", "1024")
DEFAULT_MAX_SEQ_LEN = os.environ.get("MAX_SEQ_LEN", "256")
DEFAULT_MAX_POSITION_TOKENS = os.environ.get("MAX_POSITION_TOKENS", "196")
DEFAULT_CONSISTENCY_WEIGHT = os.environ.get("CONSISTENCY_WEIGHT", "0.5")

# Modal containers do not inherit the launching shell's environment. Without
# forwarding, every os.environ override read at import time above would
# silently revert to its default when this module is re-imported inside the
# container. Bake the resolved values into the image so local and remote
# agree. Values must be strings; RUN_TRAINING is re-encoded as "1"/"0".
_FORWARDED_ENV = {
    "OUTPUT_STEM": OUTPUT_STEM,
    "OUTPUT_VOLUME_NAME": OUTPUT_VOLUME_NAME,
    "KATAGO_EXECUTABLE_URL": KATAGO_EXECUTABLE_URL,
    "KATAGO_MODEL_URL": KATAGO_MODEL_URL,
    "KATAGO_CONFIG_URL": KATAGO_CONFIG_URL,
    "KATAGO_ANALYSIS_VISITS": KATAGO_ANALYSIS_VISITS,
    "KATAGO_SAMPLE_EVERY_N_MOVES": KATAGO_SAMPLE_EVERY_N_MOVES,
    "KATAGO_MAX_POSITIONS_PER_GAME": KATAGO_MAX_POSITIONS_PER_GAME,
    "KATAGO_MAX_GAMES": KATAGO_MAX_GAMES,
    "KATAGO_DATASET_PREFIX": KATAGO_DATASET_PREFIX,
    "KATAGO_SPLIT_SEED": KATAGO_SPLIT_SEED,
    "RUN_TRAINING": "1" if RUN_TRAINING else "0",
    "KATAGO_VARIANTS": DEFAULT_VARIANTS,
    "BATCH_SIZE": DEFAULT_BATCH_SIZE,
    "EPOCHS": DEFAULT_EPOCHS,
    "LR": DEFAULT_LR,
    "D_MODEL": DEFAULT_D_MODEL,
    "N_LAYERS": DEFAULT_N_LAYERS,
    "N_HEADS": DEFAULT_N_HEADS,
    "D_FF": DEFAULT_D_FF,
    "MAX_SEQ_LEN": DEFAULT_MAX_SEQ_LEN,
    "MAX_POSITION_TOKENS": DEFAULT_MAX_POSITION_TOKENS,
    "CONSISTENCY_WEIGHT": DEFAULT_CONSISTENCY_WEIGHT,
}

# Container image: Debian slim + Python 3.11, archive tools for unpacking the
# KataGo release, Python/CUDA libraries, the forwarded configuration, then the
# project scripts and SGF corpus copied in.
image = (
    modal.Image.debian_slim(python_version="3.11")
    .apt_install("curl", "unzip", "xz-utils")
    .pip_install(
        "torch",
        "numpy",
        "pandas",
        "sgfmill",
        "nvidia-cublas-cu12",
        "nvidia-cuda-runtime-cu12",
        "nvidia-cudnn-cu12",
        "nvidia-cusolver-cu12",
    )
    # Placed after pip_install so config changes don't invalidate the
    # (slow) dependency layers.
    .env(_FORWARDED_ENV)
    .add_local_dir(LOCAL_DIR, remote_path=REMOTE_PROJECT_DIR, copy=True)
    .add_local_dir(LOCAL_SGF_DIR, remote_path=REMOTE_SGF_DIR, copy=True)
)

app = modal.App(APP_NAME, image=image)
# Persistent store for datasets, logs, and results; created on first use.
output_volume = modal.Volume.from_name(OUTPUT_VOLUME_NAME, create_if_missing=True)


def build_training_command(train_dataset: Path, eval_dataset: Path) -> list[str]:
    """Return the argv that launches ``run_katago_winprob_gpu.py`` for training.

    Dataset paths come from the arguments. All hyperparameters come from the
    module-level ``DEFAULT_*`` settings, and results are written to
    ``<OUTPUT_STEM>.csv``. ``--variants`` is appended only when
    ``DEFAULT_VARIANTS`` names at least one variant.
    """
    flag_values = (
        ("--train-path", str(train_dataset)),
        ("--eval-path", str(eval_dataset)),
        ("--output-csv", f"{OUTPUT_STEM}.csv"),
        ("--batch-size", DEFAULT_BATCH_SIZE),
        ("--epochs", DEFAULT_EPOCHS),
        ("--lr", DEFAULT_LR),
        ("--d-model", DEFAULT_D_MODEL),
        ("--n-layers", DEFAULT_N_LAYERS),
        ("--n-heads", DEFAULT_N_HEADS),
        ("--d-ff", DEFAULT_D_FF),
        ("--max-seq-len", DEFAULT_MAX_SEQ_LEN),
        ("--max-position-tokens", DEFAULT_MAX_POSITION_TOKENS),
        ("--consistency-weight", DEFAULT_CONSISTENCY_WEIGHT),
    )
    argv = ["python", "run_katago_winprob_gpu.py", "--require-gpu"]
    for flag, value in flag_values:
        argv += [flag, value]

    # str.split() with no separator never yields empty tokens.
    chosen_variants = DEFAULT_VARIANTS.split()
    if chosen_variants:
        argv.append("--variants")
        argv.extend(chosen_variants)
    return argv


def katago_paths() -> tuple[Path, Path, Path]:
    """Return where the KataGo binary, model, and analysis config are installed.

    All three live directly under ``REMOTE_KATAGO_DIR``, in the order
    ``(binary, model, config)``.
    """
    base = Path(REMOTE_KATAGO_DIR)
    binary = base / "katago"
    model = base / "model.bin.gz"
    config = base / "analysis.cfg"
    return binary, model, config


def _download(url: str, destination: Path) -> None:
    """Fetch ``url`` into ``destination`` with curl, failing on HTTP errors.

    ``--fail`` makes curl exit non-zero on 4xx/5xx responses. Without it the
    server's error page is saved as if it were the asset, and ``check=True``
    never triggers. ``--retry`` rides out transient network failures.
    """
    subprocess.run(
        ["curl", "--fail", "--location", "--retry", "3", url, "-o", str(destination)],
        check=True,
    )


def install_katago_assets() -> dict[str, str]:
    """Download and install the KataGo binary, model weights, and analysis config.

    The executable archive at ``KATAGO_EXECUTABLE_URL`` is extracted into the
    directory returned by ``katago_paths()``. The first regular file named
    ``katago`` found there (shallowest first) is copied to the canonical binary
    path and marked executable. The model and config are downloaded directly
    to their canonical paths.

    Returns:
        Mapping with keys ``katago_binary``, ``katago_model``, and
        ``katago_config``, each holding the installed file path as a string.

    Raises:
        RuntimeError: if any URL is empty, the archive extension is not
            recognised, or no ``katago`` file exists after extraction.
        subprocess.CalledProcessError: if a download (including an HTTP error
            status) or the extraction command fails.
    """
    if not (KATAGO_EXECUTABLE_URL and KATAGO_MODEL_URL and KATAGO_CONFIG_URL):
        raise RuntimeError(
            "Set KATAGO_EXECUTABLE_URL, KATAGO_MODEL_URL, and KATAGO_CONFIG_URL before running the Modal pipeline."
        )

    katago_bin, katago_model, katago_config = katago_paths()
    katago_root = katago_bin.parent
    katago_root.mkdir(parents=True, exist_ok=True)

    archive_path = katago_root / "katago_download"
    _download(KATAGO_EXECUTABLE_URL, archive_path)

    # The archive is saved under a neutral name, so the format is inferred
    # from the URL's extension.
    lower_url = KATAGO_EXECUTABLE_URL.lower()
    if lower_url.endswith(".zip"):
        extract_cmd = ["unzip", "-o", str(archive_path), "-d", str(katago_root)]
    elif lower_url.endswith((".tar.gz", ".tgz")):
        extract_cmd = ["tar", "-xzf", str(archive_path), "-C", str(katago_root)]
    elif lower_url.endswith((".tar.xz", ".txz", ".xz")):
        extract_cmd = ["tar", "-xJf", str(archive_path), "-C", str(katago_root)]
    else:
        raise RuntimeError(f"Unsupported KataGo executable archive format: {KATAGO_EXECUTABLE_URL}")
    subprocess.run(extract_cmd, check=True)

    # Only regular files qualify: an archive could contain a directory named
    # ``katago``. Sorting by depth, then path, makes the choice deterministic.
    found_bins = sorted(
        (path for path in katago_root.rglob("katago") if path.is_file()),
        key=lambda path: (len(path.parts), str(path)),
    )
    if not found_bins:
        raise RuntimeError("Could not find `katago` binary after extracting the archive")
    source_bin = found_bins[0]
    if source_bin.resolve() != katago_bin.resolve():
        shutil.copy2(source_bin, katago_bin)
    katago_bin.chmod(0o755)

    _download(KATAGO_MODEL_URL, katago_model)
    _download(KATAGO_CONFIG_URL, katago_config)

    return {
        "katago_binary": str(katago_bin),
        "katago_model": str(katago_model),
        "katago_config": str(katago_config),
    }


def write_split_dataset(
    full_dataset: Path,
    train_out: Path,
    eval_out: Path,
    split_seed: int,
    train_fraction: float = 0.9,
) -> dict[str, int]:
    """Shuffle a JSONL dataset deterministically and split it into train/eval files.

    Blank lines in ``full_dataset`` are ignored. Every other line is parsed as
    JSON, the rows are shuffled with ``random.Random(split_seed)``, and the
    first ``int(len(rows) * train_fraction)`` rows go to ``train_out``. The
    remainder go to ``eval_out``. Each output row is re-serialised with
    ``json.dumps`` on its own line; both output files are always (re)written.

    The work is done in-process. Earlier versions ran the same logic in a
    child interpreter, which hid parse errors behind a bare
    ``CalledProcessError``.

    Args:
        full_dataset: source JSONL file (read as UTF-8).
        train_out: destination for the training rows.
        eval_out: destination for the evaluation rows.
        split_seed: seed for the shuffle, so the split is reproducible.
        train_fraction: share of rows assigned to training (default 0.9).

    Returns:
        ``{"train_rows": <count>, "eval_rows": <count>}``.

    Raises:
        ValueError: if ``train_fraction`` is outside ``[0, 1]``.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    import random

    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be within [0, 1], got {train_fraction!r}")

    lines = Path(full_dataset).read_text(encoding="utf-8").splitlines()
    rows = [json.loads(line) for line in lines if line.strip()]
    random.Random(split_seed).shuffle(rows)

    split = int(len(rows) * train_fraction)
    train_rows, eval_rows = rows[:split], rows[split:]

    for path, chunk in ((Path(train_out), train_rows), (Path(eval_out), eval_rows)):
        with path.open("w", encoding="utf-8") as handle:
            handle.writelines(json.dumps(row) + "\n" for row in chunk)

    return {"train_rows": len(train_rows), "eval_rows": len(eval_rows)}


def run_katago_smoke_check(workdir: Path, katago_info: dict[str, str]) -> dict:
    """Prove the installed KataGo works by running a tiny (8-visit) analysis.

    The check runs ``katago_winprob_experiment.katago_analysis_smoke_check`` in
    a child ``python`` process with ``workdir`` as its cwd, so the staged
    project module is importable. The child prints its result as JSON.

    Returns:
        ``{"stdout": ..., "stderr": ..., "result": <parsed JSON from stdout>}``.

    Raises:
        RuntimeError: if the child exits non-zero; the message embeds both
            captured streams.
    """
    binary = katago_info["katago_binary"]
    model = katago_info["katago_model"]
    config = katago_info["katago_config"]
    script = f"""
import json
from katago_winprob_experiment import katago_analysis_smoke_check

result = katago_analysis_smoke_check(
    binary_path={binary!r},
    model_path={model!r},
    config_path={config!r},
    visits=8,
)
print(json.dumps(result))
"""
    completed = subprocess.run(
        ["python", "-c", script],
        cwd=workdir,
        capture_output=True,
        text=True,
    )
    if completed.returncode == 0:
        return {
            "stdout": completed.stdout,
            "stderr": completed.stderr,
            "result": json.loads(completed.stdout.strip()),
        }
    raise RuntimeError(
        f"KataGo smoke check failed with return code {completed.returncode}\n"
        f"STDOUT:\n{completed.stdout}\nSTDERR:\n{completed.stderr}"
    )


def collect_katago_debug_info(workdir: Path, katago_info: dict[str, str]) -> dict:
    """Gather dynamic-linker diagnostics for the KataGo binary (best effort).

    A child ``python`` process (cwd ``workdir``) does the following:
    - lists the CUDA-related shared libraries (cuBLAS, cuBLASLt, cuDNN,
      cudart, cuSOLVER) found under its ``sys.prefix``, ``sys.path``, and
      site-packages ``lib`` / ``torch/lib`` / ``nvidia/*/lib`` directories;
    - runs ``ldd`` on the binary;
    - prints everything as indented JSON.

    This step is purely diagnostic, so a failure of the probe script is not
    raised. Its traceback is returned in ``stderr`` so the caller can persist
    it. Previously ``check=True`` aborted the pipeline and dropped the
    captured stderr exactly when it was most needed.

    Returns:
        ``{"stdout": ..., "stderr": ..., "returncode": <int>}``. ``stdout`` is
        the JSON report when ``returncode`` is 0.
    """
    debug_code = f"""
import json
import os
import site
import subprocess
import sys
from pathlib import Path

binary_path = Path({katago_info["katago_binary"]!r})
search_roots = []
seen = set()
for root in [Path(sys.prefix), *[Path(p) for p in sys.path if p], *[Path(p) for p in site.getsitepackages()]]:
    try:
        resolved = str(root.resolve())
    except Exception:
        resolved = str(root)
    if resolved not in seen:
        seen.add(resolved)
        search_roots.append(str(root))

candidate_dirs = []
seen_dirs = set()
for root_str in search_roots:
    root = Path(root_str)
    if not root.exists():
        continue
    for candidate in [root / "lib", *root.glob("torch/lib"), *root.glob("nvidia/*/lib")]:
        if candidate.is_dir():
            key = str(candidate)
            if key not in seen_dirs:
                seen_dirs.add(key)
                candidate_dirs.append(candidate)

interesting = {{}}
for directory in candidate_dirs:
    found = []
    for pattern in ["libcublas.so*", "libcudnn.so*", "libcudart.so*", "libcusolver.so*", "libcublasLt.so*"]:
        found.extend(sorted(str(path.name) for path in directory.glob(pattern)))
    if found:
        interesting[str(directory)] = found

ldd_proc = subprocess.run(["ldd", str(binary_path)], text=True, capture_output=True)

print(json.dumps({{
    "python_executable": sys.executable,
    "sys_prefix": sys.prefix,
    "ld_library_path": os.environ.get("LD_LIBRARY_PATH", ""),
    "search_roots": search_roots,
    "interesting_lib_dirs": interesting,
    "ldd_stdout": ldd_proc.stdout,
    "ldd_stderr": ldd_proc.stderr,
    "ldd_returncode": ldd_proc.returncode,
}}, indent=2))
"""
    # No check=True: a failing probe (e.g. `ldd` missing) must not abort the
    # pipeline, and its stderr is the useful part.
    proc = subprocess.run(["python", "-c", debug_code], cwd=workdir, text=True, capture_output=True)
    return {
        "stdout": proc.stdout,
        "stderr": proc.stderr,
        "returncode": proc.returncode,
    }


@app.function(
    gpu="A10G",
    cpu=8,
    memory=32768,
    timeout=60 * 60 * 8,
    volumes={REMOTE_OUTPUT_DIR: output_volume},
)
def run_pipeline() -> dict:
    """Run KataGo labelling, dataset splitting, and (optionally) training on Modal.

    Stages, in order:
    1. Stage the project scripts into ``REMOTE_WORK_DIR``.
    2. Install KataGo and record linker diagnostics.
    3. Run a KataGo smoke check.
    4. Label sampled SGF positions via ``katago_winprob_experiment.py
       preprocess-sgf``.
    5. Split the rows into train/eval.
    6. When ``RUN_TRAINING`` is set, run ``run_katago_winprob_gpu.py``.

    Artifacts go under ``REMOTE_OUTPUT_DIR/OUTPUT_STEM`` on the output volume.
    Each stage's logs are written before its return code is checked, and the
    labelled datasets are copied and committed before training starts. A late
    failure therefore no longer discards the expensive KataGo output or the
    logs needed to debug it.

    Returns:
        The manifest dict that is also written to ``manifest.json``.

    Raises:
        RuntimeError: if the smoke check, preprocessing, or training
            subprocess exits non-zero.
    """
    workdir = Path(REMOTE_WORK_DIR)
    workdir.mkdir(parents=True, exist_ok=True)
    out_dir = Path(REMOTE_OUTPUT_DIR) / OUTPUT_STEM
    out_dir.mkdir(parents=True, exist_ok=True)

    for filename in [
        "katago_winprob_experiment.py",
        "run_katago_winprob_gpu.py",
        "README.md",
    ]:
        shutil.copy(Path(REMOTE_PROJECT_DIR) / filename, workdir / filename)

    katago_info = install_katago_assets()
    debug_info = collect_katago_debug_info(workdir=workdir, katago_info=katago_info)
    (out_dir / "katago_debug_stdout.txt").write_text(debug_info["stdout"])
    (out_dir / "katago_debug_stderr.txt").write_text(debug_info["stderr"])
    output_volume.commit()

    try:
        smoke_info = run_katago_smoke_check(workdir=workdir, katago_info=katago_info)
    except Exception as exc:
        # Persist the failure text so it outlives the container, then re-raise.
        (out_dir / "katago_smoke_failure.txt").write_text(str(exc))
        output_volume.commit()
        raise
    (out_dir / "katago_smoke_stdout.txt").write_text(smoke_info["stdout"])
    (out_dir / "katago_smoke_stderr.txt").write_text(smoke_info["stderr"])

    dataset_prefix = KATAGO_DATASET_PREFIX
    full_dataset = workdir / f"{dataset_prefix}_full.jsonl"
    train_dataset = workdir / f"{dataset_prefix}_train.jsonl"
    eval_dataset = workdir / f"{dataset_prefix}_eval.jsonl"

    preprocess_cmd = [
        "python",
        "katago_winprob_experiment.py",
        "preprocess-sgf",
        "--sgf-dir",
        REMOTE_SGF_DIR,
        "--output-path",
        str(full_dataset),
        "--sample-every-n-moves",
        KATAGO_SAMPLE_EVERY_N_MOVES,
        "--max-positions-per-game",
        KATAGO_MAX_POSITIONS_PER_GAME,
        "--max-games",
        KATAGO_MAX_GAMES,
        "--katago-binary",
        katago_info["katago_binary"],
        "--katago-model",
        katago_info["katago_model"],
        "--katago-config",
        katago_info["katago_config"],
        "--katago-visits",
        KATAGO_ANALYSIS_VISITS,
    ]

    preprocess_proc = subprocess.run(preprocess_cmd, cwd=workdir, text=True, capture_output=True)
    # Write logs before inspecting the return code so a failed run can be
    # diagnosed from the volume, not only from the exception text.
    (out_dir / "preprocess_stdout.txt").write_text(preprocess_proc.stdout)
    (out_dir / "preprocess_stderr.txt").write_text(preprocess_proc.stderr)
    if preprocess_proc.returncode != 0:
        output_volume.commit()
        raise RuntimeError(
            f"Preprocessing failed with return code {preprocess_proc.returncode}\n"
            f"STDOUT:\n{preprocess_proc.stdout}\nSTDERR:\n{preprocess_proc.stderr}"
        )

    split_info = write_split_dataset(
        full_dataset=full_dataset,
        train_out=train_dataset,
        eval_out=eval_dataset,
        split_seed=int(KATAGO_SPLIT_SEED),
    )

    # The KataGo-labelled datasets are the expensive artefact; persist them
    # before training so they survive a training failure or timeout.
    for dataset in (full_dataset, train_dataset, eval_dataset):
        shutil.copy2(dataset, out_dir / dataset.name)
    output_volume.commit()

    copied_files = [
        f"{OUTPUT_STEM}/{full_dataset.name}",
        f"{OUTPUT_STEM}/{train_dataset.name}",
        f"{OUTPUT_STEM}/{eval_dataset.name}",
        f"{OUTPUT_STEM}/preprocess_stdout.txt",
        f"{OUTPUT_STEM}/preprocess_stderr.txt",
        f"{OUTPUT_STEM}/katago_smoke_stdout.txt",
        f"{OUTPUT_STEM}/katago_smoke_stderr.txt",
    ]

    training_proc = None
    training_command: list[str] = []
    if RUN_TRAINING:
        training_command = build_training_command(train_dataset=train_dataset, eval_dataset=eval_dataset)
        training_proc = subprocess.run(training_command, cwd=workdir, text=True, capture_output=True)
        # Same rationale as preprocessing: logs first, then check.
        (out_dir / "stdout.txt").write_text(training_proc.stdout)
        (out_dir / "stderr.txt").write_text(training_proc.stderr)
        if training_proc.returncode != 0:
            output_volume.commit()
            raise RuntimeError(
                f"Training failed with return code {training_proc.returncode}\n"
                f"STDOUT:\n{training_proc.stdout}\nSTDERR:\n{training_proc.stderr}"
            )
        for filename in [f"{OUTPUT_STEM}.csv", f"{OUTPUT_STEM}.md"]:
            src = workdir / filename
            if src.exists():
                shutil.copy2(src, out_dir / filename)
                copied_files.append(f"{OUTPUT_STEM}/{filename}")
        copied_files.extend([f"{OUTPUT_STEM}/stdout.txt", f"{OUTPUT_STEM}/stderr.txt"])

    manifest = {
        "output_stem": OUTPUT_STEM,
        "volume_output_dir": OUTPUT_STEM,
        "files": copied_files,
        "preprocess_command": preprocess_cmd,
        "training_command": training_command,
        "preprocess_returncode": preprocess_proc.returncode,
        "training_returncode": None if training_proc is None else training_proc.returncode,
        "split_info": split_info,
        "katago_urls": {
            "executable": KATAGO_EXECUTABLE_URL,
            "model": KATAGO_MODEL_URL,
            "config": KATAGO_CONFIG_URL,
        },
        "katago_smoke_result": smoke_info["result"],
        "source_sgf_dir": REMOTE_SGF_DIR,
    }
    (out_dir / "manifest.json").write_text(json.dumps(manifest, indent=2))
    output_volume.commit()
    return manifest


@app.local_entrypoint()
def main() -> None:
    """Local entry point: run the pipeline on Modal and pretty-print its manifest."""
    result = run_pipeline.remote()
    rendered = json.dumps(result, indent=2)
    print(rendered)
