#!/usr/bin/env python3
"""
Batch run di.py on the JSONL shards produced by the auto_split_parquets.py script.

For each dataset subfolder under:
    {data_dir}/processed/{model_name_safe}/<dataset>/
it will run di.py on the following files:
  - train:  inference-train.jsonl
  - test:   inference-test.jsonl
  - val:    val.jsonl

Files named fine-tune-*.jsonl are ignored.
"""
import argparse
import os
import random
import subprocess
import sys
from pathlib import Path

def _discover_dataset_folders(data_dir, dataset_name):
    """Return the dataset directories to process.

    Args:
        data_dir: Directory whose immediate subdirectories are datasets.
            A leading ``~`` is expanded.
        dataset_name: ``"all"`` to pick every subdirectory (in shuffled
            order), otherwise the name of a single subdirectory.

    Returns:
        A list of ``Path`` objects, one per dataset directory.

    Raises:
        SystemExit: If ``data_dir`` is not a directory, contains no
            subdirectories, or the requested dataset does not exist.
    """
    root = Path(os.path.expanduser(data_dir))
    # Fail with a readable message instead of a FileNotFoundError traceback
    # from iterdir() when the root itself is missing.
    if not root.is_dir():
        raise SystemExit(f"Data directory not found: {root}")

    if dataset_name != "all":
        ds_path = root / dataset_name
        if not ds_path.is_dir():
            raise SystemExit(f"Dataset directory not found: {ds_path}")
        return [ds_path]

    folders = [p for p in root.iterdir() if p.is_dir()]
    if not folders:
        raise SystemExit(f"No dataset subdirectories found under {data_dir}")
    # NOTE(review): the shuffle presumably spreads work when several batch
    # runs are started concurrently -- confirm before removing.
    random.shuffle(folders)
    return folders


def _select_split_paths(ds_folder, max_val_train, split_name):
    """Return the ``(split, path)`` pairs to run for one dataset folder.

    Args:
        ds_folder: ``Path`` of the dataset directory.
        max_val_train: When true, use ``train_full.jsonl`` / ``train_val.jsonl``
            instead of ``train.jsonl`` / ``test.jsonl`` / ``val.jsonl``.
        split_name: ``"all"`` or the name of the single split to keep.

    Returns:
        A (possibly empty) list of ``(split_name, Path)`` tuples.
    """
    if max_val_train:
        candidates = [
            ("train", ds_folder / "train_full.jsonl"),
            ("val", ds_folder / "train_val.jsonl"),
        ]
    else:
        candidates = [
            ("train", ds_folder / "train.jsonl"),
            ("test", ds_folder / "test.jsonl"),
            ("val", ds_folder / "val.jsonl"),
        ]
    if split_name == "all":
        return candidates
    return [pair for pair in candidates if pair[0] == split_name]


def main():
    """Parse CLI arguments and run raw_values.py once per dataset split.

    For every selected dataset folder, each required split file must exist;
    otherwise the whole dataset is skipped with a warning. For each split a
    ``raw_values.py`` subprocess is launched with the model name, the split
    file, the output directory ``{result_output}/{model_name_safe}/{dataset}``
    and the result file name ``{split}_metrics.json``.

    Raises:
        SystemExit: On invalid arguments or when no dataset folder is found.
        subprocess.CalledProcessError: If a ``raw_values.py`` run exits with
            a non-zero status (the batch stops at the first failure).
    """
    parser = argparse.ArgumentParser(
        description="Batch run di.py on inference and val shards"
    )
    parser.add_argument(
        "--data_dir", type=str, required=True,
        help="Root directory containing 'processed/{model_name_safe}/{dataset}/' subfolders"
    )
    parser.add_argument(
        "--model_name", type=str, default="EleutherAI/pythia-410m-deduped",
        help="Name of the model (used to form subpath under data_dir/processed)"
    )
    parser.add_argument(
        "--result_output", type=str, required=True,
        help="Directory to save di.py results"
    )
    parser.add_argument(
        "--cache_dir", type=str, default="~/.cache",
        help="HuggingFace cache directory"
    )
    # NOTE(review): the help text below mentions test.jsonl / test_val.jsonl,
    # but only train_full.jsonl and train_val.jsonl are actually processed in
    # this mode -- confirm which of the two is intended.
    parser.add_argument(
        "--max-val-train", action="store_true",
        help="If set, use train_full.jsonl, train_val.jsonl, test.jsonl, test_val.jsonl and run train_full vs train_val and test vs test_val."
    )
    parser.add_argument(
        "--text_column", type=str, default="all",
        help="The column name in the dataset to use as text input"
    )
    parser.add_argument(
        "--dataset_name", type=str, default="all",
        help="Name of the dataset to process (default: all)"
    )
    parser.add_argument(
        "--split_name", type=str, default="all",
        help="Which split to process: all/train/val/test (default: all)"
    )
    args = parser.parse_args()

    # Make model safe for filesystem
    model_safe = args.model_name.replace("/", "_")

    # The child is started without a shell, so "~" (including the one in the
    # --cache_dir default) would otherwise reach it as a literal character.
    cache_dir = os.path.expanduser(args.cache_dir)
    result_root = Path(os.path.expanduser(args.result_output))

    # Use the interpreter running this script so the child sees the same
    # environment / installed packages; fall back to PATH lookup if unknown.
    interpreter = sys.executable or "python3"

    for ds_folder in _discover_dataset_folders(args.data_dir, args.dataset_name):
        dataset_name = ds_folder.name  # e.g., "arxiv", "wikipedia", etc.
        print(f"Processing dataset: {dataset_name}", flush=True)

        split_paths = _select_split_paths(
            ds_folder, args.max_val_train, args.split_name
        )
        if not split_paths:
            # Only reachable when --split_name filtered everything out.
            print(f"  \u26a0\ufe0f  Split '{args.split_name}' not found for dataset '{dataset_name}', skipping.")
            continue

        # Report every missing file before skipping, not just the first one.
        missing = [(name, path) for name, path in split_paths if not path.exists()]
        for name, path in missing:
            print(f"  \u26a0\ufe0f  Missing {name} file at {path}, skipping dataset.")
        if missing:
            continue

        output_dir = result_root / model_safe / dataset_name
        output_dir.mkdir(parents=True, exist_ok=True)

        for split_name, dataset_path in split_paths:
            result_file = f"{split_name}_metrics.json"
            # NOTE: "raw_values.py" is resolved relative to the current
            # working directory, as in the original invocation.
            cmd = [
                interpreter, "raw_values.py",
                "--model_name", args.model_name,
                "--dataset_path", str(dataset_path),
                "--output_dir", str(output_dir),
                "--result_file_name", result_file,
                "--cache_dir", cache_dir,
                "--text_column", args.text_column,
            ]
            print(f"  Running: {' '.join(cmd)}", flush=True)
            subprocess.run(cmd, check=True)

if __name__ == "__main__":
    main()
