from openai import OpenAI
from .base_backend import BaseBackend

class OpenAIBackend(BaseBackend):
    """Backend that sends one system + user exchange to the OpenAI Chat Completions API.

    Token usage reported on each response is optionally forwarded to a
    ``token_tracker`` and/or returned to the caller alongside the answer text.
    """

    def __init__(self, model_name, temperature=0.7, max_tokens=1024, token_tracker=None, **kwargs):
        """Create the backend.

        Args:
            model_name: Model identifier sent as ``model`` on every request.
            temperature: Sampling temperature sent on every request.
            max_tokens: Completion token limit sent on every request.
            token_tracker: Optional object whose ``log(agent_id=..., input_tokens=...,
                output_tokens=..., cached_input_tokens=...)`` method is called once
                per ``chat`` call.
            **kwargs: Accepted and ignored here.
        """
        # No explicit credentials are passed; the client uses the SDK's default configuration.
        self.client = OpenAI()
        self.model = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.token_tracker = token_tracker

    @staticmethod
    def _extract_usage(usage):
        """Return ``(input_tokens, output_tokens, cached_input_tokens)`` from a usage object.

        A missing ``usage`` object, a missing field, or a field explicitly set to
        ``None`` all count as 0, so callers always get integers.
        """
        input_tokens = getattr(usage, "prompt_tokens", None) or 0
        output_tokens = getattr(usage, "completion_tokens", None) or 0
        details = getattr(usage, "prompt_tokens_details", None)
        # cached_tokens can be present with a None value (it is Optional in the SDK
        # types), so a plain getattr default is not enough; coerce None to 0.
        cached_input_tokens = getattr(details, "cached_tokens", None) or 0
        return input_tokens, output_tokens, cached_input_tokens

    def chat(self, system: str, user: str, agent_id: str = None, return_usage: bool = False):
        """Send a system prompt and a user message and return the model's reply.

        Args:
            system: Content of the system message.
            user: Content of the user message.
            agent_id: Label passed to ``token_tracker.log``; ``"unknown"`` when falsy.
            return_usage: When true, return a dict with the answer and token counts
                instead of just the answer text.

        Returns:
            The first choice's message text with surrounding whitespace stripped
            (``""`` if the message has no text content), or, when ``return_usage``
            is true, ``{"answer_text": str, "usage": {"input_tokens",
            "output_tokens", "cached_input_tokens", "total_tokens"}}``.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )

        # Token usage logging
        input_tokens, output_tokens, cached_input_tokens = self._extract_usage(response.usage)

        if self.token_tracker:
            self.token_tracker.log(
                agent_id=agent_id or "unknown",
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                cached_input_tokens=cached_input_tokens
            )

        # message.content may be None (e.g. refusals or tool-call-only replies);
        # treat that as an empty answer instead of crashing on .strip().
        content = response.choices[0].message.content
        answer_text = (content or "").strip()

        if return_usage:
            return {
                "answer_text": answer_text,
                "usage": {
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens,
                    "cached_input_tokens": cached_input_tokens,
                    "total_tokens": input_tokens + output_tokens
                }
            }

        return answer_text