from __future__ import annotations
import asyncio
import json
import os
import re
from dataclasses import dataclass
from typing import Tuple, Optional, List

import httpx

# Optional local HF judge
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# AsyncOpenAI judge
from openai import AsyncOpenAI


# =========================
# Prompt templates (aligned with axbench/prompt_templates.py)
# =========================

# Judge prompt asking whether a concept is present in a sentence fragment.
# Filled with str.format; placeholders: {concept} and {sentence}.
# The prompt asks the judge to explain first, then give a 0-2 rating as "Rating: [[score]]".
CONCEPT_TEMPLATE = """[System]
Please act as an impartial judge and evaluate whether the specified concept is presented in the sentence fragment provided below. Focus solely on whether the concept is clearly incorporated, without regard for grammar or logical coherence.

Begin your evaluation by providing a short explanation of whether the concept is clearly incorporated. Point out references to the concept in the fragment, noting any minimal or awkward incorporations. After providing your explanation, rate the concept's relevance on a scale from 0 to 2, where 0 indicates the concept is not present at all, 1 indicates the concept is somewhat present but minimally or awkwardly incorporated, and 2 indicates the concept is more fully and effectively incorporated, with stronger and more natural integration. Provide your rating using this exact format: "Rating: [[score]]".

[Concept Start]
{concept}
[Concept End]

[Sentence Fragment Start]
{sentence}
[Sentence Fragment End]
"""

# Judge prompt asking how related a sentence fragment is to an instruction (topic only).
# Filled with str.format; placeholders: {instruction} and {sentence}.
# The prompt asks the judge to explain first, then give a 0-2 rating as "Rating: [[score]]".
INSTRUCT_TEMPLATE = """[System]
Please act as an impartial judge and evaluate whether the sentence fragment provided below is related to the instruction. Focus solely on the degree of relatedness in terms of topic, regardless of grammar, coherence, or informativeness.

Begin your evaluation by providing a brief explanation of whether the sentence is related to the instruction, and point out references related to the instruction. After providing your explanation, rate the instruction relevance on a scale from 0 to 2, where 0 indicates the sentence is unrelated to the instruction, 1 indicates it is somewhat related but only minimally or indirectly relevant in terms of topic, and 2 indicates it is more clearly and directly related to the instruction. Provide your rating using this exact format: "Rating: [[score]]".

[Instruction Start]
{instruction}
[Instruction End]

[Sentence Fragment Start]
{sentence}
[Sentence Fragment End]
"""

# Judge prompt asking for the fluency of a sentence fragment, ignoring relevance and completeness.
# Filled with str.format; the only placeholder is {sentence}.
# The prompt asks the judge to explain first, then give a 0-2 rating as "Rating: [[score]]".
FLUENCY_TEMPLATE = """[System]
Please act as an impartial judge and evaluate the fluency of the sentence fragment provided below. Focus solely on fluency, disregarding its completeness, relevance, coherence with any broader context, or informativeness.

Begin your evaluation by briefly describing the fluency of the sentence, noting any unnatural phrasing, awkward transitions, grammatical errors, or repetitive structures that may hinder readability. After providing your explanation, rate the sentence's fluency on a scale from 0 to 2, where 0 indicates the sentence is not fluent and highly unnatural (e.g., incomprehensible or repetitive), 1 indicates it is somewhat fluent but contains noticeable errors or awkward phrasing, and 2 indicates the sentence is fluent and almost perfect. Provide your rating using this exact format: "Rating: [[score]]".

[Sentence Fragment Start]
{sentence}
[Sentence Fragment End]
"""


