import argparse
import os
import json
from pathlib import Path
from tqdm import tqdm

def _shard_sort_key(filename: str, shard_prefix: str) -> tuple:
    """Return a sort key that orders shard files by their numeric index.

    The text between `shard_prefix` and the ".json" suffix is treated as the
    shard index. Purely numeric indices sort by integer value (so shard 2 comes
    before shard 10); anything else sorts after them, by file name.
    """
    stem = filename[len(shard_prefix):-len(".json")]
    if stem.isdecimal():
        return (0, int(stem), filename)
    return (1, 0, filename)


def combine_shards(
    base_dir: str,
    shard_prefix: str = "qualitative_prompt_output_shard_",
    output_name: str = "combined.json",
) -> None:
    """
    Merge JSON shard files found under `base_dir` into one JSON file per folder.

    Walks `base_dir` recursively. In every directory whose path does not
    contain the substring "calibration", the files named
    `<shard_prefix>*.json` are loaded in shard-index order and merged into a
    single list: a shard holding a JSON list contributes its elements, any
    other JSON value is appended as one entry. The merged list is written to
    `output_name` inside the same directory, overwriting an existing file.

    Args:
        base_dir: Directory to search. If it is not an existing directory a
            warning is printed and nothing is merged.
        shard_prefix: File-name prefix identifying shard files.
        output_name: File name of the merged output written next to the shards.

    Raises:
        json.JSONDecodeError: If a shard file does not contain valid JSON.
        OSError: If a shard cannot be read or the output cannot be written.
    """
    base = Path(base_dir)
    if not base.is_dir():
        # os.walk silently yields nothing for a missing path; say so instead
        # of exiting without any output.
        print(f"Warning: {base} is not a directory; nothing to merge")
        return

    for root, dirs, files in os.walk(base):
        if "calibration" in root:
            # Every descendant path also contains "calibration", so there is
            # no point in descending further.
            dirs[:] = []
            continue

        shard_files = sorted(
            (
                f for f in files
                if f.startswith(shard_prefix)
                and f.endswith(".json")
                # Never merge a previous output back into itself.
                and f != output_name
            ),
            # Numeric ordering keeps shard_2 ahead of shard_10.
            key=lambda name: _shard_sort_key(name, shard_prefix),
        )
        if not shard_files:
            continue

        merged = []
        for sf in tqdm(shard_files, desc=f"Merging {root}", leave=False):
            # JSON is UTF-8 by specification; do not rely on the locale default.
            with open(Path(root) / sf, "r", encoding="utf-8") as f:
                data = json.load(f)
            if isinstance(data, list):
                merged.extend(data)
            else:
                merged.append(data)

        out_path = Path(root) / output_name
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(merged, f, indent=2)

        print(f"Saved {out_path} with {len(merged)} entries")


if __name__ == "__main__":
    # Command-line entry point: merge the shard files found under
    # <root_dir>/alignment_results.
    parser = argparse.ArgumentParser()
    parser.add_argument("--root_dir", type=str, required=True,
                        help="Root directory to search for shard files.")
    args = parser.parse_args()

    # Top-level results directory. The joined component must be relative:
    # os.path.join discards everything before an absolute component, so a
    # leading "/" would resolve to "/alignment_results" and ignore --root_dir.
    results_root = os.path.join(args.root_dir, "alignment_results")
    combine_shards(results_root)
