#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import json
import sys
from pathlib import Path
from typing import Dict, Any, Tuple, List, Optional


def iter_json_items(path: Path):
    """
    Yield ``(idx, obj)`` pairs from a JSON array, a JSON object, or a JSONL file.

    - JSON array: ``idx`` is the array index.
    - JSON object: ``idx`` is always 0.
    - JSONL: ``idx`` is the 1-based line number; blank lines and lines that
      are not valid JSON are skipped.

    The whole text is first parsed as a single JSON document; if that fails,
    or the top-level value is neither an array nor an object, the text is
    read as JSONL instead. Decoding is UTF-8 with undecodable bytes dropped;
    a leading UTF-8 byte-order mark is stripped.
    """
    # "utf-8-sig" removes a leading BOM. With plain "utf-8", json.loads rejects
    # the document ("Unexpected UTF-8 BOM") and the first JSONL line is lost.
    text = path.read_text(encoding="utf-8-sig", errors="ignore")

    # Only the parse sits inside the try, so nothing raised while the caller
    # drives this generator can be swallowed and silently restarted as JSONL.
    try:
        data = json.loads(text)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        data = None

    if isinstance(data, list):
        yield from enumerate(data)
        return
    if isinstance(data, dict):
        yield 0, data
        return

    # JSONL fallback (also used when the top-level value is a JSON scalar).
    for lineno, line in enumerate(text.splitlines(), start=1):
        line = line.strip()
        if not line:
            continue
        try:
            obj = json.loads(line)
        except ValueError:
            continue  # best-effort: malformed lines are skipped
        yield lineno, obj


def resolve_existing(path: Path) -> Optional[Path]:
    """
    Return an existing file for ``path``, tolerating a .json/.jsonl mix-up.

    ``path`` itself wins when it is a file. Otherwise the same path with the
    other JSON extension is tried (.json -> .jsonl, .jsonl -> .json; the
    extension test is case-insensitive), followed by the ``.jsonl`` and
    ``.json`` variants of the path. Returns ``None`` when none exists.
    """
    if path.is_file():
        return path

    current = path.suffix.lower()
    if current == ".json":
        swapped = (".jsonl",)
    elif current == ".jsonl":
        swapped = (".json",)
    else:
        swapped = ()

    # Swapped extension first, then the generic .jsonl / .json probe.
    for ext in swapped + (".jsonl", ".json"):
        candidate = path.with_suffix(ext)
        if candidate.is_file():
            return candidate
    return None


def build_index_by_id(path: Path) -> Dict[str, Dict[str, Any]]:
    """
    Map ``str(obj["id"])`` to the object for each dict record of one original file.

    Only the top-level ``'id'`` key is used; non-dict items and dicts without
    ``'id'`` are skipped. When an id repeats, the first occurrence is kept.
    A ``None`` or non-existent ``path`` yields an empty mapping.
    """
    by_id: Dict[str, Dict[str, Any]] = {}
    if path and path.is_file():
        for _, record in iter_json_items(path):
            if not isinstance(record, dict) or "id" not in record:
                continue
            # setdefault keeps the first object seen for a duplicated id.
            by_id.setdefault(str(record["id"]), record)
    return by_id


def normalize_ws(s: str) -> str:
    """Collapse every whitespace run in ``str(s)`` to one space and trim both ends."""
    words = str(s).split()
    return " ".join(words)


