"""
Prototype tool-augmented procedural activity assistant
* agent: openai models
* tool: multimodal
"""

from argparse import ArgumentParser
import json
import logging
from pathlib import Path
from pprint import pprint, pformat
import sys
from copy import deepcopy
from collections import defaultdict
import random

from utils import (
    get_date,
    format_instruction,
    # check_prompt,
)

from utils_api import (
    encode_image,
    estimate_cost,
    call_api_single,
    OPENAI_USAGE_KEYS,
)

from tools import (
    call_function,
    load_tools,
)

from prompts import (
    # format_initial_user_message,
    load_prompt_template,
    get_duration,
)


def cut_in(
    args,
    example,
    template_components,
    count_turn,
    count_non_tool_output,
    flag_answer_predicted,
):
    """Decide whether to inject an extra user message after a model turn.

    Depending on how many turns have elapsed, whether an answer has already
    been produced, and whether the latest turn called a tool, the injected
    message nudges the model to keep exploring, to reflect on a tool result,
    or to commit to an answer.

    Args:
        args: namespace providing ``threshold_max_text_turn`` and
            ``threshold_continue``.
        example: QA example; ``example["question"]`` is read only when an
            answer prompt is produced.
        template_components: prompt templates; ``["user"]["continue"]`` and
            ``["user"]["answer"]`` are read depending on the branch taken.
        count_turn: zero-based index of the turn that just finished.
        count_non_tool_output: number of consecutive turns without a function
            call (0 means the latest turn called a tool).
        flag_answer_predicted: whether an ``<answer>`` tag has been seen.

    Returns:
        ``(message, message_text)``: a ``{"role": "user", "content": ...}``
        dict and a ``["user", content]`` pair, or ``(None, "")`` when no
        cut-in is needed.
    """
    logging.info(f"Check if cut-in is required: {count_turn=}, {count_non_tool_output=}")

    max_text_turns = args.threshold_max_text_turn
    min_turns = args.threshold_continue
    review_prompt = (
        "Make sure to review the tool result and think before making any "
        "additional tool calls or answering the question."
    )
    past_min_turns = count_turn > min_turns

    content = ""
    if flag_answer_predicted and past_min_turns:
        # Enough turns spent and an answer exists: no nudge.
        logging.info(f"{min_turns} turns have passed. No prolonging required.")
    elif flag_answer_predicted and count_non_tool_output > 0:
        # Answered early without a tool call: push for more exploration.
        logging.info(
            "Answer is provided, but let's encourage the model to explore more."
        )
        content = template_components["user"]["continue"]
    elif flag_answer_predicted:
        # Answered early right after a tool call: ask it to reflect first.
        logging.info("tool_call detected. let model answer based on the tool result")
        logging.info("Encourage model to think based on tool result.")
        content = review_prompt
    elif past_min_turns:
        # No answer yet after the minimum number of turns: ask for one.
        logging.info("Let's encourage the model to output an answer")
        content = template_components["user"]["answer"].replace(
            "{question}", example["question"]
        )
    elif count_non_tool_output >= max_text_turns:
        # Too many text-only turns in a row: push for tool use.
        logging.info("Let's encourage the model to continue its exploration")
        content = template_components["user"]["continue"]
    elif count_non_tool_output == 0:
        logging.info("Encourage model to think based on tool result.")
        content = review_prompt
    else:
        logging.info("No need right now")

    if not content:
        return None, ""
    return {"role": "user", "content": content}, ["user", content]


def count_image_in_messages(messages):
    """Return the number of ``input_image`` blocks across all user messages.

    Only messages with ``role == "user"`` whose ``content`` is a list of
    blocks are inspected; assistant messages, plain-string contents, and
    items without a ``role`` (e.g. function-call records) are ignored.
    """

    def _is_multimodal_user(msg):
        return "role" in msg and msg["role"] == "user" and isinstance(msg["content"], list)

    return sum(
        block["type"] == "input_image"
        for msg in messages
        if _is_multimodal_user(msg)
        for block in msg["content"]
    )


