# Copyright 2025 The corr_faith Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A shared interface for sampling from LLMs via APIs."""

import abc
import dataclasses
import os
from typing import Any
from absl import logging
import anthropic
from google import genai
import openai
import retry


class RetriableError(Exception):
  """Signals a failed API call that is worth attempting again."""


class UnknownApiError(Exception):
  """Raised when a model name does not match any known API prefix."""


# Placeholder stored in `APIResponse.text` by `LanguageModelClient.generate`
# when a client hands back a text value that is not a string.
FAILURE_STR = "[ERROR: API_SAMPLING_FAILED]"


@dataclasses.dataclass
class APIResponse:
  """Text and token accounting for a single sampling call.

  Instances are intentionally left mutable (not frozen) so that `text` can be
  overwritten after the fact if sampling fails.
  """

  # Completion text produced by the model.
  text: str
  # Token counts for the request. Caching is not used at present; if it is
  # ever implemented, cached prompt tokens could be logged separately too.
  n_prompt_tokens: int
  n_completion_tokens: int
  # Provider-specific payload this response was built from.
  raw_response: Any


class LanguageModelClient(abc.ABC):
  """Common interface implemented by every API-backed language model client.

  Subclasses provide `_generate`; callers use `generate`, which layers
  retrying and response-text validation on top of it.
  """

  @abc.abstractmethod
  def __init__(
      self,
      model_name: str,
  ):
    """Set up a client for `model_name`; concrete subclasses override this."""

  @retry.retry(
      exceptions=RetriableError,
      tries=40,  # ~24 hours total.
      delay=5,
      max_delay=int(5e3),
      backoff=1.5,
      jitter=(0, 5),  # Avoid backoff stampeding.
  )
  def generate(
      self, prompt: str, max_new_tokens: int, do_sample: bool = False
  ) -> APIResponse:
    """Sample a completion for `prompt`, retrying on `RetriableError`.

    Args:
      prompt: Prompt text forwarded to `_generate`.
      max_new_tokens: Completion length limit forwarded to `_generate`.
      do_sample: Sampling flag forwarded to `_generate`.

    Returns:
      The `APIResponse` produced by `_generate`. If its `text` is not a string,
      the problem is logged and `text` is replaced with `FAILURE_STR`.
    """
    response = self._generate(prompt, max_new_tokens, do_sample)
    if isinstance(response.text, str):
      return response
    # Non-string text: record what came back and substitute the failure marker
    # rather than raising.
    logging.error(
        "Unexpected response text: %s. Prompt: %s. Full response: %s."
        " Replacing.",
        response.text,
        prompt,
        response.raw_response,
    )
    response.text = FAILURE_STR
    return response

  @abc.abstractmethod
  def _generate(
      self, prompt: str, max_new_tokens: int, do_sample: bool
  ) -> APIResponse:
    """Make a single sampling attempt; implemented by each concrete client."""


# HTTP status codes that the API clients convert into `RetriableError`.
RetriableHttpErrorCodes = frozenset((
    429,  # RESOURCE_EXHAUSTED; Gemini uses this for rate limiting.
    500,  # INTERNAL
    # BAD_GATEWAY, e.g. "The server encountered a temporary error and could
    # not complete your request. Please try again in 30 seconds."
    502,
    503,  # UNAVAILABLE
    504,  # DEADLINE_EXCEEDED
    529,  # Service is overloaded, from Anthropic.
))


