import base64
from collections import defaultdict
import logging
import os
from pathlib import Path
import sys
from copy import deepcopy

from openai import OpenAI
from anthropic import Anthropic
from google import genai
from google.genai import types


# Per-token prices, keyed first by model id and then by usage-counter name.
# Every value is written as (price per 1e6 tokens) / 1e6, i.e. a per-token rate.
# estimate_cost() multiplies these rates by the matching token counts and logs
# the total with a "$" prefix.
# The inner keys differ per provider and mirror OPENAI_USAGE_KEYS,
# ANTHROPIC_USAGE_KEYS and GOOGLE_USAGE_KEYS defined further down in this file.
PRICE = {
    # OpenAI models: keys follow OPENAI_USAGE_KEYS ("input", "cached", "output").
    "gpt-5-2025-08-07": {
        "input": 1.25 / 1e6,
        "cached": 0.125 / 1e6,
        "output": 10 / 1e6,
    },
    "gpt-5-mini-2025-08-07": {
        "input": 0.25 / 1e6,
        "cached": 0.025 / 1e6,
        "output": 2 / 1e6,
    },
    "gpt-4o-2024-11-20": {
        "input": 2.5 / 1e6,
        "cached": 1.25 / 1e6,
        "output": 10 / 1e6,
    },
    "o3-2025-04-16": {
        "input": 2 / 1e6,
        "cached": 0.5 / 1e6,
        "output": 8 / 1e6,
    },
    # Anthropic models: keys follow ANTHROPIC_USAGE_KEYS.
    "claude-3-5-haiku-20241022": {
        "input_tokens": 0.8 / 1e6,
        "cache_creation_input_tokens": 1 / 1e6,
        "cache_read_input_tokens": 0.08 / 1e6,
        "output_tokens": 4 / 1e6,
    },
    "claude-3-7-sonnet-20250219": {
        "input_tokens": 3 / 1e6,
        "cache_creation_input_tokens": 3.75 / 1e6,
        "cache_read_input_tokens": 0.3 / 1e6,
        "output_tokens": 15 / 1e6,
    },
    "claude-sonnet-4-20250514": {
        "input_tokens": 3 / 1e6,
        "cache_creation_input_tokens": 3.75 / 1e6,
        "cache_read_input_tokens": 0.3 / 1e6,
        "output_tokens": 15 / 1e6,
    },
    "claude-opus-4-20250514": {
        "input_tokens": 15 / 1e6,
        "cache_creation_input_tokens": 18.75 / 1e6,
        "cache_read_input_tokens": 1.5 / 1e6,
        "output_tokens": 75 / 1e6,
    },
    "claude-opus-4-1-20250805": {
        "input_tokens": 15 / 1e6,
        "cache_creation_input_tokens": 18.75 / 1e6,
        "cache_read_input_tokens": 1.5 / 1e6,
        "output_tokens": 75 / 1e6,
    },
    # Google models: keys follow GOOGLE_USAGE_KEYS.
    "gemini-2.5-pro": {  # price for <=200k
        "prompt_token_count": 1.25 / 1e6,
        "cached_content_token_count": 0.31 / 1e6,
        "candidates_token_count": 10 / 1e6,
        "thoughts_token_count": 10 / 1e6,
    },
    "gemini-2.5-flash": {
        "prompt_token_count": 0.30 / 1e6,
        "cached_content_token_count": 0.075 / 1e6,
        "candidates_token_count": 2.5 / 1e6,
        "thoughts_token_count": 2.5 / 1e6,
    },
    "gemini-2.5-flash-lite": {
        "prompt_token_count": 0.1 / 1e6,
        "cached_content_token_count": 0.025 / 1e6,
        "candidates_token_count": 0.4 / 1e6,
        "thoughts_token_count": 0.4 / 1e6,
    },
}

# Model ids grouped by backend. call_api_single() picks the client function by
# testing args.model_id for membership in these lists, and estimate_cost() uses
# the same lists to choose which usage counters to price.

