# llm_judge.py
import os, re, json, requests, time, random, sys
from requests.exceptions import ConnectTimeout, ReadTimeout, HTTPError
from typing import Literal, Tuple, List, Optional
from dataclasses import dataclass
from requests.exceptions import ConnectionError as RequestsConnectionError

# Answer families the judge supports; each one selects a pair of labels plus
# "Undefined" (see _schema_for).
JudgeGroup = Literal["yesno", "truefalse", "validinvalid"]

# Default system prompt sent to the judge model. LLMJudge.judge() uses it when
# SYSTEM_PROMPT has no entry for the requested task.
SYSTEM = """You are a strict short-answer validator.
Decide the ONE token label for the user's text: choose ONLY from the allowed set.
Prefer the earliest decisive clause (often the first sentence) and ignore later explanations.

Decision rules (apply in order):
1) "X tells the truth" or "X doesn't tell a lie" → Yes.
2) "X does not tell the truth" or "X tells a lie" → No; "not X" → "No".
3) "X returns to the starting point" → Yes; "X does not return to the starting point" → No.
4) Expressions like "plausible", "likely true", "unlikely false", "probably true" → Yes.
5) Expressions like "not plausible", "implausible", "likely false", "unlikely true", "probably false" → No.
6) If the text is meaningless (e.g., random characters, gibberish, or an incomplete/broken sentence with no clear semantics), return "Undefined".
7) If the text does not clearly express a definite viewpoint (yes/no; true/false; valid/invalid), e.g., if the text includes "cannot determine", "unknown", "insufficient information" etc, return "Undefined".
8) If the allowed labels are not of the yes/no type, output the semantically corresponding label from the provided Allowed labels set (for example, use True/False or Valid/Invalid as appropriate).

Return a JSON object only, no extra text.
"""

# Directory containing this file, and the directory two levels above it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.normpath(os.path.join(BASE_DIR, "..", ".."))

# Make PROJECT_ROOT importable before the `src.*` imports that follow.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from src.bbh.causal_judgement.question_inputs import CAUSAL_JUDGEMENT_QUESTIONS
from src.bbh.web_of_lies.question_inputs import WEB_OF_LIES_QUESTIONS
from src.bbh.system_prompt import SYSTEM_PROMPT

def find_question(id: str, task: str) -> str:
    """Return the stored question text for ``id`` within ``task``.

    Only the ``causal_judgement`` and ``web_of_lies`` tasks (matched
    case-insensitively) have question tables; any other task yields ``""``.
    An ``id`` missing from the selected table propagates the table's own
    lookup error.
    """
    normalized = task.lower()
    if normalized not in ("causal_judgement", "web_of_lies"):
        return ""
    # Pick the table lazily so only the one actually needed is touched.
    table = (
        CAUSAL_JUDGEMENT_QUESTIONS
        if normalized == "causal_judgement"
        else WEB_OF_LIES_QUESTIONS
    )
    return table[id]

def _first_sentence(t: str, max_chars=400) -> str:
    t = (t or "").strip()
    m = re.search(r"[.!?。！？](?:\s|$)", t)
    return t[:m.end()].strip() if m else t[:min(len(t), max_chars)]

def _schema_for(group: JudgeGroup):
    if group == "yesno":
        enum = ["Yes", "No", "Undefined"]
    elif group == "truefalse":
        enum = ["True", "False", "Undefined"]
    else:
        enum = ["Valid", "Invalid", "Undefined"]
    schema = {
        "type": "object",
        "properties": {
            "label": {"type": "string", "enum": enum}
        },
        "required": ["label"],
        "additionalProperties": False
    }
    return enum, schema