class GeminiClient(LanguageModelClient):
  """Language model client backed by the Gemini API."""

  def __init__(self, model_name: str):
    """Store the model name and build a Gemini API client.

    Args:
      model_name: Gemini model identifier passed to the API on each request.
    """
    self._model_name = model_name
    # The API key is read from the GEMINI_API_KEY environment variable.
    api_key = os.getenv("GEMINI_API_KEY")
    self._client = genai.Client(api_key=api_key)

  def _is_reasoning_model(self) -> bool:
    """Return True when the model name starts with "gemini-2.5"."""
    return self._model_name.startswith("gemini-2.5")

  def _generate(
      self, prompt: str, max_new_tokens: int, do_sample: bool = False
  ) -> APIResponse:
    """Send one generation request to Gemini.

    Args:
      prompt: Text sent as the request contents.
      max_new_tokens: Output token cap; only applied to non-reasoning models.
      do_sample: When False, `top_k` is set to 1.

    Returns:
      An `APIResponse` built from the response text and usage metadata.

    Raises:
      RetriableError: If the API error code is in `RetriableHttpErrorCodes`.
      genai.errors.APIError: For any other API error code.
    """
    generation_options = {}
    if not do_sample:
      # Only choose the most likely token (greedy sampling).
      generation_options["top_k"] = 1
    if not self._is_reasoning_model():
      # No cap is sent for reasoning models, which may need longer output to
      # allow for thinking tokens.
      generation_options["max_output_tokens"] = max_new_tokens

    # Turn blocking off for every harm category except the unspecified one,
    # which can't be filtered.
    block_none = genai.types.HarmBlockThreshold.BLOCK_NONE
    safety_settings = [
        genai.types.SafetySetting(category=harm_category, threshold=block_none)
        for harm_category in genai.types.HarmCategory
        if harm_category != "HARM_CATEGORY_UNSPECIFIED"
    ]
    # Disable AFC to prevent getting spammed with logs, e.g.
    # "AFC is enabled with max remote calls: 10."
    afc_disabled = genai.types.AutomaticFunctionCallingConfig(disable=True)
    request_config = genai.types.GenerateContentConfig(
        safety_settings=safety_settings,
        automatic_function_calling=afc_disabled,
        **generation_options,
    )

    try:
      response = self._client.models.generate_content(
          model=self._model_name,
          contents=prompt,
          config=request_config,
      )
    except genai.errors.APIError as api_error:
      if api_error.code not in RetriableHttpErrorCodes:
        raise
      raise RetriableError(api_error) from api_error

    return APIResponse(
        text=response.text,
        n_prompt_tokens=response.usage_metadata.prompt_token_count,
        n_completion_tokens=response.usage_metadata.candidates_token_count,
        raw_response=response,
    )


class OpenAIClient(LanguageModelClient):
  """Client for the OpenAI API."""

  def __init__(self, model_name: str):
    """Store the model name and build an OpenAI API client.

    Args:
      model_name: OpenAI model identifier passed to the API on each request.
    """
    self._model_name = model_name
    # API key should be in environment variable.
    self._client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

  def _generate(
      self, prompt: str, max_new_tokens: int, do_sample: bool = False
  ) -> APIResponse:
    """Generate a completion for the given prompt.

    Args:
      prompt: Text sent as a single user message.
      max_new_tokens: Value passed as `max_tokens`.
      do_sample: When False, temperature is pinned to 0.0.

    Returns:
      An `APIResponse` built from the first choice and the usage counts. For
      length / content-filter finish-reason errors, an `APIResponse` whose text
      is "[ERROR: <message>]" with zero token counts.

    Raises:
      RetriableError: On timeouts, connection errors, and status errors whose
        code is in `RetriableHttpErrorCodes`.
      openai.APIStatusError: For any other HTTP status error.
    """
    # API Reference:
    # https://github.com/openai/openai-python/blob/a6b493071b843bec3db807637e441c1768b695f8/src/openai/resources/completions.py#L52
    sample_kwargs = {}
    if not do_sample:
      sample_kwargs["temperature"] = 0.0
    try:
      response = self._client.chat.completions.create(
          model=self._model_name,
          messages=[{"role": "user", "content": prompt}],
          max_tokens=max_new_tokens,
          **sample_kwargs,
      )
    except (
        openai.LengthFinishReasonError,
        openai.ContentFilterFinishReasonError,
    ) as e:
      # These exception types are not `APIError` subclasses in openai-python
      # and do not define `body`; reading `e.body` directly would turn this
      # graceful path into an AttributeError. Prefer `body` if present, then
      # the attached `completion`, otherwise record nothing.
      raw_response = getattr(e, "body", None)
      if raw_response is None:
        raw_response = getattr(e, "completion", None)
      return APIResponse(
          text=f"[ERROR: {str(e)}]",
          n_prompt_tokens=0,
          n_completion_tokens=0,
          raw_response=raw_response,
      )
    # The OpenAI API retries some errors by default, but we'll catch them
    # and retry them ourselves to unify error handling configuration between
    # API models.
    except (
        openai.APITimeoutError,
        openai.APIConnectionError,
    ) as e:
      raise RetriableError(e) from e
    except openai.APIStatusError as e:
      if e.status_code in RetriableHttpErrorCodes:
        raise RetriableError(e) from e
      raise
    return APIResponse(
        text=response.choices[0].message.content,
        n_prompt_tokens=response.usage.prompt_tokens,
        n_completion_tokens=response.usage.completion_tokens,
        raw_response=response,
    )


