#!/usr/bin/env python3
# Read the original JSONL, preserve all fields, and only (re)write "group_traces"
import argparse
import contextlib
import json
import re
import sys
from pathlib import Path


def compile_patterns(n_agents: int):
    """Build one compiled regex per agent, keyed by 1-based agent number.

    The pattern for agent ``i`` captures, non-greedily, whatever sits between
    ``<Agent i>`` and ``</Agent i>``. Optional whitespace is allowed between
    ``Agent`` and the number, matching is case-insensitive, and ``.`` also
    matches newlines so a block may span several lines.

    Returns a dict ``{1: pattern_1, ..., n_agents: pattern_n}``; it is empty
    when ``n_agents`` is less than 1.
    """
    flags = re.DOTALL | re.IGNORECASE
    patterns = {}
    for i in range(1, n_agents + 1):
        patterns[i] = re.compile(rf"<Agent\s*{i}>(.*?)</Agent\s*{i}>", flags)
    return patterns


def _predictions_to_text(predictions):
    """Return *predictions* as one string that the agent regexes can scan."""
    if predictions is None:
        return ""
    if isinstance(predictions, (list, dict)):
        # NOTE(review): the JSON dump keeps JSON escapes (a newline becomes the
        # two characters backslash + "n", quotes become backslash + quote), so
        # text extracted from list/dict predictions carries those escapes.
        # Left unchanged here so existing output stays the same.
        return json.dumps(predictions, ensure_ascii=False)
    return str(predictions)


def _extract_group_traces(text, patterns):
    """Map each zero-based agent index (as a str) to that agent's text.

    *patterns* maps the 1-based agent number to its compiled regex. Every
    match is stripped and the matches are joined with single spaces; an agent
    without any match maps to the empty string.
    """
    return {
        str(agent - 1): " ".join(m.strip() for m in pattern.findall(text))
        for agent, pattern in patterns.items()
    }


def main():
    """Command-line entry point: rewrite ``group_traces`` for each JSONL record.

    Reads JSON objects line by line from ``--input_file`` (``-`` for stdin),
    extracts the ``<Agent i>...</Agent i>`` text found in each record's
    ``predictions`` field, stores it under ``group_traces`` (keys ``"0"`` to
    ``str(agents - 1)``), and writes the record, with every other field
    preserved, to ``--output_file`` (``-`` for stdout).

    Blank lines are ignored. Lines that are not valid JSON, or whose JSON
    value is not an object, are skipped and counted. A summary is printed to
    stderr so that stdout stays clean when it carries the output.

    Raises:
        FileNotFoundError: if the input path does not exist.
        ValueError: if input and output refer to the same file, because
            opening the output for writing would truncate the input.
    """
    parser = argparse.ArgumentParser(
        description="Extract <Agent i>...</Agent i> text from 'predictions' and rewrite only 'group_traces'."
    )
    parser.add_argument("--input_file", "-i", required=True,
                        help="Path to input JSONL (use '-' for stdin)")
    parser.add_argument("--output_file", "-o", required=True,
                        help="Path to output JSONL (use '-' for stdout)")
    parser.add_argument("--agents", type=int, default=4,
                        help="Number of agents to scan (default: 4)")
    args = parser.parse_args()

    # Resolve paths unless using stdio
    src_path = None if args.input_file == "-" else Path(args.input_file)
    dst_path = None if args.output_file == "-" else Path(args.output_file)

    if src_path is not None and not src_path.exists():
        raise FileNotFoundError(f"Input not found: {src_path}")
    if dst_path is not None:
        # Opening the output with "w" truncates it immediately; if it is the
        # input file, all data would be lost before a single line is read.
        if (src_path is not None and dst_path.exists()
                and src_path.samefile(dst_path)):
            raise ValueError(
                f"Input and output are the same file: {src_path}"
            )
        dst_path.parent.mkdir(parents=True, exist_ok=True)

    pattern_cache = compile_patterns(args.agents)

    total = 0      # non-blank lines seen
    written = 0    # records written to the output
    skipped = 0    # lines dropped (invalid JSON or not a JSON object)
    # Keyed by 1-based agent number, unlike group_traces (0-based keys).
    missing_counts = {str(i): 0 for i in range(1, args.agents + 1)}
    preview = []

    # ExitStack closes whichever real files were opened, even when opening
    # the second one fails; stdin/stdout are never registered, so never closed.
    with contextlib.ExitStack() as stack:
        if src_path is None:
            fin = sys.stdin
        else:
            fin = stack.enter_context(src_path.open("r", encoding="utf-8"))
        if dst_path is None:
            fout = sys.stdout
        else:
            fout = stack.enter_context(dst_path.open("w", encoding="utf-8"))

        for raw in fin:
            line = raw.strip()
            if not line:
                continue
            total += 1
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                # Skip malformed line, but keep track of it for the summary
                skipped += 1
                continue
            if not isinstance(obj, dict):
                # Valid JSON but not an object: there are no fields to keep.
                skipped += 1
                continue

            predictions_str = _predictions_to_text(obj.get("predictions"))

            # Build group_traces {"0": "...", "1": "...", ...}
            group_traces = _extract_group_traces(predictions_str, pattern_cache)
            for i in range(1, args.agents + 1):
                if not group_traces[str(i - 1)]:
                    missing_counts[str(i)] += 1

            # Preserve all original fields; only set/replace "group_traces"
            obj["group_traces"] = group_traces

            fout.write(json.dumps(obj, ensure_ascii=False) + "\n")
            written += 1

            if len(preview) < 3:
                preview.append({"group_traces": group_traces})

    # Summary to stderr (so stdout stays clean if piped)
    print(f"[INFO] Total read: {total}", file=sys.stderr)
    print(f"[INFO] Total written: {written}", file=sys.stderr)
    print(f"[INFO] Skipped (invalid JSON or non-object): {skipped}", file=sys.stderr)
    print(f"[INFO] Missing per agent: {missing_counts}", file=sys.stderr)
    if preview:
        print(f"[INFO] Preview of first {len(preview)} group_traces:", file=sys.stderr)
        for p in preview:
            print(p, file=sys.stderr)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
