"""
LLM judge / rater for SAE auto-interpretation.

This module uses an OpenAI-compatible Chat Completions HTTP endpoint.
It is designed to be compatible with internal API gateways that provide
OpenAI-style JSON responses and optional SSE streaming.

The implementation mirrors the behavior of a typical high-throughput
aiohttp-based client:
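- Retries failed requests with exponential backoff.
- Optionally consumes SSE (server-sent events) streaming responses.
- Optionally sends the ``X-DashScope-DataInspection`` header that some
  gateways use to disable content inspection.
- Disables TLS certificate verification for all requests (see ``ssl_context``).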

The public interface is intentionally kept consistent with the previous
AsyncOpenAIJudge implementation:
  - generate_explanation(...)
  - score_latent(...)
"""

from __future__ import annotations

import asyncio
import codecs
import json
import re
import ssl
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import aiohttp


def _make_unverified_ssl_context() -> ssl.SSLContext:
    """Build the TLS context handed to every judge HTTP connector.

    Starts from ``ssl.create_default_context()`` and then switches off both
    hostname checking and certificate verification. ``check_hostname`` must be
    cleared first: ``ssl`` rejects ``CERT_NONE`` while it is still enabled.
    """
    context = ssl.create_default_context()
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE
    return context


# NOTE(review): server certificates are NOT verified, so traffic (including the
# bearer token) is exposed to man-in-the-middle interception. Presumably this
# exists for internal gateways with self-signed certificates -- confirm before
# pointing the judge at a public endpoint.
ssl_context = _make_unverified_ssl_context()


# Optionally signed integer or decimal ("-3", "+0.5", ".75", "12"). The sign is
# grouped so it applies to both alternatives, not only to decimals.
_NUMBER_RE = re.compile(r"[-+]?(?:\d*\.\d+|\d+)")


def extract_score(text: str) -> float:
    """Extract a numeric score from judge output text.

    Expected formats:
      - "Score: 0.73"
      - "score=0.73"
      - "0.73"
    If multiple numbers exist, the last one is taken. A leading ``+``/``-``
    sign is honoured for integers as well as decimals.

    Raises:
        ValueError: if ``text`` contains no number at all.
    """
    nums = _NUMBER_RE.findall(text)
    if not nums:
        raise ValueError(f"Cannot extract score from: {text!r}")
    return float(nums[-1])


@dataclass
class APIJudgeConfig:
    """Configuration for an OpenAI-compatible judge endpoint.

    ``api_key`` is excluded from the generated ``repr`` so that printing or
    logging a config object does not leak the credential.
    """

    # Model identifier sent as the request's "model" field.
    model: str
    # Bearer token placed in the "Authorization" header.
    api_key: str = field(repr=False)
    # Full URL the chat-completions request is POSTed to (used verbatim).
    base_url: str
    # Per-request timeout in seconds.
    timeout: int = 60
    # Extra attempts after the first failed request.
    max_retries: int = 3
    # Maximum number of chat calls in flight at once.
    max_concurrent: int = 10
    # Request an SSE streaming response instead of a single JSON body.
    stream: bool = False
    # Send the X-DashScope-DataInspection header that disables inspection.
    close_dash_inspect: bool = True
    # If set, forwarded as {"dashscope_extend_params": {"provider": ...}}.
    provider: Optional[str] = None


# Backward-compat aliases (older code imports these names)
OpenAIJudgeConfig = APIJudgeConfig


class AsyncAPIJudge:
    """Async judge implementation using aiohttp against an OpenAI-compatible API.

    Every chat call opens its own ``aiohttp.ClientSession`` (using the
    module-level ``ssl_context``), and at most ``cfg.max_concurrent`` calls are
    in flight at once, enforced by a semaphore.
    """

    # finish_reason values accepted as "complete". Truncation ("length",
    # "max_tokens") and "content_filter" are included, so callers may receive
    # partial text; None / "null" cover gateways that omit the field.
    _COMPLETE_FINISH_REASONS = frozenset(
        {
            "stop",
            "length",
            "function_call",
            "content_filter",
            "tool_calls",
            "stop_sequence",
            "max_tokens",
            "content_block",
            None,
            "null",
        }
    )

    def __init__(self, cfg: APIJudgeConfig):
        self.cfg = cfg
        # Caps concurrent chat calls; held for the whole retry sequence.
        self._sem = asyncio.Semaphore(cfg.max_concurrent)

    def _build_headers(self) -> Dict[str, str]:
        """Return bearer-auth + JSON headers, plus the DashScope header that
        disables data inspection when ``cfg.close_dash_inspect`` is set."""
        headers = {
            "Authorization": f"Bearer {self.cfg.api_key}",
            "Content-Type": "application/json",
        }
        if self.cfg.close_dash_inspect:
            headers["X-DashScope-DataInspection"] = '{"input": "disable", "output": "disable"}'
        return headers

    def _is_response_complete(self, finish_reason: Optional[str]) -> bool:
        """Return True if ``finish_reason`` is one of the accepted values."""
        return finish_reason in self._COMPLETE_FINISH_REASONS

    @staticmethod
    def _parse_sse_line(line: str) -> Optional[Dict[str, object]]:
        """Return the JSON object carried by one SSE ``data:`` line, or None.

        Non-data lines (``event:``/``id:`` fields, ``:`` comments, blank event
        separators), the ``[DONE]`` sentinel, undecodable JSON, non-object
        payloads and gateway ``keep_alive`` messages all yield None.
        """
        line = line.strip()
        if not line.startswith("data:"):
            return None
        message = line[len("data:"):].strip()
        if not message or message == "[DONE]":
            return None
        try:
            payload = json.loads(message)
        except json.JSONDecodeError:
            return None
        if not isinstance(payload, dict) or "keep_alive" in payload:
            return None
        return payload

    async def _handle_stream_response(self, response: aiohttp.ClientResponse) -> Tuple[str, Optional[str]]:
        """Parse an SSE streaming response and return ``(content, finish_reason)``.

        The body is split into lines and each ``data:`` line is treated as an
        OpenAI-style chunk: ``delta.content`` pieces are concatenated, and a
        full ``message.content`` replaces whatever was accumulated. Bytes are
        decoded incrementally, so a multi-byte UTF-8 character split across
        network chunks survives, and a final line with no trailing newline is
        still processed.

        Raises:
            RuntimeError: if the stream carries an ``error`` object.
        """
        parts: List[str] = []
        finish_reason: Optional[str] = None
        decoder = codecs.getincrementaldecoder("utf-8")(errors="ignore")
        pending = ""

        # Apply one SSE line to the running (parts, finish_reason) state.
        def consume(line: str) -> None:
            nonlocal finish_reason
            payload = self._parse_sse_line(line)
            if payload is None:
                return
            if payload.get("error"):
                raise RuntimeError(f"Judge API stream error: {payload['error']}")
            choices = payload.get("choices")
            if not choices or not isinstance(choices[0], dict):
                return
            choice = choices[0]
            if "delta" in choice:
                delta = choice["delta"]
                if isinstance(delta, dict) and delta.get("content"):
                    parts.append(delta["content"])
            elif isinstance(choice.get("message"), dict):
                full = choice["message"].get("content")
                if full:
                    parts[:] = [full]
            fr = choice.get("finish_reason")
            if fr:
                finish_reason = fr

        async for chunk in response.content.iter_any():
            pending += decoder.decode(chunk)
            # Keep the trailing (possibly incomplete) line for the next chunk.
            *lines, pending = pending.split("\n")
            for line in lines:
                consume(line)

        pending += decoder.decode(b"", final=True)
        if pending:
            consume(pending)

        return "".join(parts), finish_reason

    @staticmethod
    def _parse_completion(body: bytes) -> Tuple[str, Optional[str]]:
        """Parse a non-streaming JSON body into ``(content, finish_reason)``.

        A null ``content`` is returned as ``""``.

        Raises:
            RuntimeError: if the body is not JSON or lacks
                ``choices[0].message``.
        """
        try:
            data = json.loads(body)
            choice = data["choices"][0]
            content = choice["message"].get("content") or ""
        except (ValueError, KeyError, IndexError, TypeError, AttributeError) as e:
            raise RuntimeError(f"Malformed judge API response: {body[:500]!r}") from e
        return content, choice.get("finish_reason")

    async def _request_once(
        self,
        session: aiohttp.ClientSession,
        messages: List[Dict[str, str]],
        max_tokens: int,
    ) -> str:
        """Send a single chat-completions request and return the reply text.

        Raises:
            RuntimeError: on a non-200 status, a malformed body or stream
                error, a finish_reason outside the accepted set, or empty
                content.
        """
        request_data: Dict[str, object] = {
            "model": self.cfg.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "stream": self.cfg.stream,
        }
        if self.cfg.provider:
            request_data["dashscope_extend_params"] = {"provider": self.cfg.provider}

        async with session.post(
            self.cfg.base_url,
            headers=self._build_headers(),
            json=request_data,
            timeout=aiohttp.ClientTimeout(total=self.cfg.timeout),
        ) as resp:
            if resp.status != 200:
                # Decode leniently: error bodies may not be valid UTF-8.
                text = (await resp.read()).decode("utf-8", errors="replace")
                raise RuntimeError(f"HTTP {resp.status}: {text[:500]}")

            if self.cfg.stream:
                content, finish_reason = await self._handle_stream_response(resp)
            else:
                content, finish_reason = self._parse_completion(await resp.read())

        if not self._is_response_complete(finish_reason):
            raise RuntimeError(f"Incomplete response: finish_reason={finish_reason}")
        if not content:
            raise RuntimeError("Empty content from judge API")
        return content

    async def _call_chat(self, messages: List[Dict[str, str]], max_tokens: int) -> str:
        """Run one chat request with retries and exponential backoff.

        The first attempt is followed by up to ``cfg.max_retries`` retries
        (at least one attempt is always made), sleeping 1s, 2s, 4s, ...
        (capped at 16s) in between. Any exception triggers a retry. The
        concurrency slot is held for the whole sequence, sleeps included.

        Raises:
            RuntimeError: once every attempt has failed; the last underlying
                error is chained as ``__cause__``.
        """
        attempts = max(self.cfg.max_retries, 0) + 1
        backoff = 1.0
        last_err: Optional[Exception] = None

        async with self._sem:
            connector = aiohttp.TCPConnector(ssl=ssl_context)
            async with aiohttp.ClientSession(connector=connector) as session:
                for attempt in range(attempts):
                    try:
                        return await self._request_once(session, messages, max_tokens=max_tokens)
                    except Exception as e:
                        last_err = e
                        if attempt + 1 >= attempts:
                            break
                        await asyncio.sleep(backoff)
                        backoff = min(backoff * 2.0, 16.0)

        raise RuntimeError(f"Judge API failed after retries: {last_err}") from last_err

    async def generate_explanation(
        self,
        prompt_messages: List[Dict[str, str]],
        max_tokens: int = 1024,
    ) -> str:
        """Return the judge's free-text reply to ``prompt_messages``."""
        return await self._call_chat(prompt_messages, max_tokens=max_tokens)

    async def score_latent(
        self,
        prompt_messages: List[Dict[str, str]],
        max_tokens: int = 256,
    ) -> float:
        """Return the last number found in the judge's reply.

        Raises:
            ValueError: if the reply contains no number (from extract_score).
        """
        text = await self._call_chat(prompt_messages, max_tokens=max_tokens)
        return extract_score(text)


# Backward-compat alias (older code imports this name).
AsyncOpenAIJudge = AsyncAPIJudge