def truncate_images_if_needed(args, messages):
    """Replace the earliest images with a text placeholder when over budget.

    Counts ``input_image`` blocks in user messages; if the total exceeds
    ``args.max_total_images``, the oldest excess images (in message order) are
    replaced by an ``input_text`` placeholder so the request stays within the
    image limit.

    Args:
        args: namespace providing ``max_total_images``.
        messages: conversation items in the format sent to the API.

    Returns:
        ``messages`` itself when within budget; otherwise a new list in which
        each multimodal user message is a shallow copy with a rebuilt
        ``content`` list (all other keys preserved) and every other item is
        passed through unchanged.
    """
    total_images = count_image_in_messages(messages)
    if total_images <= args.max_total_images:
        logging.info(f"#images: {total_images} (not changed)")
        return messages

    images_to_remove = total_images - args.max_total_images
    images_removed = 0
    logging.info(f"[debug] remove {images_to_remove} images")

    new_messages = []
    for message in messages:
        is_multimodal_user = (
            "role" in message
            and message["role"] == "user"
            and isinstance(message["content"], list)
        )
        if not is_multimodal_user:
            new_messages.append(message)
            continue

        new_content = []
        for block in message["content"]:
            # Images are built and counted as "input_image" blocks in this
            # file, so that is the type that has to be matched here.
            if block["type"] == "input_image" and images_removed < images_to_remove:
                new_content.append(
                    {
                        "type": "input_text",
                        "text": "<image>Truncated because #images exceeds the max</image>",
                    }
                )
                images_removed += 1
                logging.info(
                    f"debug: {images_removed}/{images_to_remove} images removed"
                )
            else:
                new_content.append(block)
        # Keep every other key of the original message; only "content" changes.
        new_messages.append({**message, "content": new_content})

    final_count = count_image_in_messages(new_messages)
    logging.info(f"#image: {total_images} -> {final_count} (reduced)")

    return new_messages


def uniform_sample(args, example):
    """Pick the frames of a clip to attach to the initial prompt.

    Every integer ``i`` in ``[video.start, video.end]`` maps to the file
    ``<dirpath_frame>/<sequence_id>/<angle>/<i>.png``, paired with a
    zero-padded offset id ``f"{i - start:04d}"`` (presumably seconds from the
    clip start — confirm against the frame extractor).

    If there are more candidate frames than ``args.max_frames``, a random
    subset of that size is drawn with the module-level ``random`` state and
    re-sorted by id. Otherwise all frames are kept.

    Returns:
        ``(frames, fps)``: a list of ``[filepath_str, id_str]`` pairs, and a
        string giving the kept/available ratio to two decimals, or ``"1"``
        when nothing was dropped.
    """
    frame_dir = args.dirpath_frame / example["sequence_id"] / args.angle
    first = int(example["video"]["start"])
    last = int(example["video"]["end"])

    candidates = [
        [str(frame_dir / f"{sec}.png"), f"{(sec - first):04d}"]
        for sec in range(first, last + 1)
    ]
    n_available = len(candidates)

    if n_available <= args.max_frames:
        return candidates, "1"

    picked = sorted(
        random.sample(candidates, args.max_frames), key=lambda pair: pair[1]
    )
    return picked, f"{len(picked) / n_available:.2f}"


