import html
import os
import re
import time
import json
import sqlite3
import hashlib
import threading
import numpy as np
import torch
from torch import nn
from openai import OpenAI


class _EmbeddingCache:
    """SQLite-backed key/value store for embedding vectors.

    Values are serialised with ``json.dumps`` on write and decoded with
    ``json.loads`` on read. A single connection is shared, and every access
    to it is performed while holding ``self.lock``.
    """

    def __init__(self, path=None):
        """Open (creating if needed) the cache database.

        Args:
            path: Database location. When falsy, the ``EMBEDDING_CACHE_PATH``
                environment variable is used, then ``embedding_cache.sqlite``.
        """
        if not path:
            path = os.environ.get("EMBEDDING_CACHE_PATH", "embedding_cache.sqlite")
        self.conn = sqlite3.connect(path, check_same_thread=False)
        self.conn.execute("CREATE TABLE IF NOT EXISTS cache (k TEXT PRIMARY KEY, v BLOB)")
        self.lock = threading.Lock()

    def get(self, key):
        """Return the decoded value stored under ``key``, or ``None`` if absent."""
        with self.lock:
            hit = self.conn.execute("SELECT v FROM cache WHERE k=?", (key,)).fetchone()
            if not hit:
                return None
            return json.loads(hit[0])

    def set(self, key, vec):
        """Store ``vec`` under ``key`` (replacing any previous row) and commit."""
        with self.lock:
            payload = json.dumps(vec)
            self.conn.execute("INSERT OR REPLACE INTO cache(k,v) VALUES(?,?)", (key, payload))
            self.conn.commit()


class RemoteEmbedder:
    """Embedding client with an on-disk cache and retry with backoff."""

    def __init__(self, model=None, cache_path=None):
        """Build the API client and open the cache.

        Args:
            model: Embedding model name. When falsy, ``OPENAI_EMBED_MODEL``
                is used, then ``text-embedding-3-large``.
            cache_path: Forwarded to ``_EmbeddingCache``.

        Raises:
            RuntimeError: If ``OPENAI_API_KEY`` is not set.
        """
        env = os.environ
        api_key = env.get("OPENAI_API_KEY")
        base_url = env.get("OPENAI_BASE_URL")
        if not api_key:
            raise RuntimeError("OPENAI_API_KEY is not set")
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        if not model:
            model = env.get("OPENAI_EMBED_MODEL", "text-embedding-3-large")
        self.model = model
        self.cache = _EmbeddingCache(cache_path)

    def embed(self, text: str):
        """Return the embedding for ``text``, consulting the cache first.

        The cache key is the SHA-256 of the model name, a newline, and the
        text. On a miss the API is called up to six times; the wait between
        attempts starts at one second and doubles, capped at sixteen. The
        exception from the sixth failed attempt is re-raised.
        """
        digest = hashlib.sha256("\n".join((self.model, text)).encode("utf-8")).hexdigest()
        cached = self.cache.get(digest)
        if cached is not None:
            return cached

        max_attempts = 6
        backoff = 1.0
        failures = 0
        while True:
            try:
                response = self.client.embeddings.create(model=self.model, input=text)
                embedding = response.data[0].embedding
                self.cache.set(digest, embedding)
                return embedding
            except Exception:
                failures += 1
                if failures == max_attempts:
                    raise
                time.sleep(backoff)
                backoff = min(backoff * 2, 16.0)


# Extracts the original question/answer pair from a prompt laid out as
# "**Original Question** ... **Original Answer** ..." (answer runs to end of text).
ORIG_QA_RE = re.compile(r"\*\*Original Question\*\*\s*(.+?)\s*\*\*Original Answer\*\*\s*(.+?)\s*$", re.S)

# <question>..</question> <solution>..</solution> <answer>..</answer>, in that
# order, tolerant of whitespace inside the tags and of text between the blocks.
TRIPLET_XML_RE = re.compile(
    r"<\s*question\s*>\s*(.*?)\s*<\s*/\s*question\s*>\s*.*?"
    r"<\s*solution\s*>\s*(.*?)\s*<\s*/\s*solution\s*>\s*.*?"
    r"<\s*answer\s*>\s*(.*?)\s*<\s*/\s*answer\s*>",
    re.S | re.I
)

# "{question} .. {solution} .. {answer} .." placeholder layout; the answer
# group runs to the end of the text.
TRIPLET_BRACE_RE = re.compile(
    r"\{\s*question\s*\}\s*(.*?)\s*\{\s*solution\s*\}\s*(.*?)\s*\{\s*answer\s*\}\s*(.*)",
    re.S | re.I
)

