"""
workflow-version of our idea
* agent: openai models
* tool: multimodal
"""

from argparse import ArgumentParser
import json
import logging
from pathlib import Path
from pprint import pprint, pformat
from copy import deepcopy
from collections import defaultdict

from utils import (
    get_date,
    format_instruction,
    # check_prompt,
)

from utils_api import (
    estimate_cost,
    call_api_single,
    OPENAI_USAGE_KEYS,
)

from tools import (
    load_tools,
)
from tools_workflow import (
    call_function,
)

from prompts import (
    format_initial_user_message,
    load_prompt_template,
)

# Maps each tool id (one character of ``--order``) to the tool name that is
# passed to ``call_function``. The "0" ("finish") step is always appended as
# the final step of every workflow run.
TOOLS = dict(
    [
        ("1", "sample_frame"),
        ("2", "check_instruction"),
        ("3", "check_final_picture"),
        ("0", "finish"),
    ]
)


def run_single_example(args, example, template_components, tools, instructions):
    count_turn = 0  # count conversational turns
    conversation_history = []
    total_usage = defaultdict(int)

    system_developer_message = template_components["system"]
    initial_user_message, initial_user_message_text = format_initial_user_message(
        args.model_id, template_components, example
    )
    # register turn 0
    conversation_history.append(
        [["system", system_developer_message], ["user", initial_user_message_text]]
    )

    messages = [initial_user_message]
    answer = ""
    for count_turn, tool_id in enumerate(args.order + "0"):
        logging.info(f"Next: turn {count_turn + 1}")

        current_conversation = []

        tool_name = TOOLS[tool_id]
        logging.info(f"debug: {tool_name=}")
        # call tool
        content, content_text = call_function(
            args,
            example,
            template_components,
            instructions,
            tool_name,
        )
        messages.append({"role": "user", "content": content})
        current_conversation.append(["user-tool", content_text])
        system_developer_message = ""

        try:
            response, usage = call_api_single(args, system_developer_message, messages)
        except Exception:
            logging.exception("Error in run_single_example")
            conversation_history.append(["Error"])
            continue

        for key in OPENAI_USAGE_KEYS:
            total_usage[key] += usage[key]

        if len(response.output) > 1:  # sanity check
            logging.info(
                "[Debug] "
                f"Multiple messages in response: {[x.type for x in response.output]}"
            )

        # postprocess
        latest_assistant_response = ""
        for _response in response.output:
            match _response.type:
                case "message":
                    # append model's response
                    if len(_response.content) > 1:
                        logging.warning(f"Multiple contents: {_response.content=}")
                    _response_text = "\n".join([x.text for x in _response.content])
                    # note: looks like "id" should be added to assistant's output
                    messages.append(
                        {
                            "id": _response.id,
                            "role": "assistant",
                            "type": "message",
                            "content": _response_text,
                        }
                    )
                    current_conversation.append(["assistant", _response_text])
                    latest_assistant_response = _response_text
                case "reasoning":
                    messages.append(
                        {
                            "id": _response.id,
                            "type": "reasoning",
                            "summary": [dict(x) for x in _response.summary],
                        }
                    )
                    _response_text = "\n".join([x.text for x in _response.summary])
                    current_conversation.append(["reasoning", _response_text])
                case _:
                    logging.error(f"Undefined {_response.type=}")

        # check if answer is produced
        if "<answer>" in latest_assistant_response:
            answer = latest_assistant_response.split("<answer>")[-1].replace(
                "</answer>", ""
            )

        # logging
        for _conversation in current_conversation:
            logging.info(f"[{_conversation[0]}]")
            logging.info(_conversation[1])

        conversation_history.append(current_conversation)

    cost = estimate_cost(args.model_id, total_usage)

    return messages, conversation_history, answer, cost


def _save_output(filepath_output, metadata, new_examples, total_cost):
    """Write *metadata* (with the updated cost) and *new_examples* as JSON.

    The data is first written to a sibling ``.tmp`` file, which is then
    renamed over *filepath_output*. An interrupted write therefore does not
    truncate the existing output file that ``main`` resumes from.
    """
    metadata["cost"] = f"{total_cost:.3f}"
    final_output = {"metadata": metadata, "examples": new_examples}
    filepath_tmp = filepath_output.with_name(filepath_output.name + ".tmp")
    with open(filepath_tmp, "w", encoding="utf-8") as f:
        json.dump(final_output, f, indent=4)
        f.write("\n")
    filepath_tmp.replace(filepath_output)


