# cirbench/utils/api/llama.py
from __future__ import annotations
import os, time, json
from typing import List, Dict, Any, Optional

# Small helper: parse boolean-ish env vars ("1/true/yes/on")
def _env_on(name: str, default: str = "0") -> bool:
    """Return True when environment variable *name* holds a truthy token.

    Accepted tokens are "1", "true", "yes" and "on", compared case-insensitively
    after stripping surrounding whitespace. When the variable is unset, *default*
    is inspected instead; a non-string default (e.g. None) yields False.
    """
    raw = os.getenv(name, default)
    if not isinstance(raw, str):
        return False
    token = raw.strip().lower()
    return token in {"1", "true", "yes", "on"}

try:
    # Requires: pip install openai>=1.0
    from openai import OpenAI
except Exception:
    OpenAI = None


class _Out:
    """Result record for one prompt: the generated text plus its metadata dict."""

    def __init__(self, text: str, meta: Dict[str, Any]):
        # Both values are stored as-is and exposed as plain attributes.
        self.text, self.meta = text, meta


def _safe_get_usage_tokens(usage: Any) -> tuple[Optional[int], Optional[int], Optional[int]]:
    """
    Extract token counts from a usage payload of an OpenAI-compatible server.

    The payload may be an SDK object (counts as attributes) or a plain dict
    (counts as keys); for dicts, present keys take precedence over attributes.
    Falsy payloads (None, {}) give all-None counts.

    Returns (prompt_tokens, completion_tokens, total_tokens).
    """
    if not usage:
        return (None, None, None)
    fields = ("prompt_tokens", "completion_tokens", "total_tokens")
    counts = [getattr(usage, field, None) for field in fields]
    if isinstance(usage, dict):
        counts = [usage.get(field, fallback) for field, fallback in zip(fields, counts)]
    return tuple(counts)


def _finish_reason(choice: Any) -> str:
    """
    Return the choice's finish_reason as a stripped, lower-case keyword.

    Accepts an SDK choice object or a dict. 'max_tokens' (any case) is reported
    as 'length' so both spellings compare equal; a missing/empty value, or a
    None choice, yields 'unknown'.
    """
    if choice is None:
        return "unknown"
    raw = getattr(choice, "finish_reason", None)
    if isinstance(choice, dict):
        raw = choice.get("finish_reason", raw)
    if not raw:
        return "unknown"
    keyword = str(raw).strip().lower()
    if keyword == "max_tokens":
        return "length"
    return keyword