# "Question: .. / Solution: .. / Answer: .." labelled lines (ASCII or full-width
# colon). The Solution and Answer sections are optional, so their groups may be
# None. The trailing end-of-text anchor is required: without it the lazy
# question group is satisfied by the empty string and the optional sections
# are never attempted, so nothing is ever captured.
LABEL_BLOCK_RE = re.compile(
    r"(?:^|\n)\s*(?:Question)\s*[:：]\s*(.*?)"
    r"(?:\n\s*(?:Solution)\s*[:：]\s*(.*?))?"
    r"(?:\n\s*(?:Answer)\s*[:：]\s*(.*))?"
    r"\s*\Z",
    re.S | re.I
)


def _strip_code_fences(t: str) -> str:
    """Remove a surrounding Markdown code fence and decode HTML entities.

    Falsy input yields ``""``. Otherwise the text is trimmed, a leading
    fence line (with optional language tag) and a trailing fence are
    dropped, and HTML character references are unescaped.
    """
    if not t:
        return ""
    without_opening = re.sub(r"^\s*```[a-zA-Z0-9_-]*\s*", "", t.strip())
    without_closing = re.sub(r"\s*```\s*$", "", without_opening)
    return html.unescape(without_closing)


def _normalize_tags(t: str) -> str:
    """Canonicalise question/solution/answer tags.

    Opening and closing tags are rewritten to lowercase with no inner
    whitespace (``< Question >`` -> ``<question>``, ``</ ANSWER >`` ->
    ``</answer>``). Other text is left untouched.
    """
    # Opening tags are rewritten first, then closing tags.
    passes = (
        (r"<\s*(question|solution|answer)\s*>", "<"),
        (r"<\s*/\s*(question|solution|answer)\s*>", "</"),
    )
    for pattern, opener in passes:
        t = re.sub(
            pattern,
            lambda m, opener=opener: f"{opener}{m.group(1).lower().strip()}>",
            t,
            flags=re.I,
        )
    return t


def _json_text(obj, *keys):
    """Return the first non-blank string/number stored under ``keys``, as stripped text."""
    for key in keys:
        val = obj.get(key)
        # bool is an int subclass; it and containers are not usable as text.
        if isinstance(val, bool) or not isinstance(val, (str, int, float)):
            continue
        text = str(val).strip()
        if text:
            return text
    return ""


def process_response(response: str):
    """Parse a model response into a rewritten question/solution/answer triplet.

    The response is first stripped of code fences and has its tags
    normalised. A JSON object is tried first (keys ``question`` /
    ``solution`` / ``answer`` or their ``rewrite_``-prefixed forms, all three
    required). Failing that, the XML-tag, brace-placeholder and labelled-line
    layouts are tried in that order; there the question and answer are
    required and the solution may be empty.

    Args:
        response: Raw model output; ``None`` is treated as empty.

    Returns:
        A dict with ``rewrite_question``, ``rewrite_solution`` and
        ``rewrite_answer`` string values, or ``{}`` when nothing usable
        could be parsed.
    """
    txt = _normalize_tags(_strip_code_fences(response or ""))

    mjson = re.search(r"\{.*\}", txt, flags=re.S)
    if mjson:
        try:
            obj = json.loads(mjson.group(0))
        except (ValueError, RecursionError):
            # Not valid JSON (e.g. brace placeholders); use the regex layouts.
            obj = None
        if isinstance(obj, dict):
            q = _json_text(obj, "question", "rewrite_question")
            s = _json_text(obj, "solution", "rewrite_solution")
            a = _json_text(obj, "answer", "rewrite_answer")
            if q and s and a:
                return {"rewrite_question": q, "rewrite_solution": s, "rewrite_answer": a}

    m = TRIPLET_XML_RE.search(txt) or TRIPLET_BRACE_RE.search(txt) or LABEL_BLOCK_RE.search(txt)
    if m:
        # Optional groups that did not participate come back as None unless
        # a default is given; pad so a pattern with fewer groups still unpacks.
        parts = (m.groups(default="") + ("", "", ""))[:3]
        q, s, a = (part.strip() for part in parts)
        if q and a:
            return {"rewrite_question": q, "rewrite_solution": s, "rewrite_answer": a}
    return {}


