"""Optional GPU smoke check for inference.

Requires:
- CUDA
- local Qwen3-VL-8B weights available

Run from the submission root directory:
    python scripts/verify_gpu_infer.py --weights_dir /path/to/hf_cache --out_dir /tmp/test
"""

from __future__ import annotations

import argparse
import subprocess
import sys
from pathlib import Path

import torch

# Executing this file directly (``python scripts/verify_gpu_infer.py``) puts
# the script's own directory -- not the submission root -- at sys.path[0].
# Prepend the root, at most once, so the ``inference`` package imported in
# main() can be resolved.
_SCRIPT_DIR = Path(__file__).resolve().parent
_ROOT_DIR = _SCRIPT_DIR.parent
if not any(entry == str(_ROOT_DIR) for entry in sys.path):
    sys.path.insert(0, str(_ROOT_DIR))


def _missing_outputs(
    out_dir: Path,
    names: tuple[str, ...] = ("data.json", "image.png", "attribution.png"),
) -> list[Path]:
    """Return the expected output files under *out_dir* that are absent or empty.

    Args:
        out_dir: Directory the inference script was told to write into.
        names: File names expected directly inside *out_dir*.

    Returns:
        The offending paths, in the order of *names*; an empty list when every
        expected file exists as a non-empty regular file.
    """
    bad: list[Path] = []
    for name in names:
        pth = out_dir / name
        # is_file() rather than exists(): a directory with the same name is not
        # a usable output, and a zero-byte file cannot be valid JSON or PNG.
        if not pth.is_file() or pth.stat().st_size == 0:
            bad.append(pth)
    return bad


def main() -> None:
    """Run inference once on the bundled example image and verify its outputs.

    Parses ``--weights_dir`` and ``--out_dir`` from the command line, runs
    ``inference/infer_attribution.py`` in a subprocess with a fixed question,
    ``--use_uniform_theta`` and ``--source_mode block``, then checks that
    ``data.json``, ``image.png`` and ``attribution.png`` were written to
    ``--out_dir``. Prints ``OK`` on success.

    Raises:
        RuntimeError: CUDA is unavailable, or the inference subprocess exits
            with a non-zero status.
        FileNotFoundError: the bundled example image is missing, or an expected
            output file is missing or empty after the run.
    """
    p = argparse.ArgumentParser()
    p.add_argument(
        "--weights_dir",
        type=Path,
        required=True,
        help="Local weights directory, passed through to the inference script.",
    )
    p.add_argument(
        "--out_dir",
        type=Path,
        required=True,
        help="Directory the inference script writes its outputs into.",
    )
    a = p.parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available")

    # Import the inference module in-process first, so a broken install or a
    # missing dependency fails here with a direct traceback. The imported
    # name itself is unused (hence the noqa).
    from inference.infer_attribution import main as infer_main  # noqa: F401

    image = _ROOT_DIR / "assets" / "example_image.png"
    # Validate the bundled input before launching the subprocess, so a missing
    # asset is reported as such rather than as a generic inference failure.
    if not image.is_file():
        raise FileNotFoundError(f"Missing example image: {image}")

    cmd = [
        sys.executable,
        str(_ROOT_DIR / "inference" / "infer_attribution.py"),
        "--weights_dir",
        str(a.weights_dir),
        "--image",
        str(image),
        "--question",
        "What is shown?",
        "--use_uniform_theta",
        "--out_dir",
        str(a.out_dir),
        "--source_mode",
        "block",
    ]
    r = subprocess.run(cmd, check=False)
    if r.returncode != 0:
        # On POSIX a negative code means the child was killed by that signal
        # (e.g. -9), which is worth surfacing rather than hiding.
        raise RuntimeError(f"Inference command failed (exit code {r.returncode})")

    missing = _missing_outputs(a.out_dir)
    if missing:
        # Report every bad output at once instead of stopping at the first.
        raise FileNotFoundError(
            "Missing or empty output(s): " + ", ".join(str(m) for m in missing)
        )

    print("OK")


# Run the smoke check only when executed as a script, not on import.
# main() returns None, so the process exits with status 0 on success.
if __name__ == "__main__":
    sys.exit(main())