# =========================
# Robust rating extraction (now supports floats 0.0–2.0)
# =========================
# Primary patterns: explicit 'Rating'/'Score' (English & Chinese), allow CJK punctuation and floats
# Keyword-anchored patterns: "<keyword>: [[n]]" where the keyword is Rating/Score or the
# Chinese equivalents, the colon may be ASCII or full-width, and n is 0-2 with optional decimals.
_RATING_PATTERNS: List[re.Pattern] = [
    re.compile(
        keyword + r'\s*[:：]?\s*\[\s*\[\s*([0-2](?:\.\d+)?)\s*\]\s*\]',
        re.IGNORECASE | re.DOTALL,
    )
    for keyword in ('Rating', 'Score', '评分', '得分')
]
# Keyword-free fallbacks, tried only after the patterns above: a bare double-bracketed
# number (ASCII or full-width brackets), then JSON-like "rating"/"score" fields.
_FALLBACK_PATTERNS: List[re.Pattern] = [
    re.compile(
        r'[\[\【]\s*[\[\【]\s*([0-2](?:\.\d+)?)\s*[\]\】]\s*[\]\】]',
        re.IGNORECASE | re.DOTALL,
    ),
    *(
        re.compile('"' + key + r'"\s*:\s*([0-2](?:\.\d+)?)', re.IGNORECASE | re.DOTALL)
        for key in ('rating', 'score')
    ),
]

def _clamp_0_2_float(x: float) -> float:
    """
    Coerce *x* to float and clamp it into the closed range [0.0, 2.0].

    Values that cannot be converted (None, non-numeric strings, ints too large for a
    float) and NaN all map to 0.0, so the result is always a real number in range.
    """
    try:
        value = float(x)
    except (TypeError, ValueError, OverflowError):
        return 0.0
    # NaN compares False against everything, so it would slip past both bounds
    # below; `value != value` is True only for NaN.
    if value != value or value < 0.0:
        return 0.0
    if value > 2.0:
        return 2.0
    return value

def _extract_score(text: str) -> float:
    """
    Extract a rating in [0.0, 2.0] (float) from judge output. Do NOT round to integers.

    Stages are tried in order and the first stage that matches wins:
      1. _RATING_PATTERNS: keyword plus double brackets, e.g.
         Rating: [[2]], Rating: [[1.5]], Score: [[2.0]], 评分：[[1.2]]
      2. _FALLBACK_PATTERNS: bare double brackets (incl. full-width 【【1.7】】)
         or JSON-like fields: {"rating": 1.8} or {"score": 0.5}
      3. Tail fallback: the LAST rating keyword that is followed by a standalone
         number 0–2 (possibly decimal), e.g. "... so my rating is 1.5."

    Returns 0.0 when *text* is empty/None or when nothing matches.
    """
    if not text:
        return 0.0

    # Stages 1 and 2: within each pattern the first occurrence is used. The captured
    # group is always a valid float literal, so no exception handling is needed.
    for pat in (*_RATING_PATTERNS, *_FALLBACK_PATTERNS):
        m = pat.search(text)
        if m:
            return _clamp_0_2_float(float(m.group(1)))

    # Stage 3: the prompts ask for the explanation first and the rating afterwards,
    # so the last keyword/number pair is the most likely verdict. The trailing
    # (?!\d) keeps the number standalone: "Rating: 10" must not be read as 1.
    tail_matches = list(re.finditer(
        r'(rating|score|评分|得分)[^0-9]*([0-2](?:\.\d+)?)(?!\d)',
        text,
        flags=re.IGNORECASE | re.DOTALL,
    ))
    if tail_matches:
        return _clamp_0_2_float(float(tail_matches[-1].group(2)))

    return 0.0

def harmonic_mean_0_2(c: float, i: float, f: float) -> float:
    """
    Combine three float scores on the [0, 2] scale into their harmonic mean.

    The aggregate is deliberately strict: as soon as one component is zero or
    negative the result is 0.0, which also avoids dividing by zero.
    """
    if any(part <= 0.0 for part in (c, i, f)):
        return 0.0
    reciprocal_sum = 1.0 / c + 1.0 / i + 1.0 / f
    return 3.0 / reciprocal_sum


# =========================
# AsyncOpenAI Judge
# =========================
@dataclass
class OpenAIJudgeConfig:
    """Settings consumed by AsyncOpenAIJudge: model, endpoint, client limits and debug output."""
    # Model name sent with every chat-completion request.
    model: str = "gpt-4o-mini"
    # Forwarded as `timeout` to the AsyncOpenAI client.
    timeout: float = 60.0
    # Forwarded as `base_url` to the AsyncOpenAI client.
    # NOTE(review): the default is not the official OpenAI host — confirm this endpoint is intended.
    base_url: str = "https://api.shubiaobiao.cn/v1"
    # Forwarded as `max_retries` to the AsyncOpenAI client.
    max_retries: int = 3
    # Debug: when True, a_score prints each raw judge reply and the score parsed from it.
    debug: bool = False
    # Maximum number of characters of each raw reply shown in the debug output.
    debug_truncate: int = 800

