import time
from typing import Union, List, Dict
import openai
import os
from src.utils.extract_json_reliable import extract_json


def _split_reasoning_suffix(model):
    """Split a ``gpt-5.2-<effort>`` alias into ``("gpt-5.2", "<effort>")``.

    Any other model name -- including the bare ``"gpt-5.2-"`` with nothing
    after the dash -- is returned unchanged together with ``None``.
    """
    prefix = "gpt-5.2-"
    if isinstance(model, str) and model.startswith(prefix) and len(model) > len(prefix):
        return "gpt-5.2", model[len(prefix):]
    return model, None


def _completion_params(base_model, max_new_tokens, temperature):
    """Return the family-specific sampling/length kwargs for ``create()``."""
    # o-series: no temperature is sent; the length cap is max_completion_tokens.
    if base_model.startswith(("o1", "o3", "o4")):
        return {"max_completion_tokens": max_new_tokens}
    # gpt-5.2 family: temperature is sent; the length cap is max_completion_tokens.
    if base_model.startswith("gpt-5.2"):
        return {"temperature": temperature, "max_completion_tokens": max_new_tokens}
    # Everything else uses the classic max_tokens parameter.
    return {"temperature": temperature, "max_tokens": max_new_tokens}


def get_gpt_output(message: Union[str, List[Dict[str, str]]],
                   model: str = "gpt-4o",
                   max_new_tokens: int = 2048,
                   temperature: float = 1.0,
                   max_retry: int = 5,
                   sleep_time: int = 60,
                   json_object: bool = False,
                   request_timeout: int = 30,
                   **kwargs) -> str:
    """
    Call the OpenAI chat-completions API and return the model's reply.

    Args:
        message (Union[str, List[Dict[str, str]]]): A prompt string (sent as a
            single user message) or a ready-made list of message dicts.
        model (str): Model name. Default is "gpt-4o". A name of the form
            "gpt-5.2-<effort>" is sent as model "gpt-5.2" with
            ``reasoning_effort=<effort>``, unless ``reasoning_effort`` is
            already given in ``kwargs``.
        max_new_tokens (int): Maximum number of tokens to generate. Default is 2048.
        temperature (float): Sampling temperature. Default is 1.0. Not sent
            for o1/o3/o4 models.
        max_retry (int): Maximum number of attempts. Default is 5.
        sleep_time (int): Seconds to sleep between attempts. Default is 60.
            No sleep happens after the final attempt.
        json_object (bool): Request JSON mode (``response_format`` of type
            ``json_object``) and post-process the reply with ``extract_json``.
            Default is False.
        request_timeout (int): Per-request timeout in seconds. Default is 30.
        **kwargs: Extra keyword arguments forwarded to
            ``client.chat.completions.create``.

    Returns:
        str: The generated text. When ``json_object`` is True, this is whatever
        ``extract_json`` returns for the reply (presumably parsed JSON; verify
        against ``src.utils.extract_json_reliable``). When the API returns
        ``None`` content, a single space ``' '`` is returned.

    Raises:
        Exception: If every attempt fails. The last underlying error is
            attached as ``__cause__``.
    """
    if json_object and isinstance(message, str) and 'json' not in message.lower():
        # OpenAI's JSON mode expects the word "json" to appear in the prompt.
        message = 'You are a helpful assistant designed to output JSON. ' + message

    # NOTE(review): list-form messages are forwarded unchanged, so in JSON
    # mode the caller is responsible for mentioning "json" in them.
    messages = [{"role": "user", "content": message}] if isinstance(message, str) else message

    if json_object:
        kwargs["response_format"] = {"type": "json_object"}

    base_model, reasoning_effort = _split_reasoning_suffix(model)
    if reasoning_effort and "reasoning_effort" not in kwargs:
        kwargs["reasoning_effort"] = reasoning_effort

    params = _completion_params(base_model, max_new_tokens, temperature)
    client = openai.OpenAI()

    last_error = None
    for attempt in range(1, max_retry + 1):
        try:
            chat = client.chat.completions.create(
                messages=messages,
                model=base_model,
                timeout=request_timeout,
                **params,
                **kwargs,
            )
            content = chat.choices[0].message.content
            print("get_gpt_output: ", content)
            if content is None:
                print('Warning! None response! Could be due to safety filter.\n', chat)
                return ' '
            return extract_json(content) if json_object else content
        except Exception as e:
            last_error = e
            if attempt >= max_retry:
                # Last attempt: don't sleep before giving up.
                print(f"Attempt {attempt}/{max_retry} failed: {e}. No retries left.")
                break
            print(f"Attempt {attempt}/{max_retry} failed: {e}. Retrying after {sleep_time} seconds...")
            time.sleep(sleep_time)

    raise Exception("Failed to get GPT output after maximum retries") from last_error

