# llm_anthropic.py
from __future__ import annotations

import json
import time
from dataclasses import dataclass
from typing import Any, Dict, Optional

import config

try:
    from anthropic import Anthropic
except Exception:
    Anthropic = None  # type: ignore


def _looks_like_anthropic_model(name: str) -> bool:
    n = (name or "").strip().lower()
    return n.startswith("claude-")


# Module-wide flag: set to True the first time AnthropicClient._normalize_model
# prints its "non-Anthropic model" warning, so the warning is printed only once.
_WARNED_MODEL_MISMATCH = False


def _strip_code_fences(s: str) -> str:
    """
    If the model returns ```json ... ``` or ``` ... ```, extract the inner block.
    Otherwise return original string.
    """
    s = (s or "").strip()
    if not s.startswith("```"):
        return s

    # Find first newline after the opening fence
    first_nl = s.find("\n")
    if first_nl == -1:
        return s

    # Find closing fence
    closing = s.rfind("```")
    if closing <= first_nl:
        return s

    inner = s[first_nl + 1 : closing].strip()
    return inner


def _extract_first_json_object(text: str) -> str:
    """
    Return the substring holding the first brace-balanced ``{...}`` object.

    The text is stripped and unwrapped from any surrounding code fence first.
    Scanning begins at the first ``{``; braces that appear inside
    double-quoted strings (including after backslash escapes) are not
    counted. The result is only brace-balanced -- it is not validated as
    JSON here.

    Raises:
        ValueError: if the text is empty, contains no ``{``, or the braces
            never balance before the end of the text.
    """
    cleaned = _strip_code_fences((text or "").strip())
    if not cleaned:
        raise ValueError("Empty response text (no JSON to parse).")

    begin = cleaned.find("{")
    if begin < 0:
        snippet = cleaned[:200].replace("\n", "\\n")
        raise ValueError(f"No '{{' found in response. Preview: {snippet}")

    nesting = 0
    inside_string = False
    escaped = False
    for pos in range(begin, len(cleaned)):
        char = cleaned[pos]

        if inside_string:
            if escaped:
                # Character following a backslash is taken literally.
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                inside_string = False
        elif char == '"':
            inside_string = True
        elif char == "{":
            nesting += 1
        elif char == "}":
            nesting -= 1
            if nesting == 0:
                return cleaned[begin : pos + 1]

    snippet = cleaned[begin : begin + 200].replace("\n", "\\n")
    raise ValueError(f"Unbalanced JSON braces; couldn't find end of object. Preview: {snippet}")