def verify_records(
    records: List[Dict[str, Any]],
    train_path: Path,
    val_path: Path,
    test_path: Path
) -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
    """
    Check that every extracted record really comes from one of the original files.

    For each record, ``id_or_index`` is looked up against the top-level ``'id'``
    key of the original train/val/test files. Files whose name matches the
    record's ``source_file`` (file name only; ``.json``/``.jsonl`` treated as
    interchangeable, exact names first) are searched first, then every file in
    train, val, test order. The record's ``input`` snippet must then be a
    substring of the original ``input``, verbatim or after whitespace
    normalization.

    Args:
        records: dicts with ``id_or_index``, ``input`` and optional ``source_file``.
        train_path, val_path, test_path: original files; each is resolved with
            ``resolve_existing`` and skipped with a warning if nothing exists.

    Returns:
        ``(results, stats)``: one result dict per record, plus counters where
        ``total == missing_id + found`` and
        ``found == no_input_field + pass_raw + pass_norm_only + fail``.
        ``missing_id`` counts both records without an id and ids not found.
        ``matched_file`` is the name of the file the id was actually found in.
        A missing/empty/whitespace-only snippet counts as ``fail``: an empty
        string is trivially "contained" and proves nothing.
    """
    requested = {"train": train_path, "val": val_path, "test": test_path}
    resolved = {split: resolve_existing(p) for split, p in requested.items()}

    for split, p in resolved.items():
        if p is None:
            print(f"[WARN] 未找到 {split} 对应文件：{requested[split]}", file=sys.stderr)
        else:
            print(f"[INFO] 使用 {split}: {p}", file=sys.stderr)

    # One (file, .json/.jsonl counterpart name, id index) entry per existing
    # file, in train/val/test order. A list rather than a dict keyed by file
    # name: splits whose files share a name (in different directories) must
    # all stay searchable instead of overwriting each other.
    sources: List[Tuple[Path, str, Dict[str, Dict[str, Any]]]] = []
    for p in resolved.values():
        if p is None:
            continue
        counterpart = p.with_suffix(".jsonl" if p.suffix.lower() == ".json" else ".json")
        sources.append((p, counterpart.name, build_index_by_id(p)))

    results: List[Dict[str, Any]] = []
    stats = {
        "total": 0,
        "found": 0,
        "missing_id": 0,
        "no_input_field": 0,
        "pass_raw": 0,
        "pass_norm_only": 0,
        "fail": 0,
    }

    def find_obj(preferred_name: Any, sid: str) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
        """Return (actual file name, original object) for ``sid``, or (None, None)."""
        candidates: List[Tuple[Path, str, Dict[str, Dict[str, Any]]]] = []
        if preferred_name:
            # Only the hint's file name matters; str() guards non-string hints.
            hint = Path(str(preferred_name)).name
            # Exact file-name matches outrank .json/.jsonl counterpart matches.
            candidates += [s for s in sources if s[0].name == hint]
            candidates += [s for s in sources if s[1] == hint]
        # Fallback over every file; re-probing a hinted file is a harmless miss.
        candidates += sources
        for path, _, index in candidates:
            hit = index.get(sid)
            if hit is not None:
                return path.name, hit
        return None, None

    for rec in records:
        stats["total"] += 1
        raw_id = rec.get("id_or_index")
        # None means "no id", not the literal id string "None".
        sid = "" if raw_id is None else str(raw_id).strip()
        snippet = rec.get("input")
        preferred = rec.get("source_file")

        item = {
            "id_or_index": sid,
            "source_file_hint": preferred,
            "matched_file": None,
            "found": False,
            "has_input": False,
            "contains_raw": False,
            "contains_norm": False,
            "note": "",
        }

        if not sid:
            item["note"] = "记录缺少 id_or_index"
            stats["missing_id"] += 1
            results.append(item)
            continue

        matched_file, obj = find_obj(preferred, sid)
        if obj is None:
            item["note"] = "未在任一原始文件中找到该 id（按键 'id' 查找）"
            stats["missing_id"] += 1
            results.append(item)
            continue

        item["found"] = True
        item["matched_file"] = matched_file
        stats["found"] += 1

        if "input" not in obj:
            item["note"] = "找到 id，但原记录缺少 input 字段"
            stats["no_input_field"] += 1
            results.append(item)
            continue

        item["has_input"] = True
        orig = str(obj["input"])
        snip = "" if snippet is None else str(snippet)
        snip_norm = normalize_ws(snip)

        if not snip_norm:
            # "" is a substring of everything, so an empty snippet must not pass.
            item["note"] = "记录的 input 为空，无法校验"
            stats["fail"] += 1
        elif snip in orig:
            item["contains_raw"] = True
            item["note"] = "OK (raw)"
            stats["pass_raw"] += 1
        elif snip_norm in normalize_ws(orig):
            item["contains_norm"] = True
            item["note"] = "OK (norm)"
            stats["pass_norm_only"] += 1
        else:
            item["note"] = "snippet 未被原始 input 完整包含"
            stats["fail"] += 1

        results.append(item)

    return results, stats