def extract_original_qa(prompt_text: str):
    """Pull the original question and answer out of a prompt.

    Returns:
        ``(question, answer)`` with surrounding whitespace removed, or
        ``(None, None)`` when ``ORIG_QA_RE`` does not match. Falsy input is
        treated as an empty string.
    """
    found = ORIG_QA_RE.search(prompt_text or "")
    if not found:
        return None, None
    question, answer = found.group(1, 2)
    return question.strip(), answer.strip()


class MLP(nn.Module):
    """Feed-forward network with two hidden layers and a sigmoid output.

    Layer widths are ``input_dim -> hidden_dim -> hidden_dim // 2 -> 1``,
    with ReLU and dropout (p=0.3) after each hidden layer.
    """

    def __init__(self, input_dim, hidden_dim=512):
        super().__init__()
        half_dim = hidden_dim // 2
        # Layer order fixes the ``mlp.<index>`` names used in the state dict.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, half_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(half_dim, 1),
            nn.Sigmoid(),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network and drop the trailing size-1 dimension."""
        activations = self.mlp(x)
        return activations.squeeze(-1)


@torch.no_grad()
def _predict_pref(model, e1, e2):
    """Run ``model`` on the pair features built from ``e1`` and ``e2``.

    The model is switched to eval mode. Both inputs are converted to tensors
    with the dtype and device of the model's first parameter, and the
    feature row is ``[e1, e2, e1 - e2, |e1 - e2|]`` concatenated along
    dim 0 with a leading batch dimension added.
    NOTE(review): this presumes ``e1`` and ``e2`` are 1-D sequences of equal
    length -- confirm against callers.

    Returns:
        The model output as a Python scalar.

    Raises:
        AssertionError: If the feature width differs from
            ``model.mlp[0].in_features``.
    """
    model.eval()
    ref = next(model.parameters())
    first = torch.tensor(e1, dtype=ref.dtype, device=ref.device)
    second = torch.tensor(e2, dtype=ref.dtype, device=ref.device)
    delta = first - second
    features = torch.cat([first, second, delta, delta.abs()], dim=0).unsqueeze(0)
    expected = model.mlp[0].in_features
    assert features.shape[-1] == expected, (
        f"Input dim mismatch: x={features.shape[-1]} vs model.in_features={expected}"
    )
    # Gradients are already disabled by the decorator.
    return model(features).item()


class PreferenceScorer:
    """Wraps the preference MLP and scores one embedding against another."""

    def __init__(self, dim, ckpt_path=None):
        """Build the model and load its weights when a checkpoint exists.

        Args:
            dim: Length of a single embedding; the MLP input is ``dim * 4``.
            ckpt_path: Checkpoint file. When falsy, ``PREFERENCE_CKPT_PATH``
                is used, then ``reward_model.pth``.
        """
        import logging  # local import: only needed for the missing-checkpoint warning

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = MLP(input_dim=dim * 4).to(self.device)
        ckpt = ckpt_path or os.environ.get("PREFERENCE_CKPT_PATH", "reward_model.pth")
        if os.path.exists(ckpt):
            # NOTE(security): torch.load unpickles the file; only point this
            # at trusted checkpoints (consider weights_only=True where the
            # installed torch version supports it).
            self.model.load_state_dict(torch.load(ckpt, map_location=self.device))
        else:
            # Construction still succeeds, but the model keeps its random
            # initial weights, so say so instead of staying silent.
            logging.getLogger(__name__).warning(
                "Preference checkpoint %r not found; scoring with randomly initialised weights",
                ckpt,
            )
        self.dim = dim

    @torch.no_grad()
    def score_direction(self, e2, e1):
        """Return ``P(e2, e1) - P(e1, e2)`` as a float.

        ``P`` is the model output from ``_predict_pref`` for the ordered
        pair; the difference of two sigmoid outputs lies in ``[-1, 1]``.
        """
        p1 = _predict_pref(self.model, e2, e1)
        p2 = _predict_pref(self.model, e1, e2)
        return float(p1 - p2)


# Module-wide embedding client, built at import time. RemoteEmbedder raises
# RuntimeError here when OPENAI_API_KEY is not set, and opens the cache DB.
_embedder = RemoteEmbedder(
    model=os.environ.get("OPENAI_EMBED_MODEL", "text-embedding-3-large"),
    cache_path=os.environ.get("EMBEDDING_CACHE_PATH", "embedding_cache.sqlite"),
)
# Module-wide preference scorer, also built at import time. EMBED_DIM is the
# length of one embedding; the MLP input is four times that.
# NOTE(review): the 3072 default presumably matches the default embedding
# model's output size -- confirm when changing OPENAI_EMBED_MODEL.
_pref = PreferenceScorer(
    dim=int(os.environ.get("EMBED_DIM", "3072")),
    ckpt_path=os.environ.get("PREFERENCE_CKPT_PATH", "reward_model.pth"),
)

# Tag markers whose presence _keyword_score counts (case-insensitively).
TARGET_KWS = [
    "<question>", "<solution>", "<answer>", "</question>", "</solution>", "</answer>",
]
# Length threshold in characters used by _length_score (env: MAX_RESPONSE_LEN).
MAX_LEN = int(os.environ.get("MAX_RESPONSE_LEN", "8192"))


def _keyword_score(txt: str) -> float:
    """Score a response by how many ``TARGET_KWS`` markers it contains.

    Matching is case-insensitive. With ``n > 0`` distinct markers present
    the score is ``(2 * n - len(TARGET_KWS)) * 0.05``; with none it is
    ``-0.1``. A further ``0.05`` is subtracted when the text contains
    "justification".
    """
    lowered = (txt or "").lower()
    hits = 0
    for marker in TARGET_KWS:
        if marker.lower() in lowered:
            hits += 1
    if hits > 0:
        result = (2 * hits - len(TARGET_KWS)) * 0.05
    else:
        result = -0.1
    if "justification" in lowered:
        result -= 0.05
    return result


def _length_score(txt: str) -> float:
    """Return ``0.1`` when ``txt`` is shorter than ``MAX_LEN`` characters, else ``-0.1``."""
    if len(txt or "") < MAX_LEN:
        return 0.1
    return -0.1


def _embed(text: str):
    """Embed ``text`` using the module-level ``_embedder``."""
    vector = _embedder.embed(text)
    return vector


class _JudgeCache:
    """SQLite-backed key/value store for judge verdict strings.

    The connection is opened with ``check_same_thread=False`` so it can be
    used from any thread; every access is therefore serialised through
    ``self.lock``, mirroring ``_EmbeddingCache``.
    """

    def __init__(self, path=None):
        """Open (creating if needed) the verdict database.

        Args:
            path: Database location. When falsy, ``JUDGE_CACHE_PATH`` is
                used, then ``judge_cache.sqlite``.
        """
        db_path = path or os.environ.get("JUDGE_CACHE_PATH", "judge_cache.sqlite")
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self.lock = threading.Lock()
        with self.lock:
            self.conn.execute("CREATE TABLE IF NOT EXISTS cache (k TEXT PRIMARY KEY, v TEXT)")
            self.conn.commit()

    def get(self, key: str):
        """Return the stored string for ``key``, or ``None`` if absent."""
        with self.lock:
            row = self.conn.execute("SELECT v FROM cache WHERE k=?", (key,)).fetchone()
        return row[0] if row else None

    def set(self, key: str, val: str):
        """Store ``val`` under ``key`` (replacing any previous row) and commit."""
        with self.lock:
            self.conn.execute("INSERT OR REPLACE INTO cache(k, v) VALUES(?, ?)", (key, val))
            self.conn.commit()


def _make_client() -> OpenAI:
    """Create an ``OpenAI`` client from environment configuration.

    Reads ``OPENAI_API_KEY`` (required) and ``OPENAI_BASE_URL`` (optional).

    Raises:
        RuntimeError: If ``OPENAI_API_KEY`` is not set.
    """
    env = os.environ
    api_key = env.get("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY is not set")
    return OpenAI(api_key=api_key, base_url=env.get("OPENAI_BASE_URL"))


# Chat client used by _judge_correct; left as None until the first call
# creates it via _make_client().
_JUDGE_CLIENT = None
# Verdict cache, opened at import time at the default location
# (JUDGE_CACHE_PATH, falling back to judge_cache.sqlite).
_JUDGE_CACHE = _JudgeCache()
# System prompt sent with every judge request: the model must answer with
# exactly "1" (correct) or "0" (incorrect) and nothing else.
_SYSTEM_PROMPT = (
    "You are a math teacher.\n"
    "Task: Judge whether the student's solution and final answer are correct based on the given math problem.\n"
    "Output rules:\n"
    '- If correct → output exactly "1"\n'
    '- If incorrect → output exactly "0"\n'
    "Do not include any explanation."
)


def _judge_correct(question: str, solution: str, answer: str, model_name: str = None, max_retries: int = 6) -> float:
    """Ask the judge model whether a solution/answer pair is correct.

    Verdicts are cached in ``_JUDGE_CACHE`` under the SHA-256 of the model
    name and the request content, so repeated triplets are not re-judged.

    Args:
        question: Problem statement.
        solution: Worked solution to be judged.
        answer: Final answer to be judged.
        model_name: Judge model; when falsy, ``OPENAI_JUDGE_MODEL`` is used,
            then ``gpt-4o-mini``.
        max_retries: Maximum number of API attempts. The wait between
            attempts starts at one second and doubles, capped at sixteen.

    Returns:
        ``1.0`` when the judge replied exactly "1", ``-1.0`` for any other
        reply, and ``0.0`` when no attempt succeeded (including
        ``max_retries <= 0``).

    Raises:
        RuntimeError: From ``_make_client`` on first use when
            ``OPENAI_API_KEY`` is not set.
    """
    global _JUDGE_CLIENT
    if _JUDGE_CLIENT is None:
        _JUDGE_CLIENT = _make_client()
    mdl = model_name or os.environ.get("OPENAI_JUDGE_MODEL", "gpt-4o-mini")
    content = f"Question: {question}\nSolution: {solution}\nAnswer: {answer}"
    key = hashlib.sha256((mdl + "\n" + content).encode("utf-8")).hexdigest()
    cached = _JUDGE_CACHE.get(key)
    if cached is not None:
        return 1.0 if cached.strip() == "1" else -1.0
    delay = 1.0
    for attempt in range(max_retries):
        try:
            resp = _JUDGE_CLIENT.chat.completions.create(
                model=mdl,
                messages=[
                    {"role": "system", "content": _SYSTEM_PROMPT},
                    {"role": "user", "content": content},
                ],
                stream=False,
            )
            txt = (resp.choices[0].message.content or "").strip()
            # Anything other than an exact "1" is recorded as incorrect.
            out = "1" if txt == "1" else "0"
            _JUDGE_CACHE.set(key, out)
            return 1.0 if out == "1" else -1.0
        except Exception:
            # Best-effort: back off and retry; no sleep after the last attempt.
            if attempt < max_retries - 1:
                time.sleep(delay)
                delay = min(delay * 2, 16.0)
    # Neutral score when the judge could not be reached. Placing this after
    # the loop also covers max_retries <= 0, which would otherwise yield None.
    return 0.0


def compute_score(data_source: str, solution_str: str, ground_truth: str, extra_info=None) -> float:
    """Compute the reward for one model response.

    The reward is ``keyword + length + 0.5 * judge + 0.5 * direction``:
    keyword and length come from ``_keyword_score`` / ``_length_score``,
    judge from ``_judge_correct`` on the parsed triplet, and direction from
    ``_pref.score_direction`` between the rewrite's embedding and the
    original's. When the response cannot be parsed the result is
    ``keyword + length - 0.1``.

    Args:
        data_source: Unused by this function.
        solution_str: Raw model response.
        ground_truth: Unused by this function.
        extra_info: Optional mapping. ``orig_embedding`` (list or array) is
            used directly when non-empty; otherwise the original Q/A is
            taken from ``orig_q`` / ``orig_a`` or parsed out of
            ``prompt_text`` and embedded.

    Returns:
        The reward as a float.
    """
    kw_score = _keyword_score(solution_str)
    len_score = _length_score(solution_str)
    item = process_response(solution_str)
    if not item:
        return kw_score + len_score - 0.1
    r = _judge_correct(item["rewrite_question"], item["rewrite_solution"], item["rewrite_answer"])
    try:
        info = extra_info or {}
        e1 = info.get("orig_embedding")
        # Use an explicit emptiness test: truth-testing a multi-element
        # numpy array raises, which would silently force the -0.5 fallback.
        if e1 is None or len(e1) == 0:
            oq = info.get("orig_q", "")
            oa = info.get("orig_a", "")
            if not (oq and oa):
                parsed_q, parsed_a = extract_original_qa(info.get("prompt_text", ""))
                oq = oq or parsed_q
                oa = oa or parsed_a
            e1 = _embed(f"Q:{oq}\nA:{oa}") if oq and oa else None
        if e1 is not None and len(e1) > 0:
            e2 = _embed(f"Question:{item['rewrite_question']}\nAnswer:{item['rewrite_answer']}")
            d = _pref.score_direction(e2, e1)
        else:
            d = -0.5
    except Exception:
        # Deliberate best-effort: any embedding or scoring failure falls
        # back to the same penalty as having no original to compare against.
        d = -0.5
    score = kw_score + len_score + 0.5 * r + 0.5 * d
    return float(score)