class AnthropicClient(LanguageModelClient):
  """Client for the Anthropic API."""

  def __init__(self, model_name: str):
    """Store the model name and build an Anthropic API client.

    Args:
      model_name: Anthropic model identifier passed to the API on each request.
    """
    self._model_name = model_name
    # API key should be in environment variable.
    self._client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

  def _generate(
      self, prompt: str, max_new_tokens: int, do_sample: bool = False
  ) -> APIResponse:
    """Generate a completion for the given prompt.

    Args:
      prompt: Text sent as a single user message.
      max_new_tokens: Value passed as `max_tokens`.
      do_sample: When False, temperature is pinned to 0.0.

    Returns:
      An `APIResponse` holding the text of the single content block and the
      input/output token counts.

    Raises:
      RetriableError: On timeouts, connection errors, and status errors whose
        code is in `RetriableHttpErrorCodes`.
      anthropic.APIStatusError: For any other HTTP status error.
      ValueError: If the response does not consist of exactly one content
        block carrying string text.
    """
    sample_kwargs = {}
    if not do_sample:
      sample_kwargs["temperature"] = 0.0
    try:
      response = self._client.messages.create(
          model=self._model_name,
          messages=[{"role": "user", "content": prompt}],
          max_tokens=max_new_tokens,
          **sample_kwargs,
      )
    except (
        anthropic.APITimeoutError,
        anthropic.APIConnectionError,
    ) as e:
      raise RetriableError(e) from e
    except anthropic.APIStatusError as e:
      if e.status_code in RetriableHttpErrorCodes:
        raise RetriableError(e) from e
      raise
    # Exactly one block with string text is expected. `getattr` keeps a block
    # without a `text` attribute on the same ValueError path (with full
    # context) instead of surfacing as a bare AttributeError.
    text = None
    if len(response.content) == 1:
      text = getattr(response.content[0], "text", None)
    if not isinstance(text, str):
      # Both literals must be f-strings; otherwise "{response}" is emitted
      # verbatim and the actual response never reaches the message.
      raise ValueError(
          f"Unexpected response content: {response.content}. Prompt: {prompt}."
          f" Full response: {response}."
      )
    return APIResponse(
        text=text,
        n_prompt_tokens=response.usage.input_tokens,
        n_completion_tokens=response.usage.output_tokens,
        raw_response=response,
    )


# Maps a model-name prefix to the client class that serves it.
# `from_model_name` strips the matching prefix before constructing the client.
PREFIX_TO_CLIENT = {
    "gemini_api/": GeminiClient,
    "openai_api/": OpenAIClient,
    "anthropic_api/": AnthropicClient,
}


def from_model_name(model_name: str) -> LanguageModelClient:
  """Build the client whose API prefix starts `model_name`.

  Args:
    model_name: Prefixed model name; the first matching key of
      `PREFIX_TO_CLIENT` selects the client class.

  Returns:
    An instance of the selected client class, constructed with the prefix
    stripped from `model_name`.

  Raises:
    UnknownApiError: If no key of `PREFIX_TO_CLIENT` is a prefix of
      `model_name`.
  """
  matched_prefix = next(
      (
          candidate
          for candidate in PREFIX_TO_CLIENT
          if model_name.startswith(candidate)
      ),
      None,
  )
  if matched_prefix is None:
    raise UnknownApiError(f"Unknown model name: {model_name}")
  client_class = PREFIX_TO_CLIENT[matched_prefix]
  bare_model_name = model_name[len(matched_prefix) :]
  return client_class(bare_model_name)