def main(args):
    """Run the workflow over every QA example and save predictions incrementally.

    Output is written to ``<dirpath_output>/<model_id>_<reasoning>_<order>.json``
    after every example that is (re-)run. If that file already exists, its
    predictions are reused. Only examples beyond it, or ones recorded there as
    ``"Error"``, are re-run, and the accumulated cost is carried over from its
    metadata.
    """
    # load data
    with open(args.filepath_qa, "r", encoding="utf-8") as f:
        examples = json.load(f)

    # load instruction
    toy2instruction = format_instruction(
        args.filepath_instruction,
        args.dirpath_instruction_image,
        args.dirpath_parts_image,
    )

    # load prompt template
    template_components = load_prompt_template(args.filepath_template)

    # load tools
    tools = load_tools(args)
    logging.info(pformat(tools, width=100))

    # output file (also the resume checkpoint)
    filepath_output = (
        args.dirpath_output / f"{args.model_id}_{args.reasoning}_{args.order}.json"
    )
    prev_examples = []
    total_cost = 0
    if filepath_output.exists():
        with open(filepath_output, "r", encoding="utf-8") as f:
            output_prev = json.load(f)
        prev_examples = output_prev["examples"]
        logging.info(f"Prev attempt exists. Restart from {len(prev_examples)}")
        total_cost = float(output_prev["metadata"]["cost"])
    else:
        logging.info("Initial attempt")

    metadata = {
        "data-created": get_date(),
        "args": {k: str(v) if isinstance(v, Path) else v for k, v in vars(args).items()},
    }

    new_examples = []
    num_saved = 0  # len(new_examples) at the time of the last save

    # start inference
    logging.info("Prediction starts ...")
    for idx, example in enumerate(examples):
        # reuse successful predictions from a previous attempt; retry errors
        if idx < len(prev_examples) and prev_examples[idx] != "Error":
            new_examples.append(prev_examples[idx])
            continue

        logging.info(f"[Example {idx}]")
        try:
            messages, conversation_history, answer, cost = run_single_example(
                args,
                example,
                template_components,
                tools,
                toy2instruction[example["toy_id"]],
            )
            new_example = deepcopy(example)
            new_example["prediction"] = {
                "conversation_history": conversation_history,
                "answer": answer,
                "cost": cost,
            }
            new_examples.append(new_example)
            total_cost += cost
        except Exception:
            logging.exception("An error occurred")
            new_examples.append("Error")

        # save output every time
        _save_output(filepath_output, metadata, new_examples, total_cost)
        num_saved = len(new_examples)

    # When an earlier "Error" entry was retried, predictions reused after it are
    # only in memory; without this save the file would stay truncated at the
    # retried index. If nothing was re-run, the file is left untouched.
    if 0 < num_saved < len(new_examples):
        _save_output(filepath_output, metadata, new_examples, total_cost)

    logging.info(f"{total_cost=}")


if __name__ == "__main__":
    parser = ArgumentParser(description="workflow")
    parser.add_argument("--filepath_qa", type=Path, help="filepath for qa input")
    parser.add_argument(
        "--filepath_instruction", type=Path, help="filepath for instruction"
    )
    parser.add_argument(
        "--dirpath_instruction_image", type=Path, help="dirpath for instruction (image)"
    )
    parser.add_argument(
        "--dirpath_parts_image", type=Path, help="dirpath for parts (image)"
    )
    parser.add_argument("--dirpath_frame", type=Path, help="dirpath for frames")
    parser.add_argument(
        "--filepath_template", type=Path, help="filepath for prompt template"
    )
    parser.add_argument(
        "--dirpath_intermediate_output",
        type=Path,
        help="dirpath to store intermediate outputs",
    )
    parser.add_argument("--filepath_tool", type=Path, help="filepath for tools")
    parser.add_argument("--dirpath_output", type=Path, help="filepath for output")
    parser.add_argument("--model_id", type=str, help="model id", default="")
    parser.add_argument("--temperature", type=float, help="temperature", default=0.1)
    parser.add_argument(
        "--max_tokens", type=int, help="max tokens to generate", default=128
    )
    parser.add_argument("--max_frames", type=int, help="max frames to feed", default=30)
    parser.add_argument("--max_turn", type=int, help="max turns", default=10)
    parser.add_argument(
        "--threshold_continue",
        type=int,
        help="the minimum turn to spend before generating answers for each example",
        default=5,
    )
    parser.add_argument(
        "--threshold_max_text_turn",
        type=int,
        help="the max #turn of non-function outputs",
        default=2,
    )
    parser.add_argument("--reasoning", action="store_true", help="Enable reasoning")
    parser.add_argument(
        "--reasoning_effort",
        type=str,
        help="reasoning level for openAI models",
        default="medium",
    )
    parser.add_argument(
        "--summary_type", type=str, help="summary type for openai api", default="auto"
    )
    parser.add_argument("--angle", type=str, help="angle", default="C10118_rgb")
    parser.add_argument("--resolution", type=str, help="resolution", default="360p")
    parser.add_argument("--color", type=str, help="color", default="rgb")
    parser.add_argument("--order", type=str, help="order of tool calling")
    parser.add_argument("--wait_time", type=int, help="API call wait time", default=10)
    parser.add_argument("--dirpath_log", type=Path, help="dirpath for log")

    args = parser.parse_args()

    if not args.dirpath_log.exists():
        args.dirpath_log.mkdir()

    if not args.dirpath_output.exists():
        args.dirpath_output.mkdir(parents=True)

    logging.basicConfig(
        format="%(asctime)s:%(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(args.dirpath_log / f"workflow_{get_date()}.log"),
        ],
    )

    logging.info("Arguments:")
    pprint(vars(args))

    main(args)
