#!/usr/bin/env python3
"""
Extract data from a JSONL file and organize it into a dataset format (`meta.json`).

It extracts four fields:
- problem: question text (from `visualize_qa.question` or `record.question`)
- cot: reasoning process from the visualization QA (`visualize_qa.cot`)
- answer: `generation.answer`
- actual_data: only keeps four fields: `points` / `segments` / `circles` / `annotations`

General rules:
- Keep only samples whose `status == "success"`.
- For `actual_data`, prefer top-level `record.actual_data`;
  if it does not exist, fall back to `record.plotting.actual_data`.

Example usage (run from project root):

1) For the 1213 sanitized file, write out as dataset 5:
    python scripts/extract_dataset.py \
        --input data/output/vllm_sanitized_1213_new.jsonl \
        --output-dir data/datasets/5

2) For the 0117 file (also filtering `status == "success"`), write out as dataset 6:
    python scripts/extract_dataset.py \
        --input data/output/vllm_results_0117.jsonl \
        --output-dir data/datasets/6
"""

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, Iterable, List


def iter_jsonl(path: Path) -> Iterable[Dict[str, Any]]:
    """
    Lazily read a JSONL file and yield one dict per valid line.

    Blank lines are skipped. Lines that are not valid JSON, or whose top-level
    JSON value is not an object, are reported on stderr and skipped, so every
    yielded item is guaranteed to be a dict.

    Args:
        path: Path of the JSONL file to read.

    Yields:
        The parsed JSON object of each valid line, in file order.
    """
    # "utf-8-sig" reads plain UTF-8 unchanged but also strips a leading BOM,
    # which would otherwise make the first line fail to parse.
    with path.open("r", encoding="utf-8-sig") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"[WARN] Failed to parse JSON at line {line_no}: {e}", file=sys.stderr)
                continue
            # Valid JSON is not necessarily an object (e.g. `[1, 2]` or `3`);
            # callers use dict methods on every record, so drop anything else.
            if not isinstance(record, dict):
                print(
                    f"[WARN] Skipping line {line_no}: expected a JSON object, "
                    f"got {type(record).__name__}",
                    file=sys.stderr,
                )
                continue
            yield record