class LlamaRunner:
    """
    OpenAI-compatible Llama runner.
    Expected by cirbench: generate(prompts: List[str]) -> List[_Out]

    Designed for providers exposing Llama models via OpenAI Chat Completions
    (e.g., OpenRouter, DeepInfra, self-hosted gateways).

    Features:
      - Respects temperature/top_p and user-specified max_tokens/max_output_tokens.
      - Stop sequences are sent only when the caller configures stop/stop_sequences;
        in that case the IR closing tags are appended as well (auto_stop_irout).
      - Optional single retry with a larger budget when the answer was cut by the
        length limit (retry_on_length), and optional continuation calls
        (max_continuations) while no closing tag is seen and the model did not stop.
      - Optional streaming mode (params.stream=True or env CIRBENCH_LLAMA_FORCE_STREAM=1).
    """

    # Finish reasons meaning the model ended the answer itself (or a configured
    # stop sequence fired). Per the OpenAI API, a matched stop sequence is NOT
    # included in the returned text, so the closing tag is usually absent from
    # the output in that case and must not be used as the only "done" signal.
    _COMPLETE_FINISH_REASONS = frozenset({"stop", "content_filter"})

    # Follow-up user turn sent when asking the model to continue a cut answer.
    _CONTINUE_INSTRUCTION = (
        "Continue exactly where you left off. "
        "Do not repeat earlier text. "
        "Stop immediately after you emit the closing tag (e.g., </CIR_JSON> or </IR_OUT>)."
    )

    def __init__(self, model: str, params: Dict[str, Any] | None = None):
        """
        Build the OpenAI client and read generation settings from ``params``.

        Recognized keys: api_key, base_url, system, temperature, top_p,
        max_tokens / max_output_tokens, stop / stop_sequences, auto_stop_irout,
        ctx_limit, safety_tokens, min_gen_tokens, retry_on_length,
        retry_cap_tokens, max_continuations, timeout, extra_body, stream.

        Raises:
            RuntimeError: if the ``openai`` package is unavailable or no API key
                is given (neither ``params.api_key`` nor env LLAMA_API_KEY).
        """
        if OpenAI is None:
            raise RuntimeError("Missing dependency: `openai` package is required.")
        ps = dict(params or {})

        # Llama providers are usually OpenAI-compatible; allow llama-specific envs
        api_key = ps.get("api_key") or os.getenv("LLAMA_API_KEY")
        base_url = ps.get("base_url") or os.getenv("LLAMA_BASE_URL")
        if not api_key:
            raise RuntimeError("LLAMA_API_KEY not set (and no `params.api_key` provided).")

        if base_url:
            self.client = OpenAI(api_key=api_key, base_url=base_url)
        else:
            self.client = OpenAI(api_key=api_key)

        self.model = model
        self.system_prompt = ps.get("system") or "You are a helpful assistant."

        # Optional generation params
        self.temperature: Optional[float] = ps.get("temperature", None)
        self.top_p: Optional[float] = ps.get("top_p", None)

        # User-specified output limit (OpenAI-style or Gemini-style). Both are
        # coerced to int the same way; an unparsable value counts as "not set".
        self.user_max_tokens: Optional[int] = self._opt_int(ps.get("max_tokens"))
        if self.user_max_tokens is None:
            self.user_max_tokens = self._opt_int(ps.get("max_output_tokens"))

        # Stop sequences: allow 'stop' or 'stop_sequences'
        self.stop: Optional[List[str]] = None
        _stop = ps.get("stop", None)
        if _stop is None:
            _stop = ps.get("stop_sequences", None)
        if _stop is not None:
            if isinstance(_stop, (list, tuple)):
                self.stop = [str(s) for s in _stop if s is not None]
            else:
                self.stop = [str(_stop)]

        # Append IR closing tags (and half-closed variants) to user-provided stops
        self.auto_stop_irout: bool = bool(ps.get("auto_stop_irout", True))

        # Dynamic retry & continuation controls
        self.ctx_limit: Optional[int] = ps.get("ctx_limit", None)     # e.g., 32000/131072 if you know it
        self.safety_tokens: int = int(ps.get("safety_tokens", 128))   # budget cushion
        self.min_gen_tokens: int = int(ps.get("min_gen_tokens", 64))
        self.retry_on_length: bool = bool(ps.get("retry_on_length", True))
        self.retry_cap_tokens: int = int(ps.get("retry_cap_tokens", 1024))
        self.max_continuations: int = int(ps.get("max_continuations", 0))

        # Request timeout (seconds)
        self.timeout: Optional[float] = ps.get("timeout", None)

        # Provider-specific passthrough: extra_body (OpenAI SDK supports this).
        # Accepts a dict or a JSON string encoding a dict; anything else is ignored
        # (a non-dict payload would otherwise break every request).
        self.extra_body: Optional[Dict[str, Any]] = None
        eb = ps.get("extra_body", None)
        if isinstance(eb, str):
            try:
                eb = json.loads(eb)
            except ValueError:
                eb = None
        if isinstance(eb, dict):
            self.extra_body = eb

        # Optional streaming flag (not related to thinking)
        self.stream_flag: bool = bool(ps.get("stream", False)) or _env_on("CIRBENCH_LLAMA_FORCE_STREAM")

    # ---- Internal helpers -------------------------------------------------

    @staticmethod
    def _opt_int(value: Any) -> Optional[int]:
        """Return ``int(value)``, or None when value is None or not int-convertible."""
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    # Default IR stops are appended ONLY when user provided stop sequences.
    @staticmethod
    def _default_stops() -> List[str]:
        return ["</CIR_JSON>", "</IR_OUT>", "</CIR_JSON", "</IR_OUT"]

    def _compose_stops(self, prompt: str) -> List[str]:
        """
        Return the stop sequences to send (``prompt`` is currently unused).

        Empty when the caller configured no stop/stop_sequences. Otherwise the
        configured stops, followed by the IR defaults when auto_stop_irout is on,
        de-duplicated while preserving first-occurrence order.
        """
        if not self.stop:
            return []
        base = list(self.stop)
        if self.auto_stop_irout:
            base += self._default_stops()
        return list(dict.fromkeys(base))

    def _has_stop(self, text: str, prompt: str) -> bool:
        """True if ``text`` contains any composed stop sequence (always False without stops)."""
        if not text:
            return False
        return any(s in text for s in self._compose_stops(prompt))

    def _build_messages(self, prompt: str, prior: Optional[str] = None) -> List[Dict[str, str]]:
        """System + user messages; with ``prior``, add it as the assistant turn plus a continue request."""
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": prompt},
        ]
        if prior:
            messages.append({"role": "assistant", "content": prior})
            messages.append({"role": "user", "content": self._CONTINUE_INSTRUCTION})
        return messages

    def _build_kwargs(self, prompt: str, *, stream: bool, max_tokens: Optional[int],
                      prior: Optional[str]) -> Dict[str, Any]:
        """Assemble keyword arguments for ``chat.completions.create``; shared by both call paths."""
        kwargs: Dict[str, Any] = {
            "model": self.model,
            "messages": self._build_messages(prompt, prior),
            "stream": stream,
        }
        if stream:
            # Ask for a usage payload at the end of the stream.
            kwargs["stream_options"] = {"include_usage": True}
        if self.temperature is not None:
            kwargs["temperature"] = self.temperature
        if self.top_p is not None:
            kwargs["top_p"] = self.top_p
        if max_tokens is not None:
            kwargs["max_tokens"] = int(max_tokens)
        if self.timeout is not None:
            kwargs["timeout"] = float(self.timeout)

        # Stops are sent only when the caller configured some (see _compose_stops).
        stops = self._compose_stops(prompt)
        if stops:
            kwargs["stop"] = stops

        if self.extra_body:
            kwargs["extra_body"] = self.extra_body
        return kwargs

    def _extra_body_keys(self) -> Optional[List[str]]:
        """Keys of the attached extra_body (for meta), or None when none is sent."""
        return list(self.extra_body.keys()) if self.extra_body else None

    def _error_meta(self, ex: Exception, t0: float, max_tokens: Optional[int],
                    streaming: bool) -> Dict[str, Any]:
        """Meta dict describing a failed call (finish_reason='exception')."""
        return {
            "provider": "llama",
            "model": self.model,
            "error": f"{type(ex).__name__}:{str(ex)[:300]}",
            "latency_ms": int(round((time.time() - t0) * 1000)),
            "finish_reason": "exception",
            "prompt_tokens": None,
            "out_tokens": None,
            "total_tokens": None,
            "max_tokens_used": max_tokens,
            "streaming": streaming,
        }

    def _chat_stream(self, prompt: str, *, max_tokens: Optional[int] = None, prior: Optional[str] = None) -> tuple[str, Dict[str, Any]]:
        """
        Streaming chat call for OpenAI-compatible Llama providers.

        Concatenates every chunk's ``delta.content``, keeps the last non-empty
        finish_reason and the last usage payload, and records per-chunk counters
        for diagnostics. Never raises: errors yield ``("", meta)`` with
        ``finish_reason="exception"``. The stream is always closed.
        """
        t0 = time.time()

        log_chunks = _env_on("CIRBENCH_LLAMA_LOG_CHUNKS")
        log_summary = _env_on("CIRBENCH_LLAMA_LOG_STREAM_SUMMARY")

        chunks_total = 0
        chunks_with_choices = 0
        chunks_with_delta_content = 0
        chunks_usage_only = 0
        finish_reason_set_count = 0
        first_token_ms: Optional[int] = None

        stream = None
        try:
            kwargs = self._build_kwargs(prompt, stream=True, max_tokens=max_tokens, prior=prior)

            if os.getenv("CIRBENCH_DEBUG") == "1":
                print("[DEBUG.llama.stream] kwargs keys:", sorted(kwargs))

            stream = self.client.chat.completions.create(**kwargs)
            parts: List[str] = []
            finish = "unknown"
            last_usage = None

            for chunk in stream:
                chunks_total += 1
                had_usage = False
                had_choices = False
                had_content = False
                content_len = 0
                had_finish = False

                try:
                    u = getattr(chunk, "usage", None)
                    if u is not None:
                        last_usage = u
                        had_usage = True
                except Exception:
                    pass

                # Chunks may be SDK objects or plain dicts depending on the gateway.
                try:
                    chs = getattr(chunk, "choices", None)
                    if not chs and isinstance(chunk, dict):
                        chs = chunk.get("choices")
                    if chs:
                        had_choices = True
                        c0 = chs[0]
                        fr = getattr(c0, "finish_reason", None)
                        if fr is None and isinstance(c0, dict):
                            fr = c0.get("finish_reason")
                        if fr:
                            finish = _finish_reason(c0)
                            had_finish = True
                            finish_reason_set_count += 1

                        delta = getattr(c0, "delta", None)
                        if delta is None and isinstance(c0, dict):
                            delta = c0.get("delta")
                        if delta is not None:
                            content = getattr(delta, "content", None)
                            if content is None and isinstance(delta, dict):
                                content = delta.get("content")
                            if content:
                                parts.append(content)
                                content_len = len(content)
                                had_content = True
                                if first_token_ms is None:
                                    first_token_ms = int(round((time.time() - t0) * 1000))
                except Exception:
                    pass

                if had_usage and not had_choices:
                    chunks_usage_only += 1
                if had_choices:
                    chunks_with_choices += 1
                if had_content:
                    chunks_with_delta_content += 1

                if log_chunks:
                    try:
                        print(
                            f"[llama.stream.chunk #{chunks_total}] usage={had_usage} choices={had_choices} "
                            f"finish={had_finish} content_len={content_len}"
                        )
                    except Exception:
                        pass

            text = "".join(parts)
            ptoks, ctoks, ttoks = _safe_get_usage_tokens(last_usage or {})

            if log_summary or os.getenv("CIRBENCH_DEBUG") == "1":
                try:
                    print(
                        "[llama.stream.summary] "
                        f"chunks_total={chunks_total}, with_choices={chunks_with_choices}, "
                        f"with_content={chunks_with_delta_content}, usage_only={chunks_usage_only}, "
                        f"finish_reason='{finish}', ttft_ms={first_token_ms}, usage_present={last_usage is not None}"
                    )
                except Exception:
                    pass

            meta = {
                "provider": "llama",
                "model": self.model,
                "latency_ms": int(round((time.time() - t0) * 1000)),
                "finish_reason": finish,
                "prompt_tokens": ptoks,
                "out_tokens": ctoks,
                "total_tokens": ttoks,
                "max_tokens_used": max_tokens,
                "extra_body_keys": self._extra_body_keys(),
                "streaming": True,
                "ttft_ms": first_token_ms,
                "chunks_total": chunks_total,
                "chunks_with_choices": chunks_with_choices,
                "chunks_with_delta_content": chunks_with_delta_content,
                "chunks_usage_only": chunks_usage_only,
                "finish_reason_events": finish_reason_set_count,
            }
            return text, meta
        except Exception as ex:
            return "", self._error_meta(ex, t0, max_tokens, streaming=True)
        finally:
            # Release the HTTP connection even if iteration failed mid-stream.
            close = getattr(stream, "close", None)
            if callable(close):
                try:
                    close()
                except Exception:
                    pass

    def _chat_once(self, prompt: str, *, max_tokens: Optional[int] = None, prior: Optional[str] = None) -> tuple[str, Dict[str, Any]]:
        """
        Make a single chat call. Uses streaming if enabled.

        Returns ``(text, meta)``. Never raises: errors yield ``("", meta)`` with
        ``finish_reason="exception"`` and an ``error`` entry.
        """
        if self.stream_flag:
            return self._chat_stream(prompt, max_tokens=max_tokens, prior=prior)

        t0 = time.time()
        try:
            kwargs = self._build_kwargs(prompt, stream=False, max_tokens=max_tokens, prior=prior)

            if os.getenv("CIRBENCH_DEBUG") == "1":
                print("[DEBUG.llama.nostream] kwargs keys:", sorted(kwargs))

            resp = self.client.chat.completions.create(**kwargs)

            text = ""
            finish = "unknown"
            if resp and getattr(resp, "choices", None):
                choice0 = resp.choices[0]
                finish = _finish_reason(choice0)
                msg = getattr(choice0, "message", None)
                if msg is None and isinstance(choice0, dict):
                    msg = choice0.get("message")
                if msg is not None:
                    content = getattr(msg, "content", None)
                    if content is None and isinstance(msg, dict):
                        content = msg.get("content", "")
                    text = content or ""

            ptoks, ctoks, ttoks = _safe_get_usage_tokens(getattr(resp, "usage", None) or {})

            meta = {
                "provider": "llama",
                "model": self.model,
                "latency_ms": int(round((time.time() - t0) * 1000)),
                "finish_reason": finish,
                "prompt_tokens": ptoks,
                "out_tokens": ctoks,
                "total_tokens": ttoks,
                "max_tokens_used": max_tokens,
                "extra_body_keys": self._extra_body_keys(),
                "streaming": False,
            }
            return text, meta
        except Exception as ex:
            return "", self._error_meta(ex, t0, max_tokens, streaming=False)

    def _maybe_retry(self, prompt: str, text: str, meta: Dict[str, Any]) -> tuple[str, Dict[str, Any]]:
        """
        Optionally repeat the call once with a recomputed budget.

        A retry is considered when the first answer was cut by length, or when
        no completion tokens were reported while ctx_limit and the prompt size
        are known. It is skipped when the context has no room left, when a stop
        sequence is already present, or (for length cuts) when the retry budget
        is not larger than the budget that was cut — such a retry could only
        produce a shorter answer. Always sets retry/retry_allowed_tokens/
        context_overflow/ctx_limit in ``meta``.
        """
        fr = str(meta.get("finish_reason", "")).lower()
        pt = meta.get("prompt_tokens")
        ct = meta.get("out_tokens")
        cut_by_length = fr in {"length", "max_tokens"}

        need_retry = False
        context_overflow = False
        retry_allowed: Optional[int] = None

        if self.retry_on_length:
            if cut_by_length:
                need_retry = True
            elif (ct in (0, None)) and (self.ctx_limit is not None) and (pt is not None):
                need_retry = True

        if need_retry and self.ctx_limit is not None and pt is not None:
            allowed = int(self.ctx_limit) - int(pt) - int(self.safety_tokens)
            if allowed <= 0:
                context_overflow = True
                need_retry = False
            else:
                retry_allowed = max(self.min_gen_tokens, min(allowed, self.retry_cap_tokens))

        if need_retry and retry_allowed is None:
            retry_allowed = min(self.retry_cap_tokens, max(self.min_gen_tokens, 512))

        if need_retry and cut_by_length and retry_allowed is not None:
            # Budget that produced the truncated answer: reported out_tokens, else the user limit.
            prev_budget = ct if ct else self.user_max_tokens
            if prev_budget is not None and retry_allowed <= int(prev_budget):
                need_retry = False

        meta["context_overflow"] = context_overflow
        meta["ctx_limit"] = self.ctx_limit

        if not (need_retry and retry_allowed is not None and retry_allowed > 0
                and not self._has_stop(text, prompt)):
            meta["retry"] = False
            meta["retry_allowed_tokens"] = retry_allowed
            return text, meta

        text2, meta2 = self._chat_once(prompt, max_tokens=int(retry_allowed))
        meta["retry"] = True
        meta["retry_allowed_tokens"] = int(retry_allowed)
        meta["latency_ms"] = meta.get("latency_ms", 0) + meta2.get("latency_ms", 0)
        if text2:
            text = text2
            for key in ("finish_reason", "prompt_tokens", "out_tokens", "total_tokens", "max_tokens_used"):
                meta[key] = meta2.get(key, meta.get(key))
        else:
            # Retry produced nothing (e.g. provider error): keep the first answer
            # and the metadata that describes it.
            meta["retry_error"] = meta2.get("error")
        return text, meta

    def _needs_continuation(self, prompt: str, text: str, meta: Dict[str, Any]) -> bool:
        """True while no stop sequence is in ``text`` and the model did not finish on its own."""
        if self._has_stop(text, prompt):
            return False
        fr = str(meta.get("finish_reason", "")).lower()
        return fr not in self._COMPLETE_FINISH_REASONS

    def _continue_until_closed(self, prompt: str, text: str, meta: Dict[str, Any]) -> tuple[str, Dict[str, Any]]:
        """
        Ask the model to continue, up to ``max_continuations`` times.

        Each step sends the last 4000 characters as the assistant's prior turn
        and appends the new text. Stops early when the context has no room
        left, a step returns nothing, a stop sequence appears, or the model
        reports a natural finish. Sets continued/continued_steps in ``meta``.
        """
        continued = 0
        while continued < self.max_continuations and self._needs_continuation(prompt, text, meta):
            cont_tokens = min(self.retry_cap_tokens, max(self.min_gen_tokens, 512))
            if (self.ctx_limit is not None) and (meta.get("prompt_tokens") is not None):
                allowed = int(self.ctx_limit) - int(meta["prompt_tokens"]) - int(self.safety_tokens)
                if allowed <= 0:
                    break
                cont_tokens = max(self.min_gen_tokens, min(cont_tokens, allowed))

            tail = text[-4000:] if text else ""
            t_more, m_more = self._chat_once(prompt, max_tokens=int(cont_tokens), prior=tail)
            if not t_more:
                break

            text += t_more
            continued += 1

            meta["latency_ms"] = int(meta.get("latency_ms", 0)) + int(m_more.get("latency_ms", 0))
            meta["out_tokens"] = (meta.get("out_tokens") or 0) + (m_more.get("out_tokens") or 0)
            meta["finish_reason"] = m_more.get("finish_reason", meta.get("finish_reason"))

        meta["continued"] = bool(continued)
        meta["continued_steps"] = continued
        return text, meta

    def _generate_one(self, prompt: str) -> _Out:
        """Run first call, optional retry, then optional continuations for one prompt."""
        text, meta = self._chat_once(prompt, max_tokens=self.user_max_tokens)

        if _env_on("CIRBENCH_LLAMA_LOG_STOP_MATCH"):
            try:
                stops = self._compose_stops(prompt)
                tail = (text or "")[-160:]
                print(f"[llama.stop] detected={self._has_stop(text, prompt)} stops={stops} tail=\"{tail}\"")
            except Exception:
                pass

        text, meta = self._maybe_retry(prompt, text, meta)
        text, meta = self._continue_until_closed(prompt, text, meta)
        return _Out(text, meta)

    # ---- Public API -------------------------------------------------------

    def generate(self, prompts: List[str]) -> List[_Out]:
        """Generate one ``_Out`` per prompt, in order."""
        return [self._generate_one(p) for p in prompts]


def make(model_cfg: Dict[str, Any]):
    """
    Factory entrypoint for cirbench.utils.api.base.make_runner()
    Expects: model_cfg = {"kind": "llama", "name": "<model>", "params": {...}}

    A missing/empty "name" falls back to "llama4-maverick"; missing/empty
    "params" becomes an empty dict. Returns a LlamaRunner.
    """
    return LlamaRunner(
        model_cfg.get("name") or "llama4-maverick",
        model_cfg.get("params") or {},
    )