import json
import argparse
import os

def build_parser():
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Combine generation files and filter level-5 questions"
    )
    parser.add_argument(
        "--input_files",
        nargs="+",
        required=True,
        help="List of generation JSON files to combine",
    )
    parser.add_argument(
        "--level5_file",
        required=True,
        help="math_500_level_5.json file",
    )
    parser.add_argument(
        "--output_dir",
        default="results",
        help="Directory to save output file",
    )
    parser.add_argument(
        "--output_file",
        default="math_bottomk_level5_generations.json",
        help="Output filename",
    )
    return parser


def _load_json(path):
    """Read and parse the UTF-8 encoded JSON file at *path*."""
    # Explicit encoding: question text may contain non-ASCII characters, and the
    # platform default encoding (e.g. cp1252 on Windows) would fail to decode them.
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)


def combine_generations(paths):
    """Concatenate the ``"generations"`` lists of every JSON file in *paths*.

    Files are processed in the given order, so the result preserves input order.

    Raises:
        ValueError: if a file's top level is not a JSON object containing a
            ``"generations"`` key.
    """
    combined = []
    for path in paths:
        data = _load_json(path)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(data, dict) or "generations" not in data:
            raise ValueError(f"{path} missing 'generations'")
        combined.extend(data["generations"])
    return combined


def load_level5_questions(path):
    """Return the set of ``"question"`` strings from the JSON list stored at *path*."""
    return {item["question"] for item in _load_json(path)}


def filter_generations(generations, questions):
    """Return the generations whose ``"question"`` is in *questions*.

    Order is preserved. Duplicates are kept, so several generations for the same
    question (e.g. from different input files) all appear in the result.
    Exact string matching is used.
    """
    return [gen for gen in generations if gen["question"] in questions]


def main(argv=None):
    """Run the combine-and-filter pipeline.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The list of matched generation rows that was written to the output file.
    """
    args = build_parser().parse_args(argv)

    # ---------- Combine generations ----------
    combined_generations = combine_generations(args.input_files)
    print("Total generations:", len(combined_generations))

    # ---------- Load level5 questions ----------
    level5_questions = load_level5_questions(args.level5_file)

    # ---------- Match ----------
    matched_rows = filter_generations(combined_generations, level5_questions)

    # ---------- Save ----------
    # Created only after the inputs loaded successfully, so a failed run does
    # not leave behind an empty output directory.
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, args.output_file)
    with open(output_path, "w", encoding="utf-8") as fh:
        json.dump(matched_rows, fh, indent=2)

    print(f"Extracted {len(matched_rows)} rows into {output_path}")
    return matched_rows


if __name__ == "__main__":
    main()
