#!/usr/bin/env python3
"""
concat_blocks.py

Concatenate per-problem JSON block files (<id>_v*.json) into JSONL.

Behavior:
- If --grouped-output is provided, write grouped JSONL:
    {"id": <int>, "versions": [ { ... }, { ... }, ... ]}
- If --flat-output is provided, write flat JSONL:
    {"id": <int>, "version": <int>, ...fields from block...}
- If both flags are provided, write both files.

Usage
-----
python concat_blocks.py \
  --json-blocks-dir outputs/json_blocks \
  --grouped-output outputs/dataset.grouped.jsonl \
  --flat-output outputs/dataset.flat.jsonl
"""

import argparse
import json
import re
from pathlib import Path
from typing import Dict, List

# File names look like "<id>_v<version>.json"; both parts are decimal digits.
BLOCK_RE = re.compile(r"^(?P<id>\d+)_v(?P<v>\d+)\.json$")


def load_blocks(json_blocks_dir: Path) -> Dict[int, Dict[int, dict]]:
    """
    Load all <id>_v<k>.json files from json_blocks_dir.

    Files whose names do not match BLOCK_RE are ignored. Files that cannot be
    read or parsed, or whose top-level JSON value is not an object, are
    skipped with a warning printed to stdout.

    Files are processed in sorted path order so the result is deterministic.
    If two file names map to the same (id, version) pair (e.g. "1_v1.json"
    and "01_v1.json"), a warning is printed and the later file wins.

    Returns: {id: {version: json_obj}}
    """
    data: Dict[int, Dict[int, dict]] = {}
    # glob() yields entries in arbitrary order; sort for reproducible results.
    for path in sorted(json_blocks_dir.glob("*.json")):
        m = BLOCK_RE.match(path.name)
        if not m:
            continue
        pid = int(m.group("id"))
        v = int(m.group("v"))
        try:
            obj = json.loads(path.read_text(encoding="utf-8"))
        except (OSError, ValueError, RecursionError) as e:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            print(f"Warning: failed to parse {path}: {e}")
            continue
        if not isinstance(obj, dict):
            # The writers merge/inspect block fields, which requires an object.
            print(
                f"Warning: skipping {path}: expected a JSON object, "
                f"got {type(obj).__name__}"
            )
            continue
        versions = data.setdefault(pid, {})
        if v in versions:
            print(
                f"Warning: {path} duplicates id={pid} version={v}; "
                "overriding the earlier file"
            )
        versions[v] = obj
    return data


def write_grouped(blocks: Dict[int, Dict[int, dict]], out_path: Path) -> None:
    """
    Write grouped JSONL: one line per problem id, in ascending id order.

    Each line has the form {"id": <id>, "versions": [...]} with the versions
    listed in ascending version order. The "question" and "answer" fields of
    each version are coerced to strings with str() when present and not
    already strings.

    The coercion is applied to shallow copies, so the caller's `blocks`
    mapping is left unmodified. Parent directories of out_path are created
    if missing, and a summary line is printed to stdout.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    n = 0
    with out_path.open("w", encoding="utf-8") as f:
        for pid in sorted(blocks):
            versions: List[dict] = []
            for v in sorted(blocks[pid]):
                vobj = blocks[pid][v]
                if isinstance(vobj, dict):
                    # Copy before coercing so the input data is not mutated.
                    vobj = dict(vobj)
                    # make sure the "question" and "answer" fields are strings
                    for key in ("question", "answer"):
                        if key in vobj and not isinstance(vobj[key], str):
                            vobj[key] = str(vobj[key])
                versions.append(vobj)
            record = {"id": pid, "versions": versions}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
            n += 1
    print(f"[grouped] Wrote {n} problems to {out_path}")


def write_flat(blocks: Dict[int, Dict[int, dict]], out_path: Path) -> None:
    """
    Write flat JSONL: one line per (id, version) pair.

    Lines are emitted in ascending id order, then ascending version order.
    Each record starts as {"id": <id>, "version": <version>} and is then
    updated with the block's own fields, so a block that defines "id" or
    "version" replaces the value derived from the file name. The "question"
    and "answer" fields are coerced to strings with str() when present and
    not already strings.

    Parent directories of out_path are created if missing, and a summary
    line is printed to stdout.
    """
    text_fields = ("question", "answer")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    written = 0
    with out_path.open("w", encoding="utf-8") as sink:
        for problem_id in sorted(blocks):
            by_version = blocks[problem_id]
            for version in sorted(by_version):
                row = dict(id=problem_id, version=version)
                # block fields override id/version if collision
                row.update(by_version[version])
                # make sure the "question" and "answer" fields are strings
                for field in text_fields:
                    if field in row and not isinstance(row[field], str):
                        row[field] = str(row[field])
                sink.write(json.dumps(row, ensure_ascii=False) + "\n")
                written += 1
    print(f"[flat] Wrote {written} versions to {out_path}")


def main() -> int:
    """
    Command-line entry point.

    Parses the arguments, loads the JSON blocks from --json-blocks-dir and
    writes the grouped and/or flat JSONL outputs that were requested.

    Returns a process exit code: 0 on success, 1 when the blocks directory
    is missing (or not a directory) or contains no matching blocks. If
    neither --grouped-output nor --flat-output is given, argparse reports a
    usage error and exits with status 2 before any file is read.
    """
    import sys  # local import: only needed here for stderr reporting

    p = argparse.ArgumentParser(description="Concatenate JSON blocks into grouped and/or flat JSONL.")
    p.add_argument("--json-blocks-dir", type=str, required=True, help="Directory with <id>_v*.json files")
    p.add_argument("--grouped-output", type=str, default=None, help="Path to grouped JSONL (one line per id)")
    p.add_argument("--flat-output", type=str, default=None, help="Path to flat JSONL (one line per version)")
    args = p.parse_args()

    # Fail fast: there is no point loading every block if nothing is written.
    if not args.grouped_output and not args.flat_output:
        p.error("no output requested: please specify --grouped-output, --flat-output, or both.")

    blocks_dir = Path(args.json_blocks_dir)
    # is_dir() also rejects a regular file passed by mistake.
    if not blocks_dir.is_dir():
        print(f"Error: blocks dir not found or not a directory: {blocks_dir}", file=sys.stderr)
        return 1

    blocks = load_blocks(blocks_dir)
    if not blocks:
        print(f"No matching JSON blocks found in {blocks_dir}", file=sys.stderr)
        return 1

    if args.grouped_output:
        write_grouped(blocks, Path(args.grouped_output))
    if args.flat_output:
        write_flat(blocks, Path(args.flat_output))
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
