import os
from pathlib import Path
from dotenv import load_dotenv
from openai import OpenAI
from typing import List, Dict
import re, json

# Load environment variables from the `.env` file that sits next to this module
# (openai_call reads MODEL_NAME from the environment), then create the single
# OpenAI client shared by every function below.
env_path = Path(__file__).resolve().parent.joinpath('.env')
load_dotenv(env_path)
client = OpenAI()

def openai_call(
    messages: List[dict],
    llm_model_name: str | None = None,
    json_format: bool = False,
    max_continuations: int | None = None,
) -> str:
    """
    Call the OpenAI chat completions API and return the full reply text.

    If the model stops because it hit its output limit (finish_reason ==
    "length"), the partial answer plus a "Continue the response" user turn are
    appended to a private copy of the conversation and the request is repeated.
    All chunks are concatenated into the returned string.

    Args:
        messages (List[dict]): The messages to send to the OpenAI API. The list
            is NOT modified; continuation turns go into a local copy.
        llm_model_name (str | None): The name of the LLM model. When None, the
            MODEL_NAME environment variable is read at call time.
        json_format (bool): If True, request response_format={"type": "json_object"}.
        max_continuations (int | None): Maximum number of follow-up requests made
            after truncated replies. None (default) means no limit. When the
            limit is hit, the (truncated) text gathered so far is returned.

    Returns:
        str: The concatenated response text. It is empty if the API returned no
        message content.
    """
    if llm_model_name is None:
        # Resolved per call (not at import) so later changes to the env are honoured.
        llm_model_name = os.environ.get("MODEL_NAME")

    # Work on a copy: continuation turns must not leak into the caller's list,
    # which callers such as retry loops may send again.
    conversation = list(messages)

    request_kwargs = {"model": llm_model_name, "temperature": 0.0}
    if json_format:
        request_kwargs["response_format"] = {"type": "json_object"}

    chunks = []
    continuations = 0
    while True:
        response = client.chat.completions.create(messages=conversation, **request_kwargs)
        choice = response.choices[0]
        # message.content may be None (e.g. no text produced); treat it as empty.
        content = choice.message.content or ""
        chunks.append(content)

        # finish_reason "length" means the reply was cut off and is incomplete.
        if choice.finish_reason != "length":
            break
        if max_continuations is not None and continuations >= max_continuations:
            break
        continuations += 1
        conversation.append({"role": "assistant", "content": content})
        conversation.append({"role": "user", "content": "Continue the response"})

    return "".join(chunks)

def get_chat_completion(messages, logger=None, max_retries: int = 3) -> str | None:
    """Return a non-empty chat completion for *messages*, retrying on failure.

    Args:
        messages: Chat messages passed to ``openai_call``. A shallow copy is sent
            on every attempt, so the caller's list is never extended and each
            retry starts from the same conversation.
        logger: Optional logger. Each failed attempt is reported with
            ``logger.info``; when no logger is given it is printed instead.
        max_retries: Maximum number of attempts (default 3).

    Returns:
        The reply text, or None if every attempt raised or produced an empty reply.
    """
    for attempt in range(1, max_retries + 1):
        try:
            reply = openai_call(messages=list(messages))
        except Exception as e:  # any failure from openai_call: record it and retry
            error = e
        else:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if reply:
                return reply
            error = ValueError("failed to get reply.")

        message = (
            f"{attempt}/{max_retries} - try to get reply for message: {messages}. "
            f"ERROR: {error}"
        )
        if logger:
            logger.info(message)
        else:
            print(message)

    # All attempts failed.
    return None
        
def get_json_chat_completion(messages, logger=None) -> Dict | None:
    def load_json_result(content):
        # tools
        def extract_json_res(content: str) -> Dict:
            json_marker_pattern = r"```json([\s\S]*?)```"
            match = re.search(json_marker_pattern, content, re.DOTALL)
            json_content = match.group(1)
            dict_content = json.loads(json_content)
            return dict_content
        def llm_json_formatter(content: str) -> str:
            prompt = f'''Please check if the text conforms to JSON format. If it does not, output the correct JSON format result or extract the part in JSON format; if it does, return the original text.
                Please return a valid JSON result, without any extra explanations or symbols.
                TEXT: {content}'''
            res = openai_call(
                messages=[{"role": "user", "content": prompt.format(content)}],
                llm_model_name='gpt-4o-mini'
            )
            return res
        
        # main workflow
        try:
            json_res = json.loads(content)
            return json_res
        except:
            pass
        try:
            json_res = extract_json_res(content=content)
            return json_res
        except Exception as e:
            if logger:
                logger.debug(f"failed to directly load the content: {content}. ERROR: {e}. try to use LLM to format it.")
            else:
                print(f"failed to directly load the content: {content}. ERROR: {e}. try to use LLM to format it.")
        try:
            new_content = llm_json_formatter(content=content)
            json_res = extract_json_res(new_content)
            return json_res
        except Exception as e:
            if logger:
                logger.debug(f"the content: {content}, improved by the LLM: {new_content}, still not conform to the JSON format. ERROR: {e}")
            else:
                print(f"the content: {content}, improved by the LLM: {new_content}, still not conform to the JSON format. ERROR: {e}")
    for _ in range(3):
        try:
            _reply = get_chat_completion(messages=messages, logger=logger)
            reply = load_json_result(_reply)
            assert reply, "failed to get json reply."
            return reply
        except:
            pass
    if logger:
        logger.error(f"failed to get result for message: {messages}")
    else:
        print(f"failed to get result for message: {messages}")


if __name__ == '__main__':
    # Manual smoke test: send a single prompt and print the model's reply.
    demo_messages = [{"role": "user", "content": "say it is a test"}]
    demo_reply = openai_call(demo_messages)
    print(demo_reply)