# Served through the OpenAI Responses API (call_openai_single_response_api).
OPENAI_MODELS = [
    "gpt-4o-2024-11-20",
    "gpt-5-2025-08-07",
    "gpt-5-mini-2025-08-07",
    "o3-2025-04-16",
]
# Served through the Anthropic Messages API (call_anthropic_message_api_single).
ANTHROPIC_MODELS = [
    "claude-3-5-haiku-20241022",
    "claude-3-7-sonnet-20250219",
    "claude-sonnet-4-20250514",
    "claude-opus-4-20250514",
    "claude-opus-4-1-20250805",
]
# Served through the Google GenAI client (call_google_message_api_single).
GOOGLE_MODELS = [
    "gemini-2.5-pro",
    "gemini-2.5-flash",
    "gemini-2.5-flash-lite",
]
# The three groups below are reached through the call_*_vllm_single functions;
# they have no PRICE entries and estimate_cost() prices them with an empty key list.
QWEN_MODELS = [
    "Qwen/Qwen2.5-VL-7B-Instruct",
    "Qwen/Qwen2.5-VL-32B-Instruct",
    "Qwen/Qwen2.5-VL-72B-Instruct",
]
INTERNVL_MODELS = [
    "OpenGVLab/InternVL3-38B",
]
MIMO_MODELS = [
    "XiaomiMiMo/MiMo-VL-7B-RL-2508",
]

# Names of the usage counters tracked per provider. Each list matches the inner
# keys of that provider's PRICE entries, so estimate_cost() can pair a counter
# with its per-token rate.

# Filled in by call_openai_single_response_api: non-cached input, cached input, output.
OPENAI_USAGE_KEYS = ["input", "cached", "output"]
# Copied by name from the Anthropic response's `usage` object.
ANTHROPIC_USAGE_KEYS = [
    "input_tokens",
    "cache_creation_input_tokens",
    "cache_read_input_tokens",
    "output_tokens",
]
# Copied by name from the Google response's `usage_metadata` object.
GOOGLE_USAGE_KEYS = [
    "prompt_token_count",  # for input
    "cached_content_token_count",  # for cached input
    "candidates_token_count",  # for output
    "thoughts_token_count",  # for thinking
]


def call_openai_single_response_api(
    args,
    system_developer_message: str,
    messages: list,
    tools: list | None = None,
    tool_choice: str = "none",
) -> tuple[object, defaultdict[str, int]]:
    """
    Send a single request to the OpenAI Responses API and report token usage.

    Args:
        args: Run configuration. Always reads ``model_id`` and ``reasoning``.
            With ``reasoning`` truthy it also reads ``reasoning_effort`` and
            ``summary_type``; otherwise it reads ``max_tokens`` and, for
            models other than the two gpt-5 ids, ``temperature``.
        system_developer_message: Sent as the request's ``instructions``.
        messages: Sent as the request's ``input``.
        tools: Tool definitions forwarded unchanged. ``None`` is treated as
            an empty list.
        tool_choice: Forwarded unchanged as ``tool_choice``.

    Returns:
        ``(response, usage)`` where ``usage`` holds the counters ``"input"``
        (input tokens minus cached tokens), ``"cached"`` and ``"output"``.

    Note:
        Any exception raised while building or sending the request, or while
        reading its usage, is logged and the process is terminated with
        ``sys.exit("Error")``. A missing ``OPENAI_API_KEY`` raises ``KeyError``
        before that handler is entered.
    """
    reasoning_models = (
        "gpt-5-2025-08-07",
        "gpt-5-mini-2025-08-07",
        "o3-2025-04-16",
    )
    # These ids are sent without a `temperature` argument in non-reasoning mode.
    models_without_temperature = ("gpt-5-2025-08-07", "gpt-5-mini-2025-08-07")

    tools = [] if tools is None else tools
    client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

    try:
        # Arguments shared by every request variant.
        request = {
            "model": args.model_id,
            "parallel_tool_calls": False,
            "instructions": system_developer_message,
            "tools": tools,
            "tool_choice": tool_choice,
            "input": messages,
        }

        if args.reasoning:
            # Explicit check instead of `assert`, which is stripped under -O.
            if args.model_id not in reasoning_models:
                raise ValueError(
                    f"reasoning is not supported for {args.model_id=}"
                )
            request["reasoning"] = {
                "effort": args.reasoning_effort,
                "summary": args.summary_type,
            }
        else:
            request["max_output_tokens"] = args.max_tokens
            if args.model_id not in models_without_temperature:
                request["temperature"] = args.temperature

        response = client.responses.create(**request)

        usage = defaultdict(int)
        cached_input_tokens = response.usage.input_tokens_details.cached_tokens
        # Report cached and non-cached input separately so each can be priced
        # at its own rate.
        usage["input"] = response.usage.input_tokens - cached_input_tokens
        usage["cached"] = cached_input_tokens
        usage["output"] = response.usage.output_tokens
    except Exception:
        logging.exception("An error occurred")
        sys.exit("Error")

    return response, usage