def format_initial_user_message_pre_sample(args, example, template_components):
    """Build the first user message: the question prompt plus pre-sampled frames.

    Args:
        args: namespace forwarded to ``uniform_sample`` (which reads
            ``dirpath_frame``, ``angle`` and ``max_frames``).
        example: QA example providing ``question``, ``video.start`` and
            ``video.end`` (and ``sequence_id`` for ``uniform_sample``).
        template_components: prompt templates; ``["user"]["initial"]``
            (``{question}``, ``{duration}``) and ``["user"]["pre_sample"]``
            (``{num}``, ``{fps}``) are filled in here.

    Returns:
        ``(message, message_text)``: the multimodal user message, made of
        interleaved ``input_text``/``input_image`` blocks, and a plain-text
        transcript of it for logging in which the image data is replaced by
        file paths.
    """
    content = []
    text_lines = []

    # initial message
    content_initial = (
        template_components["user"]["initial"]
        .replace("{question}", example["question"])
        .replace(
            "{duration}",
            get_duration(example["video"]["start"], example["video"]["end"]),
        )
    )
    content.append({"type": "input_text", "text": content_initial})
    text_lines.append(content_initial)

    # pre-sampled frames
    filepaths_and_ids, fps = uniform_sample(args, example)
    # Report the number of frames actually attached. When the clip is shorter
    # than args.max_frames, uniform_sample keeps every frame, so this can be
    # smaller than args.max_frames.
    content_uniform_sample = (
        template_components["user"]["pre_sample"]
        .replace("{num}", str(len(filepaths_and_ids)))
        .replace("{fps}", fps)
    )
    content.append({"type": "input_text", "text": content_uniform_sample})
    text_lines.append(content_uniform_sample)

    for filepath, idx in filepaths_and_ids:
        content.append({"type": "input_text", "text": f"Frame {idx}"})
        # NOTE(review): uniform_sample returns ".png" paths but the data URL is
        # labelled image/jpeg; confirm whether encode_image re-encodes to JPEG.
        content.append(
            {
                "type": "input_image",
                "image_url": f"data:image/jpeg;base64,{encode_image(filepath)}",
            }
        )
        text_lines.append(f"Frame: {idx}, {str(filepath)}")

    message = {"role": "user", "content": content}
    message_text = "".join(f"{line}\n" for line in text_lines)

    return message, message_text


def run_single_example(args, example, template_components, tools, instructions):
    count_turn = 0  # count conversational turns
    count_non_tool_output = 0
    conversation_history = []
    total_usage = defaultdict(int)

    system_developer_message = template_components["system"]
    initial_user_message, initial_user_message_text = (
        format_initial_user_message_pre_sample(args, example, template_components)
    )
    # register turn 0
    conversation_history.append(
        [["system", system_developer_message], ["user", initial_user_message_text]]
    )

    messages = [initial_user_message]
    flag_answer_predicted = False
    answer = ""
    flag_end_conversation = False
    while count_turn < args.max_turn:
        logging.info(f"Next: turn {count_turn + 1}")

        # api call
        tool_choice = "auto"  # for gpt5

        try:
            response, usage = call_api_single(
                args, system_developer_message, messages, tools, tool_choice
            )
        except Exception:
            logging.exception("Error in run_single_example")
            conversation_history.append(["Error"])
            continue

        for key in OPENAI_USAGE_KEYS:
            total_usage[key] += usage[key]

        if len(response.output) > 1:  # sanity check
            logging.info(
                "[Debug] "
                f"Multiple messages in response: {[x.type for x in response.output]}"
            )

        latest_assistant_response = ""
        current_conversation = []
        message_types_in_response = []
        for _response in response.output:
            match _response.type:
                case "function_call":
                    # append model's function call message
                    messages.append(
                        {
                            "id": _response.id,
                            "type": "function_call",
                            "call_id": _response.call_id,
                            "name": _response.name,
                            "arguments": _response.arguments,
                        }
                    )
                    current_conversation.append(
                        ["function_call", f"{_response.name} {_response.arguments}"]
                    )

                    # function call
                    messages_function_call, messages_function_call_text, _ = (
                        call_function(
                            args,
                            example,
                            template_components,
                            instructions,
                            _response,
                        )
                    )
                    # append function call output + alpha
                    # note: use extend as multiple messages may be included
                    if messages_function_call:
                        messages.extend(messages_function_call)
                        current_conversation.extend(messages_function_call_text)
                    else:
                        sys.exit("no output returned from function call")
                    message_types_in_response.append("function_call")
                case "message":
                    # think: does it effect if i change this to 'reasoning' manually?
                    # append model's response
                    if len(_response.content) > 1:
                        logging.warning(f"Multiple contents: {_response.content=}")
                    latest_assistant_response = "\n".join(
                        [x.text for x in _response.content]
                    )
                    # note: looks like "id" should be added to assistant's output
                    messages.append(
                        {
                            "id": _response.id,
                            "role": "assistant",
                            "type": "message",
                            "content": latest_assistant_response,
                        }
                    )
                    current_conversation.append(["assistant", latest_assistant_response])
                    message_types_in_response.append("message")
                case "reasoning":
                    messages.append(
                        {
                            "id": _response.id,
                            "type": "reasoning",
                            "summary": [dict(x) for x in _response.summary],
                        }
                    )

                    current_conversation.append(
                        ["reasoning", "\n".join([x.text for x in _response.summary])]
                    )
                    message_types_in_response.append("reasoning")
                case _:
                    logging.error(f"Undefined {_response.type=}")

        # check if answer is produced
        if "<answer>" in latest_assistant_response:
            flag_answer_predicted = True
            answer = latest_assistant_response.split("<answer>")[-1].replace(
                "</answer>", ""
            )

        # check output types
        if "function_call" in message_types_in_response:
            count_non_tool_output = 0
        else:
            count_non_tool_output += 1

        # manual cut-in to prolong or end conversation
        message_cut_in, message_cut_in_text = cut_in(
            args,
            example,
            template_components,
            count_turn,
            count_non_tool_output,
            flag_answer_predicted,
        )
        if message_cut_in:
            messages.append(message_cut_in)
            current_conversation.append(message_cut_in_text)

        for _conversation in current_conversation:
            logging.info(f"[{_conversation[0]}]")
            logging.info(_conversation[1])

        conversation_history.append(current_conversation)

        messages = truncate_images_if_needed(args, messages)

        # termination condition
        if count_turn >= args.max_turn:
            logging.info(f"Reached {args.max_turn=}. Terminate this process.")
            flag_end_conversation = True
        else:
            if flag_answer_predicted:
                if count_turn > args.threshold_continue:
                    if count_non_tool_output > 0:
                        logging.info(
                            "Answer predicted & Latest response contains no tool call. "
                            "So finish now."
                        )
                        flag_end_conversation = True
                    else:
                        logging.info(
                            f"Although turn {args.threshold_continue} reached, "
                            "as a tool is called, continue this process"
                        )
                else:
                    pass
            else:
                pass

        if flag_end_conversation:
            break

        count_turn += 1

    cost = estimate_cost(args.model_id, total_usage)

    return messages, conversation_history, answer, cost


