import json
import os
from typing import Literal

from PIL import Image
from ale_bench.utils import pil_to_base64jpeg
from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex
from anthropic.types import Message
from dotenv import load_dotenv
from openai import OpenAI
from openai.types.chat import ChatCompletion

from common_resource import (
    EXP_ROOT_DIR, SYSTEM_PROMPT,
    BaseLLM, BaseNamespace,
    get_common_argument_parser, main_loop,
)


class AnthropicNamespace(BaseNamespace):
    """Typed view of the command-line arguments this script adds on top of ``BaseNamespace``.

    The fields mirror the ``--client``, ``--model``, ``--max_tokens`` and
    ``--thinking_budget`` flags registered in ``parse_args``.
    """

    # Backend used to reach the model; selects the SDK client built in ``get_client``.
    client: Literal["anthropic", "bedrock", "vertex", "openrouter"]
    # Model name forwarded unchanged to the selected client.
    model: str
    # Forwarded as ``max_tokens`` on every request.
    max_tokens: int
    # Thinking/reasoning token budget; ``None`` (the CLI default) leaves it disabled.
    thinking_budget: int | None


class AnthropicModel(BaseLLM):
    """Conversation wrapper for Claude models, reached natively or through OpenRouter.

    Two transports are supported:

    * the Anthropic SDK clients (``Anthropic``, ``AnthropicBedrock``,
      ``AnthropicVertex``), which use ``client.messages.create``;
    * OpenRouter through the OpenAI SDK, which uses
      ``client.chat.completions.create``.

    NOTE(review): ``self.messages``, ``self.system_prompt``, ``self.log_dir`` and
    ``self.num_retry`` are not assigned in this class; they are presumably
    provided by ``BaseLLM.__init__``. ``self.log_dir`` is combined with ``/``
    below, so it is assumed to be a ``pathlib.Path`` -- confirm against ``BaseLLM``.
    """

    # Text printed for each OpenRouter ``finish_reason`` that counts as a failed trial.
    _OPENROUTER_FAILURE_REASONS = {
        "error": "the error is returned by Anthropic.",
        "length": "the response is cut off by Anthropic.",
        "content_filter": "the content is filtered by Anthropic.",
    }

    def __init__(
        self,
        client: Anthropic | AnthropicBedrock | AnthropicVertex | OpenAI,
        is_openrouter: bool,
        model: str,
        max_tokens: int,
        thinking_budget: int | None,
        system_prompt: str | None,
        log_dir: str | os.PathLike
    ) -> None:
        """Store the client and pre-compute the keyword arguments sent with every request.

        Args:
            client: SDK client used to issue requests.
            is_openrouter: ``True`` when ``client`` is an OpenAI client pointed at OpenRouter.
            model: Model name forwarded to the API.
            max_tokens: Forwarded as ``max_tokens`` on every request.
            thinking_budget: Token budget for thinking/reasoning; ``None`` disables it.
            system_prompt: Optional system prompt.
            log_dir: Directory handed to ``BaseLLM``; raw responses are logged under it.
        """
        super().__init__(system_prompt, log_dir)
        self.client = client
        self.model = model
        self.is_openrouter = is_openrouter
        self.model_kwargs = {"max_tokens": max_tokens}
        if is_openrouter:
            # OpenRouter-specific routing options travel in the request body.
            self.model_kwargs["extra_body"] = {
                "provider": {
                    "data_collection": "deny",
                    "require_parameters": True,
                    "order": ["Google", "Anthropic"],
                    "allow_fallbacks": False,
                },
                "usage": {"include": True},
            }
        if thinking_budget is not None:
            if is_openrouter:
                self.model_kwargs["extra_body"]["reasoning"] = {"max_tokens": thinking_budget}
            else:
                self.model_kwargs["thinking"] = {"type": "enabled", "budget_tokens": thinking_budget}
        # The native API takes the system prompt as a request parameter; for
        # OpenRouter it is sent as a leading "system" message on each request.
        if system_prompt is not None and not is_openrouter:
            self.model_kwargs["system"] = system_prompt
        # Counter used to number the response log files. It starts at 1 for
        # OpenRouter (presumably to account for the system message) and at 0 otherwise.
        self.num_total_messages = 1 if is_openrouter else 0

    def _build_user_content(self, contents: list[str | Image.Image]) -> list[dict]:
        """Convert strings and PIL images into the content-part dicts of the active API.

        Raises:
            ValueError: If an item is neither ``str`` nor ``Image.Image``.
        """
        parts = []
        for content in contents:
            if isinstance(content, str):
                parts.append({"type": "text", "text": content})
            elif isinstance(content, Image.Image):
                encoded = pil_to_base64jpeg(content)
                if self.is_openrouter:
                    # OpenAI-style image part: a data URL.
                    parts.append({
                        "type": "image_url", "image_url": {
                            "url": f"data:image/jpeg;base64,{encoded}"
                        },
                    })
                else:
                    # Anthropic-style image part: a base64 source object.
                    parts.append({
                        "type": "image", "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": encoded
                        },
                    })
            else:
                raise ValueError(f"Unsupported content type: {type(content)}")
        return parts

    def _attempt_openrouter(self, trial: int, request_messages: list[dict]) -> tuple[ChatCompletion, dict] | None:
        """Run one OpenRouter request; return ``(response, assistant_message)`` or ``None`` on failure."""
        try:
            response = self.client.chat.completions.create(
                model=self.model, messages=request_messages, **self.model_kwargs
            )
        except Exception as e:
            # Retry boundary: any client error only costs this trial.
            print(f"Trial {trial} failed: {e}")
            return None
        if response is None or not response.choices:
            print(f"Trial {trial} failed: no response from Anthropic.")
            return None
        choice = response.choices[0]
        failure = self._OPENROUTER_FAILURE_REASONS.get(choice.finish_reason)
        if failure is not None:
            print(f"Trial {trial} failed: {failure}")
            return None
        text = choice.message.content
        if text is None or text == "":
            print(f"Trial {trial} failed: empty response from Anthropic.")
            return None
        return response, {"role": "assistant", "content": text}

    def _attempt_anthropic(self, trial: int) -> tuple[Message, dict] | None:
        """Run one native request; return ``(response, assistant_message)`` or ``None`` on failure."""
        try:
            response = self.client.messages.create(
                model=self.model, messages=self.messages, **self.model_kwargs
            )
        except Exception as e:
            # Retry boundary: any client error only costs this trial.
            print(f"Trial {trial} failed: {e}")
            return None
        if response is None:
            print(f"Trial {trial} failed: no response from Anthropic.")
            return None
        if not response.content:
            print(f"Trial {trial} failed: empty response from Anthropic.")
            return None
        if response.stop_reason == "max_tokens":
            print(f"Trial {trial} failed: the response is cut off by Anthropic.")
            return None
        # Keep every returned content block (not just text) in the history.
        return response, {"role": response.role, "content": [c.model_dump() for c in response.content]}

    def _log_response(self, response: ChatCompletion | Message) -> None:
        """Dump the raw response to ``response_turn<NNNNNNNN>.json`` under ``self.log_dir``."""
        log_path = self.log_dir / f"response_turn{self.num_total_messages:08d}.json"
        with open(log_path, "w", encoding="utf-8") as f:
            json.dump(response.model_dump(), f, ensure_ascii=False)

    def send_user_message(self, contents: list[str | Image.Image]) -> None:
        """Append a user message, request a reply, and record it in the history.

        Up to ``self.num_retry`` trials are made. A trial fails when the client
        raises, or when the reply is missing, empty, cut off by the token limit,
        or (OpenRouter only) reported as an error or content-filtered.

        Args:
            contents: Text strings and/or PIL images making up the user message.

        Raises:
            ValueError: If an item of ``contents`` has an unsupported type, or if
                no trial produced a valid reply. In the latter case the user
                message is removed from the history again.
        """
        self.messages.append({"role": "user", "content": self._build_user_content(contents)})
        request_messages: list[dict] = []
        if self.is_openrouter:
            if self.system_prompt is not None:
                request_messages.append({"role": "system", "content": self.system_prompt})
            request_messages.extend(self.messages)
        # Wait for the assistant's response
        for trial in range(1, self.num_retry + 1):
            if self.is_openrouter:
                outcome = self._attempt_openrouter(trial, request_messages)
            else:
                outcome = self._attempt_anthropic(trial)
            if outcome is None:
                continue
            response, assistant_message = outcome
            self.messages.append(assistant_message)
            self.num_total_messages += 2
            self._log_response(response)
            return
        # Every trial failed. Success is tracked by the early return above rather
        # than by inspecting the last response object, because a trial can fail
        # even though the API returned something (e.g. a truncated reply).
        self.messages.pop()  # Remove the last user message
        raise ValueError("Failed to get a valid response from Anthropic.")

    def send_user_message_new_thread(self, contents: list[str | Image.Image]) -> None:
        """Discard the current history and send ``contents`` as the first message."""
        self.messages = []
        self.send_user_message(contents)

    def get_last_response(self) -> str:
        """Return the text of the most recent assistant message.

        For the native API this is the first ``"text"`` block of the message.

        Raises:
            ValueError: If the history is empty, the last message is not from
                the assistant, or it contains no text block.
        """
        if not self.messages:
            raise ValueError("No messages to get response from.")
        last_message = self.messages[-1]
        if last_message["role"] != "assistant":
            raise ValueError("The last message is not from the assistant.")
        if self.is_openrouter:
            # OpenRouter replies are stored as a plain string.
            return last_message["content"]
        for content in last_message["content"]:
            if content["type"] == "text":
                return content["text"]
        raise ValueError("The last message does not contain text content.")

    def load_history(self, file_path: str | os.PathLike) -> None:
        """Restore ``messages`` and ``system_prompt`` from a JSON file written by ``save_history``."""
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.messages = data["messages"]
        # NOTE(review): ``self.model_kwargs["system"]`` is not refreshed here, so
        # the native (non-OpenRouter) path keeps sending the prompt given to
        # ``__init__`` -- confirm whether that is intended.
        self.system_prompt = data["system"]

    def save_history(self, file_path: str | os.PathLike) -> None:
        """Write ``messages`` and ``system_prompt`` to ``file_path`` as UTF-8 JSON."""
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump({"messages": self.messages, "system": self.system_prompt}, f, ensure_ascii=False)


