import json
import os
from typing import Literal

from PIL import Image
from ale_bench.utils import pil_to_base64jpeg
from dotenv import load_dotenv
from openai import OpenAI
from openai.types.chat import ChatCompletion

from common_resource import (
    EXP_ROOT_DIR, SYSTEM_PROMPT,
    BaseLLM, BaseNamespace,
    get_common_argument_parser, main_loop,
)


# Supported API backends; both are reached through an OpenAI-compatible client (see get_client).
DeepSeekClient = Literal["deepseek", "openrouter"]


class DeepSeekNamespace(BaseNamespace):
    """Typed view of the parsed command-line arguments for this script.

    Extends the common ALE-Bench arguments inherited from ``BaseNamespace``
    with the DeepSeek-specific options registered in ``parse_args``.
    """
    client: DeepSeekClient  # which API backend to call (--client)
    providers: str | None  # raw comma-separated OpenRouter provider list (--providers); split in main()
    model: str  # model name sent to the chat-completions API (--model)
    temperature: float | None  # sampling temperature; None means the parameter is not sent (--temperature)
    max_tokens: int | None  # completion token cap; None means the parameter is not sent (--max_tokens)


class DeepSeekModel(BaseLLM):
    """LLM wrapper that talks to DeepSeek models through an OpenAI-compatible client.

    The conversation is kept in ``self.messages`` in OpenAI chat format. Every
    successful assistant turn is additionally dumped as raw JSON into
    ``self.log_dir``.
    """

    def __init__(
        self,
        client: OpenAI,
        is_openrouter: bool,
        providers: list[str] | None,
        model: str,
        system_prompt: str | None,
        temperature: float | None,
        max_tokens: int | None,
        log_dir: str | os.PathLike
    ) -> None:
        """Configure the wrapper.

        Args:
            client: OpenAI-compatible client pointed at DeepSeek or OpenRouter.
            is_openrouter: If true, OpenRouter provider-routing options are sent
                via ``extra_body``.
            providers: OpenRouter providers to use, in order, with fallbacks
                disabled; ``None`` leaves provider choice to OpenRouter. Only
                used when ``is_openrouter`` is true.
            model: Model name passed to the chat-completions API.
            system_prompt: Optional system message placed at the start of
                every conversation thread.
            temperature: Sampling temperature; ``None`` omits the parameter.
            max_tokens: Completion token cap; ``None`` omits the parameter.
            log_dir: Directory handed to ``BaseLLM``; per-turn response dumps
                are written under ``self.log_dir``.
        """
        super().__init__(system_prompt, log_dir)
        self.client = client
        self.model = model
        if system_prompt is not None:
            self.messages.append({"role": "system", "content": system_prompt})
        self.model_kwargs = {}
        if is_openrouter:
            # OpenRouter-specific request fields (provider-routing preferences and
            # usage reporting); the OpenAI SDK forwards `extra_body` verbatim.
            self.model_kwargs["extra_body"] = {
                "provider": {
                    "data_collection": "deny",
                    "quantizations": ["int8", "fp8", "fp16", "bf16", "fp32"],
                    "require_parameters": True,
                },
                "usage": {"include": True},
            }
            if providers is not None:
                # Pin the provider order and forbid falling back to others.
                self.model_kwargs["extra_body"]["provider"]["order"] = providers
                self.model_kwargs["extra_body"]["provider"]["allow_fallbacks"] = False
        if temperature is not None:
            self.model_kwargs["temperature"] = temperature
        if max_tokens is not None:
            self.model_kwargs["max_tokens"] = max_tokens
        # Running message counter; only used to give response dump files unique names.
        self.num_total_messages = len(self.messages)

    @staticmethod
    def _build_user_content(contents: list[str | Image.Image]) -> list[dict]:
        """Convert text and PIL images into OpenAI multi-part message content.

        Raises:
            ValueError: if an item is neither ``str`` nor ``PIL.Image.Image``.
        """
        parts = []
        for content in contents:
            if isinstance(content, str):
                parts.append({"type": "text", "text": content})
            elif isinstance(content, Image.Image):
                # Images are inlined as base64 JPEG data URLs.
                parts.append({
                    "type": "image_url", "image_url": {
                        "url": f"data:image/jpeg;base64,{pil_to_base64jpeg(content)}"
                    },
                })
            else:
                raise ValueError(f"Unsupported content type: {type(content)}")
        return parts

    def _request_completion(self, trial: int) -> ChatCompletion | None:
        """Issue one chat-completion request for the current history.

        Returns the response if it contains a complete, non-empty answer;
        otherwise prints why the trial failed and returns ``None``.
        """
        try:
            response: ChatCompletion = self.client.chat.completions.create(
                model=self.model, messages=self.messages, **self.model_kwargs
            )
        except Exception as e:  # any API/network error just consumes this trial
            print(f"Trial {trial} failed: {e}")
            return None
        if response is None or response.choices is None or len(response.choices) == 0:
            print(f"Trial {trial} failed: no response from DeepSeek.")
            return None
        choice = response.choices[0]
        if choice.finish_reason == "error":
            print(f"Trial {trial} failed: the error is returned by DeepSeek.")
            return None
        if choice.finish_reason == "length":
            print(f"Trial {trial} failed: the response is cut off by DeepSeek.")
            return None
        if choice.finish_reason == "content_filter":
            print(f"Trial {trial} failed: the content is filtered by DeepSeek.")
            return None
        if not choice.message.content:
            print(f"Trial {trial} failed: empty response from DeepSeek.")
            return None
        return response

    def send_user_message(self, contents: list[str | Image.Image]) -> None:
        """Append a user turn and retry until a valid assistant reply is obtained.

        On success the assistant reply is appended to ``self.messages`` and the
        raw response is dumped to ``self.log_dir``.

        Raises:
            ValueError: if ``contents`` has an unsupported item type (history is
                left untouched), or if all ``self.num_retry`` trials fail (the
                user turn is removed again so history stays consistent).
        """
        self.messages.append({"role": "user", "content": self._build_user_content(contents)})
        for trial in range(1, self.num_retry + 1):
            response = self._request_completion(trial)
            if response is None:
                continue
            self.messages.append({"role": "assistant", "content": response.choices[0].message.content})
            self.num_total_messages += 2
            log_path = self.log_dir / f"response_turn{self.num_total_messages:08d}.json"
            with open(log_path, "w", encoding="utf-8") as f:
                json.dump(response.model_dump(), f, ensure_ascii=False)
            return
        # Every trial failed (exception, bad finish_reason, or empty content):
        # drop the unanswered user turn instead of leaving it dangling.
        self.messages.pop()
        raise ValueError("Failed to get a valid response from DeepSeek.")

    def send_user_message_new_thread(self, contents: list[str | Image.Image]) -> None:
        """Discard the conversation (keeping the system prompt) and send ``contents``."""
        self.messages = []
        if self.system_prompt is not None:
            self.messages.append({"role": "system", "content": self.system_prompt})
        self.send_user_message(contents)

    def get_last_response(self) -> str:
        """Return the latest assistant message.

        Raises:
            ValueError: if there are no messages or the last one is not from the assistant.
        """
        if not self.messages:
            raise ValueError("No messages to get response from.")
        if self.messages[-1]["role"] != "assistant":
            raise ValueError("The last message is not from the assistant.")
        return self.messages[-1]["content"]

    def load_history(self, file_path: str | os.PathLike) -> None:
        """Replace the conversation with the JSON message list stored at ``file_path``.

        ``system_prompt`` is taken from the first message when its role is
        ``system``/``developer``; otherwise (including an empty history) it is
        reset to ``None``.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            self.messages = json.load(f)
        if self.messages and self.messages[0]["role"] in {"developer", "system"}:
            self.system_prompt = self.messages[0]["content"]
        else:
            self.system_prompt = None

    def save_history(self, file_path: str | os.PathLike) -> None:
        """Write the conversation to ``file_path`` as UTF-8 JSON."""
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(self.messages, f, ensure_ascii=False)


def parse_args() -> DeepSeekNamespace:
    """Parse the common ALE-Bench arguments plus the DeepSeek-specific options."""
    parser = get_common_argument_parser(description="Run DeepSeek model on ALE-Bench.")
    parser.add_argument("--client", type=str, choices=["deepseek", "openrouter"], default="openrouter", help="API client to use (DeepSeek API or OpenRouter).")
    parser.add_argument("--providers", type=str, default=None, help="Providers for the model (used in OpenRouter).")
    parser.add_argument("--model", type=str, required=True, help="Model name to run.")
    parser.add_argument("--temperature", type=float, default=None, help="Temperature for the model.")
    parser.add_argument("--max_tokens", type=int, default=None, help="Max tokens for the model.")
    # NOTE(review): the namespace *class* (not an instance) is passed, so argparse
    # sets parsed values as class attributes and returns the class itself.
    return parser.parse_args(namespace=DeepSeekNamespace)


def get_client(client: DeepSeekClient) -> OpenAI:
    """Build an OpenAI-compatible client for the selected backend.

    The API key is read from ``DEEPSEEK_API_KEY`` (DeepSeek API) or
    ``OPENROUTER_API_KEY`` (OpenRouter).

    Raises:
        ValueError: if ``client`` is unsupported or its API-key variable is
            unset or empty.
    """
    # backend -> (base URL, environment variable holding the API key)
    endpoints = {
        "deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"),
        "openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
    }
    if client not in endpoints:
        raise ValueError(f"Unsupported client: {client}")
    base_url, key_var = endpoints[client]
    api_key = os.getenv(key_var)
    if not api_key:
        # Passing api_key=None would make the OpenAI SDK fall back to
        # OPENAI_API_KEY and send that key to this third-party endpoint.
        raise ValueError(f"Environment variable {key_var} must be set to use the {client} client.")
    return OpenAI(base_url=base_url, api_key=api_key)


def main() -> None:
    """Load the environment, parse CLI args, build the DeepSeek wrapper, and run ALE-Bench."""
    # Load environment variables from .env file
    load_dotenv(EXP_ROOT_DIR / ".env")

    # Parse command line arguments
    args = parse_args()

    # --providers is a comma-separated list and only applies to OpenRouter.
    # Whitespace around names and empty items are ignored; if nothing remains,
    # no provider preference is sent.
    providers = None
    if args.client == "openrouter" and args.providers is not None:
        providers = [name.strip() for name in args.providers.split(",") if name.strip()] or None
    llm = DeepSeekModel(
        client=get_client(args.client),
        is_openrouter=args.client == "openrouter",
        providers=providers,
        model=args.model,
        system_prompt=SYSTEM_PROMPT[args.prompt_language],
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        log_dir=args.exp_dir / f"llm_log_{args.problem_id}"
    )

    # Start the ALE-Bench main loop
    main_loop(args, llm)


if __name__ == "__main__":
    main()
