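"""
Workflow version of our approach: tools are called in a fixed order given by
``--order`` (followed by a final "finish" turn) rather than chosen by the model.
* agent: Gemini models
* tools: multimodal
"""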

from argparse import ArgumentParser
import json
import logging
import os
from pathlib import Path
from pprint import pprint, pformat
from copy import deepcopy
from collections import defaultdict

from utils import (
    get_date,
    format_instruction,
    # check_prompt,
)

from utils_api import (
    estimate_cost,
    call_api_single,
    GOOGLE_USAGE_KEYS,
)

from tools import (
    load_tools,
)
from tools_workflow import (
    call_function,
)

from prompts import (
    format_initial_user_message,
    load_prompt_template,
)

from google import genai
from google.genai import types

# Maps each character of ``--order`` to the tool invoked on that turn.
# run_single_example always appends "0" ("finish") as the last turn.
TOOLS = dict(
    (
        ("1", "sample_frame"),
        ("2", "check_instruction"),
        ("3", "check_final_picture"),
        ("0", "finish"),
    )
)


def delete_all_uploaded_files():
    """Delete every file currently stored in the Gemini Files API.

    Authenticates with the ``GEMINI_API_KEY`` environment variable and issues
    one delete request per file returned by ``client.files.list()``.

    Returns:
        None
    """
    api_key = os.environ["GEMINI_API_KEY"]
    gemini_client = genai.Client(api_key=api_key)
    logging.info("Delete all uploaded files ...")
    uploaded_names = (uploaded.name for uploaded in gemini_client.files.list())
    for uploaded_name in uploaded_names:
        gemini_client.files.delete(name=uploaded_name)


def run_single_example(args, example, template_components, tools, instructions):
    """Run the fixed-order tool workflow on one example and collect the answer.

    Each character of ``args.order`` selects a tool from ``TOOLS``; a final
    ``"0"`` ("finish") turn is always appended. On every turn the tool output
    is added to ``contents`` as a user message, the whole ``contents`` list is
    sent via ``call_api_single``, and the model's reply is appended.

    Args:
        args: Parsed CLI namespace (``model_id`` and ``order`` are read here;
            it is also forwarded to the helpers).
        example: One QA example, forwarded to the prompt and tool helpers.
        template_components: Prompt template parts; ``"system"`` is read here.
        tools: Loaded tool definitions. Unused in this workflow variant; kept
            so the signature matches existing callers.
        instructions: Instruction data for the example's toy, forwarded to
            ``call_function``.

    Returns:
        Tuple ``(contents, conversation_history, answer, cost)``: the message
        list sent to the API; a per-turn log of ``[role, text]`` pairs (turn 0
        holds the system/user prompt, a failed API call is logged as
        ``["Error"]``); the text between ``<answer>`` and ``</answer>`` in the
        most recent reply that contained an ``<answer>`` tag (``""`` if none);
        and the cost from ``estimate_cost``.

    Raises:
        Exceptions from ``format_initial_user_message`` / ``call_function``
        propagate to the caller. Files uploaded so far are deleted regardless.
    """
    conversation_history = []
    total_usage = defaultdict(int)

    system_developer_message = template_components["system"]
    initial_user_message, initial_user_message_text = format_initial_user_message(
        args.model_id, template_components, example
    )
    # register turn 0
    conversation_history.append(
        [["system", system_developer_message], ["user", initial_user_message_text]]
    )

    # NOTE(review): the template's system prompt is recorded in the history
    # above, but an empty system message is what is passed to call_api_single
    # on every turn (including the first). Confirm this is intentional.
    api_system_message = ""

    contents = [initial_user_message]
    answer = ""
    all_uploaded_files = []
    try:
        for count_turn, tool_id in enumerate(args.order + "0"):
            logging.info(f"Next: turn {count_turn + 1}")

            current_conversation = []

            tool_name = TOOLS[tool_id]
            logging.info(f"debug: {tool_name=}")
            # call tool
            parts_tool, parts_tool_text, uploaded_files = call_function(
                args,
                example,
                template_components,
                instructions,
                tool_name,
            )
            contents.append(types.Content(role="user", parts=parts_tool))
            current_conversation.append(["user-tool", parts_tool_text])
            all_uploaded_files += uploaded_files

            try:
                response, usage = call_api_single(args, api_system_message, contents)
            except Exception:
                logging.exception("Error in run_single_example")
                conversation_history.append(["Error"])
                continue

            for key in GOOGLE_USAGE_KEYS:
                total_usage[key] += usage[key]

            if not response.candidates:
                logging.error("No response generated.")
                break

            if len(response.candidates) > 1:  # sanity check
                logging.info("[Debug] Multiple candidates in response")
            candidate = response.candidates[0]

            # Candidate.content and Content.parts are Optional in the SDK (e.g.
            # for a blocked response); stop instead of crashing on None below.
            if candidate.content is None or candidate.content.parts is None:
                logging.error("Response candidate has no content.")
                break

            # append model's response
            contents.append(candidate.content)

            # postprocess: thoughts are logged but never used as the answer text
            latest_assistant_response = ""
            for part in candidate.content.parts:
                if part.thought:
                    current_conversation.append(["assistant", part.text])
                elif part.text:
                    current_conversation.append(["assistant", part.text])
                    latest_assistant_response = part.text
                else:
                    logging.warning(f"Uncovered type of {part=}")

            # check if answer is produced; keep only the text between the last
            # "<answer>" and the next "</answer>" so trailing text after the
            # closing tag does not leak into the answer
            if "<answer>" in latest_assistant_response:
                after_open = latest_assistant_response.split("<answer>")[-1]
                answer = after_open.split("</answer>")[0]

            # logging
            for _conversation in current_conversation:
                logging.info(f"[{_conversation[0]}]")
                logging.info(_conversation[1])

            conversation_history.append(current_conversation)
    finally:
        # Best-effort cleanup that also runs when a turn raised, so uploads do
        # not accumulate; a failed delete is logged instead of discarding the
        # results of this example.
        if all_uploaded_files:
            client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
            for file in all_uploaded_files:
                try:
                    client.files.delete(name=file.name)
                except Exception:
                    logging.exception(f"Failed to delete uploaded file {file.name}")

    cost = estimate_cost(args.model_id, total_usage)

    return contents, conversation_history, answer, cost


def check_if_prediction_exists(example, examples_prev):
    """Return a copy of a previous result for ``example`` if it has an answer.

    Args:
        example: Current QA example; only ``example["question"]`` is read.
        examples_prev: Previously saved output (a dict with an ``"examples"``
            list), or an empty container when there is no previous run.
            Entries that are not dicts (e.g. the ``"Error"`` placeholder that
            ``main`` stores for failed examples) are ignored.

    Returns:
        A deep copy of the last previous entry with the same question whose
        ``prediction["answer"]`` is a non-blank string, or ``None``.
    """
    if not isinstance(examples_prev, dict):
        return None

    # Scan from the end so the first hit is the last match in file order
    # ("last match wins").
    for example_prev in reversed(examples_prev.get("examples") or []):
        if not isinstance(example_prev, dict):
            continue
        if "question" not in example_prev:
            continue
        if example["question"] != example_prev["question"]:
            continue
        prediction = example_prev.get("prediction")
        answer = prediction.get("answer") if isinstance(prediction, dict) else None
        if isinstance(answer, str) and answer.strip():
            # deep copy so the caller can modify the result independently
            return deepcopy(example_prev)

    return None


def _write_json_atomic(filepath, obj):
    """Write ``obj`` to ``filepath`` as indented JSON, replacing it atomically.

    The data goes to a sibling ``.tmp`` file first and is then moved over the
    target with ``os.replace``, so an interruption mid-write leaves the
    previous file intact rather than a truncated JSON that the resume logic in
    ``main`` could no longer parse.
    """
    tmp_path = filepath.with_name(filepath.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=4)
        f.write("\n")
    os.replace(tmp_path, filepath)


def main(args):
    """Run the workflow over every QA example and save predictions incrementally.

    Reads the examples from ``args.filepath_qa`` and writes results to
    ``args.dirpath_output / f"{model_id}_{reasoning}_{order}.json"``. If that
    file already exists, examples whose previous prediction has a non-blank
    answer are copied over instead of being re-run, and ``metadata["cost"]``
    continues from the previous total (it is cumulative across resumed runs).
    The output file is rewritten after every example so progress survives
    interruptions. A failed example is stored as the string ``"Error"``.

    Args:
        args: Parsed CLI namespace (see the ``__main__`` block).
    """
    # load data
    with open(args.filepath_qa, "r", encoding="utf-8") as f:
        examples = json.load(f)

    # load instruction
    toy2instruction = format_instruction(
        args.filepath_instruction,
        args.dirpath_instruction_image,
        args.dirpath_parts_image,
    )

    # load prompt template
    template_components = load_prompt_template(args.filepath_template)

    # load tools
    tools = load_tools(args)
    logging.info(pformat(tools, width=100))

    new_examples, total_cost = [], 0

    # output file
    filepath_output = (
        args.dirpath_output / f"{args.model_id}_{args.reasoning}_{args.order}.json"
    )
    if filepath_output.exists():
        with open(filepath_output, "r", encoding="utf-8") as f:
            examples_prev = json.load(f)
        logging.info(
            f"Prev attempt exists. Restart from {len(examples_prev['examples'])}"
        )
        total_cost = float(examples_prev["metadata"]["cost"])
    else:
        logging.info("Initial attempt")
        # same type as a loaded output file, just without any "examples"
        examples_prev = {}

    metadata = {
        "data-created": get_date(),
        "args": {k: str(v) if isinstance(v, Path) else v for k, v in vars(args).items()},
    }

    # start inference
    logging.info("Prediction starts ...")
    for idx, example in enumerate(examples):
        logging.info(f"[Example {idx}]")

        # check if prediction already exists
        corresponding_example_prev = check_if_prediction_exists(example, examples_prev)
        if corresponding_example_prev:
            logging.info("Previous prediction exists. Skip.")
            new_examples.append(corresponding_example_prev)
        else:
            try:
                _, conversation_history, answer, cost = run_single_example(
                    args,
                    example,
                    template_components,
                    tools,
                    toy2instruction[example["toy_id"]],
                )
                new_example = deepcopy(example)
                new_example["prediction"] = {
                    "conversation_history": conversation_history,
                    "answer": answer,
                    "cost": cost,
                }
                new_examples.append(new_example)
                total_cost += cost
            except Exception:
                logging.exception("An error occurred")
                new_examples.append("Error")

        # save output every time (atomically, so a crash cannot corrupt it)
        metadata["cost"] = f"{total_cost:.3f}"
        final_output = {"metadata": metadata, "examples": new_examples}
        _write_json_atomic(filepath_output, final_output)

    logging.info(f"{total_cost=}")


if __name__ == "__main__":
    parser = ArgumentParser(description="workflow")
    parser.add_argument(
        "--filepath_qa", type=Path, required=True, help="filepath for qa input"
    )
    parser.add_argument(
        "--filepath_instruction", type=Path, help="filepath for instruction"
    )
    parser.add_argument(
        "--dirpath_instruction_image", type=Path, help="dirpath for instruction (image)"
    )
    parser.add_argument(
        "--dirpath_parts_image", type=Path, help="dirpath for parts (image)"
    )
    parser.add_argument("--dirpath_frame", type=Path, help="dirpath for frames")
    parser.add_argument(
        "--filepath_template", type=Path, help="filepath for prompt template"
    )
    parser.add_argument(
        "--dirpath_intermediate_output",
        type=Path,
        help="dirpath to store intermediate outputs",
    )
    parser.add_argument("--filepath_tool", type=Path, help="filepath for tools")
    parser.add_argument(
        "--dirpath_output", type=Path, required=True, help="dirpath for output"
    )
    parser.add_argument("--model_id", type=str, help="model id", default="")
    parser.add_argument("--temperature", type=float, help="temperature", default=0.1)
    parser.add_argument(
        "--max_tokens", type=int, help="max tokens to generate", default=128
    )
    parser.add_argument("--max_frames", type=int, help="max frames to feed", default=30)
    parser.add_argument("--max_turn", type=int, help="max turns", default=10)
    parser.add_argument(
        "--threshold_continue",
        type=int,
        help="the minimum turn to spend before generating answers for each example",
        default=5,
    )
    parser.add_argument(
        "--threshold_max_text_turn",
        type=int,
        help="the max #turn of non-function outputs",
        default=2,
    )
    parser.add_argument("--reasoning", action="store_true", help="Enable reasoning")
    parser.add_argument(
        "--reasoning_effort",
        type=str,
        help="reasoning level for openAI models",
        default="medium",
    )
    parser.add_argument(
        "--summary_type", type=str, help="summary type for openai api", default="auto"
    )
    parser.add_argument(
        "--budget_tokens",
        type=int,
        help="thinking budget tokens for anthropic api",
        default=2048,
    )
    parser.add_argument("--angle", type=str, help="angle", default="C10118_rgb")
    parser.add_argument("--resolution", type=str, help="resolution", default="360p")
    parser.add_argument("--color", type=str, help="color", default="rgb")
    parser.add_argument(
        "--order",
        type=str,
        required=True,
        help="order of tool calling: one tool id per turn, e.g. '123' "
        "(a final 'finish' turn is always appended)",
    )
    parser.add_argument("--wait_time", type=int, help="API call wait time", default=10)
    parser.add_argument(
        "--dirpath_log", type=Path, required=True, help="dirpath for log"
    )

    args = parser.parse_args()

    # Fail fast on unknown tool ids instead of every example later failing
    # with a KeyError on the TOOLS lookup.
    unknown_ids = sorted(set(args.order) - set(TOOLS))
    if unknown_ids:
        parser.error(
            f"--order contains unknown tool ids {unknown_ids}; "
            f"valid ids: {sorted(TOOLS)}"
        )

    # parents/exist_ok: allow nested paths and tolerate concurrent creation
    args.dirpath_log.mkdir(parents=True, exist_ok=True)
    args.dirpath_output.mkdir(parents=True, exist_ok=True)

    logging.basicConfig(
        format="%(asctime)s:%(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(
                args.dirpath_log / f"workflow_{Path(args.model_id).name}_{get_date()}.log"
            ),
        ],
    )

    logging.info("Arguments:")
    # log (instead of printing to stdout) so the arguments also land in the log file
    logging.info(pformat(vars(args)))

    main(args)