class AsyncOpenAIJudge:
    """
    Judge a sentence with three independent prompts (concept, instruction, fluency)
    sent through AsyncOpenAI, and extract a 0.0–2.0 float rating from each reply.
    """
    def __init__(self, cfg: "OpenAIJudgeConfig"):
        """Build the AsyncOpenAI client from *cfg*; the API key is read from OPENAI_API_KEY."""
        self.cfg = cfg
        self.client = AsyncOpenAI(
            api_key=os.environ.get("OPENAI_API_KEY"),
            base_url=cfg.base_url,
            timeout=cfg.timeout,
            http_client=httpx.AsyncClient(
                limits=httpx.Limits(max_keepalive_connections=100, max_connections=1000),
                # "Connection: close" asks the server to close the connection after each
                # response. NOTE(review): presumably a workaround so pooled connections are
                # not reused across the separate event loops created by score() — confirm.
                headers={"Connection": "close"},
            ),
            max_retries=cfg.max_retries,
        )

    async def _ask(self, system_text: str, user_text: str) -> str:
        """
        Send one chat-completion request and return the reply text ("" if the content is None).

        We pass the full template (including the [System] header and the "Rating: [[score]]" instruction)
        in the user message content to align exactly with axbench prompt templates.
        A system message is only added when *system_text* is non-empty.
        """
        messages = []
        if system_text:
            messages.append({"role": "system", "content": system_text})
        messages.append({"role": "user", "content": user_text})

        resp = await self.client.chat.completions.create(
            model=self.cfg.model,
            messages=messages,
            temperature=0.0,
        )
        return resp.choices[0].message.content or ""

    def _trunc(self, x: Optional[str]) -> str:
        """Render *x* on one line for debug output, cut to cfg.debug_truncate characters."""
        if not x:
            return ""
        x = x.replace("\n", "\\n")
        limit = self.cfg.debug_truncate
        return (x[:limit] + " ...<truncated>") if len(x) > limit else x

    async def a_score(self, concept: str, instruction: str, sentence: str) -> Tuple[float, float, float]:
        """
        Return (concept, instruction, fluency) ratings for *sentence*, each a float in [0.0, 2.0].

        The three prompts do not depend on each other, so the requests are issued
        concurrently; an exception from any request propagates to the caller.
        A reply without a parsable rating yields 0.0 for that component.
        """
        # System channel intentionally empty; the full template is in the user content.
        concept_out, instruct_out, fluency_out = await asyncio.gather(
            self._ask("", CONCEPT_TEMPLATE.format(concept=concept, sentence=sentence)),
            self._ask("", INSTRUCT_TEMPLATE.format(instruction=instruction, sentence=sentence)),
            self._ask("", FLUENCY_TEMPLATE.format(sentence=sentence)),
        )

        c = _extract_score(concept_out)
        i = _extract_score(instruct_out)
        f = _extract_score(fluency_out)

        if self.cfg.debug:
            for label, raw, parsed in (
                ("Concept", concept_out, c),
                ("Instruct", instruct_out, i),
                ("Fluency", fluency_out, f),
            ):
                print(f"[JUDGE DEBUG] --- {label} raw ---")
                print(self._trunc(raw))
                print(f"[JUDGE DEBUG] {label} score parsed -> {parsed}")

        return c, i, f

    def score(self, concept: str, instruction: str, sentence: str) -> Tuple[float, float, float]:
        """
        Blocking wrapper around a_score.

        Uses asyncio.run, which starts a fresh event loop per call and raises
        RuntimeError when invoked from a thread that already runs an event loop;
        async callers should await a_score directly.
        """
        return asyncio.run(self.a_score(concept, instruction, sentence))


