# cirbench/utils/api/claude.py
from __future__ import annotations
import os, time, re
from typing import List, Dict, Any, Optional

# Small helper: parse boolean-ish env vars ("1/true/yes/on")
def _env_on(name: str, default: str = "0") -> bool:
    v = os.getenv(name, default)
    return isinstance(v, str) and v.strip().lower() in ("1", "true", "yes", "on")

# --- IR extraction helpers -------------------------------------------------
_IR_BLOCK_RE = re.compile(r"(<CIR_JSON>.*?</CIR_JSON>|<IR_OUT>.*?</IR_OUT>)", re.S)

def _extract_ir_block(text: str) -> tuple[str, bool]:
    """Return (block, True) if an IR block is found, else (text, False)."""
    if not text:
        return text, False
    m = _IR_BLOCK_RE.search(text)
    if m:
        return m.group(1), True
    return text, False

try:
    # Requires: pip install anthropic>=0.30
    from anthropic import Anthropic
except Exception:
    Anthropic = None


class _Out:
    def __init__(self, text: str, meta: Dict[str, Any]):
        self.text = text
        self.meta = meta


# ---- Normalizers ---------------------------------------------------------

def _safe_get_usage_tokens_anthropic(usage: Any) -> tuple[Optional[int], Optional[int], Optional[int]]:
    """
    Normalize Anthropic usage to (prompt_tokens, completion_tokens, total_tokens).
    Anthropic exposes usage.input_tokens / usage.output_tokens.
    """
    if not usage:
        return (None, None, None)
    inp = getattr(usage, "input_tokens", None)
    out = getattr(usage, "output_tokens", None)
    tot = None
    try:
        tot = (int(inp) if inp is not None else 0) + (int(out) if out is not None else 0)
    except Exception:
        tot = None
    if isinstance(usage, dict):
        inp = usage.get("input_tokens", inp)
        out = usage.get("output_tokens", out)
        if tot is None:
            try:
                tot = (int(inp) if inp is not None else 0) + (int(out) if out is not None else 0)
            except Exception:
                tot = None
    return (inp, out, tot)


def _finish_reason_claude(stop_reason: Optional[str]) -> str:
    """
    Map Anthropic stop_reason -> CIRBench's unified finish_reason.
    Values include: "end_turn", "max_tokens", "stop_sequence", "tool_use", ...
    """
    if not stop_reason:
        return "unknown"
    s = str(stop_reason).strip().lower()
    if s == "max_tokens":
        return "length"
    if s in ("stop_sequence", "end_turn"):
        return "stop"
    return s


# ---- Runner --------------------------------------------------------------

