from __future__ import annotations

import json
import os
import shutil
import subprocess
from pathlib import Path

import modal

# Name under which the Modal app is registered.
APP_NAME = "generated-rationale-scalar-run"
# Base name shared by the output directory and the result files; can be
# overridden through the OUTPUT_STEM environment variable.
OUTPUT_STEM = os.environ.get("OUTPUT_STEM", "generated_rationale_scalar_20260429")
# Path at which the results volume is mounted inside the remote function.
VOLUME_DIR = "/outputs"
# Directory containing this launcher; the experiment scripts are read from here.
LOCAL_DIR = Path(__file__).resolve().parent

app = modal.App(APP_NAME)
# Named volume that receives the run artifacts; created on first use.
volume = modal.Volume.from_name("generated-rationale-scalar-results", create_if_missing=True)

# Dependency layer: slim Debian with Python 3.11 plus the packages the
# experiment scripts import.
_base_image = modal.Image.debian_slim(python_version="3.11").pip_install(
    "torch",
    "numpy",
    "pandas",
    "tabulate",
)

# Script layer: place both experiment scripts under /root in the container.
# The local files are attached after the package install, as in the original
# builder chain.
image = _base_image.add_local_file(
    LOCAL_DIR / "generated_rationale_scalar_verifier_experiment.py",
    remote_path="/root/generated_rationale_scalar_verifier_experiment.py",
).add_local_file(
    LOCAL_DIR / "run_generated_rationale_scalar_gpu.py",
    remote_path="/root/run_generated_rationale_scalar_gpu.py",
)


def build_command(output_stem: str | None = None) -> list[str]:
    """Build the argv list that launches the GPU experiment script.

    Args:
        output_stem: Base name (without extension) of the CSV passed through
            ``--output-csv``. When omitted, the module-level ``OUTPUT_STEM``
            is used, so the zero-argument call behaves exactly as before.

    Returns:
        The argument vector as a list of strings, intended to be handed to
        ``subprocess.run`` directly (no shell is involved).
    """
    # Only consult the module global when the caller did not supply a stem.
    stem = OUTPUT_STEM if output_stem is None else output_stem
    return [
        "python",
        "run_generated_rationale_scalar_gpu.py",
        "--require-gpu",
        "--num-train", "2048",
        "--num-eval", "512",
        "--batch-size", "64",
        "--epochs", "20",
        "--lr", "3e-4",
        "--d-model", "128",
        "--n-layers", "2",
        "--n-heads", "4",
        "--d-ff", "256",
        "--consistency-weight", "0.5",
        "--output-csv", f"{stem}.csv",
    ]


@app.function(
    image=image,
    gpu="A10G",
    timeout=60 * 60 * 4,
    volumes={VOLUME_DIR: volume},
)
def run_job():
    """Run the experiment script in a scratch directory and persist its outputs.

    The two experiment scripts are copied from ``/root`` into ``/root/work``,
    the command from ``build_command()`` is executed there, and the captured
    stdout/stderr, any produced CSV/Markdown result files, and a
    ``manifest.json`` are written to ``VOLUME_DIR/OUTPUT_STEM``. The volume is
    committed before the return code is checked, so logs are saved for failed
    runs as well.

    Returns:
        The manifest dict: output stem, output directory (relative to the
        volume root), the volume-relative paths of the files that were
        actually written, the command, and the subprocess return code.

    Raises:
        FileNotFoundError: If one of the required scripts is missing.
        RuntimeError: If the subprocess exits with a non-zero return code.
    """
    # NOTE(review): OUTPUT_STEM is read from the environment at import time;
    # presumably an override set only on the launching machine is not visible
    # in this remote process - confirm how the variable is propagated.
    workdir = Path("/root/work")
    workdir.mkdir(parents=True, exist_ok=True)

    src_files = [
        Path("/root/generated_rationale_scalar_verifier_experiment.py"),
        Path("/root/run_generated_rationale_scalar_gpu.py"),
    ]
    for src in src_files:
        if not src.exists():
            raise FileNotFoundError(f"Missing required file: {src}")
        shutil.copy(src, workdir / src.name)

    cmd = build_command()
    proc = subprocess.run(cmd, cwd=workdir, text=True, capture_output=True)

    out_dir = Path(VOLUME_DIR) / OUTPUT_STEM
    out_dir.mkdir(parents=True, exist_ok=True)

    (out_dir / "stdout.txt").write_text(proc.stdout)
    (out_dir / "stderr.txt").write_text(proc.stderr)

    # Record only the result files that exist, so the manifest never points
    # at artifacts that a failed or partial run did not produce.
    saved_files = []
    for name in (f"{OUTPUT_STEM}.csv", f"{OUTPUT_STEM}.md"):
        src = workdir / name
        if src.exists():
            shutil.copy(src, out_dir / name)
            saved_files.append(f"{OUTPUT_STEM}/{name}")
    # The log files are always written above.
    saved_files.append(f"{OUTPUT_STEM}/stdout.txt")
    saved_files.append(f"{OUTPUT_STEM}/stderr.txt")

    manifest = {
        "output_stem": OUTPUT_STEM,
        "volume_output_dir": OUTPUT_STEM,
        "files": saved_files,
        "command": cmd,
        "returncode": proc.returncode,
    }
    (out_dir / "manifest.json").write_text(json.dumps(manifest, indent=2))
    # Commit before raising so the logs and manifest survive a failed run.
    volume.commit()

    if proc.returncode != 0:
        raise RuntimeError(
            f"Job failed with return code {proc.returncode}\nSTDOUT:\n{proc.stdout}\nSTDERR:\n{proc.stderr}"
        )

    return manifest


@app.local_entrypoint()
def main():
    """Invoke ``run_job`` remotely and print the returned manifest as JSON."""
    result = run_job.remote()
    rendered = json.dumps(result, indent=2)
    print(rendered)
