import os
import argparse
import subprocess
import json
from pathlib import Path

def get_num_samples(metrics_path):
    """Return how many entries the ``"ppl"`` list of a metrics JSON file holds.

    Args:
        metrics_path: Path to a JSON file whose top-level value is an object.

    Returns:
        The length of the value stored under ``"ppl"``; ``0`` when the file
        does not exist or the object has no ``"ppl"`` key.
    """
    if os.path.exists(metrics_path):
        with open(metrics_path, "r") as handle:
            payload = json.load(handle)
        ppl_values = payload.get("ppl", [])
        return len(ppl_values)
    # A missing metrics file is reported as "no samples" rather than an error.
    return 0

def _build_linear_di_cmd(args, subset, num_samples, train_metrics_path,
                         val_metrics_path, save_file_prefix, extra_flags=()):
    """Assemble the argv list for a single ``linear_di.py`` invocation.

    Args:
        args: Parsed command-line namespace of this batch script.
        subset: Name of the subset directory, forwarded as ``--dataset_name``.
        num_samples: Value forwarded as ``--num_samples``.
        train_metrics_path: Path forwarded as ``--train_metrics_path``.
        val_metrics_path: Path forwarded as ``--val_metrics_path``.
        save_file_prefix: Value forwarded as ``--save_file_prefix``.
        extra_flags: Extra bare flags inserted right after ``--features``.

    Returns:
        A list of strings suitable for ``subprocess.run``.
    """
    return [
        "python3", "linear_di.py",
        "--model_name", args.model_name,
        "--dataset_name", subset,
        "--num_samples", str(num_samples),
        # Forward the user's choice; it was previously parsed but ignored.
        "--normalize", args.normalize,
        "--features", "selected",
        *extra_flags,
        "--num_random", str(args.num_random),
        "--train_metrics_path", str(train_metrics_path),
        "--val_metrics_path", str(val_metrics_path),
        "--output_dir", args.output_dir,
        "--save_file_prefix", save_file_prefix,
        "--outliers", args.outliers,
    ]


def main():
    """Run ``linear_di.py`` for every subset directory of a model's results.

    Subsets are the sub-directories of ``<results_dir>/<model_name>`` (with
    ``/`` in the model name replaced by ``_``). For each subset the script
    launches one run on the train/val metrics and then either a second
    train/val run with ``--false_positive`` (when ``--no-test`` is given) or a
    run on the test/val metrics (when the test metrics file exists).

    Raises:
        SystemExit: If the model's results directory does not exist.
        subprocess.CalledProcessError: If any ``linear_di.py`` run exits with
            a non-zero status (``check=True``).
    """
    parser = argparse.ArgumentParser(description="Batch run linear_di.py for each dataset subset.")
    parser.add_argument('--results_dir', type=str, required=True, help='Directory containing the metrics files')
    parser.add_argument('--output_dir', type=str, required=True, help='Directory to save output results')
    parser.add_argument('--model_name', type=str, default='EleutherAI/pythia-410m-deduped')
    parser.add_argument('--num_random', type=int, default=1)
    parser.add_argument('--outliers', type=str, default='mean', choices=["randomize", "keep", "zero", "mean", "clip", "mean+p-value", "p-value"], help='Outlier handling method')
    parser.add_argument('--normalize', type=str, default='train', choices=["no", "train", "combined"], help='Normalization method')
    parser.add_argument('--percent_to_train', type=float, default=0.25, help='Percentage of data to use for training')
    parser.add_argument('--no-test', action='store_true', help='If set, do not run test split; rerun the train split with --false_positive instead')

    args = parser.parse_args()

    model_name_safe = args.model_name.replace("/", "_")
    results_root = Path(args.results_dir) / model_name_safe
    if not results_root.exists():
        raise SystemExit(f"No results for model {model_name_safe} in {args.results_dir}")

    # iterdir() yields entries in arbitrary order; sort for reproducible runs.
    subsets_list = sorted(p.name for p in results_root.iterdir() if p.is_dir())

    metrics_suffix = "metrics.json"

    for subset in subsets_list:
        print(f"Processing subset: {subset}", flush=True)
        metrics_dir = results_root / subset
        train_metrics_path = metrics_dir / f"train_{metrics_suffix}"
        val_metrics_path = metrics_dir / f"val_{metrics_suffix}"
        test_metrics_path = metrics_dir / f"test_{metrics_suffix}"

        # Without both train and val metrics there is nothing to compare; skip
        # instead of launching the child with --num_samples 0 and paths that
        # do not exist (presumably a failure that would abort the whole batch).
        missing = [str(p) for p in (train_metrics_path, val_metrics_path) if not p.exists()]
        if missing:
            print(f"Skipping subset {subset}: missing metrics file(s): {', '.join(missing)}", flush=True)
            continue

        num_samples_train = get_num_samples(train_metrics_path)
        num_samples_val = get_num_samples(val_metrics_path)
        num_samples = int(min(num_samples_train, num_samples_val) * args.percent_to_train)

        # Run for train split
        cmd_train = _build_linear_di_cmd(
            args, subset, num_samples,
            train_metrics_path, val_metrics_path, "train_val",
        )
        print(f"Running (train): {' '.join(cmd_train)}", flush=True)
        subprocess.run(cmd_train, check=True)

        if args.no_test:
            # Run train split again instead of test, with --false_positive.
            # The sample budget here is derived from half of the val set only.
            num_samples_train2 = int(num_samples_val // 2 * args.percent_to_train)
            cmd_train2 = _build_linear_di_cmd(
                args, subset, num_samples_train2,
                train_metrics_path, val_metrics_path, "test_val",
                extra_flags=("--false_positive",),
            )
            print(f"Running (train, no_test): {' '.join(cmd_train2)}", flush=True)
            subprocess.run(cmd_train2, check=True)
        elif test_metrics_path.exists():
            num_samples_test = get_num_samples(test_metrics_path)
            num_samples_test_run = int(min(num_samples_test, num_samples_val) * args.percent_to_train)
            # The test metrics are passed through the --train_metrics_path slot.
            cmd_test = _build_linear_di_cmd(
                args, subset, num_samples_test_run,
                test_metrics_path, val_metrics_path, "test_val",
            )
            print(f"Running (test): {' '.join(cmd_test)}", flush=True)
            subprocess.run(cmd_test, check=True)
        else:
            print(f"Skipping test split for {subset}: {test_metrics_path} not found", flush=True)

# Script entry point: run the batch driver only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
