#!/usr/bin/env python
"""
One-stop installer for ImpuGen
-------------------------------------------------------
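Creates a conda environment, installs the proper PyTorch wheel with a
visible progress bar, installs the remaining dependencies from
requirements.txt, and finally runs the bundled helper scripts
(script_sklearn.py, script_tabsyn_dataset.py, script_uci_dataset.py)
with the new environment's Python.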

Example:
    python install.py --python 3.12 --cp cu126 --name impugen_aaai
"""
from __future__ import annotations

import argparse
import subprocess
import sys
from pathlib import Path


# ────────────────────────── helper functions ─────────────────────────
def run(cmd: list[str]) -> None:
    """Echo *cmd*, then execute it and wait for it to finish.

    The command is passed to ``subprocess.check_call`` as an argument list,
    so no shell is involved and the child's output streams straight to this
    process's stdout/stderr.

    Args:
        cmd: Program and arguments, one list element per argument.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status. The exception is not caught here, so an unhandled failure
            aborts the installer.
    """
    # Flush explicitly: when stdout is piped or redirected it is block-buffered,
    # and without the flush this echo could show up *after* the child's output.
    print(f"▶ {' '.join(cmd)}", flush=True)
    subprocess.check_call(cmd)


def parse_args() -> argparse.Namespace:
    """Parse the installer's command-line options from ``sys.argv``.

    Options:
        --python  Python version for the new environment (default "3.12").
        --cp      Compute platform tag; required, one of the listed choices.
        --name    Name of the conda environment (default "impugen_aaai").

    Returns:
        The populated namespace with ``python``, ``cp`` and ``name`` attributes.
    """
    platform_tags = ["cu118", "cu126", "cu128", "cpu"]

    cli = argparse.ArgumentParser(description="ImpuGen one-line installer")
    cli.add_argument(
        "--python",
        default="3.12",
        help="Python version (e.g. 3.10, 3.11, 3.12)",
    )
    cli.add_argument(
        "--cp",
        choices=platform_tags,
        required=True,
        help="Compute platform tag (CUDA wheels or CPU wheel)",
    )
    cli.add_argument(
        "--name",
        default="impugen_aaai",
        help="Conda env name",
    )

    namespace = cli.parse_args()
    return namespace


# ────────────────────────────── main ─────────────────────────────────
def main() -> None:
    """Create the conda environment and install everything into it.

    Steps, in order:
      1. ``conda create`` an environment with the requested Python and pip,
         then look up that environment's python executable.
      2. Install torch / torchvision / torchaudio for the platform chosen
         with ``--cp``.
      3. ``pip install -r requirements.txt`` (the file next to this script).
      4. Run the three helper scripts that live next to this script, using
         the environment's python.

    Exits through ``sys.exit`` with a message when ``requirements.txt`` is
    missing. A failing subprocess raises ``subprocess.CalledProcessError``,
    which is not caught here.
    """
    args = parse_args()
    env_name: str = args.name
    python_version: str = args.python
    cp_tag: str = args.cp

    # Everything shipped with the installer is resolved relative to this file,
    # not to the caller's working directory.
    repo_root = Path(__file__).resolve().parent
    req_file = repo_root / "requirements.txt"
    if not req_file.exists():
        sys.exit("❌  requirements.txt not found – place it next to install.py.")

    # 1. Create conda environment
    run(["conda", "create", "-y", "-n", env_name, f"python={python_version}", "pip"])

    # 1-a. Locate the python executable inside the new env
    env_python = subprocess.check_output(
        ["conda", "run", "--no-capture-output", "-n", env_name,
         "python", "-c", "import sys, json; print(sys.executable)"],
        text=True,
    ).strip()
    print(f"🟢  Using env-python: {env_python}")

    def run_in_env(py_args: list[str]) -> None:
        """Invoke commands with the environment's python (streams progress)."""
        run([env_python, *py_args])

    # 2. Install PyTorch wheel matching the compute platform
    index_url = f"https://download.pytorch.org/whl/{cp_tag}"
    torch_args = [
        "-m", "pip", "install", "-v", "--progress-bar", "on",
        "torch", "torchvision", "torchaudio",
    ]
    # The PyPI default `torch` wheel on Linux is a CUDA build, so a "cpu"
    # request must also go through the PyTorch index to get the CPU-only wheel.
    # macOS is the exception: the official instructions install from PyPI there.
    use_default_index = cp_tag == "cpu" and sys.platform == "darwin"
    if not use_default_index:
        torch_args.extend(["--index-url", index_url])

    run_in_env(torch_args)

    # 3. Install remaining dependencies
    run_in_env(["-m", "pip", "install", "-v", "--progress-bar", "on",
                "-r", str(req_file)])

    # 4. Run the helper scripts shipped next to this installer.
    # Paths are anchored to repo_root (like requirements.txt) so the installer
    # also works when launched from another directory.
    # NOTE(review): the scripts still inherit the caller's working directory —
    # confirm they do not rely on being started from the repository root.
    helper_scripts = (
        "script_sklearn.py",
        "script_tabsyn_dataset.py",
        "script_uci_dataset.py",
    )
    for script_name in helper_scripts:
        run_in_env([str(repo_root / script_name)])

    # 5. Done
    print(
        f"\n🎉  Installation complete!\n"
        f"👉  Activate with:   conda activate {env_name}\n"
    )


# Run the installer only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
