#!/usr/bin/env python3
"""
count_versions_and_missing.py

Scan a json_blocks directory for versioned files named like:
  <id>_v<k>.json
  <id>_v<k>.raw.txt

Report:
1) IDs present but with version-count < --min-versions (default: 2)
2) IDs **missing entirely** (0 versions), computed relative to an expected ID set:
   - from an original dataset (--original JSON/JSONL; uses 'id' if present, otherwise line index)
   - or from a numeric range (--min-id, --max-id, inclusive)
   - or from a text file of IDs (one int per line) via --id-list
   - you may combine multiple; the expected set is the union

Options:
- --json-only counts only *.json (ignores *.raw.txt). By default, both count toward a version, de-duped per version number.

Usage examples:
  # Count and report under-threshold + missing compared to an original dataset
  python count_versions_and_missing.py \
    --json-blocks-dir outputs/json_blocks \
    --original data/original.jsonl \
    --min-versions 2 \
    --report-under outputs/under_min_versions.txt \
    --report-missing outputs/missing_ids.txt

  # Count and report using an explicit ID range
  python count_versions_and_missing.py \
    --json-blocks-dir outputs/json_blocks \
    --min-id 0 --max-id 999 \
    --min-versions 3

  # Count and report using a list of expected IDs
  python count_versions_and_missing.py \
    --json-blocks-dir outputs/json_blocks \
    --id-list ids.txt
"""

import argparse
import json
import re
import sys
from pathlib import Path
from typing import Dict, Set, List, Any

# Match "123_v1.json" or "123_v1.raw.txt"
FILE_RE = re.compile(r"^(?P<pid>\d+)_v(?P<v>\d+)\.(?:json|raw\.txt)$")


def collect_version_counts(blocks_dir: Path, json_only: bool) -> Dict[int, Set[int]]:
    """
    Map each ID found in ``blocks_dir`` to the set of its version numbers.

    Only regular files directly inside ``blocks_dir`` (the scan is not
    recursive) whose names match ``<id>_v<k>.json`` or ``<id>_v<k>.raw.txt``
    are considered; everything else is ignored. A version that has both a
    ``.json`` and a ``.raw.txt`` file is counted once, because versions are
    collected into a set.

    Args:
        blocks_dir: Directory to scan.
        json_only: If True, count only ``*.json`` files; otherwise count both
            ``*.json`` and ``*.raw.txt``.

    Returns:
        ``{id: set_of_version_numbers}`` for every ID with at least one
        matching file.

    Raises:
        FileNotFoundError: If ``blocks_dir`` does not exist.
        NotADirectoryError: If ``blocks_dir`` exists but is not a directory.
    """
    if not blocks_dir.exists():
        raise FileNotFoundError(f"Directory not found: {blocks_dir}")
    # A regular file would otherwise scan as "empty" and make every expected
    # ID look missing, so fail loudly instead.
    if not blocks_dir.is_dir():
        raise NotADirectoryError(f"Not a directory: {blocks_dir}")

    counts: Dict[int, Set[int]] = {}
    # One pass over the directory; the regex already restricts the suffix.
    for entry in blocks_dir.iterdir():
        m = FILE_RE.fullmatch(entry.name)
        if m is None:
            continue
        if json_only and not entry.name.endswith(".json"):
            continue
        # Checked last so only matching names pay for the extra stat call.
        if not entry.is_file():
            continue
        counts.setdefault(int(m.group("pid")), set()).add(int(m.group("v")))
    return counts


def _expected_id(obj: Any, fallback: int) -> int:
    """Return ``obj['id']`` when ``obj`` is a dict with a genuine int id, else ``fallback``."""
    if isinstance(obj, dict):
        candidate = obj.get("id")
        # bool is a subclass of int; True/False must not be read as IDs 1/0.
        if isinstance(candidate, int) and not isinstance(candidate, bool):
            return candidate
    return fallback


def load_expected_from_original(path: Path) -> Set[int]:
    """
    Load expected IDs from an original dataset (JSONL or JSON list).

    For each item, an integer ``'id'`` field is used when present (booleans
    are not accepted as IDs); otherwise the item's 0-based position is used.
    For JSONL the position is the physical line index, so blank and malformed
    lines still advance it.

    Args:
        path: Dataset file; the extension (case-insensitive) selects the
            parser: ``.jsonl`` (one JSON value per line) or ``.json`` (a
            single top-level list).

    Returns:
        The set of expected IDs.

    Raises:
        ValueError: If the extension is unsupported, or a ``.json`` file does
            not contain a top-level list (or is not valid JSON).
    """
    suffix = path.suffix.lower()

    if suffix == ".jsonl":
        expected: Set[int] = set()
        with path.open("r", encoding="utf-8") as f:
            for idx, line in enumerate(f):
                s = line.strip()
                if not s:
                    continue
                try:
                    obj = json.loads(s)
                except (ValueError, RecursionError) as exc:
                    # Best effort: keep going, but tell the user a record was dropped.
                    print(
                        f"Warning: skipping malformed JSON on line {idx + 1} of {path}: {exc}",
                        file=sys.stderr,
                    )
                    continue
                expected.add(_expected_id(obj, idx))
        return expected

    if suffix == ".json":
        data = json.loads(path.read_text(encoding="utf-8"))
        if not isinstance(data, list):
            raise ValueError(f"Expected a list in {path}, got {type(data)}")
        return {_expected_id(obj, idx) for idx, obj in enumerate(data)}

    raise ValueError(f"Unsupported file extension for --original: {path.suffix}")