def parse_args() -> AnthropicNamespace:
    """Parse the command line: the common ALE-Bench options plus the Anthropic-specific flags.

    The ``AnthropicNamespace`` class itself is passed as the ``namespace``
    argument, so the parsed values are set as attributes on that class.
    """
    supported_clients = ["anthropic", "bedrock", "vertex", "openrouter"]
    # (flag, add_argument keyword arguments), registered in this order.
    argument_specs = (
        ("--client", {"type": str, "choices": supported_clients, "required": True, "help": "Client to use."}),
        ("--model", {"type": str, "required": True, "help": "Model name to run."}),
        ("--max_tokens", {"type": int, "required": True, "help": "Max tokens for the model."}),
        ("--thinking_budget", {"type": int, "default": None, "help": "Thinking budget tokens for the model."}),
    )
    cli = get_common_argument_parser(description="Run Anthropic model on ALE-Bench.")
    for flag, options in argument_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args(namespace=AnthropicNamespace)


def get_client(client: str) -> Anthropic | AnthropicBedrock | AnthropicVertex | OpenAI:
    """Build the SDK client for the requested backend from environment variables.

    Args:
        client: One of ``"anthropic"``, ``"bedrock"``, ``"vertex"`` or ``"openrouter"``.

    Returns:
        The matching Anthropic SDK client, or an ``OpenAI`` client pointed at
        the OpenRouter base URL when ``client`` is ``"openrouter"``.

    Raises:
        ValueError: If ``client`` is not one of the supported names.
    """
    if client == "anthropic":
        return Anthropic(
            api_key=os.getenv("ANTHROPIC_API_KEY"),
            auth_token=os.getenv("ANTHROPIC_AUTH_TOKEN"),
        )
    if client == "bedrock":
        return AnthropicBedrock(
            aws_access_key=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
            aws_session_token=os.getenv("AWS_SESSION_TOKEN"),
            aws_region=os.getenv("AWS_REGION"),
        )
    if client == "vertex":
        return AnthropicVertex(
            project_id=os.getenv("GCP_PROJECT_ID"),
            region=os.getenv("GCP_REGION"),
        )
    if client == "openrouter":
        return OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=os.getenv("OPENROUTER_API_KEY"),
        )
    # The message lists every supported backend, including "openrouter".
    raise ValueError(
        f"Unknown client: {client}. Supported clients are: anthropic, bedrock, vertex, openrouter."
    )


def main() -> None:
    """Entry point: load the environment, parse arguments, build the model and run ALE-Bench."""
    # Pull API credentials and related settings from the experiment's .env file.
    load_dotenv(EXP_ROOT_DIR / ".env")

    args = parse_args()

    # Assemble the model wrapper around the SDK client selected on the command line.
    sdk_client = get_client(args.client)
    uses_openrouter = args.client == "openrouter"
    prompt = SYSTEM_PROMPT[args.prompt_language]
    response_log_dir = args.exp_dir / f"llm_log_{args.problem_id}"
    llm = AnthropicModel(
        client=sdk_client,
        model=args.model,
        is_openrouter=uses_openrouter,
        max_tokens=args.max_tokens,
        thinking_budget=args.thinking_budget,
        system_prompt=prompt,
        log_dir=response_log_dir,
    )

    # Hand control to the shared ALE-Bench loop.
    main_loop(args, llm)


# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