def _with_cache_on_latest_user_message(messages: list) -> list:
    """Return a deep copy of ``messages`` with an ephemeral cache marker on the latest user turn.

    Only the last two messages are inspected. The marker is placed on the last
    content block of the first of them (searching from the end) whose role is
    ``"user"``, provided its content is a non-empty list. The input is never
    modified.
    """
    marked = deepcopy(messages)
    for offset in (-1, -2):
        if len(marked) < -offset:
            # Fewer messages than this offset needs; nothing left to inspect.
            break
        if marked[offset]["role"] != "user":
            continue
        content = marked[offset]["content"]
        if not isinstance(content, list):
            logging.info(f"[debug] messages[{offset}]['content'] is not list")
        elif content:
            # Mark the same message that was just checked.
            content[-1]["cache_control"] = {"type": "ephemeral"}
        return marked

    logging.info("[debug] No user hit in the last two messages so no caching")
    return marked


def call_anthropic_message_api_single(
    args,
    system_developer_message: str,
    messages: list,
    tools: list | None = None,
    tool_choice: str = "none",
    benchmark: bool = False,
) -> tuple[object, defaultdict[str, int]]:
    """
    Send a single request to the Anthropic Messages API and report token usage.

    Args:
        args: Run configuration. Reads ``model_id``, ``reasoning`` and
            ``max_tokens``; with ``reasoning`` truthy also ``budget_tokens``,
            otherwise ``temperature``.
        system_developer_message: System prompt text.
        messages: Conversation to send. Not modified; cache markers are added
            to a deep copy.
        tools: Tool definitions. Ignored when ``benchmark`` is true. ``None``
            is treated as an empty list.
        tool_choice: Value for the ``type`` field of the request's
            ``tool_choice``. Ignored when ``benchmark`` is true.
        benchmark: When true, send a plain request with no tools and no cache
            markers.

    Returns:
        ``(response, usage)``. ``usage`` holds the counters named in
        ``ANTHROPIC_USAGE_KEYS``; a counter the API reports as missing or
        ``None`` is stored as 0. If the request fails, the error is logged and
        ``(None, <empty usage>)`` is returned.
    """
    reasoning_models = (
        "claude-opus-4-1-20250805",
        "claude-opus-4-20250514",
        "claude-sonnet-4-20250514",
        "claude-3-7-sonnet-20250219",
    )

    tools = [] if tools is None else tools
    client = Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

    # Arguments shared by every request variant.
    request = {
        "model": args.model_id,
        "max_tokens": args.max_tokens,
        "system": system_developer_message,
        "messages": messages,
    }

    if not benchmark:
        request["messages"] = _with_cache_on_latest_user_message(messages)

        if tools:
            # Put the cache marker on a copy so the caller's tools stay untouched.
            tools_with_cache = deepcopy(tools)
            tools_with_cache[-1]["cache_control"] = {"type": "ephemeral"}
            request["tools"] = tools_with_cache

            tool_choice_param = {"type": tool_choice}
            if tool_choice != "none":
                # NOTE(review): the "none" variant appears to be documented with
                # only a `type` field, so the flag is sent for other types only.
                tool_choice_param["disable_parallel_tool_use"] = True
            request["tool_choice"] = tool_choice_param

    try:
        if args.reasoning:
            # Explicit check instead of `assert`, which is stripped under -O.
            # Benchmark requests are not restricted to this list.
            if not benchmark and args.model_id not in reasoning_models:
                raise ValueError(
                    f"reasoning is not supported for {args.model_id=}"
                )
            request["thinking"] = {
                "type": "enabled",
                "budget_tokens": args.budget_tokens,
            }
            if not benchmark:
                # In this mode the system prompt is sent as a cacheable block.
                request["system"] = [
                    {
                        "type": "text",
                        "text": system_developer_message,
                        "cache_control": {"type": "ephemeral"},
                    }
                ]
        else:
            request["thinking"] = {"type": "disabled"}
            request["temperature"] = args.temperature

        response = client.messages.create(**request)

        reported_usage = dict(response.usage)
        usage = defaultdict(int)
        for key in ANTHROPIC_USAGE_KEYS:
            # Store 0 for absent/None counters so totals can be summed safely.
            usage[key] = reported_usage.get(key) or 0
        logging.info(f"{response.usage=}")
    except Exception:
        logging.exception("An error occurred")
        response, usage = None, defaultdict(int)

    return response, usage


