"""
workflow-version of our idea
* agent: mimo
* tool: multimodal
"""

from argparse import ArgumentParser
import json
import logging
from pathlib import Path
from pprint import pformat
from copy import deepcopy
# from collections import defaultdict

from utils import (
    get_date,
    format_instruction,
    # check_prompt,
)

from utils_api import (
    call_api_single,
)

from tools_workflow import (
    call_function,
)

from prompts import (
    format_initial_user_message,
    load_prompt_template,
)

TOOLS = {
    "1": "sample_frame",
    "2": "check_instruction",
    "3": "check_final_picture",
    "0": "finish",
}


def run_single_example(args, example, template_components, tools, instructions):
    conversation_history = []

    system_developer_message = template_components["system"]
    initial_user_message, initial_user_message_text = format_initial_user_message(
        args.model_id, template_components, example
    )
    # register turn 0
    conversation_history.append(
        [["system", system_developer_message], ["user", initial_user_message_text]]
    )
    for role, _conversation in conversation_history[0]:
        logging.info(f"{role=}")
        logging.info(f"{_conversation=}")

    messages = [initial_user_message]
    answer = ""
    for count_turn, tool_id in enumerate(args.order + "0"):
        logging.info(f"Next: turn {count_turn + 1}")

        current_conversation = []

        tool_name = TOOLS[tool_id]
        logging.info(f"debug: {tool_name=}")
        # call tool
        content, content_text, _ = call_function(
            args,
            example,
            template_components,
            instructions,
            tool_name,
        )
        messages.append({"role": "user", "content": content})
        current_conversation.append(["user-tool", content_text])
        system_developer_message = ""

        try:
            response, _ = call_api_single(args, system_developer_message, messages)
        except Exception:
            logging.exception("Error in run_single_example")
            current_conversation.append(["error", "error in run_single_example"])
            conversation_history.append(current_conversation)
            continue

        if not response.choices:
            current_conversation.append(["error", "no output"])
            conversation_history.append(current_conversation)
            continue

        if len(response.choices) > 1:  # sanity check
            logging.info(
                "[Debug] "
                f"Multiple messages in response: {[x.type for x in response.choices]}"
            )
        latest_assistant_response = ""
        message = response.choices[0].message

        # postprocess & append assistant response to input context
        if message.content:
            messages.append({"role": "assistant", "content": message.content})
            current_conversation.append(["assistant", message.content])
            latest_assistant_response = message.content

        # check if answer is produced
        if "<answer>" in latest_assistant_response:
            answer = latest_assistant_response.split("<answer>")[-1].replace(
                "</answer>", ""
            )

        # logging
        for role, _conversation in current_conversation:
            logging.info(f"{role=}")
            logging.info(f"{_conversation=}")

        conversation_history.append(current_conversation)

    return messages, conversation_history, answer


def check_if_prediction_exists(example, examples_prev):
    corresponding_example = None
    if "examples" in examples_prev:
        for example_prev in examples_prev["examples"]:
            if (
                "question" in example_prev
                and example["question"] == example_prev["question"]
            ):
                if (
                    "prediction" in example_prev
                    and example_prev["prediction"]["answer"].strip()
                ):
                    corresponding_example = deepcopy(example_prev)

    return corresponding_example