def _canon(group: JudgeGroup, label: str) -> str:
    if label is None:
        return "Undefined"

    s = (label or "").strip().lower()

    if s in {"undefined", "unknown", "undetermined", "cannot determine", "insufficient information", "unclear", "ambiguous", "can't tell", "not enough info"}:
        return "Undefined"

    POS = {"y", "yes", "yeah", "yep", "true", "t", "valid", "v", "correct", "right", "1", "✔", "✅", "✓"}
    NEG = {"n", "no", "nope", "nah", "false", "f", "invalid", "i", "incorrect", "wrong", "0", "✘", "❌", "✗", "×", "✕", "untrue", "inaccurate"}

    # —— 1：双重否定 → 正向 —— 
    if re.search(r"\bnot\s+(false|incorrect|wrong|unlikely|implausible)\b", s):
        if group == "yesno":     return "Yes"
        if group == "truefalse": return "True"
        return "Valid"

    # —— 2：likely/probably true/false —— 
    if re.search(r"\b(likely|probably)\s+true\b", s):
        if group == "yesno":     return "Yes"
        if group == "truefalse": return "True"
        return "Valid"
    if re.search(r"\b(likely|probably)\s+false\b", s):
        if group == "yesno":     return "No"
        if group == "truefalse": return "False"
        return "Invalid"

     # —— 3：unlikely true/false（与 SYSTEM 规则保持一致）——
    if re.search(r"\bunlikely\s+true\b", s):
        if group == "yesno":     return "No"
        if group == "truefalse": return "False"
        return "Invalid"
    if re.search(r"\bunlikely\s+false\b", s):
        if group == "yesno":     return "Yes"
        if group == "truefalse": return "True"
        return "Valid"

    # 句式/短语信号（先查否定，再查肯定；包含弯引号）
    neg_hit = bool(re.search(
        r"\b(?:not\s+(?:true|plausible|correct|right|valid|accurate)"
        r"|does(?:\s+not|n(?:'|’)?t)\s+hold"
        r"|unlikely|implausible|incorrect|wrong|untrue|inaccurate)\b",
        s
    ))

    # 为避免 "not true" / "doesn't hold" 触发正向词命中，
    # 先把这些否定触发词（仅当其后紧跟正向关键词时）“抹掉”，再做正向匹配
    s_pos_guarded = re.sub(
        r"\bnot\s+(?:true|plausible|correct|right|valid|accurate)\b",
        "__NEG_BLOCK__", s
    )
    s_pos_guarded = re.sub(
        r"\bdoes(?:\s+not|n(?:'|’)?t)\s+hold\b",
        "__NEG_BLOCK__", s_pos_guarded
    )

    pos_hit = bool(
        re.search(
            r"\b(true|plausible|correct|right|valid|accurate)\b",
            s_pos_guarded
        )
    )

    # 直接命中集合或单向命中
    if s in POS or (pos_hit and not neg_hit):
        if group == "yesno":     return "Yes"
        if group == "truefalse": return "True"
        return "Valid"
    if s in NEG or (neg_hit and not pos_hit):
        if group == "yesno":     return "No"
        if group == "truefalse": return "False"
        return "Invalid"

    # 回退到原先的 startswith 规则
    if group == "yesno":
        return "Yes" if s.startswith("y") else ("No" if s.startswith("n") else "Undefined")
    if group == "truefalse":
        # 兼容 "yes/no" 误用到 true/false 任务
        if s.startswith("y"): return "True"
        if s.startswith("n"): return "False"
        return "True" if s.startswith("t") else ("False" if s.startswith("f") else "Undefined")
    # validinvalid
    if s.startswith("y"): return "Valid"   # 兼容 "yes"→"Valid"
    if s.startswith("n"): return "Invalid" # 兼容 "no" →"Invalid"
    return "Valid" if s.startswith("v") else ("Invalid" if s.startswith("i") else "Undefined")


@dataclass
class _ModelEntry:
    name: str
    price: float = 0.0
    fail_count: int = 0
    last_fail_ts: float = 0.0
    cooldown: float = 0.0  # 秒

    def available(self) -> bool:
        return (time.time() - self.last_fail_ts) >= self.cooldown


