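# pattern_extractor_deepseek_improved.py
# Improved pattern extractor tuned for the official DeepSeek API.
# Main changes compared with the previous version:
# 1. Sampling is meant to be deterministic (temperature 0.0, top_p 1.0) so the
#    model sticks to JSON; see PATTERN_TEMPERATURE / TOP_P.
# 2. Smaller per-batch budget (USER_TOKEN_BUDGET, 2000 approximate tokens) so
#    each batch is more focused and the extracted patterns more specific.
# 3. Much stricter system and user prompts: JSON-only output, empty or overly
#    broad patterns are rejected.
# 4. Hand-written few-shot guidance embedded in the system prompt
#    (see FEW_SHOT_EXAMPLES) to steer the model towards concrete patterns.
# 5. Everything else is unchanged; interrupted runs can be resumed.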

import os
import re
import json
import time
import random
import requests
from typing import List, Dict, Any, Tuple
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

# =============== 配置 ===============
# Ranked-snippet dump that main() hands to parse_top_txt, and the maximum
# number of top-ranked snippets to keep.
INPUT_TXT = r"top2000_positive.txt"
MAX_ITEMS = 2000

# Output root. An empty string means "the current working directory".
OUT_DIR = r""
# os.makedirs("") raises FileNotFoundError, so only create the root when one
# is actually configured.
if OUT_DIR:
    os.makedirs(OUT_DIR, exist_ok=True)
# Per-batch checkpoints (parsed JSON + raw model output) live here.
PATTERN_BATCH_DIR = os.path.join(OUT_DIR, "pattern_batches_improved")
os.makedirs(PATTERN_BATCH_DIR, exist_ok=True)
PATTERNS_MERGED_JSON = os.path.join(OUT_DIR, "patterns_merged_improved.json")
PATTERNS_SUMMARY_TXT = os.path.join(OUT_DIR, "patterns_summary_improved.txt")

# DeepSeek API settings. The key is taken from the DEEPSEEK_API_KEY
# environment variable so that it never has to be written into this file;
# it falls back to an empty string when the variable is not set.
API_KEY = os.environ.get("DEEPSEEK_API_KEY", "")
ENDPOINT = "https://api.deepseek.com/v1/chat/completions"
# NOTE(review): the original comment claimed this model is backed by the
# latest V3 series -- confirm against DeepSeek's current model list.
MODEL = "deepseek-reasoner"

# Tuning knobs that matter most for extraction quality.
SNIPPET_CHAR_LIMIT = 510      # sanitize_text truncates every snippet to this many characters
# Approximate-token budget per user message (see approx_tokens). With
# full-length ASCII snippets this is roughly a dozen examples per batch.
USER_TOKEN_BUDGET = 2000
JSON_MAX_TOKENS = 1600        # sent as max_tokens on every request
# 0.0 for deterministic output, as the file header and the merged summary
# ("zero temp") describe; the previous value of 1.0 contradicted both.
# NOTE(review): DeepSeek's documentation states that deepseek-reasoner
# ignores temperature/top_p -- confirm whether these have any effect.
PATTERN_TEMPERATURE = 0.0
TOP_P = 1.0                   # 1.0 disables nucleus-sampling truncation
SLEEP_BETWEEN_CALLS = 1.0     # base pause (seconds) between batches
RESUME_FROM_EXISTING = True   # reuse per-batch JSON checkpoints when present
MAX_RETRIES = 5               # application-level attempts per API call

# HTTP 会话
# One shared session so connections to the API host are pooled and reused.
SESSION = requests.Session()
# Transport-level retry policy handled by urllib3, independent of the
# application-level retry loop in call_chat_api.
adapter = HTTPAdapter(
    max_retries=Retry(total=3, connect=3, read=3),
    pool_connections=10,
    pool_maxsize=20,
)
# Only HTTPS traffic goes through the tuned adapter.
SESSION.mount("https://", adapter)

# =============== 工具函数 ===============
# ASCII control characters except tab (\x09), line feed (\x0A) and carriage
# return (\x0D); sanitize_text replaces every match with a single space.
CTRL_CHARS = re.compile(r"[\x00-\x08\x0B\x0C\x0E-\x1F]")

def approx_tokens(text: str) -> int:
    """Return a rough token estimate for ``text``.

    ASCII characters are weighted at a quarter of a token each, every other
    character at one full token, and one extra token is added before the sum
    is truncated to an int. Empty (or otherwise falsy) input counts as 0.
    """
    if not text:
        return 0
    narrow = 0
    for ch in text:
        if ord(ch) < 128:
            narrow += 1
    wide = len(text) - narrow
    return int(narrow / 4.0 + wide + 1)

def sanitize_text(s: str) -> str:
    """Clean a raw snippet and cap its length.

    Control characters matched by ``CTRL_CHARS`` become spaces, CRLF / CR
    line endings become LF, runs of two or more spaces/tabs collapse to one
    space, surrounding whitespace is stripped, and the result is cut to at
    most ``SNIPPET_CHAR_LIMIT`` characters.
    """
    cleaned = CTRL_CHARS.sub(" ", s)
    cleaned = cleaned.replace("\r\n", "\n").replace("\r", "\n")
    cleaned = re.sub(r"[ \t]{2,}", " ", cleaned).strip()
    # Slicing is a no-op for strings that are already short enough.
    return cleaned[:SNIPPET_CHAR_LIMIT]

def parse_top_txt(path: str, max_items: int) -> List[Dict[str, Any]]:
    """Parse the ranked-snippet dump at ``path`` into a list of item dicts.

    The file is split into blocks on lines consisting of five or more dashes.
    A block is kept only if one of its lines starts with ``Rank <n>``. From
    each kept block the function reads ``sample_index: <int>``,
    ``projection_score: <float>`` and everything after the first
    ``snippet:`` marker (cleaned with ``sanitize_text``).

    Args:
        path: UTF-8 text file to read.
        max_items: maximum number of items returned, taken in rank order.

    Returns:
        Dicts with keys ``rank`` (int), ``sample_index`` (int, ``-1`` when
        absent), ``score`` (float, or ``None`` when absent or not numeric)
        and ``text`` (str, possibly empty), sorted by ascending rank.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()

    # A full float literal, including a signed exponent such as "1.5e+03".
    # The previous character-class pattern stopped at "+", captured "1.5e"
    # and made float() raise ValueError.
    score_re = re.compile(
        r"projection_score:\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"
    )

    items = []
    for blk in re.split(r"\n[-]{5,}\n", content):
        blk = blk.strip()
        if not blk:
            continue
        m_rank = re.search(r"^Rank\s+(\d+)", blk, flags=re.M)
        if not m_rank:
            # Not a ranked entry (e.g. a file header); ignore it.
            continue
        m_idx = re.search(r"sample_index:\s*([0-9]+)", blk)
        m_score = score_re.search(blk)
        # re.S lets the snippet span multiple lines up to the end of the block.
        m_snip = re.search(r"\bsnippet:\s*(.*)", blk, flags=re.S)
        items.append({
            "rank": int(m_rank.group(1)),
            "sample_index": int(m_idx.group(1)) if m_idx else -1,
            "score": float(m_score.group(1)) if m_score else None,
            "text": sanitize_text(m_snip.group(1).strip() if m_snip else ""),
        })
    items.sort(key=lambda x: x["rank"])
    return items[:max_items]

# =============== 强化 Prompt（含 few-shot） ===============
# JSON schema shown to the model. It is embedded verbatim at the end of
# SYSTEM_PROMPT and is itself valid JSON.
PATTERN_SCHEMA = """
{
  "patterns": [
    {
      "id": "P001",
      "name": "Short name",
      "description": "Linguistic/structural traits, keywords, domain",
      "anchor_token": "Anchor words or phrases that reliably appear (multiple allowed, comma-separated)",
      "repeat_unit": "Smallest repeating unit (e.g., 'Key: Value' per line, or 'Q:\\nA:' as two lines)",
      "template": "Overall template (show several repeat_unit combinations; mark placeholders <TEXT>/<NUM>/<DATE>/<TIME>/<URL>/<EMAIL>/<LANG>/<KEY>/<LEVEL>, etc.)",
      "fields": [
        {"name": "placeholder", "type": "text|id|code|name|date|category|number|time|url|email", "values_or_rules": "Sampling rules/distribution/valid values", "notes": "Optional notes"}
      ],
      "length_control": {
        "target_tokens": 2049,
        "min_tokens": 1800,
        "max_tokens": 2300,
        "strategy": "How to increase/decrease repeat units, cut at natural boundaries, avoid EOD and garbled text"
      },
      "examples_from_batch": [0, 3, 7]
    }
  ],
  "coverage_estimate": "Approximate coverage description",
  "notes": "Any additional notes"
}
"""

# One complete worked example: a small batch followed by the JSON the model is
# expected to return for it. The batch has three snippets so that it satisfies
# the "at least 3 examples" rule stated in SYSTEM_PROMPT. The earlier
# "Example 2" entry was only an "omitted" placeholder that was sent to the
# model verbatim, so it has been removed. The expected template previously
# carried a stray extra "]", which is fixed here as well.
FEW_SHOT_EXAMPLES = """
Example batch (JSON array pattern):
Example (index=123, score=9.87):
<SNIPPET_START>
[
  {"id": 1, "name": "Alice", "age": 30},
  {"id": 2, "name": "Bob", "age": 25},
  ...
]
<SNIPPET_END>
-----
Example (index=456, score=9.12):
<SNIPPET_START>
[ {"title": "Book1", "author": "X"}, {"title": "Book2", "author": "Y"} ]
<SNIPPET_END>
-----
Example (index=789, score=8.95):
<SNIPPET_START>
[{"sku": "A-100", "qty": 4}, {"sku": "B-220", "qty": 1}, {"sku": "C-305", "qty": 9}]
<SNIPPET_END>
-----

Correct output:
{
  "patterns": [
    {
      "id": "P001",
      "name": "JSON Object Array",
      "description": "List of JSON objects in array format, common in data dumps",
      "anchor_token": "[,{,}]",
      "repeat_unit": "  {\\\"key1\\\": <VALUE1>, \\\"key2\\\": <VALUE2>},\\n",
      "template": "[\\n<repeat_unit repeated many times>\\n]",
      "fields": [
        {"name": "key1", "type": "text", "values_or_rules": "Common field names like id/name/age/title", "notes": ""},
        {"name": "VALUE1", "type": "text|number", "values_or_rules": "Realistic values", "notes": ""}
      ],
      "length_control": {"target_tokens": 2049, "min_tokens": 1800, "max_tokens": 2300, "strategy": "Repeat objects until near target; end with ]"},
      "examples_from_batch": [0, 1, 2]
    }
  ],
  "coverage_estimate": "Covers ~80% of batch",
  "notes": "Highly repetitive structure ideal for induction"
}
"""

# System message sent with every request. Literal braces in the fallback JSON
# are doubled because this is an f-string.
SYSTEM_PROMPT = f"""You are an expert pattern extractor specialized in finding concrete repeating patterns for induction heads research.

CRITICAL RULES (MUST OBEY):
- Analyze the batch and extract ONLY specific, concrete repeating structural patterns that appear in MULTIPLE examples.
- Output EXACTLY one valid JSON object and NOTHING ELSE. No explanations, no markdown, no extra text before/after.
- If no clear repeating pattern exists in at least 3 examples, output: {{"patterns": [], "coverage_estimate": "No clear repeating patterns in this batch", "notes": "Samples too diverse"}}
- Aim for 1-5 high-quality patterns per batch. Avoid overly broad categories like "natural language".
- Each pattern must have reliable anchor_token and repeat_unit suitable for synthetic generation.

Here is a high-quality example of a batch and the desired output:
{FEW_SHOT_EXAMPLES}

Strict JSON schema to follow exactly:{PATTERN_SCHEMA}
"""

# =============== API 调用 ===============
def call_chat_api(messages: List[Dict[str, str]]) -> str:
    """POST ``messages`` to the chat-completions endpoint and return the reply text.

    Up to ``MAX_RETRIES`` attempts are made with exponential backoff
    (1s, 2s, 4s, ...) between them. Transport errors, malformed responses,
    5xx responses and HTTP 408/429 are retried. Any other 4xx response
    (bad request, invalid key, ...) aborts immediately because repeating the
    same request cannot succeed.

    Args:
        messages: chat messages, each a dict with ``role`` and ``content``.

    Returns:
        The ``content`` string of the first choice in the response.

    Raises:
        RuntimeError: when no attempt produced a usable reply; the message
            includes the last error that was observed.
    """
    headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}
    payload = {
        "model": MODEL,
        "messages": messages,
        "temperature": PATTERN_TEMPERATURE,
        "top_p": TOP_P,
        "max_tokens": JSON_MAX_TOKENS,
        "response_format": {"type": "json_object"},
    }
    last_error = "no attempt was made"
    for attempt in range(1, MAX_RETRIES + 1):
        retryable = True
        try:
            r = SESSION.post(ENDPOINT, headers=headers, json=payload, timeout=(10, 300))
            if r.status_code == 200:
                content = r.json()["choices"][0]["message"]["content"]
                if isinstance(content, str):
                    return content
                # A null/non-string content would break callers that write
                # the reply to disk, so treat it as a failed attempt.
                last_error = f"response content is {type(content).__name__}, expected str"
                print(f"[API] Unexpected payload (attempt {attempt}): {last_error}")
            else:
                last_error = f"HTTP {r.status_code}: {r.text[:300]}"
                print(f"[API] Error {r.status_code}: {r.text[:300]}")
                # Client errors other than timeout / rate limit are permanent.
                is_client_error = 400 <= r.status_code < 500
                retryable = not is_client_error or r.status_code in (408, 429)
        except Exception as e:
            # Best-effort by design: network failures and malformed JSON
            # bodies are logged and retried rather than propagated.
            last_error = f"{type(e).__name__}: {e}"
            print(f"[API] Exception (attempt {attempt}): {e}")
        if not retryable or attempt == MAX_RETRIES:
            # Do not sleep when no further attempt will follow.
            break
        time.sleep(2 ** (attempt - 1))
    raise RuntimeError(f"API 调用多次失败: {last_error}")

# =============== JSON 兜底解析 ===============
def parse_json_loose(content: str) -> Dict[str, Any]:
    """Best-effort extraction of a JSON object from model output.

    Candidates are tried in this order:
      1. the whole (stripped) text;
      2. the body of every Markdown code fence (with or without a ``json`` tag);
      3. the first ``{...}`` object that decodes anywhere in the text, which
         covers replies that wrap the JSON in explanatory prose.

    Args:
        content: raw text returned by the model.

    Returns:
        The first candidate that decodes to a dict, or ``{}`` when nothing
        does (including empty or non-string input). Decoded values that are
        not objects (arrays, numbers, ...) are never returned.
    """
    if not isinstance(content, str):
        return {}
    txt = content.strip()
    if not txt:
        return {}

    candidates = [txt]
    candidates.extend(
        m.group(1)
        for m in re.finditer(r"```(?:json)?\s*(.*?)\s*```", txt, flags=re.S | re.I)
    )
    for cand in candidates:
        try:
            obj = json.loads(cand)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        if isinstance(obj, dict):
            return obj

    # Fall back to scanning for an embedded object; raw_decode tolerates
    # trailing text after the object.
    decoder = json.JSONDecoder()
    pos = txt.find("{")
    while pos != -1:
        try:
            obj, _ = decoder.raw_decode(txt, pos)
        except ValueError:
            obj = None
        if isinstance(obj, dict):
            return obj
        pos = txt.find("{", pos + 1)
    return {}

# =============== 批处理 ===============
def _format_example(it: Dict[str, Any]) -> str:
    """Render one item as the delimited example block used in the user prompt."""
    score = it.get("score")
    # parse_top_txt yields score=None when the source block has no score;
    # formatting None with ".6f" would raise TypeError.
    score_txt = f"{score:.6f}" if isinstance(score, (int, float)) else "n/a"
    return (
        f"Example (index={it['sample_index']}, score={score_txt}):\n"
        f"<SNIPPET_START>\n{it['text']}\n<SNIPPET_END>\n-----\n"
    )


def _pack_batches(items: List[Dict[str, Any]]) -> List[List[str]]:
    """Greedily group rendered examples so each batch stays within USER_TOKEN_BUDGET."""
    batches: List[List[str]] = []
    current: List[str] = []
    budget = 0
    for it in items:
        s = _format_example(it)
        tk = approx_tokens(s)
        # An example larger than the whole budget still gets a batch of its own.
        if budget + tk > USER_TOKEN_BUDGET and current:
            batches.append(current)
            current, budget = [], 0
        current.append(s)
        budget += tk
    if current:
        batches.append(current)
    return batches


def extract_patterns_dynamic(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Ask the model for repeating patterns, one request per token-budgeted batch.

    For batch ``k`` the raw reply is written to ``raw_batch_<k>.txt`` and, if
    it parses to a non-empty JSON object, the parsed result is written to
    ``patterns_batch_<k>.json`` inside ``PATTERN_BATCH_DIR``. When
    ``RESUME_FROM_EXISTING`` is set, a batch whose JSON checkpoint already
    exists and is non-empty is loaded from disk instead of being requested
    again. Replies that cannot be parsed are not checkpointed, so a later
    run retries them.

    Args:
        items: item dicts with ``sample_index``, ``score`` and ``text`` keys.

    Returns:
        One result object per batch, in batch order. Batches whose reply could
        not be parsed contribute a placeholder with an empty ``patterns`` list.

    Raises:
        RuntimeError: propagated from ``call_chat_api`` when a request fails.
    """
    batches = _pack_batches(items)
    total = len(batches)

    pattern_objs: List[Dict[str, Any]] = []
    for bi, batch in enumerate(batches):
        json_path = os.path.join(PATTERN_BATCH_DIR, f"patterns_batch_{bi:03d}.json")
        raw_path = os.path.join(PATTERN_BATCH_DIR, f"raw_batch_{bi:03d}.txt")

        if RESUME_FROM_EXISTING and os.path.exists(json_path):
            parsed = None
            try:
                with open(json_path, "r", encoding="utf-8") as f:
                    parsed = json.load(f)
            except (OSError, ValueError) as e:
                # A corrupt or unreadable checkpoint is not fatal: report it
                # and fall through to a fresh request.
                print(f"[模式] 批 {bi+1}/{total} checkpoint unreadable, re-requesting: {e}")
            if parsed:
                pattern_objs.append(parsed)
                print(f"[模式] 批 {bi+1}/{total} 跳过（已存在）")
                continue

        user_content = (f"This batch contains {len(batch)} examples. Extract concrete repeating structural patterns that appear in multiple examples. "
                        f"Focus on specific formats suitable for induction heads.\n" + "".join(batch) +
                        "\n\nOutput ONLY the valid JSON object. No other text whatsoever.")

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT.strip()},
            {"role": "user", "content": user_content},
        ]
        content = call_chat_api(messages)

        # Always keep the raw reply for debugging, even when it does not parse.
        with open(raw_path, "w", encoding="utf-8") as f:
            f.write(content)

        obj = parse_json_loose(content)
        if isinstance(obj, dict) and obj:
            with open(json_path, "w", encoding="utf-8") as f:
                f.write(json.dumps(obj, ensure_ascii=False, indent=2))
            print(f"[模式] 批 {bi+1}/{total} 完成")
        else:
            # parse_json_loose signals failure with an empty dict, so an
            # isinstance check alone would never reach this branch.
            obj = {"patterns": [], "coverage_estimate": "Parse failed", "notes": "Model output invalid"}
            print(f"[模式] 批 {bi+1}/{total} parse failed; raw output kept at {raw_path}")

        pattern_objs.append(obj)
        if bi + 1 < total:
            # Throttle between requests; there is nothing to wait for after the last one.
            time.sleep(SLEEP_BETWEEN_CALLS + random.uniform(0, 0.3))

    return pattern_objs

# =============== 归一化与合并 ===============
def _coerce_int(value: Any, default: int) -> int:
    """Convert ``value`` to int, falling back to ``default`` when it is not numeric."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        pass
    try:
        # Accept numeric strings with a fractional part, e.g. "2049.0".
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return default


def normalize_pattern(p: Dict[str, Any], idx: int) -> Dict[str, Any]:
    """Return a copy of pattern ``p`` with every expected key present and well-typed.

    The input comes from model output, so nothing about its shape is trusted:
    missing or empty values get defaults, ``fields`` entries that are not
    dicts are dropped, and ``length_control`` numbers that cannot be read as
    integers fall back to 2049 / 1800 / 2300.

    Args:
        p: raw pattern object; a non-dict value is treated as an empty pattern.
        idx: zero-based position, used to build the default id ``P<idx+1>``.

    Returns:
        Dict with keys ``id``, ``name``, ``description``, ``anchor_token``,
        ``repeat_unit``, ``template``, ``fields``, ``length_control`` and
        ``examples_from_batch``.
    """
    if not isinstance(p, dict):
        p = {}

    pid = p.get("id") or f"P{idx+1:03d}"
    name = p.get("name") or f"Pattern-{pid}"
    desc = p.get("description") or ""
    anchor = p.get("anchor_token") or ""
    repeat_unit = p.get("repeat_unit") or "<TEXT>"
    template = p.get("template") or repeat_unit

    fields = p.get("fields")
    if isinstance(fields, dict):
        # A single field object instead of a list of them.
        fields = [fields]
    if not isinstance(fields, list) or not fields:
        fields = [{"name": "TEXT", "type": "text", "values_or_rules": "自由文本", "notes": ""}]
    norm_fields = [
        {
            "name": f.get("name", "TEXT"),
            "type": f.get("type", "text"),
            "values_or_rules": f.get("values_or_rules", ""),
            "notes": f.get("notes", ""),
        }
        for f in fields
        if isinstance(f, dict)
    ]

    lc = p.get("length_control")
    if not isinstance(lc, dict):
        lc = {}
    length_control = {
        "target_tokens": _coerce_int(lc.get("target_tokens", 2049), 2049),
        "min_tokens": _coerce_int(lc.get("min_tokens", 1800), 1800),
        "max_tokens": _coerce_int(lc.get("max_tokens", 2300), 2300),
        "strategy": lc.get("strategy", "重复 repeat_unit 直至接近目标；在自然边界截断；避免 EOD。"),
    }
    examples = p.get("examples_from_batch") or []
    return {
        "id": pid,
        "name": name,
        "description": desc,
        "anchor_token": anchor,
        "repeat_unit": repeat_unit,
        "template": template,
        "fields": norm_fields,
        "length_control": length_control,
        "examples_from_batch": examples,
    }

def merge_and_normalize(pattern_batches: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Combine per-batch results into one de-duplicated, normalized catalogue.

    Two patterns are considered duplicates when their lower-cased name and
    the first 200 characters of their template are identical; the first
    occurrence wins. Batch results or pattern entries with an unexpected
    shape (non-dict result, ``patterns`` that is not a list, non-dict
    pattern) are skipped instead of aborting the merge.

    Every batch numbers its patterns from ``P001``, so ids are reassigned
    sequentially after de-duplication to keep them unique in the output.

    Args:
        pattern_batches: one result object per batch.

    Returns:
        Dict with ``patterns`` (normalized list), ``coverage_estimate`` and
        ``notes``.
    """
    seen = set()
    merged_list = []
    for obj in pattern_batches:
        pats = obj.get("patterns") if isinstance(obj, dict) else None
        if not isinstance(pats, list):
            # Covers a missing key as well as an explicit null.
            continue
        for p in pats:
            if not isinstance(p, dict):
                continue
            key = (str(p.get("name", "")).strip().lower(), str(p.get("template", "")).strip()[:200])
            if key in seen:
                continue
            seen.add(key)
            merged_list.append(p)
    # Overriding "id" on a shallow copy leaves the per-batch objects untouched.
    out = [
        normalize_pattern({**p, "id": f"P{i + 1:03d}"}, i)
        for i, p in enumerate(merged_list)
    ]
    return {
        "patterns": out,
        "coverage_estimate": "改进版抽取（小批次+强化prompt+zero temp）",
        "notes": "已去重，建议人工再筛一次具体模式。",
    }

def save_patterns_summary(merged: Dict[str, Any], path: str):
    """Write a short plain-text digest of ``merged["patterns"]`` to ``path``.

    Each pattern gets a numbered heading with its id and name, followed by
    its anchor token, repeat unit, template (stripped, cut to 200 characters
    with a ``...`` marker when longer) and the list of field names, then a
    blank line. The file is written as UTF-8.
    """
    report = ["# 改进版模式摘要（更具体、更高质量）\n"]
    for number, pattern in enumerate(merged.get("patterns", []), start=1):
        heading = f"#{number} id={pattern.get('id','')} name={pattern.get('name','')}"
        anchor_line = f"  anchor_token: {pattern.get('anchor_token','')}"
        unit_line = f"  repeat_unit : {pattern.get('repeat_unit','')}"
        template_text = (pattern.get("template", "") or "").strip()
        suffix = "..." if len(template_text) > 200 else ""
        field_names = [entry.get("name") for entry in (pattern.get("fields") or [])]
        report.extend([
            heading,
            anchor_line,
            unit_line,
            f"  template    : {template_text[:200]}{suffix}",
            f"  fields      : {field_names}",
            "",
        ])
    with open(path, "w", encoding="utf-8") as out:
        out.write("\n".join(report))

# =============== 主流程 ===============
def main():
    """Run the pipeline: parse the input, extract patterns per batch, merge, save.

    Writes the merged catalogue to ``PATTERNS_MERGED_JSON`` and a readable
    digest to ``PATTERNS_SUMMARY_TXT``, printing progress along the way.
    """
    print(f"读取输入文件：{INPUT_TXT}")
    items = parse_top_txt(INPUT_TXT, MAX_ITEMS)
    print(f"成功读取片段数：{len(items)}（取前 {MAX_ITEMS}）")

    merged = merge_and_normalize(extract_patterns_dynamic(items))

    # Serialize first, then write, so the file receives the text in one call.
    merged_text = json.dumps(merged, ensure_ascii=False, indent=2)
    with open(PATTERNS_MERGED_JSON, "w", encoding="utf-8") as out:
        out.write(merged_text)
    save_patterns_summary(merged, PATTERNS_SUMMARY_TXT)

    pattern_count = len(merged.get('patterns', []))
    print(f"\n[完成] 合并后模式数：{pattern_count}")
    print(f"[保存] 合并 JSON：{PATTERNS_MERGED_JSON}")
    print(f"[保存] 摘要 TXT ：{PATTERNS_SUMMARY_TXT}")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()


