#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import json
import sys
from pathlib import Path
import re
from typing import Iterable, Dict, Any, List, Tuple, Union

def iter_json_items(path: Path) -> Iterable[Tuple[Union[int, str], Dict[str, Any]]]:
    """
    读取 JSON 或 JSONL：
    - JSON 数组：yield (index, obj)
    - JSON 对象：若对象本身即为一条记录，yield (0, obj)
    - JSONL：每行一个 JSON，yield (line_no, obj)
    """
    text = path.read_text(encoding="utf-8", errors="ignore")
    # 粗略判断是否 JSONL（多行且每行看起来像 JSON）
    looks_like_jsonl = ("\n" in text) and all(
        (not line.strip()) or line.strip().startswith(("{", "["))
        for line in text.splitlines()[:50]
    )
    # 优先尝试普通 JSON
    try:
        data = json.loads(text)
        if isinstance(data, list):
            for i, obj in enumerate(data):
                if isinstance(obj, dict):
                    yield i, obj
        elif isinstance(data, dict):
            yield 0, data
        else:
            # 回退到逐行解析
            raise ValueError("Root JSON is not list/dict.")
    except Exception:
        # 逐行 JSON（JSONL）
        for i, line in enumerate(text.splitlines(), start=1):
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except Exception:
                continue
            if isinstance(obj, dict):
                yield i, obj

def first_paragraph(text: str) -> str:
    if not text:
        return ""
    # 以空行分段：一个或多个空行/仅空白的行作为分隔
    paras = re.split(r"\n\s*\n", text.strip(), maxsplit=1)
    return paras[0].strip() if paras else text.strip()

def extract_last_sentence_of_first_paragraph(text: str) -> str:
    """
    在“第一段”中提取“最后一句话（优先问句）”。

    规则：
    1) 若第一段中存在 '?'，取**最后一个 '?' 所在的句子**
    2) 否则，回退为取第一段的“最后一句”（以 . ! ? 作为句末标点）
    - 兼容包含缩写/人名（Mr., Dr., U.S.）、撇号（'s/’s）等情况
    """
    para = first_paragraph(text)
    if not para:
        return ""

    # 优先寻找最后一个问号
    q_pos = para.rfind("?")
    if q_pos != -1:
        # 在问号之前寻找最近的句边界（. ! ?）+ 可选引号
        # 例如："... end.\"  Did ..."，我们希望从 Did 开始截取
        start = 0
        for m in re.finditer(r'[.!?]["\']?\s+', para):
            if m.end() <= q_pos:
                start = m.end()
            else:
                break
        return para[start:q_pos + 1].strip()

    # 若没有问号，则取“最后一句”
    # 使用基于标点的句子切分；不过不对 Mr. / U.S. 做复杂 NLP，只取“最后一段落句”
    # 以标点 + 空白/行末作为分界
    sentences = re.findall(r'.+?(?:[.!?]+(?=\s|$)|$)', para, flags=re.S)
    sentences = [s.strip() for s in sentences if s.strip()]
    return sentences[-1] if sentences else para

def process_files(paths: List[Path]) -> List[Dict[str, Any]]:
    results: List[Dict[str, Any]] = []
    for p in paths:
        if not p.exists() or not p.is_file():
            continue
        for idx, obj in iter_json_items(p):
            if not isinstance(obj, dict):
                continue
            if obj.get("category") != "causal_judgement":
                continue

            raw_q = obj.get("input", "")
            extracted = extract_last_sentence_of_first_paragraph(str(raw_q))

            # 尝试带上可回溯的标识
            ident = obj.get("id", idx)

            results.append({
                "source_file": str(p.name),
                "id_or_index": ident,
                "input": extracted
            })
    return results

def main():
    parser = argparse.ArgumentParser(
        description="从多个 JSON/JSONL 中提取 category=causal_judgement 的条目，并将其 question 字段处理为“第一段的最后一句话（优先问句）”。"
    )
    parser.add_argument(
        "-i", "--inputs",
        help="输入包含 JSON/JSONL 的目录（可混合传入多个）",
        default="/home/hj/BenchmarkCompress/experiments_repproduction/FACLENS/new_benchmark/bbh"
    )
    parser.add_argument("-o", "--output", default="dict_questions.json", help="输出 JSON 文件路径（例如 output_questions.json）")
    args = parser.parse_args()

    # 收集所有文件
    files: List[Path] = []
    inp = args.inputs
    p = Path(inp)
    if p.is_dir():
        files.extend([f for f in p.rglob("*") if f.suffix.lower() in {".json", ".jsonl"} and f.is_file()])
    elif p.is_file():
        files.append(p)
    else:
        print(f"[WARN] 未找到：{inp}", file=sys.stderr)

    if not files:
        print("[ERROR] 未找到任何 JSON/JSONL 文件。", file=sys.stderr)
        sys.exit(2)

    results = process_files(files)

    # 写出
    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    print(f"完成：共提取 {len(results)} 条，已写入 {out_path}")

if __name__ == "__main__":
    main()