def extract_actual_subset(actual_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Reduce `actual_data` to the four keys the dataset keeps.

    Kept keys and the default used when a key is absent:
    - points      -> {}
    - segments    -> []
    - circles     -> []
    - annotations -> {}

    Any other key (e.g., `segment_chains`, `annotation_summary`) is dropped.
    Values that are present are passed through as-is, without copying.
    """
    # (key, factory for the default) pairs; a factory is used so that every
    # call hands out a fresh empty container rather than a shared one.
    wanted = (
        ("points", dict),
        ("segments", list),
        ("circles", list),
        ("annotations", dict),
    )

    subset: Dict[str, Any] = {}
    for key, make_default in wanted:
        subset[key] = actual_data.get(key, make_default())
    return subset


def _empty_actual_data() -> Dict[str, Any]:
    """Build a fresh placeholder `actual_data` for records with no usable geometry."""
    return {
        "points": {},
        "segments": [],
        "circles": [],
        "annotations": {
            "right_angles": [],
            "length_of_line": [],
            "measure_of_angle": [],
        },
    }


def extract_record_fields(record: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract the four fields `problem`, `cot`, `answer`, and `actual_data` from a raw record.

    Args:
        record: One parsed JSONL record.

    Returns:
        A dict with exactly the keys `problem`, `cot`, `answer` and `actual_data`.
        Missing or null `problem` / `cot` / `answer` become `""`; a missing or
        unusable `actual_data` becomes an empty placeholder structure.

    Raises:
        ValueError: If the record has no `index` field (or it is null).
    """
    idx = record.get("index")
    if idx is None:
        raise ValueError("Record is missing the 'index' field")

    # problem: prefer the visualization QA question, then the top-level one.
    visualize_qa = record.get("visualize_qa") or {}
    problem = visualize_qa.get("question") or record.get("question") or ""

    # cot: normalise a JSON null to "" so it matches how `problem` is handled.
    cot = visualize_qa.get("cot") or ""

    # answer: only a missing/null answer becomes ""; a falsy but meaningful
    # answer such as 0 must be kept rather than replaced.
    gen = record.get("generation") or {}
    answer = gen.get("answer")
    if answer is None:
        answer = ""

    # actual_data (prefer top-level, otherwise fall back to plotting.actual_data)
    actual_src = record.get("actual_data")
    if not actual_src:
        plotting = record.get("plotting") or {}
        actual_src = plotting.get("actual_data")

    actual_data = _empty_actual_data()
    if actual_src:
        try:
            actual_data = extract_actual_subset(actual_src)
        except Exception as e:
            # Deliberate best effort: a malformed `actual_data` should not drop
            # the whole sample, so warn and keep the empty placeholder.
            print(f"[WARN] index={idx} failed to extract actual_data: {e}", file=sys.stderr)

    return {
        "problem": problem,
        "cot": cot,
        "answer": answer,
        "actual_data": actual_data,
    }


def main() -> None:
    """
    Command-line entry point: convert a JSONL file into `<output-dir>/meta.json`.

    Reads `--input` line by line, keeps only records whose `status` is
    `"success"`, extracts the dataset fields from each and writes the resulting
    list as a single JSON array. Records whose fields cannot be extracted are
    reported on stderr and skipped.

    Exits with status 1 if the input path is missing or not a regular file, if
    the output directory cannot be created, or if `meta.json` cannot be written.
    """
    parser = argparse.ArgumentParser(
        description="Extract problem/cot/answer/actual_data from a JSONL file and generate dataset meta.json"
    )
    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="Input JSONL file path",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Output dataset directory, e.g. data/datasets/5",
    )
    args = parser.parse_args()

    input_path = Path(args.input)
    output_dir = Path(args.output_dir)

    if not input_path.exists():
        print(f"[ERR] Input file does not exist: {input_path}", file=sys.stderr)
        sys.exit(1)
    # A directory also "exists"; reject it here instead of failing later with
    # a traceback when it is opened for reading.
    if not input_path.is_file():
        print(f"[ERR] Input path is not a file: {input_path}", file=sys.stderr)
        sys.exit(1)

    try:
        output_dir.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        # E.g. the path (or one of its parents) is an existing regular file.
        print(f"[ERR] Failed to create output directory {output_dir}: {e}", file=sys.stderr)
        sys.exit(1)

    print(f"[INFO] Input file: {input_path}")
    print(f"[INFO] Output directory: {output_dir}")

    all_meta: List[Dict[str, Any]] = []
    total = 0
    kept = 0

    for record in iter_jsonl(input_path):
        total += 1

        # Keep only records with status == "success". Anything that is not a
        # JSON object cannot carry a status, so it is skipped as well.
        if not isinstance(record, dict) or record.get("status") != "success":
            continue

        try:
            meta = extract_record_fields(record)
        except Exception as e:
            # Best effort per record: one malformed sample must not abort the run.
            idx = record.get("index")
            print(f"[WARN] index={idx} failed to extract fields: {e}", file=sys.stderr)
            continue

        all_meta.append(meta)
        kept += 1

        if kept % 500 == 0:
            print(f"[PROGRESS] Kept {kept} records (scanned {total} records)")

    # Write meta.json
    meta_file = output_dir / "meta.json"
    try:
        with meta_file.open("w", encoding="utf-8") as f:
            json.dump(all_meta, f, ensure_ascii=False, indent=2)
        print(f"[INFO] meta.json written with {len(all_meta)} records")
    except Exception as e:
        print(f"[ERR] Failed to write meta.json: {e}", file=sys.stderr)
        sys.exit(1)

    print(f"[DONE] Scanned {total} records, kept {kept} samples with status=='success'")
    print(f"[DONE] meta.json path: {meta_file}")


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