@dataclass
class AnthropicClient:
    """
    Thin wrapper around Anthropic SDK with helpers:
    - text_messages: send a message list, return the reply's concatenated text.
    - text: same as text_messages, for a single user message.
    - json_object: same as text, but parse the first JSON object in the reply.

    Robustness:
    - If caller passes a non-Anthropic model name, fall back to config.ANTHROPIC_MODEL_MAIN and warn once.
    - json_object extracts the first JSON object from the response (handles code fences / extra text).
    - Failed attempts are retried after a 0.5s pause; there is no pause after the final attempt.

    Construction raises ImportError when the ``anthropic`` package could not be
    imported, and RuntimeError when ``api_key`` is empty.
    """

    # Defaults to the key from config, read once when this class is defined.
    api_key: str = config.ANTHROPIC_API_KEY

    def __post_init__(self):
        """Validate prerequisites and create the underlying SDK client."""
        if Anthropic is None:
            raise ImportError("anthropic package is not installed. Run: pip install anthropic")
        if not self.api_key:
            raise RuntimeError("ANTHROPIC_API_KEY is missing.")
        self.client = Anthropic(api_key=self.api_key)

    def __repr__(self) -> str:
        # The dataclass-generated repr would print the API key verbatim;
        # an explicitly defined __repr__ takes precedence and keeps the
        # secret out of logs and tracebacks.
        return f"{type(self).__name__}(api_key='***')"

    def _normalize_model(self, model: str) -> str:
        """
        Return a model name that looks like an Anthropic model.

        An empty/blank ``model`` becomes config.ANTHROPIC_MODEL_MAIN. A name
        that does not pass _looks_like_anthropic_model is replaced by
        config.ANTHROPIC_MODEL_MAIN, and a warning is printed the first time
        this happens in the process.
        """
        global _WARNED_MODEL_MISMATCH
        m = (model or "").strip() or config.ANTHROPIC_MODEL_MAIN

        if not _looks_like_anthropic_model(m):
            fallback = config.ANTHROPIC_MODEL_MAIN
            if not _WARNED_MODEL_MISMATCH:
                _WARNED_MODEL_MISMATCH = True
                print(
                    f"[WARN] AnthropicClient got non-Anthropic model='{m}'. "
                    f"Falling back to ANTHROPIC_MODEL_MAIN='{fallback}'. "
                    f"(Fix your config/build_dataset call site.)"
                )
            m = fallback

        return m

    def text_messages(
        self,
        model: str,
        system: str,
        messages,
        max_tokens: int,
        temperature: Optional[float] = None,
        retries: int = 3,
    ) -> str:
        """
        Call ``messages.create`` and return the reply's text.

        Args:
            model: Model name; normalized via _normalize_model.
            system: System prompt passed through to the SDK.
            messages: Message list passed through to the SDK unchanged.
            max_tokens: Passed through to the SDK.
            temperature: Sampling temperature; ``None`` is sent as 0.0.
            retries: Number of attempts (values below 1 are treated as 1).

        Returns:
            The ``text`` of every content block that has one, concatenated
            and stripped. May be an empty string.

        Raises:
            RuntimeError: if every attempt raised; the last underlying
                exception is attached as ``__cause__``.
        """
        model = self._normalize_model(model)
        attempts = max(1, retries)

        last: Optional[Exception] = None
        for attempt in range(attempts):
            try:
                resp = self.client.messages.create(
                    model=model,
                    system=system,
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=temperature if temperature is not None else 0.0,
                )

                # The response carries a list of content blocks; keep every
                # block that exposes a non-None .text, whatever its type.
                parts = [
                    block.text
                    for block in getattr(resp, "content", None) or []
                    if getattr(block, "text", None) is not None
                ]
                return "".join(parts).strip()

            except Exception as e:
                last = e
                # Pause only when another attempt follows; sleeping right
                # before raising would just delay the error.
                if attempt + 1 < attempts:
                    time.sleep(0.5)

        raise RuntimeError(f"Anthropic text_messages failed: {last}") from last

    def text(
        self,
        model: str,
        system: str,
        user_input: str,
        max_tokens: int,
        temperature: Optional[float] = None,
        retries: int = 3,
    ) -> str:
        """
        Send ``user_input`` as a single user message and return the reply text.

        All other arguments, the return value and the raised RuntimeError are
        those of text_messages.
        """
        return self.text_messages(
            model=model,
            system=system,
            messages=[{"role": "user", "content": user_input}],
            max_tokens=max_tokens,
            temperature=temperature,
            retries=retries,
        )

    def json_object(
        self,
        model: str,
        system: str,
        user_input: str,
        max_tokens: int,
        temperature: Optional[float] = None,
        retries: int = 3,
    ) -> Dict[str, Any]:
        """
        Send ``user_input`` and parse the first JSON object found in the reply.

        Each attempt makes one API call (``text`` with ``retries=1``), so an
        API failure, an empty reply and unparseable JSON all consume one of
        the ``retries`` attempts (values below 1 are treated as 1).

        Returns:
            The result of ``json.loads`` on the extracted ``{...}`` substring.

        Raises:
            RuntimeError: if every attempt failed. The message includes a
                preview of the last attempt's reply text when there was one,
                and the last underlying exception is attached as ``__cause__``.
        """
        attempts = max(1, retries)
        last: Optional[Exception] = None
        last_txt_preview = ""

        for attempt in range(attempts):
            # Reset per attempt so the preview below never shows text left
            # over from an earlier attempt when this attempt's call fails.
            txt = ""
            try:
                txt = self.text(
                    model=model,
                    system=system,
                    user_input=user_input,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    retries=1,
                )

                if not (txt or "").strip():
                    raise ValueError("Empty response from Anthropic (no text content).")

                # Extract the first JSON object even if the model adds prose or code fences.
                json_str = _extract_first_json_object(txt)
                return json.loads(json_str)

            except Exception as e:
                last = e
                last_txt_preview = (txt or "")[:250].replace("\n", "\\n")
                # Pause only when another attempt follows.
                if attempt + 1 < attempts:
                    time.sleep(0.5)

        msg = f"Anthropic json_object failed: {last}"
        if last_txt_preview:
            msg += f" | last_response_preview='{last_txt_preview}'"
        raise RuntimeError(msg) from last