class LLMJudge:
    """硅基流动 OpenAI 兼容 /chat/completions JSON-Schema 约束输出（支持多模型轮换）"""
    def __init__(self,
                 model: str = "Qwen/Qwen2.5-32B-Instruct",
                 base_url: str = "https://api.siliconflow.cn/v1",
                 timeout: Tuple[float, float] = (5, 45),
                 max_retries: int = 3,
                 backoff_base: float = 5,
                 backoff_cap: float = 8.0,
                 jitter: float = 0.2,
                 retry_statuses = (408, 425, 429, 500, 502, 503, 504),
                 model_pool: Optional[List[Tuple[str, float]]] = None):
        self.model = model
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.session = requests.Session()

        self.max_retries = int(max_retries)
        self.backoff_base = float(backoff_base)
        self.backoff_cap = float(backoff_cap)
        self.jitter = float(jitter)
        self.retry_statuses = set(int(x) for x in retry_statuses)

        # 模型级冷却参数（用于分散并发）
        self.cooldown_base = 2
        self.cooldown_cap = 33.3

        # 从 siliconflow_apikey.json 中获取api_key
        with open("siliconflow_apikey.json", "r") as f:
            api_key = json.load(f)["apikey"]
        self.api_key = api_key or ""

        # —— 模型池（带价格）——
        default_pool = [
            # 价格单位：元（仅调度排序用途；免费=0）
            ("Qwen/Qwen2.5-32B-Instruct",                 2.00),  # 放入池，参与≤2首发挑选
            ("Qwen/Qwen3-30B-A3B-Instruct-2507",          2.80),
            ("zai-org/GLM-4.5-Air",                       6.00),
            ("Qwen/Qwen3-Next-80B-A3B-Instruct",          4.00),
            ("Qwen/Qwen3-Next-80B-A3B-Thinking",          4.00),
            ("Qwen/QwQ-32B",                              4.00),
            ("Qwen/Qwen3-32B",                            4.00),
            ("Qwen/Qwen3-14B",                            2.00),
            ("Pro/deepseek-ai/DeepSeek-V3.2-Exp",         3.00),
            ("Qwen/Qwen2.5-72B-Instruct-128K" ,           4.13),
            ("Qwen/Qwen2.5-72B-Instruct" ,                4.13),
            ("deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",  1.26),
            ("deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",  0.70),
            ("THUDM/GLM-4-32B-0414",                      1.89),
            ("Qwen/Qwen2.5-14B-Instruct",                 0.70),
            # 免费模型示例（如有）
            ("deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",     0.00),
            ("Qwen/Qwen3-8B",                             0.00),
            ("Qwen/Qwen2.5-7B-Instruct",                  0.00),
            ("THUDM/GLM-Z1-9B-0414",                      0.00),
            ("THUDM/GLM-4-9B-0414",                       0.00),
        ]
        if model_pool is None:
            model_pool = default_pool

        # 构建唯一模型列表（保留 self.model，但不固定首发）
        all_names = [self.model] + [m for (m, _) in model_pool]
        seen, ordered = set(), []
        for n in all_names:
            if n in seen:
                continue
            seen.add(n)
            p = next((price for (name, price) in model_pool if name == n), 0.0)
            ordered.append(_ModelEntry(name=n, price=p))
        self._pool: List[_ModelEntry] = ordered

    # —— 动态生成当次尝试序列 —— 
    def _attempt_sequence(self) -> List[_ModelEntry]:
        entries = list(self._pool)
        # 先尊重冷却；若全部处于冷却，则等待最短剩余冷却时间再继续，避免无效打点。
        # avail = [e for e in entries if e.available()] or entries
        avail = [e for e in entries if e.available()]
        if not avail:
            now = time.time()
            sleep_s = min(max(0.0, e.cooldown - (now - e.last_fail_ts)) for e in entries)
            if sleep_s > 0:
                time.sleep(min(sleep_s, 10.0))  # 最多等 10s
            avail = [e for e in entries if e.available()] or entries

        # 1) 首发池：仅限这11个模型
        primary_names = {
            "Qwen/Qwen2.5-32B-Instruct",
            "Qwen/Qwen3-14B",
            "Qwen/Qwen2.5-72B-Instruct-128K",
            "Qwen/Qwen2.5-72B-Instruct",
            "Qwen/Qwen3-32B",
            "Qwen/QwQ-32B",
            "Qwen/Qwen3-30B-A3B-Instruct-2507",
            "Qwen/Qwen3-Next-80B-A3B-Thinking",
            "Qwen/Qwen3-Next-80B-A3B-Instruct",
            "zai-org/GLM-4.5-Air",
            "Pro/deepseek-ai/DeepSeek-V3.2-Exp",
        }
        primary_pool = [e for e in avail if e.name in primary_names and e.price > 0.0]
        if primary_pool:
            # 均匀随机（如需区分可在此加权）
            primary = random.choice(primary_pool)
        else:
            # 若都不可用/不在可用集：退化为“最贵次发付费优先”，否则随即挑一个可用模型
            paid_avail = [e for e in avail if e.price > 0.0]
            primary = max(paid_avail, key=lambda x: x.price) if paid_avail else random.choice(avail)

        # 2) 其余：分付费/免费；付费按加权“洗牌”，免费仅在最后随机
        rest = [e for e in avail if e is not primary]
        paid_rest = [e for e in rest if e.price > 0.0]
        free_rest = [e for e in rest if e.price == 0.0]

        # 付费模型的权重：2款高权重，其余低权重
        high_weight_names = {
            "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
            "THUDM/GLM-4-32B-0414",
        }

        def paid_weight(e: _ModelEntry) -> float:
            # 可按需微调这两个数
            return 2.5 if e.name in high_weight_names else 1.0

        def weighted_shuffle(items: List[_ModelEntry], wfunc) -> List[_ModelEntry]:
            # 加权不放回抽样，生成一个“加权随机顺序”
            pool = list(items)
            order: List[_ModelEntry] = []
            while pool:
                weights = [max(1e-6, wfunc(x)) for x in pool]
                idx = random.choices(range(len(pool)), weights=weights, k=1)[0]
                order.append(pool.pop(idx))
            return order

        paid_order = weighted_shuffle(paid_rest, paid_weight)

        # 免费模型：只有在前面所有付费都失败时才会触达，因此这里只做无权重随机顺序
        random.shuffle(free_rest)

        return [primary] + paid_order + free_rest

    # —— 对外主入口 —— 
    def judge(self, text: str, group: JudgeGroup = "yesno", task: Optional[str] = None, id: Optional[str] = None) -> Tuple[str, str]:
        enum, schema = _schema_for(group)
        # head = _first_sentence(text)
        head = (text or "").strip()

        if task.lower() == "causal_judgement" or task.lower() == "web_of_lies":
            # 根据完整 question 的匹配到精简后的 question
            QUESTION = find_question(id, task)
            base_payload = {
                "messages": [
                    {"role": "system", "content": SYSTEM_PROMPT.get(task.lower(), SYSTEM)},
                    {"role": "user", "content": [
                        {"type":"text","text":"Decide the short answer label."},
                        {"type":"text","text":f"Allowed labels: {', '.join(enum)}"},
                        {"type":"text","text":f"QUESTION:\n{QUESTION}"},
                        {"type":"text","text":f"ANSWER TEXT:\n{head}"},
                    ]}
                ],
                "temperature": 0.0,
                "response_format": {
                    "type": "json_schema",
                    "json_schema": {"name": "bbh_label", "schema": schema, "strict": True}
                }
            }
        else:
            base_payload = {
                "messages": [
                    {"role": "system", "content": SYSTEM_PROMPT.get(task.lower(), SYSTEM)},
                    {"role": "user", "content": [
                        {"type":"text","text":"Decide the short answer label."},
                        {"type":"text","text":f"Allowed labels: {', '.join(enum)}"},
                        {"type":"text","text":f"ANSWER TEXT:\n{head}"}
                    ]}
                ],
                "temperature": 0.0,
                "response_format": {
                    "type": "json_schema",
                    "json_schema": {"name": "bbh_label", "schema": schema, "strict": True}
                }
            }
        headers = {"Content-Type": "application/json"}
        if getattr(self, "api_key", None):
            headers["Authorization"] = f"Bearer {self.api_key}"
        url = f"{self.base_url}/chat/completions"

        seq = self._attempt_sequence()
        primary, rest = seq[0], seq[1:]

        # 1) 首发：带重试
        try:
            payload = dict(base_payload, model=primary.name)
            r = self._post_with_retries(url, headers, payload)
            out = self._parse_response(r, group, head)
            self._mark_success(primary)
            return out
        except Exception as e:
            self._mark_fail(primary, e)

        # 2) 其余：快速单次轮换
        for entry in rest:
            try:
                payload = dict(base_payload, model=entry.name)
                r = self._post_once(url, headers, payload)
                out = self._parse_response(r, group, head)
                self._mark_success(entry)
                return out
            except Exception as e:
                self._mark_fail(entry, e)
                continue

        # 3) 兜底：再试一次首发
        payload = dict(base_payload, model=primary.name)
        r = self._post_with_retries(url, headers, payload)
        out = self._parse_response(r, group, head)
        self._mark_success(primary)
        return out

    # —— HTTP 层 —— 
    def _post_once(self, url: str, headers: dict, payload: dict) -> requests.Response:
        r = self.session.post(url, headers=headers, json=payload, timeout=self.timeout)
        if r.status_code in self.retry_statuses:
            raise HTTPError(f"retryable status {r.status_code}", response=r)
        r.raise_for_status()
        return r

    def _sleep_for_attempt(self, attempt: int) -> float:
        base = min(self.backoff_cap, self.backoff_base * (2 ** (attempt - 1)))
        return base * (1.0 + self.jitter * random.random())

    def _post_with_retries(self, url: str, headers: dict, payload: dict) -> requests.Response:
        last_err = None
        for attempt in range(1, self.max_retries + 1):
            try:
                r = self.session.post(url, headers=headers, json=payload, timeout=self.timeout)
                if r.status_code in self.retry_statuses:
                    ra = r.headers.get("Retry-After")
                    # if ra and attempt < self.max_retries:
                    #     try:
                    #         wait = float(ra)
                    #     except Exception:
                    #         wait = self._sleep_for_attempt(attempt)
                    if ra and attempt < self.max_retries:
                        try:
                            wait = float(ra)
                        except Exception:
                            # HTTP-date 格式
                            from email.utils import parsedate_to_datetime
                            import datetime as _dt
                            try:
                                when = parsedate_to_datetime(ra)
                                if when.tzinfo is None:
                                        when = when.replace(tzinfo=_dt.timezone.utc)
                                wait = max(0.0, (when - _dt.datetime.now(_dt.timezone.utc)).total_seconds())
                            except Exception:
                                wait = self._sleep_for_attempt(attempt)
                        print(f"[LLMJudge] server asked to retry after {wait:.2f}s (status {r.status_code})")
                        time.sleep(wait)
                        continue
                    raise HTTPError(f"retryable status {r.status_code}", response=r)

                r.raise_for_status()
                return r
            except (ReadTimeout, ConnectTimeout, RequestsConnectionError) as e:
                last_err = e
            except HTTPError as e:
                if getattr(e, "response", None) is not None and e.response is not None:
                    if e.response.status_code in self.retry_statuses:
                        last_err = e
                    else:
                        raise
                else:
                    last_err = e

            if attempt < self.max_retries:
                wait = self._sleep_for_attempt(attempt)
                print(f"[LLMJudge] {type(last_err).__name__}: {last_err}; retry {attempt}/{self.max_retries} in {wait:.2f}s")
                time.sleep(wait)

        raise last_err if last_err is not None else RuntimeError("Unknown error in _post_with_retries")

    # —— 解析响应 —— 
    def _parse_response(self, r: requests.Response, group: JudgeGroup, head: str):
        try:
            data = r.json()
        except Exception:
            return ("Undefined", head)
        
        try:
            content = data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError):
            return ("Undefined", head)

        if isinstance(content, dict):
            label = content.get("label", "")
        else:
            if not content:
                content = ""
            obj = None
            try:
                obj = json.loads(content)
            except Exception:
                m = re.search(r"\{.*?\}", content, flags=re.S)
                if m:
                    try:
                        obj = json.loads(m.group(0))
                    except Exception:
                        obj = None
            label = (obj or {}).get("label", (content or "").strip())

            if not isinstance(label, str):
                label = str(label)

        canon = _canon(group, label)
        # 兜底：若解析失败或返回值不在 schema 中，统一给 Undefined
        allowed, _ = _schema_for(group)
        if canon not in allowed:
            canon = "Undefined"
        return canon, head

    # —— 冷却与失败标记 —— 
    def _mark_fail(self, entry: _ModelEntry, err: Exception):
        entry.fail_count += 1
        entry.last_fail_ts = time.time()
        base = min(self.cooldown_cap, max(1.2, self.cooldown_base * (2 ** (entry.fail_count - 1))))
        entry.cooldown = base * (1.0 + self.jitter * random.random())
        status = getattr(getattr(err, "response", None), "status_code", None)
        print(f"[LLMJudge] model '{entry.name}' fail ({status}); cooldown {entry.cooldown:.2f}s")

    def _mark_success(self, entry: _ModelEntry):
        entry.fail_count = 0
        entry.cooldown = 0.0