def main(argv: Optional[List[str]] = None) -> None:
    """
    Command-line entry point: verify extracted records against the original bbh files.

    Parses arguments from ``argv`` (``sys.argv[1:]`` when ``None``), loads the
    records file, runs ``verify_records``, prints a summary, optionally lists
    snippet failures (``--show-fail``), and optionally writes a JSON report
    (``--report``; missing parent directories are created).

    The ``--train``/``--val``/``--test`` defaults are hard-coded absolute paths.
    Exits with status 1 when the records file does not exist or yields no
    objects.
    """
    parser = argparse.ArgumentParser(
        description="校验提取问题是否确实存在于原始 bbh_* 文件中（按 id 精确匹配，并检查 input 包含关系）。"
    )
    parser.add_argument("--records", default="dict_questions.json", help="数组/JSONL 记录文件，元素含 {id_or_index, input, source_file?}")
    parser.add_argument(
        "--train", default="/home/hj/BenchmarkCompress/experiments_repproduction/FACLENS/new_benchmark/bbh/bbh_train.jsonl",
        help="原始训练集文件路径（默认 bbh_train.jsonl）"
    )
    parser.add_argument(
        "--val", default="/home/hj/BenchmarkCompress/experiments_repproduction/FACLENS/new_benchmark/bbh/bbh_val.jsonl",
        help="原始验证集文件路径（默认 bbh_val.jsonl）"
    )
    parser.add_argument(
        "--test", default="/home/hj/BenchmarkCompress/experiments_repproduction/FACLENS/new_benchmark/bbh/bbh_test.jsonl",
        help="原始测试集文件路径（默认 bbh_test.jsonl）"
    )
    parser.add_argument(
        "--report",
        help="将校验结果写入到指定 JSON 文件（可选）"
    )
    parser.add_argument(
        "--show-fail", action="store_true",
        help="打印失败样例与其来源，便于排查（仅显示未通过的条目）。"
    )

    args = parser.parse_args(argv)

    rec_path = Path(args.records)
    if not rec_path.is_file():
        print(f"[ERROR] 记录文件不存在：{rec_path}", file=sys.stderr)
        sys.exit(1)

    # Read records (JSON array, single object, or JSONL). A JSONL line may
    # itself hold an array, so lists are flattened one level; non-dicts dropped.
    records: List[Dict[str, Any]] = []
    for _, obj in iter_json_items(rec_path):
        if isinstance(obj, list):
            records.extend([x for x in obj if isinstance(x, dict)])
        elif isinstance(obj, dict):
            records.append(obj)

    if not records:
        print("[ERROR] 记录文件未解析到任何对象。应为数组或 JSONL。", file=sys.stderr)
        sys.exit(1)

    results, stats = verify_records(
        records,
        Path(args.train),
        Path(args.val),
        Path(args.test),
    )

    print("\n=== 校验摘要 ===")
    for k, v in stats.items():
        print(f"{k}: {v}")

    if args.show_fail:
        print("\n=== 未通过条目（仅显示 snippet 失败） ===")
        for r in results:
            # Only records whose original had an input but the snippet check failed.
            if r["has_input"] and not (r["contains_raw"] or r["contains_norm"]):
                print(f"- id={r['id_or_index']} | hint={r['source_file_hint']} | matched={r['matched_file']} | note={r['note']}")

    if args.report:
        report_path = Path(args.report)
        # Create missing parent directories instead of failing after all the work.
        report_path.parent.mkdir(parents=True, exist_ok=True)
        report_path.write_text(
            json.dumps({"stats": stats, "results": results}, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
        print(f"\n[INFO] 详细报告写入：{args.report}")


if __name__ == "__main__":
    # Script entry point; main() returns None, so the exit status is 0 unless
    # main() itself calls sys.exit().
    raise SystemExit(main())