# =========================
# Local HF Judge (kept)
# =========================
@dataclass
class LocalJudgeConfig:
    """Settings consumed by LocalHFJudge: which model to load, where to run it, and debug output."""
    # Name passed to from_pretrained for both the tokenizer and the causal LM.
    model_name: str = "google/gemma-2-2b-it"
    # Device the model and the tokenized inputs are moved to.
    device: str = "cuda:0"
    # Forwarded as `max_new_tokens` to generate (explanation plus rating line must fit).
    max_new_tokens: int = 128
    # Debug: when True, score prints each raw generation and the score parsed from it.
    debug: bool = False
    # Maximum number of characters of each raw generation shown in the debug output.
    debug_truncate: int = 800

class LocalHFJudge:
    """
    Local HF judging: generate an explanation plus 'Rating: [[score]]' line, then parse.
    Uses the same axbench-aligned templates (embedded in user content).
    """
    def __init__(self, cfg: "LocalJudgeConfig"):
        """Load tokenizer and model named by *cfg* and move the model to cfg.device."""
        self.cfg = cfg
        self.tok = AutoTokenizer.from_pretrained(cfg.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(cfg.model_name).to(cfg.device)
        # generate() needs a pad token id; fall back to EOS when the tokenizer defines none.
        if self.tok.pad_token_id is None:
            self.tok.pad_token = self.tok.eos_token

    @torch.no_grad()
    def _gen(self, system_text: str, user_text: str) -> str:
        """
        Greedily generate a reply to *user_text* and return ONLY the newly generated text.

        The prompt itself is stripped from the result: a causal LM's generate() output
        starts with the input tokens, and the prompt contains the template, the concept
        and the judged sentence — text that must never reach the rating parser.
        """
        # A simple chat-like prompt format that works for many instruction-tuned models.
        # NOTE(review): a system turn is emitted even when system_text is empty — confirm
        # the target model's chat format accepts that.
        prompt = (
            f"<start_of_turn>system\n{system_text}<end_of_turn>\n"
            f"<start_of_turn>user\n{user_text}<end_of_turn>\n"
            f"<start_of_turn>model\n"
        )
        inputs = self.tok(prompt, return_tensors="pt").to(self.cfg.device)
        out_ids = self.model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=self.cfg.max_new_tokens,
            pad_token_id=self.tok.pad_token_id,
            eos_token_id=self.tok.eos_token_id,
        )
        # Drop the echoed prompt tokens so only the model's continuation is decoded.
        prompt_len = inputs["input_ids"].shape[1]
        return self.tok.decode(out_ids[0][prompt_len:], skip_special_tokens=True)

    def _trunc(self, x: Optional[str]) -> str:
        """Render *x* on one line for debug output, cut to cfg.debug_truncate characters."""
        if not x:
            return ""
        x = x.replace("\n", "\\n")
        limit = self.cfg.debug_truncate
        return (x[:limit] + " ...<truncated>") if len(x) > limit else x

    def score(self, concept: str, instruction: str, sentence: str) -> Tuple[float, float, float]:
        """
        Return (concept, instruction, fluency) ratings for *sentence*, each a float in [0.0, 2.0].

        Runs the three prompts one after another through the local model; a generation
        without a parsable rating yields 0.0 for that component.
        """
        concept_text = self._gen("", CONCEPT_TEMPLATE.format(concept=concept, sentence=sentence))
        instruct_text = self._gen("", INSTRUCT_TEMPLATE.format(instruction=instruction, sentence=sentence))
        fluency_text = self._gen("", FLUENCY_TEMPLATE.format(sentence=sentence))

        c = _extract_score(concept_text)
        i = _extract_score(instruct_text)
        f = _extract_score(fluency_text)

        if self.cfg.debug:
            for label, raw, parsed in (
                ("Concept", concept_text, c),
                ("Instruct", instruct_text, i),
                ("Fluency", fluency_text, f),
            ):
                print(f"[JUDGE DEBUG] --- {label} raw (HF) ---")
                print(self._trunc(raw))
                print(f"[JUDGE DEBUG] {label} score parsed -> {parsed}")

        return c, i, f