def main(args):
    """Run inference over every QA example, checkpointing results to JSON.

    Results go to ``<dirpath_output>/<model_id>_<reasoning>_presample.json``
    and the file is rewritten after every processed example. If that file
    already exists, its entries are reused by position, except entries equal
    to ``"Error"``, which are re-run. The accumulated cost is restored from
    its metadata, so an interrupted run can be resumed.

    NOTE(review): the global random state is seeded once here. When resuming,
    reused examples consume no random draws, so the frames sampled for
    re-run examples may differ from those of an uninterrupted run.
    """
    random.seed(args.seed)

    with open(args.filepath_qa, "r") as fin:
        qa_examples = json.load(fin)

    toy2instruction = format_instruction(
        args.filepath_instruction,
        args.dirpath_instruction_image,
        args.dirpath_parts_image,
    )
    template_components = load_prompt_template(args.filepath_template)
    tools = load_tools(args)
    logging.info(pformat(tools, width=100))

    new_examples = []
    total_cost = 0

    output_path = (
        args.dirpath_output / f"{args.model_id}_{args.reasoning}_presample.json"
    )
    prev_examples = []
    if output_path.exists():
        with open(output_path, "r") as fin:
            checkpoint = json.load(fin)
        prev_examples = checkpoint["examples"]
        logging.info(f"Prev attempt exists. Restart from {len(prev_examples)}")
        total_cost = float(checkpoint["metadata"]["cost"])
    else:
        logging.info("Initial attempt")

    metadata = {
        "data-created": get_date(),
        "args": {
            name: (str(value) if isinstance(value, Path) else value)
            for name, value in vars(args).items()
        },
    }

    logging.info("Prediction starts ...")
    for idx, example in enumerate(qa_examples):
        # Reuse a finished result from the checkpoint; re-run failed ones.
        if idx < len(prev_examples) and prev_examples[idx] != "Error":
            new_examples.append(prev_examples[idx])
            continue

        logging.info(f"[Example {idx}]")
        try:
            _, conversation_history, answer, cost = run_single_example(
                args,
                example,
                template_components,
                tools,
                toy2instruction[example["toy_id"]],
            )
            result = deepcopy(example)
            result["prediction"] = {
                "conversation_history": conversation_history,
                "answer": answer,
                "cost": cost,
            }
            new_examples.append(result)
            total_cost += cost
        except Exception:
            logging.exception("An error occurred")
            new_examples.append("Error")

        # checkpoint after every example
        metadata["cost"] = f"{total_cost:.3f}"
        with open(output_path, "w") as fout:
            json.dump({"metadata": metadata, "examples": new_examples}, fout, indent=4)
            fout.write("\n")

    logging.info(f"{total_cost=}")


