from __future__ import annotations

import json
import shutil
import subprocess
from pathlib import Path

import modal


# Modal identifiers for the app and its persistent results volume.
APP_NAME: str = "fever-scratch-transformer"
VOLUME_NAME: str = "fever-scratch-results"

# Container-side locations: the project directory is copied into the image at
# REMOTE_PROJECT_DIR, and the results volume is mounted at REMOTE_RESULTS_DIR.
REMOTE_PROJECT_DIR: str = "/root/project/feverfromscratch"
REMOTE_RESULTS_DIR: str = "/root/results"

# GPU type requested for the training function.
DEFAULT_GPU: str = "A10G"

# Absolute directory containing this launcher; uploaded as the project.
LOCAL_DIR: Path = Path(__file__).resolve().parents[0]

# Container image: slim Debian with Python 3.11 plus the training dependencies,
# followed by a copy of this project directory so the remote function can run
# the training script from REMOTE_PROJECT_DIR.
# NOTE(review): copy=True bakes in everything under LOCAL_DIR, including any
# result files left over from local runs — confirm that is intended.
_base_image = modal.Image.debian_slim(python_version="3.11")
_python_deps = ("torch", "transformers", "datasets", "numpy", "pandas", "tqdm")
image = _base_image.pip_install(*_python_deps).add_local_dir(
    LOCAL_DIR,
    remote_path=REMOTE_PROJECT_DIR,
    copy=True,
)

app = modal.App(APP_NAME, image=image)
# Named volume holding experiment outputs; created on first use if missing.
results_volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)


def _mtime_ns(path: Path) -> int | None:
    """Return the modification time of *path* in nanoseconds, or None if absent."""
    try:
        return path.stat().st_mtime_ns
    except FileNotFoundError:
        return None


@app.function(
    gpu=DEFAULT_GPU,
    cpu=8,
    memory=32768,
    timeout=60 * 60 * 6,  # 6 hours
    volumes={REMOTE_RESULTS_DIR: results_volume},
)
def run_experiment(
    output_stem: str = "fever_scratch_50k_20260428",
    train_samples: int = 50000,
    eval_samples: int = 5000,
    epochs: int = 10,
    batch_size: int = 32,
) -> dict:
    """Run the training script remotely and persist its outputs to the volume.

    Executes ``run_fever_scratch_gpu.py`` inside the baked-in project
    directory, copies the result files produced by this run into
    ``<REMOTE_RESULTS_DIR>/<output_stem>/``, writes a ``manifest.json`` next
    to them, and commits the volume.

    Args:
        output_stem: Run name. Used as the base name of the CSV/Markdown
            outputs passed to the script and as the sub-directory on the
            volume. Must be a single path component (no separators, not
            "." or "..").
        train_samples: Value for the script's ``--train-samples`` flag.
        eval_samples: Value for the script's ``--eval-samples`` flag.
        epochs: Value for the script's ``--epochs`` flag.
        batch_size: Value for the script's ``--batch-size`` flag.
            The defaults reproduce the original fixed configuration.

    Returns:
        The manifest dict with keys ``output_stem``, ``volume_output_dir``
        (relative to the volume root), ``files`` (volume-relative paths that
        were copied) and ``command`` (the argv that was executed).

    Raises:
        ValueError: If ``output_stem`` is not a plain file-name component.
        subprocess.CalledProcessError: If the training script exits non-zero.
    """
    # An absolute stem would replace the volume root when joined below, and
    # separators or ".." would write outside the intended run directory.
    if (
        not output_stem
        or output_stem in {".", ".."}
        or Path(output_stem).name != output_stem
    ):
        raise ValueError(
            f"output_stem must be a single path component, got {output_stem!r}"
        )

    relative_output_dir = Path(output_stem)
    remote_output_dir = Path(REMOTE_RESULTS_DIR) / relative_output_dir
    remote_output_dir.mkdir(parents=True, exist_ok=True)

    output_csv = f"{output_stem}.csv"
    # Derive the Markdown name from the stem: str.replace(".csv", ".md") would
    # also rewrite any ".csv" appearing inside the stem itself.
    output_md = f"{output_stem}.md"
    command = [
        "python",
        "run_fever_scratch_gpu.py",
        "--require-gpu",
        "--train-samples",
        str(train_samples),
        "--eval-samples",
        str(eval_samples),
        "--epochs",
        str(epochs),
        "--batch-size",
        str(batch_size),
        "--output-csv",
        output_csv,
    ]

    project_dir = Path(REMOTE_PROJECT_DIR)
    generated_files = [
        output_csv,
        output_md,
        "results_fever_scratch_vs_synthetic.csv",
        "results_fever_scratch_vs_synthetic.md",
    ]
    # The image contains a copy of the local project directory, so these files
    # may already exist from an earlier local run. Snapshot their mtimes so only
    # files created or rewritten by this run are reported as its results.
    mtimes_before = {name: _mtime_ns(project_dir / name) for name in generated_files}

    print("Running:", " ".join(command))
    subprocess.run(command, cwd=REMOTE_PROJECT_DIR, check=True)

    copied_files: list[str] = []
    for filename in generated_files:
        source = project_dir / filename
        mtime_after = _mtime_ns(source)
        if mtime_after is None:
            print(f"Expected output not produced, skipping: {filename}")
            continue
        if mtime_after == mtimes_before[filename]:
            print(f"Output unchanged by this run (stale copy), skipping: {filename}")
            continue
        shutil.copy2(source, remote_output_dir / filename)
        copied_files.append(str(relative_output_dir / filename))

    manifest = {
        "output_stem": output_stem,
        "volume_output_dir": str(relative_output_dir),
        "files": copied_files,
        "command": command,
    }
    manifest_path = remote_output_dir / "manifest.json"
    manifest_path.write_text(json.dumps(manifest, indent=2))
    # Persist the writes so clients reading the volume (e.g. `main`) see them.
    results_volume.commit()
    return manifest


@app.local_entrypoint()
def main(
    output_stem: str = "fever_scratch_50k_20260428",
    download_to: str = "feverfromscratch/modal_results",
) -> None:
    """Launch the remote experiment, then copy its artifacts to local disk.

    The remote call returns a manifest. Every file it lists, plus the
    manifest itself, is read back from the results volume and written flat
    into ``<download_to>/<output_stem>/`` under its base name.

    NOTE(review): ``download_to`` is resolved against the current working
    directory, not this file's directory, so the default assumes the command
    is run from the parent of the project folder — confirm.
    """
    manifest = run_experiment.remote(output_stem=output_stem)
    print(json.dumps(manifest, indent=2))

    local_dir = Path(download_to).resolve() / output_stem
    local_dir.mkdir(parents=True, exist_ok=True)

    manifest_entry = str(Path(manifest["volume_output_dir"]) / "manifest.json")
    to_fetch = [*manifest["files"], manifest_entry]
    for remote_path in to_fetch:
        local_path = local_dir / Path(remote_path).name
        # read_file yields the file's contents as chunks of bytes.
        payload = b"".join(results_volume.read_file(remote_path))
        local_path.write_bytes(payload)
        print(f"Downloaded {remote_path} -> {local_path}")
