import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
                
import openai
from models.base_model import LLMModel
from configs.LLM_configs import LLM_CONFIG
from retrying import retry
import logging

# Named logger for this module. It uses the fixed name "Uncertainty Align"
# (not __name__), so its level and handlers are configured through that name.
LOGGER = logging.getLogger("Uncertainty Align")

class OPENAIModel(LLMModel):
    """LLM backend for an OpenAI-compatible chat-completions endpoint.

    Connection settings (base URL and API key) come from ``LLM_CONFIG["openai"]``.
    They are written onto the global ``openai`` module, so they apply
    process-wide rather than only to this instance.
    """

    def __init__(self,
                 model_name: str,
                 temperature: float = 1.0,
                 max_completion_tokens: int = 1024,
                 top_p: float = 1.0):
        """
        Args:
            model_name: Model identifier, passed to ``LLMModel.__init__``.
                NOTE(review): presumably stored there as ``self.name``, which
                ``call`` sends as the request's ``model`` — confirm.
            temperature: Sampling temperature sent with every request.
            max_completion_tokens: Cap on generated tokens (sent as the
                ``max_tokens`` request field).
            top_p: Nucleus-sampling parameter sent with every request.
        """
        super().__init__(model_name)
        self.temperature = temperature
        self.max_completion_tokens = max_completion_tokens
        self.top_p = top_p
        self.set_api()
        self.is_api = True

    def set_api(self):
        """Configure the global ``openai`` client from ``LLM_CONFIG["openai"]``.

        Side effect: sets the module-level ``openai.base_url`` and
        ``openai.api_key``, affecting every user of the ``openai`` module.
        """
        openai.base_url = LLM_CONFIG["openai"]["base_url"]
        openai.api_key = LLM_CONFIG["openai"]["api_key"]

    @retry(stop_max_attempt_number=20, wait_fixed=3000)
    def call(self, llm_input: str, return_probs: bool = False, return_top_num: int = 5, do_generation: bool = True):
        """
        Call the OpenAI chat-completions API and return the response.

        Any exception raised inside this method triggers a retry. That covers
        network and API errors, but also any error raised while processing the
        response. There are up to 20 attempts, 3 s apart (``retrying``
        decorator).

        Args:
            llm_input: Input text, sent as a single ``system``-role message.
            return_probs: Whether to also return per-position token probabilities.
            return_top_num: Number of alternatives per position to request
                (``top_logprobs``). Only used when ``return_probs`` is True.
            do_generation: Not used by this implementation. NOTE(review):
                presumably kept for signature compatibility with other
                ``LLMModel`` backends — confirm.

        Returns:
            If return_probs is False, the generated text.
            If return_probs is True, a ``(generated_text, probability_distribution)``
            tuple, where ``probability_distribution`` is the list built by
            ``_format_logprobs``. It is an empty list when the response carries
            no logprobs, so the return shape is always a tuple in this mode.
        """
        # Build request parameters
        params = {
            "model": self.name,
            "messages": [{"role": "system", "content": llm_input}],
            "max_tokens": self.max_completion_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p
        }

        if return_probs:
            params["logprobs"] = True
            params["top_logprobs"] = return_top_num

        response = openai.chat.completions.create(**params)
        choice = response.choices[0]
        generated_text = choice.message.content

        if not return_probs:
            return generated_text

        # In the v1 SDK ``Choice.logprobs`` is Optional: the attribute exists but
        # may be None. Formatting None would raise, and the @retry decorator
        # would then repeat the whole request 20 times for a deterministic
        # failure. So treat a missing or None value as "no probabilities".
        logprobs_data = getattr(choice, "logprobs", None)
        if logprobs_data is None:
            LOGGER.warning(
                "Logprobs were requested but the response for model %s contained none.",
                self.name,
            )
            return generated_text, []
        return generated_text, self._format_logprobs(logprobs_data)

    def _format_logprobs(self, logprobs_data):
        """
        Convert logprobs into a list of ``{token: probability}`` dicts, one per
        generated position.

        Tokens are whitespace-stripped, so alternatives that differ only in
        surrounding whitespace (e.g. ``" Yes"`` and ``"Yes"``) are merged and
        their probabilities summed. A position without ``top_logprobs`` still
        contributes an empty dict, so index ``i`` of the result always
        corresponds to the ``i``-th generated token.

        Args:
            logprobs_data: The ``logprobs`` object of a chat-completion choice.
                Its ``content`` sequence is iterated; None is treated as empty.

        Returns:
            list of dict[str, float]: per-position maps of token -> exp(logprob).
        """
        import math

        prob_list = []
        for token_info in getattr(logprobs_data, "content", None) or []:
            position_probs = {}
            for logprob_info in getattr(token_info, "top_logprobs", None) or []:
                # Skip malformed entries rather than failing the whole call.
                if not (hasattr(logprob_info, "token") and hasattr(logprob_info, "logprob")):
                    continue
                token = logprob_info.token.strip()
                position_probs[token] = position_probs.get(token, 0.0) + math.exp(logprob_info.logprob)
            # Append unconditionally so positions stay aligned with token indices.
            prob_list.append(position_probs)

        return prob_list