if __name__ == "__main__":
    parser = ArgumentParser(description="prototype")
    parser.add_argument("--filepath_qa", type=Path, help="filepath for qa input")
    parser.add_argument(
        "--filepath_instruction", type=Path, help="filepath for instruction"
    )
    parser.add_argument(
        "--dirpath_instruction_image", type=Path, help="dirpath for instruction (image)"
    )
    parser.add_argument(
        "--dirpath_parts_image", type=Path, help="dirpath for parts (image)"
    )
    parser.add_argument("--dirpath_frame", type=Path, help="dirpath for frames")
    parser.add_argument(
        "--filepath_template", type=Path, help="filepath for prompt template"
    )
    parser.add_argument(
        "--dirpath_intermediate_output",
        type=Path,
        help="dirpath to store intermediate outputs",
    )
    parser.add_argument("--filepath_tool", type=Path, help="filepath for tools")
    # required: these are dereferenced immediately below, before logging is set up
    parser.add_argument(
        "--dirpath_output", type=Path, required=True, help="dirpath for output"
    )
    parser.add_argument("--model_id", type=str, help="model id", default="")
    parser.add_argument("--temperature", type=float, help="temperature", default=0.1)
    parser.add_argument(
        "--max_tokens", type=int, help="max tokens to generate", default=128
    )
    parser.add_argument("--max_frames", type=int, help="max frames to feed", default=30)
    parser.add_argument(
        "--max_total_images",
        type=int,
        help="max total frames to feed (100 is for anthropic api, but use 80 for 32MB limit)",
        default=70,
    )
    parser.add_argument("--max_turn", type=int, help="max turns", default=10)
    parser.add_argument(
        "--threshold_continue",
        type=int,
        help="the minimum turn to spend before generating answers for each example",
        default=5,
    )
    parser.add_argument(
        "--threshold_max_text_turn",
        type=int,
        help="the max #turn of non-function outputs",
        default=2,
    )
    parser.add_argument("--reasoning", action="store_true", help="Enable reasoning")
    parser.add_argument(
        "--reasoning_effort",
        type=str,
        help="reasoning level for openAI models",
        default="medium",
    )
    parser.add_argument(
        "--summary_type", type=str, help="summary type for openai api", default="auto"
    )
    parser.add_argument("--angle", type=str, help="angle", default="C10118_rgb")
    parser.add_argument("--resolution", type=str, help="resolution", default="360p")
    parser.add_argument("--color", type=str, help="color", default="rgb")
    parser.add_argument("--wait_time", type=int, help="API call wait time", default=10)
    parser.add_argument("--seed", type=int, help="random seed", default=7)
    parser.add_argument(
        "--dirpath_log", type=Path, required=True, help="dirpath for log"
    )

    args = parser.parse_args()

    # Create nested directories as needed; tolerate directories that already exist.
    args.dirpath_log.mkdir(parents=True, exist_ok=True)
    args.dirpath_output.mkdir(parents=True, exist_ok=True)

    logging.basicConfig(
        format="%(asctime)s:%(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(args.dirpath_log / f"prototype_{get_date()}.log"),
        ],
    )

    # Log the arguments (not just print them) so they are kept in the log file.
    logging.info("Arguments:")
    logging.info(pformat(vars(args)))

    main(args)
