"""
HuggingFace Inference API backend.

Uses huggingface_hub.InferenceClient (>= 0.34) to route chat-completion
requests through HF's provider network (Novita, Cerebras, nscale, …).
No local GPU required; billing goes through your HF account.

Install: pip install huggingface_hub
Auth:    export HF_TOKEN=<your token>   (HF_API_TOKEN and HUGGINGFACE_TOKEN
         are also accepted as fallbacks, matching the original Regym naming).

Provider routing (huggingface_hub >= 0.34):
  - provider="auto" (default): HF picks the first healthy provider for the
    model, sorted by your preference order at hf.co/settings/inference-providers.
  - provider="novita" / "nscale" / etc.: pin a specific provider.
  - base_url: bypass provider routing entirely and hit a self-hosted TGI server.
"""
from __future__ import annotations

import os
import time

from meta_rg.backends.base import BaseBackend

_RETRY_DELAYS = (5, 10, 20, 40)  # seconds between attempts on 429

# ---------------------------------------------------------------------------
# Optional Weave tracing (no-op when weave is not installed / not initialised)
# ---------------------------------------------------------------------------
try:
    import weave as _weave
    _weave_op = _weave.op
except Exception:
    def _weave_op(fn=None, *, name=None, **_kw):  # type: ignore[misc]
        return fn if fn is not None else (lambda f: f)


@_weave_op(name="hf_chat_completion")
def _log_hf_chat_call(
    messages: list,
    model: str,
    max_tokens: int,
    temperature: float,
    content: str,
    reasoning_content: str,
    prompt_tokens: int,
    completion_tokens: int,
    cached_tokens: int,
) -> str:
    """Record one raw HF Inference API round-trip as a Weave trace.

    The body is deliberately trivial. The ``_weave_op`` decorator is what
    captures the full request and response: every argument is a trace input and
    the returned assistant text is the trace output. When weave is unavailable,
    the decorator is an identity and this just returns ``content``.
    """
    traced_output = content
    return traced_output


