from __future__ import annotations

import json
import shutil
import subprocess
import sys
from pathlib import Path

import modal

# Name under which the Modal app is registered (passed to modal.App below).
APP_NAME = "gpufever50k-fever-pretrained"
# Name of the persistent Modal volume that receives experiment outputs.
VOLUME_NAME = "gpufever50k-fever-results"
# Where the local project directory is placed inside the container image; also
# used as the working directory when the training script is launched.
REMOTE_PROJECT_DIR = "/root/project/gpufever50k"
# Mount point of the results volume inside the container.
REMOTE_RESULTS_DIR = "/root/results"
# GPU type requested for the remote experiment function.
DEFAULT_GPU = "A100-40GB"
# Directory containing this file; its contents are added to the image.
LOCAL_DIR = Path(__file__).resolve().parent

# Build the container image step by step: base OS + Python, system packages,
# Python dependencies, and finally the project sources.
image = modal.Image.debian_slim(python_version="3.11")
image = image.apt_install("git")
image = image.pip_install(
    "torch",
    "transformers>=4.41",
    "datasets>=2.19",
    "pandas>=2.2",
    "numpy>=1.26",
    "accelerate>=0.30",
    "sentencepiece",
    "protobuf",
)
# copy=True requests that the directory be copied into the image itself.
image = image.add_local_dir(LOCAL_DIR, remote_path=REMOTE_PROJECT_DIR, copy=True)

# The app uses the image above by default for all of its functions.
app = modal.App(APP_NAME, image=image)
# Persistent volume for experiment outputs; created on first use if absent.
results_volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)


def _build_command(
    *,
    model_name: str,
    train_samples: int,
    eval_samples: int,
    max_seq_len: int,
    epochs: int,
    batch_size: int,
    lr: float,
    consistency_loss_weight: float,
    freeze_lower_layers_epochs: int,
    seed: int,
    smoke_test: bool,
    output_csv: str,
    variants: list[str],
) -> list[str]:
    """Assemble the argv list that launches ``run_fever_pretrained_gpu.py``.

    The script is invoked with the current interpreter (``sys.executable``)
    and always receives ``--require-gpu``. Every keyword argument maps to the
    command-line option of the same name; numeric values are converted with
    ``str()`` while ``model_name`` and ``output_csv`` are passed through as
    given. The entries of ``variants`` follow ``--variants`` as separate
    arguments, and ``--smoke-test`` is appended last when ``smoke_test`` is
    true.

    Returns:
        The command as a list of arguments suitable for ``subprocess.run``.
    """
    # Insertion order of this mapping defines the order on the command line.
    valued_options = {
        "--model-name": model_name,
        "--train-samples": str(train_samples),
        "--eval-samples": str(eval_samples),
        "--max-seq-len": str(max_seq_len),
        "--epochs": str(epochs),
        "--batch-size": str(batch_size),
        "--lr": str(lr),
        "--consistency-loss-weight": str(consistency_loss_weight),
        "--freeze-lower-layers-epochs": str(freeze_lower_layers_epochs),
        "--seed": str(seed),
        "--output-csv": output_csv,
    }

    argv = [sys.executable, "run_fever_pretrained_gpu.py", "--require-gpu"]
    for flag, value in valued_options.items():
        argv += [flag, value]

    argv.append("--variants")
    argv.extend(variants)

    if smoke_test:
        argv.append("--smoke-test")
    return argv


