import os
from openai import OpenAI
from src.llm_clients.base_llm_client import BaseLLMClient

class OpenAIClient(BaseLLMClient):
    """LLM client that sends single-turn prompts to the OpenAI chat completions API.

    The API key is read from the ``OPENAI_API_KEY`` environment variable when
    the client is constructed.
    """

    def __init__(self, model: str = 'gpt-4o-2024-08-06', sys_msg: str = ""):
        """Create the underlying OpenAI client.

        Args:
            model: Name of the chat model passed on every request.
            sys_msg: System message kept on the instance. NOTE: ``call`` does
                not currently include it in the request; only the user prompt
                is sent.
        """
        # The key is a secret: it must never be printed or logged.
        api_key = os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(api_key=api_key)
        self.model = model
        self.sys_msg = sys_msg

    def call(self, prompt: str, with_log_prob: bool = False, temp: float = 0.0):
        """Send ``prompt`` as a single user message and return the reply.

        Args:
            prompt: Text sent as the content of the user message.
            with_log_prob: When true, request token log-probabilities (with
                20 top alternatives per token) and include them in the result.
            temp: Sampling temperature. Overridden to 1.0 when the model is
                ``'o4-mini'``.

        Returns:
            ``(message, usage)`` when ``with_log_prob`` is false, otherwise
            ``(message, log_probs, usage)``. ``usage`` is a dict with the keys
            ``"input_tokens"``, ``"output_tokens"`` and ``"total_tokens"``.
        """
        # Presumably o4-mini rejects any temperature other than the default of
        # 1.0 -- confirm against the API docs before adding other models here.
        if self.model == 'o4-mini':
            temp = 1.0

        # Both request variants share everything except the logprob options.
        request_args = {
            "model": self.model,
            "store": False,
            "messages": [
                {"role": "user", "content": prompt},
            ],
            "temperature": temp,
        }
        if with_log_prob:
            request_args["logprobs"] = True
            request_args["top_logprobs"] = 20

        completion = self.client.chat.completions.create(**request_args)

        first_choice = completion.choices[0]
        returned_msg = first_choice.message.content

        usage = completion.usage
        token_usage = {
            "input_tokens": usage.prompt_tokens,
            "output_tokens": usage.completion_tokens,
            "total_tokens": usage.total_tokens,
        }

        if with_log_prob:
            return returned_msg, first_choice.logprobs.content, token_usage
        return returned_msg, token_usage

