#!/usr/bin/env python3
"""
Batch run di.py on the JSONL shards produced by the auto_split_parquets.py script.

For each dataset subfolder under:
    {data_dir}/processed/{model_name_safe}/<dataset>/
it will run di.py on the following files:
  - train:  inference-train.jsonl
  - test:   inference-test.jsonl
  - val:    val.jsonl

Files named fine-tune-*.jsonl are ignored.
"""
import os
import argparse
import subprocess
from pathlib import Path

def _build_parser():
    """Return the argument parser describing this script's command line."""
    parser = argparse.ArgumentParser(
        description="Batch run di.py on inference and val shards"
    )
    parser.add_argument(
        "--data_dir", type=str, required=True,
        help="Root directory containing 'full_split_data/{dataset}/' subfolders; "
             "raw-value paths are formed under 'raw_values/' in the same root"
    )
    parser.add_argument(
        "--model_name", type=str, default="EleutherAI/pythia-410m-deduped",
        help="Name of the model (used to form subpaths under result_output "
             "and data_dir/raw_values)"
    )
    parser.add_argument(
        "--reference_model_names", type=str, nargs='*', default=[],
        help="List of reference model names (used to form subpaths under data_dir/raw_values)"
    )
    parser.add_argument(
        "--result_output", type=str, required=True,
        help="Directory to save di.py results"
    )
    parser.add_argument(
        "--cache_dir", type=str, default="~/.cache",
        help="HuggingFace cache directory"
    )
    # NOTE(review): an earlier help text also promised a "test vs test_val" run
    # (test.jsonl / test_val.jsonl); this script has only ever used the two
    # train_* files in this mode, so the help now describes what actually runs.
    parser.add_argument(
        "--max-val-train", action="store_true",
        help="If set, use train_full.jsonl (as train) and train_val.jsonl (as val) "
             "instead of train.jsonl, test.jsonl and val.jsonl."
    )
    parser.add_argument(
        "--loss_estimation_method", type=str, default="raw",
        help="Method for loss estimation (raw, ref, sigmoid)"
    )
    return parser


def _split_paths(ds_folder, max_val_train):
    """Return the ``(split_name, jsonl_path)`` pairs to process for one dataset folder."""
    if max_val_train:
        return [
            ("train", ds_folder / "train_full.jsonl"),
            ("val", ds_folder / "train_val.jsonl"),
        ]
    return [
        ("train", ds_folder / "train.jsonl"),
        ("test", ds_folder / "test.jsonl"),
        ("val", ds_folder / "val.jsonl"),
    ]


def main():
    """Run di.py once per split for every dataset folder under ``data_dir``.

    Dataset folders are discovered under ``{data_dir}/full_split_data`` and
    processed in sorted order. A dataset with any expected split file missing
    is skipped with a warning. For every remaining split, ``di.py`` (resolved
    relative to the current working directory) is launched with the running
    Python interpreter.

    Raises:
        SystemExit: if the dataset root is missing or has no subdirectories.
        subprocess.CalledProcessError: if any di.py invocation exits non-zero.
    """
    import sys  # local import: only needed to locate the running interpreter

    args = _build_parser().parse_args()

    # Make model names safe for use as filesystem path components.
    model_safe = args.model_name.replace("/", "_")
    reference_model_safes = [name.replace("/", "_") for name in args.reference_model_names]

    # Discover dataset subfolders automatically.
    dataset_root = Path(args.data_dir) / "full_split_data"
    if not dataset_root.is_dir():
        # Fail with a readable message instead of a FileNotFoundError traceback.
        raise SystemExit(f"Dataset root {dataset_root} does not exist or is not a directory")
    # Sorted so the processing order is reproducible across runs and filesystems.
    dataset_folders = sorted(p for p in dataset_root.iterdir() if p.is_dir())
    if not dataset_folders:
        raise SystemExit(f"No dataset subdirectories found under {dataset_root}")

    raw_values_root = Path(args.data_dir) / "raw_values"
    # The command is run without a shell, so '~' must be expanded here.
    cache_dir = os.path.expanduser(args.cache_dir)
    # Use the interpreter running this script so the child sees the same environment.
    python_exe = sys.executable or "python3"

    for ds_folder in dataset_folders:
        dataset_name = ds_folder.name  # e.g., "arxiv", "wikipedia", etc.
        print(f"Processing dataset: {dataset_name}", flush=True)

        split_paths = _split_paths(ds_folder, args.max_val_train)

        # Report every missing file before skipping, not just the first one.
        missing = [(name, p) for name, p in split_paths if not p.exists()]
        for split_name, p in missing:
            print(
                f"  \u26a0\ufe0f  Missing {split_name} file at {p}, skipping dataset.",
                flush=True,
            )
        if missing:
            continue

        output_dir = Path(args.result_output) / model_safe / dataset_name
        output_dir.mkdir(parents=True, exist_ok=True)

        for split_name, dataset_path in split_paths:
            result_file = f"{split_name}_metrics.json"

            # Raw values path for the main model.
            raw_values_path = raw_values_root / model_safe / dataset_name / result_file

            # One metrics path per reference model, for the same dataset and split.
            reference_metrics_paths = [
                str(raw_values_root / ref_model_safe / dataset_name / result_file)
                for ref_model_safe in reference_model_safes
            ]

            cmd = [
                python_exe, "di.py",
                "--dataset_path", str(dataset_path),
                "--output_dir", str(output_dir),
                "--result_file_name", result_file,
                "--cache_dir", cache_dir,
                "--raw_values_path", str(raw_values_path),
                "--loss_estimation_method", args.loss_estimation_method,
            ]
            if reference_metrics_paths:
                cmd += ["--reference_models_metrics_path"] + reference_metrics_paths

            print(f"  Running: {' '.join(cmd)}", flush=True)
            # check=True aborts the whole batch on the first failing di.py run.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    main()