@app.function(gpu=DEFAULT_GPU, cpu=8, memory=32768, timeout=60 * 60 * 24, volumes={REMOTE_RESULTS_DIR: results_volume})
def run_experiment(
    model_name: str = "gpt2",
    train_samples: int = 50_000,
    eval_samples: int = 5_000,
    max_seq_len: int = 256,
    epochs: int = 5,
    batch_size: int = 16,
    lr: float = 5e-5,
    consistency_loss_weight: float = 0.5,
    freeze_lower_layers_epochs: int = 1,
    seed: int = 42,
    smoke_test: bool = False,
    output_stem: str = "results_fever_pretrained_gpu",
    variants: list[str] | None = None,
) -> dict:
    """Run the training script in the container and persist its outputs.

    Launches ``run_fever_pretrained_gpu.py`` from ``REMOTE_PROJECT_DIR`` as a
    subprocess, copies whichever of the expected result files exist into
    ``<REMOTE_RESULTS_DIR>/<output_stem>/``, writes a ``manifest.json`` next
    to them and commits the results volume.

    Args:
        output_stem: Base name of the CSV passed to the script via
            ``--output-csv`` and name of the sub-directory on the volume.
        variants: Variant names forwarded after ``--variants``. ``None``
            selects the full default set; an explicit list (even an empty
            one) is forwarded unchanged.
        The remaining arguments are forwarded to the script as command-line
        options of the same name.

    Returns:
        The manifest dict with keys ``output_stem``, ``volume_output_dir``,
        ``files`` (paths relative to the volume root), ``command`` and
        ``smoke_test``.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero
            status. Nothing is copied or committed in that case.
    """
    # A fresh list per call instead of a shared mutable default argument.
    if variants is None:
        variants = [
            "no_consistency_loss",
            "evidence_only_pooling",
            "evidence_only_strict",
            "full_sequence_pooling",
            "claim_only_pooling",
            "evidence_only_random_labels",
        ]

    relative_output_dir = Path(output_stem)
    remote_output_dir = Path(REMOTE_RESULTS_DIR) / relative_output_dir
    remote_output_dir.mkdir(parents=True, exist_ok=True)
    output_csv = f"{output_stem}.csv"
    command = _build_command(
        model_name=model_name,
        train_samples=train_samples,
        eval_samples=eval_samples,
        max_seq_len=max_seq_len,
        epochs=epochs,
        batch_size=batch_size,
        lr=lr,
        consistency_loss_weight=consistency_loss_weight,
        freeze_lower_layers_epochs=freeze_lower_layers_epochs,
        seed=seed,
        smoke_test=smoke_test,
        output_csv=output_csv,
        variants=variants,
    )
    print("Running:", " ".join(command))
    subprocess.run(command, cwd=REMOTE_PROJECT_DIR, check=True)

    # Derive the Markdown name from the stem rather than str.replace(), which
    # would rewrite every ".csv" occurrence inside the name, not just the
    # extension. NOTE(review): assumes the script writes its Markdown report
    # as "<stem>.md" -- confirm against run_fever_pretrained_gpu.py.
    generated_files = [
        output_csv,
        f"{output_stem}.md",
        "results_fever_pretrained_vs_synthetic.csv",
        "results_fever_pretrained_vs_synthetic.md",
    ]
    copied_files = []
    for filename in generated_files:
        source = Path(REMOTE_PROJECT_DIR) / filename
        if source.exists():
            destination = remote_output_dir / filename
            # filename carries a directory part when output_stem does; make
            # sure that directory exists before copying.
            destination.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(source, destination)
            copied_files.append(str(relative_output_dir / filename))

    manifest = {"output_stem": output_stem, "volume_output_dir": str(relative_output_dir), "files": copied_files, "command": command, "smoke_test": smoke_test}
    manifest_path = remote_output_dir / "manifest.json"
    manifest_path.write_text(json.dumps(manifest, indent=2))
    # Persist the files written under the mount point to the volume.
    results_volume.commit()
    return manifest


@app.local_entrypoint()
def main(
    model_name: str = "gpt2",
    train_samples: int = 50_000,
    eval_samples: int = 5_000,
    max_seq_len: int = 256,
    epochs: int = 5,
    batch_size: int = 16,
    lr: float = 5e-5,
    consistency_loss_weight: float = 0.5,
    freeze_lower_layers_epochs: int = 1,
    seed: int = 42,
    smoke_test: bool = False,
    output_stem: str = "results_fever_pretrained_gpu",
    variants: str = "no_consistency_loss,evidence_only_pooling,evidence_only_strict,full_sequence_pooling,claim_only_pooling,evidence_only_random_labels",
    download_to: str = "gpufever50k/modal_results",
) -> None:
    """Run the experiment remotely and download its results locally.

    Args:
        variants: Comma-separated variant names; surrounding whitespace and
            empty entries are dropped before the list is sent to the remote
            function.
        download_to: Local directory (resolved against the current working
            directory) under which ``<output_stem>/`` is created to receive
            the downloaded files.
        The remaining arguments are forwarded unchanged to
        ``run_experiment``.

    Every file listed in the returned manifest, plus ``manifest.json``
    itself, is read from the results volume and written into the target
    directory under its base name.
    """
    selected_variants = [name.strip() for name in variants.split(",") if name.strip()]
    manifest = run_experiment.remote(
        model_name=model_name,
        train_samples=train_samples,
        eval_samples=eval_samples,
        max_seq_len=max_seq_len,
        epochs=epochs,
        batch_size=batch_size,
        lr=lr,
        consistency_loss_weight=consistency_loss_weight,
        freeze_lower_layers_epochs=freeze_lower_layers_epochs,
        seed=seed,
        smoke_test=smoke_test,
        output_stem=output_stem,
        variants=selected_variants,
    )
    print(json.dumps(manifest, indent=2))

    target_dir = Path(download_to).resolve() / output_stem
    target_dir.mkdir(parents=True, exist_ok=True)

    # Paths inside the volume use "/" separators. Joining with the local Path
    # flavour and str() would yield backslashes on a Windows client, so
    # render the manifest path explicitly in POSIX form.
    manifest_volume_path = (Path(manifest["volume_output_dir"]) / "manifest.json").as_posix()
    for volume_path in [*manifest["files"], manifest_volume_path]:
        destination = target_dir / Path(volume_path).name
        # read_file yields the file content in chunks of bytes.
        destination.write_bytes(b"".join(results_volume.read_file(volume_path)))
        print(f"Downloaded {volume_path} -> {destination}")
