"""
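Inference code for models served via API (proprietary APIs such as OpenAI,
Anthropic, and Google, plus locally served Qwen / InternVL / MiMo endpoints)
* Input: video (frames), instruction, question
* Output: answer (plus the model's reasoning trace, if any, and estimated cost)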
"""

from argparse import ArgumentParser
from copy import deepcopy
import json
import logging
from pathlib import Path
import yaml
import os
from tqdm import tqdm

from google import genai

from utils import (
    get_date,
    format_instruction,
)

from utils_api import (
    call_api_single,
    estimate_cost,
    OPENAI_MODELS,
    ANTHROPIC_MODELS,
    GOOGLE_MODELS,
    QWEN_MODELS,
    INTERNVL_MODELS,
    MIMO_MODELS,
)

from prompts import (
    format_benchmark_input,
)


def run_single_example(args, example, template_components, instructions):
    # preprocess
    content, text_prompt, uploaded_files = format_benchmark_input(
        args, example, template_components, instructions
    )
    if args.model_id in OPENAI_MODELS:
        messages = [{"role": "user", "content": content}]
    elif args.model_id in GOOGLE_MODELS:
        messages = content
    elif args.model_id in ANTHROPIC_MODELS:
        messages = [{"role": "user", "content": content}]
    elif args.model_id in QWEN_MODELS + INTERNVL_MODELS + MIMO_MODELS:
        messages = [{"role": "user", "content": content}]
    else:
        logging.error(f"Undefined {args.model_id=}")
    system_developer_message = ""  # note: hard-coded

    # call api
    response, usage = call_api_single(
        args, system_developer_message, messages, benchmark=True
    )

    # postprocess
    answer, thinking = "", ""
    if args.model_id in OPENAI_MODELS:
        for block in response.output:
            match block.type:
                case "message":
                    answer += "\n".join([x.text for x in block.content])
                case "reasoning":
                    thinking += "\n".join([x.text for x in block.summary])
    elif args.model_id in GOOGLE_MODELS:
        for candidate in response.candidates:
            for part in candidate.content.parts:
                if not part.text:
                    continue
                if part.thought:
                    thinking += f"{part.text}\n"
                else:
                    answer += f"{part.text}\n"
        logging.info(f"debug {answer.strip()=} vs {response.text=}")
    elif args.model_id in ANTHROPIC_MODELS:
        for block in response.content:
            match block.type:
                case "thinking":
                    thinking += f"{block.thinking}\n"
                case "text":
                    answer += f"{block.text}\n"
    elif args.model_id in QWEN_MODELS + INTERNVL_MODELS:
        if len(response.choices) > 1:  # sanity check
            logging.error(f"[Debug] Multiple choices/candidates in {response=}")
        else:
            message = response.choices[0].message
        if message.content:
            answer = message.content
            if "[Answer]" in answer:
                splits = answer.split("[Answer]")
                thinking = splits[0].replace("[Rationale]", "")
                answer = splits[1].strip()
    elif args.model_id in MIMO_MODELS:
        if len(response.choices) > 1:  # sanity check
            logging.error(f"[Debug] Multiple choices/candidates in {response=}")
        else:
            message = response.choices[0].message
        if message.content:
            answer = message.content
            if "/think" in answer:
                splits = answer.split("</think>")
                thinking = splits[0].replace("<think>", "").strip()
                answer = splits[1].strip()
    else:
        logging.error(f"Undefined {args.model_id=}")

    if args.model_id in GOOGLE_MODELS:
        client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
        logging.info("Delete uploaded files ...")
        for file in uploaded_files:
            client.files.delete(name=file.name)

    cost = estimate_cost(args.model_id, usage)

    return text_prompt, answer.strip(), thinking.strip(), cost


def check_if_prediction_exists(example, examples_prev):
    """Look up a reusable prediction for ``example`` in a previous run's output.

    A previous entry is reusable when it carries the same ``"question"`` as
    ``example`` and has a ``"prediction"`` whose ``"answer"`` is non-blank.
    When several entries qualify, the last one in the list wins.

    Args:
        example: The current benchmark example; its ``"question"`` is read
            whenever a previous entry has a ``"question"`` key.
        examples_prev: Previously saved output (a dict with an ``"examples"``
            list), or an empty container for a fresh run.

    Returns:
        A deep copy of the matching previous entry, or ``None``.
    """
    previous_entries = (
        examples_prev["examples"] if "examples" in examples_prev else []
    )
    found = None
    for entry in previous_entries:
        same_question = "question" in entry and (
            example["question"] == entry["question"]
        )
        if not same_question:
            continue
        if "prediction" in entry and entry["prediction"]["answer"].strip():
            found = entry
    # Copy so callers can mutate the result without touching the loaded data.
    return None if found is None else deepcopy(found)


def _load_previous_output(filepath_output):
    """Load a previous run's output for resuming.

    Returns:
        ``(examples_prev, total_cost)``: the parsed previous output and its
        recorded cost, or ``({}, 0.0)`` when no previous output exists.
    """
    if not filepath_output.exists():
        logging.info("Initial attempt")
        return {}, 0.0
    with open(filepath_output, "r", encoding="utf-8") as f:
        examples_prev = json.load(f)
    logging.info(
        f"Prev attempt exists. Restart from {len(examples_prev['examples'])}"
    )
    return examples_prev, float(examples_prev["metadata"]["cost"])


def _write_output(filepath_output, output):
    """Replace ``filepath_output`` with ``output`` as indented JSON, atomically.

    The JSON is written to a sibling temp file and then renamed over the
    target with ``os.replace``, so an interrupted run never leaves a truncated
    file that would break ``json.load`` on resume.
    """
    filepath_tmp = filepath_output.with_name(filepath_output.name + ".tmp")
    with open(filepath_tmp, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=4)
        f.write("\n")
    os.replace(filepath_tmp, filepath_output)


def main(args):
    """Run inference over every input example, resuming from prior output.

    Examples that already have a non-blank prediction in the output file
    (matched by question) are reused. All others are sent to the model. The
    output file is rewritten after every example so progress survives
    interruption. Failed examples are recorded as the string ``"Error"``.
    """
    with open(args.filepath_input, "r", encoding="utf-8") as f:
        examples = json.load(f)

    toy2instruction = format_instruction(
        args.filepath_instruction,
        args.dirpath_instruction_image,
        args.dirpath_parts_image,
    )

    with open(args.filepath_template, "r", encoding="utf-8") as f:
        template_components = yaml.safe_load(f)

    metadata = {
        "data-created": get_date(),
        "args": {k: str(v) if isinstance(v, Path) else v for k, v in vars(args).items()},
    }

    logging.info("Prediction starts ...")

    filepath_output = (
        args.dirpath_output
        / f"{Path(args.model_id).name}_{args.reasoning}_{args.filepath_input.name}"
    )
    # total_cost starts from the previous run's recorded spend, if any.
    examples_prev, total_cost = _load_previous_output(filepath_output)
    new_examples = []

    for idx, example in tqdm(enumerate(examples), total=len(examples)):
        logging.info(f"[Example {idx}]")

        # check if prediction already exists
        corresponding_example_prev = check_if_prediction_exists(example, examples_prev)
        if corresponding_example_prev:
            logging.info("Previous prediction exists. Skip.")
            new_examples.append(corresponding_example_prev)
        else:
            try:
                text_prompt, answer, thinking, cost = run_single_example(
                    args,
                    example,
                    template_components,
                    toy2instruction[example["toy_id"]],
                )
                new_example = deepcopy(example)
                new_example["prediction"] = {
                    "text_prompt": text_prompt,
                    "answer": answer,
                    "thinking": thinking,
                    "cost": cost,
                }
                new_examples.append(new_example)
                total_cost += cost
            except Exception:
                # Top-level boundary: record the failure and keep going.
                logging.exception("Error occurred.")
                new_examples.append("Error")

        metadata["cost"] = f"{total_cost:.3f}"
        _write_output(filepath_output, {"metadata": metadata, "examples": new_examples})

    logging.info(f"{total_cost=}")


if __name__ == "__main__":
    parser = ArgumentParser(description="Inference code")
    # Arguments marked required are dereferenced unconditionally below / in main();
    # omitting them used to crash with an obscure AttributeError/TypeError.
    parser.add_argument(
        "--filepath_input", type=Path, required=True, help="filepath for input"
    )
    parser.add_argument(
        "--filepath_instruction", type=Path, help="filepath for instruction"
    )
    parser.add_argument(
        "--dirpath_instruction_image", type=Path, help="dirpath for instruction (image)"
    )
    parser.add_argument(
        "--dirpath_parts_image", type=Path, help="dirpath for parts (image)"
    )
    parser.add_argument("--dirpath_frame", type=Path, help="dirpath for frames")
    parser.add_argument(
        "--filepath_template",
        type=Path,
        required=True,
        help="filepath for prompt template",
    )
    parser.add_argument(
        "--dirpath_output", type=Path, required=True, help="dirpath for output"
    )
    parser.add_argument("--model_id", type=str, required=True, help="model id")
    parser.add_argument("--temperature", type=float, help="temperature", default=0.0)
    parser.add_argument(
        "--max_tokens", type=int, help="max tokens to generate", default=4608
    )
    # NOTE(review): 4098 looks like it may be a typo for 4096 — confirm intent.
    parser.add_argument(
        "--budget_tokens", type=int, help="max budget tokens for reasoning", default=4098
    )
    parser.add_argument("--reasoning", action="store_true", help="Enable reasoning")
    parser.add_argument(
        "--reasoning_effort",
        type=str,
        help="reasoning level for openAI models",
        default="medium",
    )
    parser.add_argument(
        "--summary_type", type=str, help="summary type for openai api", default="auto"
    )
    parser.add_argument(
        "--api_key_for_qwen", type=str, help="dummy api key for qwen", default="dummpy"
    )
    parser.add_argument(
        "--base_url_qwen",
        type=str,
        help="base URL of the qwen server",
        default="http://localhost:8000/v1",
    )
    parser.add_argument(
        "--base_url_mimo",
        type=str,
        help="base URL of the mimo server",
        default="http://localhost:8088/v1",
    )
    parser.add_argument(
        "--base_url_internvl",
        type=str,
        help="base URL of the internvl server",
        default="http://localhost:8007/v1",
    )
    parser.add_argument("--max_frames", type=int, help="max frames to feed", default=30)
    parser.add_argument("--angle", type=str, help="angle", default="C10118_rgb")
    parser.add_argument("--resolution", type=str, help="resolution", default="360p")
    parser.add_argument("--color", type=str, help="color", default="rgb")
    parser.add_argument("--wait_time", type=int, help="API call wait time", default=10)
    parser.add_argument(
        "--dirpath_log", type=Path, required=True, help="dirpath for log"
    )

    args = parser.parse_args()

    # Create nested directories as needed; no-op if they already exist.
    args.dirpath_log.mkdir(parents=True, exist_ok=True)
    args.dirpath_output.mkdir(parents=True, exist_ok=True)

    logging.basicConfig(
        format="%(asctime)s:%(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(
                args.dirpath_log
                / f"benchmark_{Path(args.model_id).name}_{get_date()}.log"
            ),
        ],
    )

    logging.info(f"Arguments: {vars(args)}")

    main(args)