def load_expected_from_id_list(path: Path) -> Set[int]:
    """
    Load expected IDs from a text file (one integer ID per line).

    Blank lines and lines starting with '#' are ignored. A line that is not a
    valid integer is skipped with a warning on stderr, so a typo in the list
    does not silently remove an ID from the expected set.

    Args:
        path: Text file to read (UTF-8).

    Returns:
        The set of integer IDs parsed from the file.
    """
    expected: Set[int] = set()
    with path.open("r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            s = line.strip()
            if not s or s.startswith("#"):
                continue
            try:
                expected.add(int(s))
            except ValueError:
                print(
                    f"Warning: ignoring non-integer ID on line {lineno} of {path}: {s!r}",
                    file=sys.stderr,
                )
    return expected


def _write_report(out_path: Path, title: str, rows: List[tuple]) -> None:
    """Write (id, versions) rows to ``out_path`` as a tab-separated report under a '# title' line."""
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        f.write(f"# {title}\n")
        f.write("id\tversions\n")
        f.writelines(f"{pid}\t{n}\n" for pid, n in rows)


def main() -> None:
    """
    Command-line entry point.

    Scans --json-blocks-dir, prints a summary plus the under-threshold and
    missing ID tables to stdout, and optionally writes them to the files
    given by --report-under / --report-missing. Missing IDs are computed
    only when at least one expected-ID source (--original, --max-id range,
    --id-list) yields a non-empty set.

    Raises:
        ValueError: If the lower bound of the ID range exceeds --max-id.
    """
    ap = argparse.ArgumentParser(description="Count versions per ID and report under-threshold + missing IDs (0 versions).")
    ap.add_argument("--json-blocks-dir", type=str, required=True, help="Directory with <id>_v<k>.(json|raw.txt) files")
    ap.add_argument("--min-versions", type=int, default=2, help="Minimum required versions per question (default: 2)")
    ap.add_argument("--json-only", action="store_true", help="Only count *.json files (ignore *.raw.txt)")
    ap.add_argument("--report-under", type=str, default=None, help="Optional path to write under-threshold IDs and counts")
    ap.add_argument("--report-missing", type=str, default=None, help="Optional path to write missing IDs (0 versions)")

    # Expected ID sources (union of all provided)
    ap.add_argument("--original", type=str, default=None, help="Original dataset (JSONL or JSON) to derive expected IDs")
    # Default is None (not 0) so an explicit --min-id can be told apart from
    # an omitted one; the effective lower bound is still 0 when omitted.
    ap.add_argument("--min-id", type=int, default=None, help="Lower bound (inclusive) for expected ID range (default: 0)")
    ap.add_argument("--max-id", type=int, default=None, help="Upper bound (inclusive) for expected ID range")
    ap.add_argument("--id-list", type=str, default=None, help="Text file with one expected ID per line")

    args = ap.parse_args()

    blocks_dir = Path(args.json_blocks_dir)
    id_to_versions = collect_version_counts(blocks_dir, json_only=args.json_only)

    # Build expected ID set (union of sources)
    expected_ids: Set[int] = set()
    if args.original:
        expected_ids |= load_expected_from_original(Path(args.original))
    if args.max_id is not None:
        min_id = 0 if args.min_id is None else args.min_id
        if min_id > args.max_id:
            raise ValueError("--min-id must be <= --max-id")
        expected_ids.update(range(min_id, args.max_id + 1))
    elif args.min_id is not None:
        # The range is only active when --max-id is given; say so rather than
        # dropping the option without a trace.
        print("Warning: --min-id is ignored because --max-id was not provided", file=sys.stderr)
    if args.id_list:
        expected_ids |= load_expected_from_id_list(Path(args.id_list))

    # Summary counts
    present_ids = set(id_to_versions)
    total_present = len(present_ids)

    # Under-threshold (only among PRESENT ids)
    under = [
        (pid, len(id_to_versions[pid]))
        for pid in sorted(present_ids)
        if len(id_to_versions[pid]) < args.min_versions
    ]

    # Missing (0 versions) only if we know expected IDs
    missing = [(pid, 0) for pid in sorted(expected_ids - present_ids)] if expected_ids else []

    # Print summary
    print(f"Scanned: {blocks_dir}")
    print(f"Present questions found: {total_present}")
    print(f"Threshold (min versions): {args.min_versions}")
    print(f"Under-threshold (present but < threshold): {len(under)}")
    if expected_ids:
        print(f"Expected IDs provided: {len(expected_ids)}")
        print(f"Missing (0 versions): {len(missing)}")
    else:
        print("Missing (0 versions): not computed (no expected IDs provided)")

    # Print details
    if under:
        print("\n[Under-threshold]")
        print("ID\tversions")
        for pid, n in under:
            print(f"{pid}\t{n}")
    else:
        print("\n[Under-threshold] None")

    if expected_ids:
        if missing:
            print("\n[Missing]")
            print("ID\tversions")
            for pid, n in missing:
                print(f"{pid}\t{n}")
        else:
            print("\n[Missing] None")

    # Optional reports
    if args.report_under:
        out_under = Path(args.report_under)
        _write_report(out_under, f"Under-threshold (versions < {args.min_versions})", under)
        print(f"\nWrote under-threshold report to: {out_under}")

    if args.report_missing:
        if expected_ids:
            out_missing = Path(args.report_missing)
            _write_report(out_missing, "Missing (0 versions)", missing)
            print(f"Wrote missing report to: {out_missing}")
        else:
            # No expected set means "missing" is undefined; explain why the
            # requested file was not produced.
            print(
                "Warning: --report-missing was not written because no expected IDs were provided",
                file=sys.stderr,
            )


if __name__ == "__main__":
    main()