def main(args):
    # load data
    with open(args.filepath_qa, "r") as f:
        examples = json.load(f)

    # load instruction
    toy2instruction = format_instruction(
        args.filepath_instruction,
        args.dirpath_instruction_image,
        args.dirpath_parts_image,
    )

    # load prompt template
    template_components = load_prompt_template(args.filepath_template)

    new_examples = []

    # output file
    filepath_output = (
        args.dirpath_output
        / f"{Path(args.model_id).name}_{args.reasoning}_{args.order}.json"
    )
    if filepath_output.exists():
        with open(filepath_output, "r") as f:
            examples_prev = json.load(f)
        logging.info(
            f"Prev attempt exists. Restart from {len(examples_prev['examples'])}"
        )
    else:
        logging.info("Initial attempt")
        examples_prev = []

    metadata = {
        "data-created": get_date(),
        "args": {k: str(v) if isinstance(v, Path) else v for k, v in vars(args).items()},
    }

    # start inference
    logging.info("Prediction starts ...")
    for idx, example in enumerate(examples):
        logging.info(f"[Example {idx}]")

        # check if prediction already exists
        corresponding_example_prev = check_if_prediction_exists(example, examples_prev)
        if corresponding_example_prev:
            logging.info("Previous prediction exists. Skip.")
            new_examples.append(corresponding_example_prev)
        else:
            try:
                messages, conversation_history, answer = run_single_example(
                    args,
                    example,
                    template_components,
                    [],
                    toy2instruction[example["toy_id"]],
                )
                new_example = deepcopy(example)
                new_example["prediction"] = {
                    "conversation_history": conversation_history,
                    "answer": answer,
                }
                new_examples.append(new_example)
            except Exception:
                logging.exception("An error occurred")
                new_examples.append("Error")

        # save output every time
        final_output = {"metadata": metadata, "examples": new_examples}
        with open(filepath_output, "w") as f:
            json.dump(final_output, f, indent=4)
            f.write("\n")


if __name__ == "__main__":
    parser = ArgumentParser(description="workflow")
    parser.add_argument("--filepath_qa", type=Path, help="filepath for qa input")
    parser.add_argument(
        "--filepath_instruction", type=Path, help="filepath for instruction"
    )
    parser.add_argument(
        "--dirpath_instruction_image", type=Path, help="dirpath for instruction (image)"
    )
    parser.add_argument(
        "--dirpath_parts_image", type=Path, help="dirpath for parts (image)"
    )
    parser.add_argument("--dirpath_frame", type=Path, help="dirpath for frames")
    parser.add_argument(
        "--filepath_template", type=Path, help="filepath for prompt template"
    )
    parser.add_argument(
        "--dirpath_intermediate_output",
        type=Path,
        help="dirpath to store intermediate outputs",
    )
    parser.add_argument("--filepath_tool", type=Path, help="filepath for tools")
    parser.add_argument("--dirpath_output", type=Path, help="filepath for output")
    parser.add_argument("--model_id", type=str, help="model id", default="")
    parser.add_argument("--temperature", type=float, help="temperature", default=0.1)
    parser.add_argument(
        "--max_tokens", type=int, help="max tokens to generate", default=4096
    )
    parser.add_argument("--max_frames", type=int, help="max frames to feed", default=30)
    parser.add_argument("--max_turn", type=int, help="max turns", default=10)
    parser.add_argument(
        "--threshold_continue",
        type=int,
        help="the minimum turn to spend before generating answers for each example",
        default=5,
    )
    parser.add_argument(
        "--threshold_max_text_turn",
        type=int,
        help="the max #turn of non-function outputs",
        default=2,
    )
    parser.add_argument("--reasoning", action="store_true", help="Enable reasoning")
    parser.add_argument(
        "--reasoning_effort",
        type=str,
        help="reasoning level for openAI models",
        default="medium",
    )
    parser.add_argument(
        "--summary_type", type=str, help="summary type for openai api", default="auto"
    )
    parser.add_argument("--angle", type=str, help="angle", default="C10118_rgb")
    parser.add_argument("--resolution", type=str, help="resolution", default="360p")
    parser.add_argument("--color", type=str, help="color", default="rgb")
    parser.add_argument(
        "--api_key_for_qwen", type=str, help="dummy api key for qwen", default="dummpy"
    )
    parser.add_argument(
        "--base_url_mimo",
        type=str,
        help="port for mimo",
        default="http://localhost:8088/v1",
    )
    parser.add_argument("--order", type=str, help="order of tool calling")
    parser.add_argument("--wait_time", type=int, help="API call wait time", default=10)
    parser.add_argument("--dirpath_log", type=Path, help="dirpath for log")

    args = parser.parse_args()

    if not args.dirpath_log.exists():
        args.dirpath_log.mkdir()

    if not args.dirpath_output.exists():
        args.dirpath_output.mkdir(parents=True)

    logging.basicConfig(
        format="%(asctime)s:%(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(
                args.dirpath_log / f"workflow_{Path(args.model_id).name}_{get_date()}.log"
            ),
        ],
    )

    logging.info("Arguments:")
    logging.info(pformat(vars(args), indent=4, width=100))

    main(args)
