import json
import os
from typing import Literal, get_args

from ale_bench.utils import pil_to_base64jpeg
from dotenv import load_dotenv
from openai import OpenAI
from PIL import Image

from common_resource import (
    EXP_ROOT_DIR, SYSTEM_PROMPT,
    BaseLLM, BaseNamespace,
    get_common_argument_parser, main_loop,
)


# Accepted values for the `reasoning_effort` request option; the
# --reasoning_effort command-line flag offers the same three levels.
ReasoningEffort = Literal[
    "low",
    "medium",
    "high",
]


class OpenAINamespace(BaseNamespace):
    """Typed view of the parsed command line for the OpenAI runner.

    Adds the OpenAI-specific flags registered in ``parse_args`` on top of
    whatever shared options ``BaseNamespace`` declares.
    """

    # Model name given via --model (required).
    model: str
    # Level given via --reasoning_effort; None when the flag is omitted.
    reasoning_effort: ReasoningEffort | None


class OpenAIModel(BaseLLM):
    """ALE-Bench LLM backend that talks to the OpenAI Chat Completions API.

    The conversation is kept in ``self.messages`` as chat-completions message
    dicts, and every accepted raw API response is dumped as JSON into
    ``self.log_dir``.
    """

    def __init__(
        self,
        client: OpenAI,
        model: str,
        reasoning_effort: ReasoningEffort | None,
        system_prompt: str | None,
        log_dir: str | os.PathLike
    ) -> None:
        """Set up the backend.

        Args:
            client: OpenAI client used for every request.
            model: Model name forwarded to ``chat.completions.create``.
            reasoning_effort: Forwarded as ``reasoning_effort`` when not
                ``None``; otherwise the option is omitted from requests.
            system_prompt: If given, sent as the leading ``developer`` message.
            log_dir: Directory that receives the raw response dumps.
        """
        # NOTE(review): self.messages, self.system_prompt, self.log_dir and
        # self.num_retry are assumed to be initialised by BaseLLM.__init__ — confirm.
        super().__init__(system_prompt, log_dir)
        if system_prompt is not None:
            self.messages.append({"role": "developer", "content": system_prompt})
        self.client = client
        self.model = model
        self.model_kwargs: dict[str, str] = {}
        if reasoning_effort is not None:
            self.model_kwargs["reasoning_effort"] = reasoning_effort
        # Running message counter, used here to number the response log files.
        # It is not reset by send_user_message_new_thread, so file names keep
        # increasing across threads.
        self.num_total_messages = len(self.messages)

    @staticmethod
    def _build_user_content(contents: list[str | Image.Image]) -> list[dict]:
        """Convert text / PIL image parts into chat-completions content parts."""
        parts = []
        for content in contents:
            if isinstance(content, str):
                parts.append({"type": "text", "text": content})
            elif isinstance(content, Image.Image):
                # Images are sent inline as base64-encoded JPEG data URLs.
                parts.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{pil_to_base64jpeg(content)}"
                    },
                })
            else:
                raise ValueError(f"Unsupported content type: {type(content)}")
        return parts

    @staticmethod
    def _response_problem(response) -> str | None:
        """Return why ``response`` is unusable, or ``None`` if it can be accepted."""
        if response is None or not response.choices:
            return "no response from OpenAI."
        choice = response.choices[0]
        if choice.finish_reason == "length":
            return "the response is cut off by OpenAI."
        if choice.finish_reason == "content_filter":
            return "the content is filtered by OpenAI."
        if choice.message.content is None:
            # get_last_response() promises a str; never store a None reply.
            return "the response has no text content."
        return None

    def _log_response(self, response) -> None:
        """Dump the raw API response to ``log_dir`` as UTF-8 JSON."""
        log_path = os.path.join(
            self.log_dir, f"response_turn{self.num_total_messages:08d}.json"
        )
        # Explicit UTF-8: ensure_ascii=False writes non-ASCII text verbatim,
        # which the platform default encoding may not be able to represent.
        with open(log_path, "w", encoding="utf-8") as f:
            json.dump(response.model_dump(), f, ensure_ascii=False)

    def send_user_message(self, contents: list[str | Image.Image]) -> None:
        """Append a user message, query the model, and append its reply.

        Args:
            contents: Ordered text and image parts making up the user message.

        Raises:
            ValueError: If ``contents`` holds an unsupported type, or if no
                acceptable response was obtained within ``self.num_retry``
                attempts. In the latter case the user message is removed again.
        """
        self.messages.append({"role": "user", "content": self._build_user_content(contents)})
        for trial in range(1, self.num_retry + 1):
            try:
                response = self.client.chat.completions.create(
                    model=self.model, messages=self.messages, **self.model_kwargs
                )
            except Exception as e:  # any API/network error just consumes a trial
                print(f"Trial {trial} failed: {e}")
                continue
            problem = self._response_problem(response)
            if problem is not None:
                print(f"Trial {trial} failed: {problem}")
                continue
            self.messages.append({"role": "assistant", "content": response.choices[0].message.content})
            self.num_total_messages += 2
            self._log_response(response)
            return
        # Every trial failed, whether by exception or by a rejected response.
        # Roll back the user message so it is not left without a reply.
        self.messages.pop()
        raise ValueError("Failed to get a valid response from OpenAI.")

    def send_user_message_new_thread(self, contents: list[str | Image.Image]) -> None:
        """Discard the conversation (keeping the system prompt) and send ``contents``."""
        self.messages = []
        if self.system_prompt is not None:
            self.messages.append({"role": "developer", "content": self.system_prompt})
        self.send_user_message(contents)

    def get_last_response(self) -> str:
        """Return the text of the most recent assistant message.

        Raises:
            ValueError: If there are no messages or the last one is not from
                the assistant.
        """
        if not self.messages:
            raise ValueError("No messages to get response from.")
        if self.messages[-1]["role"] != "assistant":
            raise ValueError("The last message is not from the assistant.")
        return self.messages[-1]["content"]

    def load_history(self, file_path: str | os.PathLike) -> None:
        """Replace the conversation with the JSON message list stored at ``file_path``.

        A leading ``developer``/``system`` message becomes ``self.system_prompt``;
        otherwise (including an empty history) the system prompt is cleared.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            self.messages = json.load(f)
        if self.messages and self.messages[0]["role"] in {"developer", "system"}:
            self.system_prompt = self.messages[0]["content"]
        else:
            self.system_prompt = None

    def save_history(self, file_path: str | os.PathLike) -> None:
        """Write the conversation to ``file_path`` as UTF-8 JSON."""
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(self.messages, f, ensure_ascii=False)


def parse_args() -> OpenAINamespace:
    """Parse the common ALE-Bench options plus the OpenAI-specific flags.

    Returns:
        A fresh ``OpenAINamespace`` instance holding the parsed values.
    """
    parser = get_common_argument_parser(description="Run OpenAI model on ALE-Bench.")
    parser.add_argument("--model", type=str, required=True, help="Model name to run.")
    parser.add_argument(
        "--reasoning_effort",
        type=str,
        # Derived from the type alias so the CLI and the type cannot drift apart.
        choices=list(get_args(ReasoningEffort)),
        default=None,
        help="Reasoning effort for the model.",
    )
    # Pass an *instance*. argparse stores results with setattr, so passing the
    # class would turn them into class attributes (process-global state). It
    # would also return the class object instead of an OpenAINamespace.
    return parser.parse_args(namespace=OpenAINamespace())


def get_client() -> OpenAI:
    """Construct an OpenAI client from environment variables.

    Reads OPENAI_API_KEY, OPENAI_ORG_ID and OPENAI_PROJECT_ID; any variable
    that is unset is passed to the client as ``None``.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    organization = os.getenv("OPENAI_ORG_ID")
    project = os.getenv("OPENAI_PROJECT_ID")
    return OpenAI(api_key=api_key, organization=organization, project=project)


def main() -> None:
    """Script entry point: build the OpenAI backend from CLI args and run ALE-Bench."""
    # Pull variables from the experiment's .env file into the environment so
    # get_client() can read the OpenAI credentials via os.getenv.
    load_dotenv(EXP_ROOT_DIR / ".env")

    args = parse_args()

    # Construct the model wrapper; one log directory per problem.
    client = get_client()
    system_prompt = SYSTEM_PROMPT[args.prompt_language]
    log_dir = args.exp_dir / f"llm_log_{args.problem_id}"
    llm = OpenAIModel(
        client=client,
        model=args.model,
        reasoning_effort=args.reasoning_effort,
        system_prompt=system_prompt,
        log_dir=log_dir,
    )

    # Hand control to the shared ALE-Bench driver.
    main_loop(args, llm)


if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()