def call_google_message_api_single(
    args,
    system_developer_message: str,
    messages: list,
    tools: list | None = None,
    tool_choice: str = "none",
) -> tuple[object, defaultdict[str, int]]:
    """
    Send a single request through the Google GenAI client and report token usage.

    Args:
        args: Run configuration. Reads ``model_id``, ``reasoning`` and
            ``temperature``; with ``reasoning`` truthy also ``budget_tokens``.
        system_developer_message: Sent as ``system_instruction``.
        messages: Sent as ``contents``.
        tools: Tool definitions forwarded unchanged. ``None`` is treated as
            an empty list.
        tool_choice: Function calling is set to ``"AUTO"`` when tools are given
            and this is anything other than ``"none"``; otherwise ``"NONE"``.

    Returns:
        ``(response, usage)`` where ``usage`` holds the counters named in
        ``GOOGLE_USAGE_KEYS``; missing or falsy counters are stored as 0.

    Raises:
        Exception: On any failure, after logging it. The original error is
            attached as ``__cause__``.
    """
    tools = [] if tools is None else tools
    client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])

    try:
        # note: seems like it's possible to use python functions directly as is
        automatic_function_calling = types.AutomaticFunctionCallingConfig(disable=True)

        function_calling_mode = "AUTO" if tools and tool_choice != "none" else "NONE"
        tool_config = types.ToolConfig(
            function_calling_config=types.FunctionCallingConfig(
                mode=function_calling_mode
            )
        )

        thinking_config = types.ThinkingConfig(
            # better to have if statement for dynamic thinking?
            thinking_budget=args.budget_tokens if args.reasoning else 0,
            # Pass a real boolean rather than 0 when reasoning is off.
            include_thoughts=bool(args.reasoning),
        )

        response = client.models.generate_content(
            model=args.model_id,
            config=types.GenerateContentConfig(
                system_instruction=system_developer_message,
                tools=tools,
                automatic_function_calling=automatic_function_calling,
                tool_config=tool_config,
                thinking_config=thinking_config,
                temperature=args.temperature,
                # max_output_tokens=args.max_tokens,  # no output if enabled, bug?
            ),
            contents=messages,
        )

        # Convert the usage object once instead of once per lookup.
        reported_usage = dict(response.usage_metadata)
        usage = defaultdict(int)
        for key in GOOGLE_USAGE_KEYS:
            usage[key] = int(reported_usage.get(key) or 0)
        # logging.info(f"debug(call_google_message_api_single): {response.usage_metadata=}")
    except Exception as err:
        logging.exception("An error occurred in call_google_message_api_single")
        # Chain the cause so the original traceback is not lost for callers.
        raise Exception("An error occurred in call_google_message_api_single") from err

    return response, usage


def _call_vllm_chat_completion(
    args, base_url, messages, tools, tool_choice, caller_name: str
):
    """Send one chat-completions request to an OpenAI-compatible endpoint at ``base_url``.

    Shared implementation of the ``call_*_vllm_single`` functions. Reads
    ``api_key_for_qwen``, ``model_id``, ``temperature`` and ``max_tokens`` from
    ``args``. Returns the response, or ``None`` after logging the exception
    (tagged with ``caller_name``) when the request fails.
    """
    client = OpenAI(api_key=args.api_key_for_qwen, base_url=base_url)

    try:
        return client.chat.completions.create(
            model=args.model_id,
            temperature=args.temperature,
            max_completion_tokens=args.max_tokens,
            tools=tools,
            tool_choice=tool_choice,
            messages=messages,
        )
    except Exception:
        logging.exception(f"An error occurred in {caller_name}")
        return None


def call_qwen_vllm_single(args, messages, tools, tool_choice):
    """
    use vLLM server mode

    Sends the request to ``args.base_url_qwen``. Returns the chat-completions
    response, or ``None`` if the request raised (the error is logged).
    """
    return _call_vllm_chat_completion(
        args,
        args.base_url_qwen,
        messages,
        tools,
        tool_choice,
        "call_qwen_vllm_single",
    )