class ClaudeRunner:
    """
    Anthropic Claude runner using the official Messages API.
    Expected by cirbench: generate(prompts: List[str]) -> List[_Out]

    Features:
      - Supports temperature/top_p (at most one of them), custom stop sequences,
        and IR stop auto-injection (closing </CIR_JSON> / </IR_OUT> tags).
      - The Messages API reports a matched stop sequence separately from the
        returned text. When generation stopped on an IR closing tag (full, or
        the partial form without '>'), the runner appends the complete closing
        tag so stop detection and `ir_only` clipping see a closed IR block.
      - Supports Claude extended thinking (params.thinking, or
        params.extra_body.thinking: {type: "enabled", budget_tokens: N}). The
        budget counts toward max_tokens; it is clamped to be >= 1024 and
        < max_tokens, and thinking is omitted from a call where that is
        impossible (e.g. max_tokens <= 1024).
      - All requests are sent NON-streaming. The client is built with an explicit
        timeout (params.timeout, default 600 s) because overriding the SDK's
        timeout disables its guard against long non-streaming requests.
        params.stream / params.streaming, params.auto_stream, the
        stream_*_threshold params and extra_body.enable_thinking are parsed and
        stored for config compatibility, but they do not currently change how
        requests are sent.
      - Optional retry-on-length and continuation passes (parity with other
        cirbench runners); see `generate`.
    """

    def __init__(self, model: str, params: Dict[str, Any] | None = None):
        """Build the Anthropic client and read generation settings from `params`.

        Raises:
            RuntimeError: the `anthropic` package is missing or no API key is available.
            ValueError: both `temperature` and `top_p` are set.
        """
        if Anthropic is None:
            raise RuntimeError("Missing dependency: `anthropic` package is required (pip install anthropic).")
        ps = dict(params or {})
        api_key = ps.get("api_key") or os.getenv("ANTHROPIC_API_KEY")
        base_url = ps.get("base_url") or os.getenv("ANTHROPIC_BASE_URL")
        if not api_key:
            raise RuntimeError("ANTHROPIC_API_KEY not set (and no `params.api_key` provided).")

        # Construct Anthropic client, explicitly overriding the default timeout.
        # According to the Anthropic SDK docs, overriding the `timeout` option
        # disables the client's 10-minute non-streaming length heuristic, so we
        # can rely on max_tokens without being forced into streaming.
        timeout_val = ps.get("timeout", None)
        if timeout_val is not None:
            try:
                timeout_val = float(timeout_val)
            except Exception:
                timeout_val = None
        if timeout_val is None:
            # Keep a sensible default (10 minutes) while still overriding the SDK default.
            timeout_val = 600.0

        client_kwargs: Dict[str, Any] = {"api_key": api_key, "timeout": timeout_val}
        if base_url:
            client_kwargs["base_url"] = base_url
        self.client = Anthropic(**client_kwargs)

        self.model = model
        self.system_prompt = ps.get("system") or "You are a helpful assistant."

        # Optional generation params
        self.temperature: Optional[float] = ps.get("temperature", None)
        self.top_p: Optional[float] = ps.get("top_p", None)
        # Disallow setting both temperature and top_p simultaneously to avoid undefined provider behavior
        if self.temperature is not None and self.top_p is not None:
            raise ValueError(
                "ClaudeRunner: 'temperature' and 'top_p' cannot both be set; please specify at most one."
            )
        # Streaming-related settings below are stored for config compatibility;
        # requests are currently always sent non-streaming (see class docstring).
        self.force_stream: bool = bool(ps.get("stream", False) or ps.get("streaming", False))

        self.auto_stream: bool = bool(ps.get("auto_stream", _env_on("CIRBENCH_CLAUDE_AUTO_STREAM", "1")))
        try:
            self.stream_max_tokens_threshold: int = int(ps.get("stream_max_tokens_threshold", os.getenv("CIRBENCH_CLAUDE_STREAM_MAXTOKENS_THRESHOLD", "12288")))
        except Exception:
            self.stream_max_tokens_threshold = 12288
        try:
            self.stream_thinking_threshold: int = int(ps.get("stream_thinking_threshold", os.getenv("CIRBENCH_CLAUDE_STREAM_THINKING_THRESHOLD", "4096")))
        except Exception:
            self.stream_thinking_threshold = 4096

        # Output limit (Claude requires max_tokens on every call)
        self.user_max_tokens: Optional[int] = ps.get("max_tokens", None)
        if self.user_max_tokens is None:
            mo = ps.get("max_output_tokens")
            if mo is not None:
                try:
                    self.user_max_tokens = int(mo)
                except Exception:
                    pass
        if self.user_max_tokens is None:
            self.user_max_tokens = 1024  # sensible default for Claude
        # Minimum thinking budget required by Anthropic docs
        self._min_thinking_budget = 1024

        # Stop sequences
        self.stop: Optional[List[str]] = None
        _stop = ps.get("stop", None)
        if _stop is None:
            _stop = ps.get("stop_sequences", None)
        if _stop is not None:
            if isinstance(_stop, (list, tuple)):
                self.stop = [str(s) for s in _stop if s is not None]
            else:
                self.stop = [str(_stop)]
        self.auto_stop_irout: bool = bool(ps.get("auto_stop_irout", True))

        # Retry/continuation controls (kept for parity with other runners)
        self.ctx_limit: Optional[int] = ps.get("ctx_limit", None)
        self.safety_tokens: int = int(ps.get("safety_tokens", 128))
        self.min_gen_tokens: int = int(ps.get("min_gen_tokens", 64))
        self.retry_on_length: bool = bool(ps.get("retry_on_length", True))
        self.retry_cap_tokens: int = int(ps.get("retry_cap_tokens", 1024))
        self.max_continuations: int = int(ps.get("max_continuations", 0))

        # Pass-through bag (not automatically sent; reserved for future use)
        self.extra_body: Optional[Dict[str, Any]] = None
        eb = ps.get("extra_body", None)
        if isinstance(eb, dict):
            self.extra_body = eb

        # Extended thinking config (Claude 3.7+/4.x): allow top-level params.thinking or extra_body.thinking
        self.thinking_cfg: Optional[Dict[str, Any]] = None
        tc = ps.get("thinking")
        if isinstance(tc, dict):
            # shallow copy to avoid accidental mutation
            self.thinking_cfg = dict(tc)
        elif isinstance(self.extra_body, dict) and isinstance(self.extra_body.get("thinking"), dict):
            # fallback if thinking is nested under extra_body
            self.thinking_cfg = dict(self.extra_body["thinking"])

        # Optional "enable_thinking" boolean/string flag from extra_body
        # (stored only; it does not currently alter the request).
        self.enable_thinking_flag: Optional[bool] = None
        if isinstance(self.extra_body, dict):
            v = self.extra_body.get("enable_thinking")
            if isinstance(v, bool):
                self.enable_thinking_flag = v
            elif isinstance(v, str):
                self.enable_thinking_flag = (v.strip().lower() == "true")

        # Enforce returning only the IR block (strip any prose around it)
        self.ir_only: bool = bool(ps.get("ir_only", _env_on("CIRBENCH_IR_ONLY", "1")))

    # ---- Helpers ----------------------------------------------------------

    @staticmethod
    def _default_stops() -> List[str]:
        """IR closing tags injected as stop sequences, including the forms without the final '>'."""
        return ["</CIR_JSON>", "</IR_OUT>", "</CIR_JSON", "</IR_OUT"]

    @staticmethod
    def _ir_closer_for(stop_seq: Optional[str]) -> Optional[str]:
        """Return the full IR closing tag that *stop_seq* stands for (with or without its '>'), else None."""
        if not isinstance(stop_seq, str):
            return None
        for closer in ("</CIR_JSON>", "</IR_OUT>"):
            if stop_seq in (closer, closer[:-1]):
                return closer
        return None

    def _compose_stops(self, prompt: str) -> List[str]:
        """User stop sequences plus (if auto_stop_irout) the IR tags, de-duplicated in first-seen order.

        `prompt` is accepted for interface parity and is currently unused.
        """
        base = list(self.stop or [])
        if self.auto_stop_irout:
            base += self._default_stops()
        return list(dict.fromkeys(base))

    # ---- Core calls -------------------------------------------------------

    def _effective_thinking_budget(self, max_out: int) -> Optional[int]:
        """
        Compute the thinking budget to send for a call capped at `max_out` tokens.

        Returns None when thinking is not configured, its `type` is not
        "enabled"/True, or no budget satisfies  1024 <= budget < max_out.
        A missing/non-int `budget_tokens` falls back to env
        CIRBENCH_CLAUDE_THINKING_BUDGET (default 8192). A budget >= max_out is
        lowered to max_out - 1.
        """
        if not isinstance(self.thinking_cfg, dict):
            return None
        t = dict(self.thinking_cfg)
        ty = t.get("type")
        if ty is True:
            ty = "enabled"
        if not (isinstance(ty, str) and ty.lower() == "enabled"):
            return None
        bt = t.get("budget_tokens")
        if not isinstance(bt, int):
            # default from env if missing
            try:
                bt = int(os.getenv("CIRBENCH_CLAUDE_THINKING_BUDGET", "8192"))
            except Exception:
                bt = self._min_thinking_budget
        budget = max(self._min_thinking_budget, int(bt))
        if max_out <= budget:
            budget = max(self._min_thinking_budget, max_out - 1)
        if budget < self._min_thinking_budget or budget >= max_out:
            return None
        return budget

    def _chat_once(self, prompt: str, *, max_tokens: Optional[int] = None, prior: Optional[str] = None) -> tuple[str, Dict[str, Any]]:
        """
        Send one non-streaming Messages API request and normalize the result.

        Args:
            prompt: user prompt text.
            max_tokens: output cap for this call; falls back to user_max_tokens, then 1024.
            prior: earlier assistant output to continue from. When truthy it is
                sent as an assistant turn followed by a "continue" instruction.

        Returns:
            (text, meta). On an API exception the text is "" and
            meta["finish_reason"] == "exception" with the error in meta["error"].
        """
        t0 = time.time()
        max_out = int(max_tokens or self.user_max_tokens or 1024)

        messages = [{"role": "user", "content": prompt}]
        if prior:
            messages.append({"role": "assistant", "content": prior})
            messages.append({
                "role": "user",
                "content": (
                    "Continue exactly where you left off. "
                    "Do not repeat earlier text. "
                    "Stop immediately after you emit the closing tag (e.g., </CIR_JSON> or </IR_OUT>)."
                ),
            })

        kwargs: Dict[str, Any] = {
            "model": self.model,
            "system": self.system_prompt,
            "messages": messages,
            "max_tokens": max_out,
        }
        if self.temperature is not None:
            kwargs["temperature"] = self.temperature
        if self.top_p is not None:
            kwargs["top_p"] = self.top_p
        stops = self._compose_stops(prompt)
        if stops:
            kwargs["stop_sequences"] = stops

        # Extended thinking (see https://docs.claude.com/en/api/messages, `thinking`);
        # budget resolution and clamping live in _effective_thinking_budget.
        thinking_budget = self._effective_thinking_budget(max_out)
        thinking_applied = thinking_budget is not None
        if thinking_applied:
            kwargs["thinking"] = {"type": "enabled", "budget_tokens": int(thinking_budget)}

        # Metadata shared by the success and error paths, so both report the
        # same keys and `thinking_enabled` reflects what was actually sent.
        base_meta: Dict[str, Any] = {
            "provider": "claude",
            "model": self.model,
            "max_tokens_used": max_out,
            "reasoning_tokens": None,
            "thinking_enabled": thinking_applied,
            "streaming": False,
            "thinking_requested": self.thinking_cfg is not None,
            "thinking_applied": thinking_applied,
            "thinking_budget_tokens": int(thinking_budget) if thinking_applied else None,
        }

        if os.getenv("CIRBENCH_DEBUG") == "1":
            print("[DEBUG.claude.nostream] kwargs keys:", sorted(list(kwargs.keys())))

        try:
            resp = self.client.messages.create(**kwargs)
        except Exception as ex:  # report any API/transport failure as data, not a crash
            return "", {
                **base_meta,
                "error": f"{type(ex).__name__}:{str(ex)[:300]}",
                "latency_ms": int(round((time.time() - t0) * 1000)),
                "finish_reason": "exception",
                "prompt_tokens": None,
                "out_tokens": None,
                "total_tokens": None,
                "thinking_text_len": 0,
                "had_thinking_stream": False,
            }

        # Single pass over content blocks: thinking blocks only contribute to
        # thinking_text_len; every other block's `text` forms the answer.
        text_parts: List[str] = []
        thinking_text_len = 0
        blocks = getattr(resp, "content", None)
        if isinstance(blocks, list):
            for b in blocks:
                btype = getattr(b, "type", None)
                if btype is None and isinstance(b, dict):
                    btype = b.get("type")
                if btype == "thinking":
                    # Extended thinking blocks may use either "thinking" or "text" field
                    thought = getattr(b, "thinking", None)
                    if thought is None:
                        thought = getattr(b, "text", None)
                    if thought is None and isinstance(b, dict):
                        thought = b.get("thinking") or b.get("text")
                    if thought:
                        thinking_text_len += len(str(thought))
                    continue
                piece = getattr(b, "text", None)
                if piece is None and isinstance(b, dict):
                    piece = b.get("text")
                if piece:
                    text_parts.append(str(piece))
        text = "".join(text_parts)

        # The matched stop sequence is reported in `stop_sequence`, separately
        # from the content. If it was an IR closing tag (possibly the partial
        # form without '>'), make the text end with the complete tag so callers
        # can detect completion and extract the IR block.
        stop_reason = getattr(resp, "stop_reason", None)
        matched_stop = getattr(resp, "stop_sequence", None)
        closer = self._ir_closer_for(matched_stop) if stop_reason == "stop_sequence" else None
        if closer is not None and not text.endswith(closer):
            if text.endswith(matched_stop):
                text = text[: -len(matched_stop)]
            text += closer

        # Usage & finish
        pt, ct, tt = _safe_get_usage_tokens_anthropic(getattr(resp, "usage", None) or {})
        finish = _finish_reason_claude(stop_reason)

        meta = {
            **base_meta,
            "latency_ms": int(round((time.time() - t0) * 1000)),
            "finish_reason": finish,
            "prompt_tokens": pt,
            "out_tokens": ct,
            "total_tokens": tt,
            "thinking_text_len": int(thinking_text_len),
            "had_thinking_stream": bool(thinking_text_len > 0),
            "stop_sequence": matched_stop if isinstance(matched_stop, str) else None,
            "ir_closer_restored": closer is not None,
        }
        return text, meta

    # ---- Public API -------------------------------------------------------

    def generate(self, prompts: List[str]) -> List[_Out]:
        """
        Run every prompt and return one _Out(text, meta) per prompt, in order.

        Per prompt:
          1. One call capped at the configured max_tokens.
          2. Optional from-scratch retry (retry_on_length). It is attempted when
             the first call hit max_tokens, or when ctx_limit is set and no
             output tokens were reported. A length retry runs only if its
             budget is larger than the first call's, because a smaller budget
             would truncate at least as early.
          3. Optional continuations (max_continuations), only while no stop
             sequence / IR closing tag is in the text and the model has not
             ended its turn by itself.
          4. Optional IR-only clipping (ir_only).
        """
        outs: List[_Out] = []

        def _has_stop(full_text: str, prompt: str) -> bool:
            # True if any configured stop sequence / IR closing tag occurs in the text.
            if not full_text:
                return False
            return any(s in full_text for s in self._compose_stops(prompt))

        for p in prompts:
            # First attempt
            text, meta = self._chat_once(p, max_tokens=self.user_max_tokens)

            # Retry on length / continuation parity with other runners
            fr = str(meta.get("finish_reason", "")).lower()
            pt = meta.get("prompt_tokens")
            ct = meta.get("out_tokens")
            hit_length = fr in {"length", "max_tokens"}

            need_retry = False
            context_overflow = False
            retry_allowed: Optional[int] = None
            retry_skipped_reason: Optional[str] = None

            if self.retry_on_length:
                if hit_length:
                    need_retry = True
                elif (ct in (0, None)) and (self.ctx_limit is not None) and (pt is not None):
                    need_retry = True

            if need_retry and self.ctx_limit is not None and (pt is not None):
                allowed = int(self.ctx_limit) - int(pt) - int(self.safety_tokens)
                if allowed <= 0:
                    context_overflow = True
                    need_retry = False
                else:
                    retry_allowed = max(self.min_gen_tokens, min(allowed, self.retry_cap_tokens))

            if need_retry and (retry_allowed is None):
                retry_allowed = min(self.retry_cap_tokens, max(self.min_gen_tokens, 512))

            # The retry regenerates from scratch. After a max_tokens cut-off it
            # can only help with a strictly larger budget; otherwise it would
            # swap the truncated answer for an even shorter one.
            first_budget = meta.get("max_tokens_used")
            if (
                need_retry
                and hit_length
                and isinstance(first_budget, int)
                and retry_allowed is not None
                and retry_allowed <= first_budget
            ):
                need_retry = False
                retry_skipped_reason = "retry_budget_not_larger"

            if need_retry and (retry_allowed is not None) and (retry_allowed > 0) and (not _has_stop(text, p)):
                text2, meta2 = self._chat_once(p, max_tokens=int(retry_allowed))
                meta["retry"] = True
                meta["retry_allowed_tokens"] = int(retry_allowed)
                meta["latency_ms"] = meta.get("latency_ms", 0) + meta2.get("latency_ms", 0)
                if text2:
                    # Report the stats of the attempt whose text is returned.
                    text = text2
                    for key in ("finish_reason", "prompt_tokens", "out_tokens", "total_tokens", "max_tokens_used"):
                        meta[key] = meta2.get(key, meta.get(key))
                elif meta2.get("error"):
                    meta["retry_error"] = meta2["error"]
            else:
                meta["retry"] = False
                meta["retry_allowed_tokens"] = retry_allowed
                if retry_skipped_reason is not None:
                    meta["retry_skipped_reason"] = retry_skipped_reason
            meta["context_overflow"] = context_overflow
            meta["ctx_limit"] = self.ctx_limit

            # Optional continuations (rarely used for Claude). A finish_reason of
            # "stop" means the model ended its turn or hit a stop sequence, so
            # asking it to continue would only append unrelated text.
            continued = 0
            last_finish = str(meta.get("finish_reason", "")).lower()
            while (
                continued < self.max_continuations
                and last_finish != "stop"
                and not _has_stop(text, p)
            ):
                cont_tokens = min(self.retry_cap_tokens, max(self.min_gen_tokens, 512))
                if (self.ctx_limit is not None) and (meta.get("prompt_tokens") is not None):
                    allowed = int(self.ctx_limit) - int(meta["prompt_tokens"]) - int(self.safety_tokens)
                    if allowed <= 0:
                        break
                    cont_tokens = max(self.min_gen_tokens, min(cont_tokens, allowed))

                tail = text[-4000:] if text else ""
                t_more, m_more = self._chat_once(p, max_tokens=int(cont_tokens), prior=tail)
                if not t_more:
                    break

                text += t_more
                continued += 1

                meta["latency_ms"] = int(meta.get("latency_ms", 0)) + int(m_more.get("latency_ms", 0))
                meta["out_tokens"] = (meta.get("out_tokens") or 0) + (m_more.get("out_tokens") or 0)
                meta["finish_reason"] = m_more.get("finish_reason", meta.get("finish_reason"))
                last_finish = str(meta.get("finish_reason", "")).lower()

            meta["continued"] = bool(continued)
            meta["continued_steps"] = continued

            if getattr(self, "ir_only", False):
                clipped, ok = _extract_ir_block(text)
                if ok:
                    text = clipped
                    meta["ir_only_clipped"] = True
                else:
                    meta["ir_only_clipped"] = False
                    meta["ir_only_missing_block"] = True
            outs.append(_Out(text, meta))
        return outs


# Factory

def make(model_cfg: Dict[str, Any]):
    """
    Factory entrypoint used by cirbench.utils.api.base.make_runner().

    `model_cfg` looks like {"kind": "claude", "name": "<model>", "params": {...}}.
    A missing/empty "name" falls back to "claude-sonnet-4-5-20250929", and
    missing/empty "params" become {}.
    """
    return ClaudeRunner(
        model_cfg.get("name") or "claude-sonnet-4-5-20250929",
        model_cfg.get("params") or {},
    )