#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Convert a local Video-MME dataset into a JSON file compatible with this project,
and optionally extract the bundled video zip files.

Expected directory layout:
  <dataset_root>/
    ├─ videomme/test-00000-of-00001.parquet
    ├─ videos_chunked_01.zip ... videos_chunked_20.zip
    └─ subtitle.zip (optional)

This script:
  1) Extracts videos_chunked_*.zip into <dataset_root>/videos (if needed)
  2) Reads parquet shards and writes videomme_test.json

Output JSON fields:
  - id:             unique question id (prefer question_id)
  - video_path:     path relative to <dataset_root>/videos ("" if no video file matched)
  - question:       question text
  - candidates:     list of answer options
  - correct_choice: index in candidates (0-based), or null if unknown
  - meta:           selected original fields (domain/sub_category/task_type/duration/url/video_id)

Usage:
  python scripts/prepare_videomme.py /path/to/Video-MME \
    --parquet videomme/test-00000-of-00001.parquet \
    --output videomme_test.json \
    --extract

Notes:
  - If <dataset_root>/videos already exists and is non-empty, extraction is skipped.
  - Requires either pandas or pyarrow to read parquet.
"""

import argparse
import json
import os
import re
import sys
import glob
import zipfile
from pathlib import Path
from typing import List, Dict, Any, Optional


# File suffixes (compared lower-cased) that count as video files when the
# videos directory is indexed.
VIDEO_EXTS = {".mp4", ".webm", ".mkv", ".mov", ".avi"}


def ensure_extract_videos(dataset_root: Path, do_extract: bool) -> Path:
    """Return ``<dataset_root>/videos``, unpacking the chunked zips into it on request.

    Args:
        dataset_root: Directory holding the ``videos_chunked_*.zip`` archives.
        do_extract: When true and the videos directory is missing or empty,
            every matching archive is extracted into it.

    Returns:
        The path of the videos directory. It is returned in every case, even
        when it does not exist (nothing to extract, or extraction not requested).
    """
    target = dataset_root / "videos"

    # A directory that already holds at least one entry is taken as extracted.
    already_populated = target.exists() and next(target.rglob("*"), None) is not None
    if already_populated:
        print(f"[INFO] Found existing videos dir: {target}")
        return target

    if not do_extract:
        print(f"[WARN] --extract not set and videos are not present; matching may fail: {target}")
        return target

    archives = sorted(dataset_root.glob("videos_chunked_*.zip"))
    if len(archives) == 0:
        print("[ERROR] videos_chunked_*.zip not found; cannot extract.")
        return target

    target.mkdir(parents=True, exist_ok=True)
    print(f"[INFO] Extracting {len(archives)} zip files into {target} ...")
    for archive_path in archives:
        print(f"  - Extract: {archive_path.name}")
        with zipfile.ZipFile(archive_path, mode="r") as archive:
            archive.extractall(path=target)
    print("[OK] Extraction complete.")
    return target


def build_video_index(videos_dir: Path) -> Dict[str, List[str]]:
    """Index the video files under ``videos_dir`` by file stem.

    The directory is walked recursively; only regular files whose lower-cased
    suffix is in ``VIDEO_EXTS`` are recorded.

    Returns:
        A dict mapping each stem (file name without extension) to the list of
        POSIX-style paths, relative to ``videos_dir``, that share that stem.
        Empty when ``videos_dir`` does not exist.
    """
    by_stem: Dict[str, List[str]] = {}
    if videos_dir.exists():
        for entry in videos_dir.rglob("*"):
            if entry.is_file() and entry.suffix.lower() in VIDEO_EXTS:
                relative = entry.relative_to(videos_dir).as_posix()
                if entry.stem not in by_stem:
                    by_stem[entry.stem] = []
                by_stem[entry.stem].append(relative)
    return by_stem


def read_parquet_records(parquet_paths: List[Path]) -> List[Dict[str, Any]]:
    """Load every parquet shard and return the rows as a list of dicts.

    pandas is tried first; if importing or reading with it fails for any
    reason, pyarrow is tried next.

    Raises:
        RuntimeError: when both backends fail; the message lists each
            backend's error.
    """
    failures: List[str] = []

    # Preferred backend: pandas.
    try:
        import pandas as pd  # type: ignore
        frames = []
        for shard in parquet_paths:
            frames.append(pd.read_parquet(str(shard)))
        if len(frames) > 1:
            combined = pd.concat(frames, ignore_index=True)
        else:
            combined = frames[0]
        return combined.to_dict(orient="records")
    except Exception as exc:
        failures.append(f"pandas read failed: {exc}")

    # Fallback backend: pyarrow.
    try:
        import pyarrow.parquet as pq  # type: ignore
        import pyarrow as pa  # noqa: F401
        shard_tables = []
        for shard in parquet_paths:
            shard_tables.append(pq.read_table(str(shard)))
        if len(shard_tables) == 1:
            merged = shard_tables[0]
        else:
            merged = pa.concat_tables(shard_tables)
        return merged.to_pylist()
    except Exception as exc:
        failures.append(f"pyarrow read failed: {exc}")

    raise RuntimeError("Unable to read parquet; install pandas or pyarrow.\n" + "\n".join(failures))


def normalize_options(opts: Any) -> List[str]:
    """Normalize answer options given in various formats into a list of strings.

    Accepted inputs:
      - ``None``: returns ``[]``.
      - Array-like objects exposing ``tolist()`` (such as numpy arrays or
        pandas Series): converted with ``tolist()`` and then handled as below.
      - list / tuple: every item is stringified and stripped; blank items are
        dropped.
      - str: tried, in order, as (1) a JSON list, (2) single- or double-quoted
        items, optionally wrapped in ``[...]``, (3) text split in front of
        ``"A. "``-style prefixes, (4) text split on ``||``, ``|`` or newlines.
      - Anything else: stringified into a single-element list (empty if the
        text is blank).

    Returns:
        The cleaned list of option strings (possibly empty).
    """

    def _clean_list(items: List[Any]) -> List[str]:
        # Stringify, strip, and drop items that end up empty.
        stripped = (str(item).strip() for item in items)
        return [text for text in stripped if text]

    # Without this step an array-like is neither list, tuple nor str, so it
    # would be stringified as a whole and yield one bogus candidate containing
    # the array's repr.
    to_list = getattr(opts, "tolist", None)
    if callable(to_list):
        opts = to_list()

    if opts is None:
        return []

    if isinstance(opts, (list, tuple)):
        return _clean_list(list(opts))

    if isinstance(opts, str):
        s = opts.strip()
        if not s:
            return []

        # 1) JSON list
        try:
            parsed = json.loads(s)
        except (ValueError, RecursionError):
            # Not valid JSON; fall through to the looser formats below.
            parsed = None
        if isinstance(parsed, list):
            result = _clean_list(parsed)
            if result:
                return result

        # 2) e.g. ['A. foo' 'B. bar'] / ["A. foo","B. bar"] / A. foo B. bar
        bracket_str = s[1:-1] if s.startswith("[") and s.endswith("]") else s

        for quoted_pattern in (r"'([^']+)'", r'"([^"]+)"'):
            quoted_items = re.findall(quoted_pattern, bracket_str)
            if quoted_items:
                result = _clean_list(quoted_items)
                if result:
                    return result

        # 3) Split in front of "A. " prefix patterns (zero-width lookahead
        #    keeps the letter with its option text).
        letter_split = re.split(r"(?=[A-Z]\.\s)", bracket_str)
        letter_split = [seg.strip() for seg in letter_split if seg.strip()]
        if len(letter_split) > 1:
            return letter_split

        # 4) Fallback: common delimiters ("||", "|", newlines)
        parts = re.split(r"\s*\|\||\|\s*|\n+", s)
        return [p.strip() for p in parts if p.strip()]

    return _clean_list([opts])


def answer_to_index(answer: Any, options: List[str]) -> Optional[int]:
    """Map an answer to its 0-based position in ``options``.

    Resolution order:
      1. A single-letter string ``A``..``Z`` (case-insensitive, surrounding
         whitespace ignored) is read as an option letter, provided the
         resulting index is within ``options``.
      2. Otherwise the stringified, stripped answer is compared
         case-insensitively with each option's text.

    Returns:
        The matching index, or ``None`` when ``answer`` is ``None``,
        ``options`` is empty, or nothing matches.
    """
    if answer is None or not options:
        return None

    text = str(answer).strip()

    # Single letter A/B/C/...
    if isinstance(answer, str) and len(text) == 1:
        letter = text.upper()
        # Uppercasing can expand one character into several ("ß" -> "SS");
        # re-check the length so ord() is never given a multi-char string.
        if len(letter) == 1 and "A" <= letter <= "Z":
            idx = ord(letter) - ord("A")
            if idx < len(options):
                return idx

    # Text match; the case-insensitive comparison also covers exact equality.
    lowered = text.lower()
    for i, opt in enumerate(options):
        # str() keeps a stray non-string option from raising here.
        if lowered == str(opt).lower():
            return i
    return None


def match_video_path(row: Dict[str, Any], video_index: Dict[str, List[str]]) -> Optional[str]:
    """Find the indexed video file that belongs to a dataset row.

    Lookup keys are the truthy values of the row fields ``video_id``,
    ``videoID``, ``videoId``, ``id`` and ``question_id`` (in that order).
    Matching runs in tiers, and a stronger tier is exhausted for every key
    before a weaker one is tried:

      1. exact match against an index key,
      2. case-insensitive exact match,
      3. case-insensitive substring match (key contained in an index key);
         the shortest containing index key wins, ties broken alphabetically.

    When an index key maps to several paths, a path without a directory
    component is preferred, then the alphabetically first one.

    Args:
        row: One dataset record.
        video_index: Mapping of file stem -> relative paths sharing that stem.

    Returns:
        The chosen relative path, or ``None`` if no tier matches.
    """

    def _preferred(paths: List[str]) -> str:
        # Top-level files ("/" not in path) sort first, then lexicographic.
        return min(paths, key=lambda p: ("/" in p, p))

    keys: List[str] = []
    for field in ("video_id", "videoID", "videoId", "id", "question_id"):
        value = row.get(field)
        if not value:
            continue
        if isinstance(value, float) and value != value:
            # NaN (a missing cell) is truthy; it must not become the key "nan".
            continue
        text = str(value)
        if text not in keys:
            keys.append(text)

    # Tier 1: exact match by basename
    for key in keys:
        paths = video_index.get(key)
        if paths:
            return _preferred(paths)

    # Lower-cased view of the index; stems differing only by case are merged
    # instead of overwriting each other.
    folded: Dict[str, List[str]] = {}
    for base, paths in video_index.items():
        folded.setdefault(base.lower(), []).extend(paths)
    folded_keys = [key.lower() for key in keys]

    # Tier 2: case-insensitive exact match, tried for every key before any
    # substring scan so a short id cannot shadow a precise match of a later key.
    for key in folded_keys:
        paths = folded.get(key)
        if paths:
            return _preferred(paths)

    # Tier 3: substring scan (OK for typical dataset sizes)
    for key in folded_keys:
        hits = [base for base, paths in folded.items() if paths and key in base]
        if hits:
            closest = min(hits, key=lambda b: (len(b), b))
            return _preferred(folded[closest])
    return None


def main():
    """Command-line entry point: build the task JSON for a local Video-MME copy.

    Steps:
      1. Parse arguments and validate the dataset root (exit code 1 if missing).
      2. Optionally extract the video archives and index the video files.
      3. Locate the parquet shards (exit code 2 if none are found) and read them.
      4. Convert each row into a task dict and write the list as JSON under
         the dataset root.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("dataset_root", help="Video-MME dataset root directory")
    ap.add_argument("--parquet", default=None, help="parquet path or glob (default: videomme/test-*.parquet)")
    ap.add_argument("--output", default="videomme_test.json", help="Output JSON filename (written under dataset root)")
    ap.add_argument("--extract", action="store_true", help="Extract videos_chunked_*.zip into videos/ if needed")
    args = ap.parse_args()

    root = Path(args.dataset_root).resolve()
    if not root.exists():
        print(f"[ERROR] Dataset root does not exist: {root}", file=sys.stderr)
        sys.exit(1)

    videos_dir = ensure_extract_videos(root, args.extract)
    video_index = build_video_index(videos_dir)
    print(f"[INFO] Indexed {sum(len(v) for v in video_index.values())} video files ({len(video_index)} basename keys)")

    # Resolve parquet files.
    parquet_glob = args.parquet or "videomme/test-*.parquet"
    # Escape the root so glob metacharacters in the directory name ("[", "*",
    # "?") are taken literally; only the user-supplied part acts as a pattern.
    # os.path.join drops the root when parquet_glob is absolute, matching the
    # behaviour of `root / parquet_glob`.
    pattern = os.path.join(glob.escape(str(root)), parquet_glob)
    # glob returns matches in arbitrary order; sort so row order (and the
    # fallback "VMME_<i>" ids) are reproducible across runs and machines.
    parquet_paths = sorted(Path(p) for p in glob.glob(pattern))
    parquet_paths = [p for p in parquet_paths if p.exists()]
    if not parquet_paths:
        print(f"[ERROR] parquet files not found: {root / parquet_glob}", file=sys.stderr)
        sys.exit(2)
    print(f"[INFO] Reading parquet shards: {len(parquet_paths)}")

    rows = read_parquet_records(parquet_paths)

    tasks: List[Dict[str, Any]] = []
    missing_videos = 0
    for i, r in enumerate(rows):
        # Prefer the dataset's own ids; fall back to a positional id.
        qid = r.get("question_id") or r.get("id") or f"VMME_{i}"
        question = r.get("question") or ""
        options = normalize_options(r.get("options"))
        answer = r.get("answer")
        correct_idx = answer_to_index(answer, options)
        rel_video = match_video_path(r, video_index)
        if rel_video is None:
            missing_videos += 1
            rel_video = ""  # placeholder; downstream will error for easier debugging

        task = {
            "id": qid,
            "video_path": rel_video,  # relative to <root>/videos
            "question": question,
            "candidates": options,
            "correct_choice": correct_idx,
            "meta": {
                "duration": r.get("duration"),
                "domain": r.get("domain"),
                "sub_category": r.get("sub_category"),
                "task_type": r.get("task_type"),
                "url": r.get("url"),
                "video_id": r.get("video_id") or r.get("videoID"),
            }
        }
        tasks.append(task)

    out_path = root / args.output
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(tasks, f, ensure_ascii=False, indent=2)

    print("\n[OK] Conversion complete")
    print(f"  Output: {out_path}")
    print(f"  Samples: {len(tasks)}")
    if missing_videos:
        print(f"  [WARN] {missing_videos} items did not match any video file; check video_id/filenames.")


# Run the conversion only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