def call_mimo_vllm_single(args, messages, tools, tool_choice):
    """
    use vLLM server mode
    note: looks like this is general format?

    Sends the request to ``args.base_url_mimo``, using ``args.api_key_for_qwen``
    as the API key. Returns the chat-completions response, or ``None`` if the
    request raised (the error is logged).
    """
    return _call_vllm_chat_completion(
        args,
        args.base_url_mimo,
        messages,
        tools,
        tool_choice,
        "call_mimo_vllm_single",
    )


def call_internvl_vllm_single(args, messages, tools, tool_choice):
    """
    use vLLM server mode
    note: looks like this is general format?

    Sends the request to ``args.base_url_internvl``, using
    ``args.api_key_for_qwen`` as the API key. Returns the chat-completions
    response, or ``None`` if the request raised (the error is logged).
    """
    return _call_vllm_chat_completion(
        args,
        args.base_url_internvl,
        messages,
        tools,
        tool_choice,
        "call_internvl_vllm_single",
    )


def call_api_single(
    args,
    system_developer_message,
    messages,
    tools: list = [],
    tool_choice: str = "none",
    benchmark: bool = False,
):
    """
    Route one request to the backend that serves ``args.model_id``.

    The backend is chosen by membership of the model id in the module's model
    lists. OpenAI, Anthropic and Google backends return their own usage
    counters; the vLLM-served backends (Qwen, MiMo, InternVL) return only a
    response, so an empty usage counter is supplied for them. ``benchmark`` is
    forwarded to the Anthropic backend only, and the system message is not
    passed to the vLLM-served backends. An unknown model id terminates the
    process via ``sys.exit``.

    Returns:
        ``(response, usage)``.
    """
    model_id = args.model_id

    if model_id in OPENAI_MODELS:
        response, usage = call_openai_single_response_api(
            args, system_developer_message, messages, tools, tool_choice
        )
    elif model_id in ANTHROPIC_MODELS:
        response, usage = call_anthropic_message_api_single(
            args, system_developer_message, messages, tools, tool_choice, benchmark
        )
    elif model_id in GOOGLE_MODELS:
        response, usage = call_google_message_api_single(
            args, system_developer_message, messages, tools, tool_choice
        )
    elif model_id in QWEN_MODELS:
        response = call_qwen_vllm_single(args, messages, tools, tool_choice)
        usage = defaultdict(int)
    elif model_id in MIMO_MODELS:
        response = call_mimo_vllm_single(args, messages, tools, tool_choice)
        usage = defaultdict(int)
    elif model_id in INTERNVL_MODELS:
        response = call_internvl_vllm_single(args, messages, tools, tool_choice)
        usage = defaultdict(int)
    else:
        sys.exit(f"Undefined (call_api) {args.model_id}")

    return response, usage


def estimate_cost(model_id: str, usage: dict[str, int]) -> float:
    """Estimate the cost of a call from its token usage.

    Args:
        model_id: Model the usage belongs to. It selects both the price entry
            in ``PRICE`` and the list of usage counters to price.
        usage: Token counters keyed by usage-counter name. Counters that are
            missing or ``None`` count as 0. The mapping is not modified.

    Returns:
        The sum of ``PRICE[model_id][key] * usage[key]`` over the provider's
        usage keys. Models served through vLLM have no priced counters and
        yield 0.0. An unknown ``model_id`` is logged as an error and also
        yields 0.0.
    """
    if model_id in OPENAI_MODELS:
        keys = OPENAI_USAGE_KEYS
    elif model_id in ANTHROPIC_MODELS:
        keys = ANTHROPIC_USAGE_KEYS
    elif model_id in GOOGLE_MODELS:
        keys = GOOGLE_USAGE_KEYS
    elif model_id in QWEN_MODELS + MIMO_MODELS + INTERNVL_MODELS:
        keys = []
    else:
        # Previously `keys` stayed unbound here and the sum below crashed
        # right after this error was logged.
        logging.error(f"Undefined (estimate cost) for {model_id=}")
        return 0.0

    # .get() tolerates plain dicts with missing counters and avoids inserting
    # keys into a defaultdict as a side effect of reading it.
    cost = sum(
        (PRICE[model_id][key] * (usage.get(key) or 0) for key in keys),
        0.0,
    )
    logging.info(f"Estimated cost: ${cost:.4f}.")
    return cost


def encode_image(filepath: Path):
    """Read the file at ``filepath`` as bytes and return its base64 encoding as a ``str``."""
    with open(filepath, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")