class HFAPIBackend(BaseBackend):
    """Chat backend that sends requests through ``huggingface_hub.InferenceClient``.

    Each request is retried on HTTP 429 and 5xx responses, sleeping for the
    delays in ``_RETRY_DELAYS`` between attempts. Token usage reported by the
    provider is accumulated on the instance, and every successful call is
    recorded via ``_log_hf_chat_call``.
    """

    def __init__(
        self,
        model_id: str,
        temperature: float = 0.7,
        max_new_tokens: int = 256,
        provider: str = "auto",
        token: str | None = None,
        base_url: str | None = None,
        system_prompt: str | None = None,
        **_,
    ) -> None:
        """Create the backend and its ``InferenceClient``.

        Args:
            model_id: Model identifier sent as ``model`` on every chat request.
            temperature: Sampling temperature forwarded to the API.
            max_new_tokens: Forwarded as ``max_tokens`` on every request.
            provider: Provider routing passed to ``InferenceClient``. Ignored
                when ``base_url`` is given.
            token: Explicit API token. If falsy, the ``HF_TOKEN``,
                ``HF_API_TOKEN`` and ``HUGGINGFACE_TOKEN`` environment variables
                are tried in that order.
            base_url: Endpoint URL (e.g. a self-hosted TGI server). When set,
                provider routing is bypassed.
            system_prompt: Optional system message prepended by ``generate`` /
                ``generate_chat``.
            **_: Extra keyword arguments are accepted and ignored (presumably so
                a shared backend config can be splatted in; TODO confirm).

        Raises:
            ImportError: If ``huggingface_hub`` is not installed.
        """
        try:
            from huggingface_hub import InferenceClient
        except ImportError as e:
            raise ImportError("pip install huggingface_hub") from e

        # Accept HF_TOKEN, HF_API_TOKEN (original Regym name), or HUGGINGFACE_TOKEN
        resolved_token = (
            token
            or os.environ.get("HF_TOKEN")
            or os.environ.get("HF_API_TOKEN")
            or os.environ.get("HUGGINGFACE_TOKEN")
        )

        self.model_id = model_id
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        self.system_prompt = system_prompt

        # Cumulative token usage tracked from API responses (available after calls).
        # cached_tokens is populated when the provider supports prefix caching
        # (OpenAI-compatible usage.prompt_tokens_details.cached_tokens).
        self.total_prompt_tokens: int = 0
        self.total_completion_tokens: int = 0
        self.total_cached_tokens: int = 0
        # Per-call values, reset at the start of every call (including calls
        # that end up raising); use for the cache_hit_rate metric.
        self.last_call_prompt_tokens: int = 0
        self.last_call_cached_tokens: int = 0
        self.last_call_reasoning_content: str = ""

        if base_url:
            # Self-hosted TGI — ignore provider routing
            self.client = InferenceClient(base_url=base_url, token=resolved_token)
        else:
            self.client = InferenceClient(provider=provider, token=resolved_token)

    @staticmethod
    def _field(obj, name: str):
        """Read *name* from *obj* as an attribute, falling back to a dict key.

        huggingface_hub output types are dict subclasses. Fields outside their
        declared schema (e.g. ``prompt_tokens_details`` or ``reasoning_content``
        from some providers) may only be reachable as mapping keys, and nested
        objects may arrive as plain dicts, so plain ``getattr`` can miss them.
        Returns ``None`` when *obj* is ``None`` or the field is absent.
        """
        if obj is None:
            return None
        value = getattr(obj, name, None)
        if value is None and isinstance(obj, dict):
            value = obj.get(name)
        return value

    @staticmethod
    def _is_retryable(exc: BaseException) -> bool:
        """Return True for HTTP 429 (rate limit) and 5xx errors; nothing else is retried."""
        status = getattr(getattr(exc, "response", None), "status_code", None)
        return status == 429 or (status is not None and 500 <= status < 600)

    def _chat_with_retry(self, messages: list[dict]):
        """Call ``chat_completion``, retrying retryable HTTP errors with ``_RETRY_DELAYS`` backoff.

        Raises:
            Exception: A non-retryable client error immediately, or the last
                retryable error once every delay in ``_RETRY_DELAYS`` is used up.
        """
        last_exc = None
        # The leading 0 makes the first attempt immediate; each later entry is
        # the sleep before one retry.
        for attempt, delay in enumerate([0] + list(_RETRY_DELAYS)):
            if delay:
                print(f"[HFAPIBackend] transient error (attempt {attempt}/{len(_RETRY_DELAYS)})"
                      f" — retrying in {delay}s")
                time.sleep(delay)
            try:
                return self.client.chat_completion(
                    model=self.model_id,
                    messages=messages,
                    max_tokens=self.max_new_tokens,
                    temperature=self.temperature,
                )
            except Exception as exc:
                if not self._is_retryable(exc):
                    raise
                last_exc = exc
        raise last_exc  # all retries exhausted

    def _record_usage(self, response) -> tuple[int, int, int]:
        """Add the response's token usage to the totals; return (prompt, completion, cached)."""
        usage = self._field(response, "usage")
        if usage is None:
            return 0, 0, 0
        call_pt = self._field(usage, "prompt_tokens") or 0
        call_ct = self._field(usage, "completion_tokens") or 0
        details = self._field(usage, "prompt_tokens_details")
        call_cac = self._field(details, "cached_tokens") or 0

        self.total_prompt_tokens += call_pt
        self.total_completion_tokens += call_ct
        self.total_cached_tokens += call_cac
        self.last_call_prompt_tokens = call_pt
        self.last_call_cached_tokens = call_cac
        return call_pt, call_ct, call_cac

    def _call_and_track(self, messages: list[dict]) -> str:
        """Send *messages*, record token usage, trace the call, and return the reply text.

        Returns:
            The first choice's message content, or ``""`` if it is empty.

        Raises:
            Exception: Whatever ``_chat_with_retry`` propagates from the client.
        """
        # Reset per-call stats first so a failed call never leaves the previous
        # call's numbers looking current.
        self.last_call_prompt_tokens = 0
        self.last_call_cached_tokens = 0
        self.last_call_reasoning_content = ""

        response = self._chat_with_retry(messages)
        call_pt, call_ct, call_cac = self._record_usage(response)

        msg = response.choices[0].message
        content = self._field(msg, "content") or ""
        reasoning = self._field(msg, "reasoning_content") or ""
        self.last_call_reasoning_content = reasoning

        _log_hf_chat_call(
            messages=messages,
            model=self.model_id,
            max_tokens=self.max_new_tokens,
            temperature=self.temperature,
            content=content,
            reasoning_content=reasoning,
            prompt_tokens=call_pt,
            completion_tokens=call_ct,
            cached_tokens=call_cac,
        )
        return content

    @property
    def cache_hit_rate(self) -> float:
        """Fraction of all prompt tokens so far served from cache (0.0 if none reported)."""
        if self.total_prompt_tokens == 0:
            return 0.0
        return self.total_cached_tokens / self.total_prompt_tokens

    def generate(self, prompt_text: str) -> str:
        """Send *prompt_text* as one user turn, preceded by the system prompt if configured."""
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": prompt_text})
        return self._call_and_track(messages)

    def generate_chat(self, messages: list[dict]) -> str:
        """Send a multi-turn conversation and return the reply text.

        The configured system prompt is prepended unless the conversation
        already starts with a system message. The caller's list is not mutated.
        """
        if self.system_prompt and (not messages or messages[0].get("role") != "system"):
            messages = [{"role": "system", "content": self.system_prompt}] + messages
        return self._call_and_track(messages)
