"""
Prototype tool-augmented procedural activity assistant

"""

from argparse import ArgumentParser
import json
import logging
from pathlib import Path
from pprint import pformat
from copy import deepcopy
import random

from utils import (
    get_date,
    format_instruction,
    # check_prompt,
)

from utils_api import (
    encode_image,
    call_api_single,
)

from tools import (
    call_function,
    load_tools,
)

from prompts import (
    # format_initial_user_message,
    load_prompt_template,
    get_duration,
)


def cut_in(
    args,
    example,
    template_components,
    count_turn,
    count_non_tool_output,
    flag_answer_predicted,
):
    """
    Decide whether to inject an extra user message ("cut-in") into the dialogue.

    The decision uses three signals:
    - whether an answer has already been predicted;
    - whether ``count_turn`` exceeds ``args.threshold_continue``;
    - how many consecutive turns had no tool call
      (``count_non_tool_output``, compared with ``args.threshold_max_text_turn``).

    Returns:
        (message, message_text): ``({"role": "user", "content": c}, ["user", c])``
        when a cut-in is needed, otherwise ``(None, "")``.
    """
    logging.info(f"Check if cut-in is required: {count_turn=}, {count_non_tool_output=}")

    max_text_turns = args.threshold_max_text_turn
    min_turns = args.threshold_continue
    past_min_turns = count_turn > min_turns

    # Prompt asking the model to reflect on the latest tool output.
    review_prompt = (
        "Make sure to review the tool result and think before making any additional "
        "tool calls or answering the question (Format: <answer>your answer</answer>)."
    )

    text = ""
    if flag_answer_predicted and past_min_turns:
        logging.info(
            f"{min_turns} turns have passed. No prolonging required."
        )
    elif flag_answer_predicted and count_non_tool_output > 0:
        # an answer came without tool use: push for more exploration
        logging.info(
            "Answer is provided, but let's encourage the model to explore more."
        )
        text = template_components["user"]["continue"]
    elif flag_answer_predicted:
        logging.info(
            "tool_call detected. let model answer based on the tool result"
        )
        logging.info("Encourage model to think based on tool result.")
        text = review_prompt
    elif past_min_turns:
        # enough turns spent without an answer: ask for one explicitly
        logging.info("Let's encourage the model to output an answer")
        text = template_components["user"]["answer"].replace(
            "{question}", example["question"]
        )
    elif count_non_tool_output >= max_text_turns:
        # too many consecutive text-only turns: push for tool use
        logging.info("Let's encourage the model to continue its exploration")
        text = template_components["user"]["continue"]
    elif count_non_tool_output == 0:
        logging.info("Encourage model to think based on tool result.")
        text = review_prompt
    else:
        logging.info("No need right now")

    if not text:
        return None, ""
    return {"role": "user", "content": text}, ["user", text]


def count_image_in_messages(messages):
    """
    Count ``image_url`` content blocks in ``function`` and ``user`` messages.

    Only messages whose ``content`` is a list of typed blocks are inspected;
    string contents and other roles contribute nothing.
    """
    return sum(
        1
        for msg in messages
        if msg["role"] in ("function", "user") and isinstance(msg["content"], list)
        for part in msg["content"]
        if part["type"] == "image_url"
    )


def truncate_images_if_needed(args, messages):
    """
    Replace the earliest images with a placeholder once the image budget is exceeded.

    Images are counted in ``function``/``user`` messages whose content is a
    list of blocks. When the total exceeds ``args.max_total_images``, the
    earliest ``image_url`` blocks (in message order) are replaced by a text
    placeholder until the total fits the budget.

    Args:
        args: namespace providing ``max_total_images``.
        messages: list of chat message dicts; not modified in place.

    Returns:
        ``messages`` itself when within budget. Otherwise a new list in which
        rewritten messages keep all of their original keys and only
        ``content`` is replaced.
    """
    # 1. check #images
    total_images = count_image_in_messages(messages)

    # 2. truncation if needed
    if total_images <= args.max_total_images:
        logging.info(f"#images: {total_images} (not changed)")
        return messages

    images_to_remove = total_images - args.max_total_images
    images_removed = 0
    logging.info(f"[debug] remove {images_to_remove} images")

    new_messages = []
    for message in messages:
        if not (
            message["role"] in ["function", "user"]
            and isinstance(message["content"], list)
        ):
            new_messages.append(message)
            continue

        new_content = []
        for block in message["content"]:
            if block["type"] == "image_url" and images_removed < images_to_remove:
                new_content.append(
                    {
                        "type": "text",
                        "text": "<image>Truncated because #images exceeds the max</image>",
                    }
                )
                images_removed += 1
                logging.info(
                    f"debug: {images_removed}/{images_to_remove} images removed"
                )
            else:
                new_content.append(block)
        # Copy the message instead of rebuilding it from role/content only, so
        # any other keys (e.g. a function/tool name or id) are still sent to
        # the API on the next call.
        new_messages.append({**message, "content": new_content})

    final_count = count_image_in_messages(new_messages)
    logging.info(f"#image: {total_images} -> {final_count} (reduced)")

    return new_messages


def uniform_sample(args, example):
    """
    Select up to ``args.max_frames`` frame files for the example's video clip.

    One frame path is built per integer index ``i`` in ``[start, end]`` as
    ``args.dirpath_frame / sequence_id / args.angle / f"{i}.png"``.

    If there are more frames than ``args.max_frames``, a random subset is kept
    and put back in chronological order. The subset is drawn from the
    module-level ``random`` state, so it is reproducible under ``random.seed``.
    Despite the function name, frames are drawn at random, not at evenly
    spaced positions.

    Returns:
        (filepaths_and_ids, fps):
        - ``filepaths_and_ids`` is a list of ``[filepath, frame_id]`` pairs,
          where ``frame_id`` is ``i - start`` zero-padded to 4 digits.
        - ``fps`` is a string: ``"1"`` when every frame is kept, otherwise
          the kept fraction formatted with two decimals.
    """
    dirpath = args.dirpath_frame / example["sequence_id"] / args.angle
    start, end = int(example["video"]["start"]), int(example["video"]["end"])

    filepaths_and_ids = [
        [str(dirpath / f"{i}.png"), f"{i - start:04d}"] for i in range(start, end + 1)
    ]
    num_frame_total = len(filepaths_and_ids)

    if num_frame_total <= args.max_frames:
        return filepaths_and_ids, "1"

    filepaths_sampled = random.sample(filepaths_and_ids, args.max_frames)
    # Sort numerically: the zero-padded ids only sort correctly as strings
    # while they stay within 4 digits (e.g. "10000" < "1001" as strings).
    filepaths_sampled.sort(key=lambda pair: int(pair[1]))
    fps = f"{len(filepaths_sampled) / num_frame_total:.2f}"

    return filepaths_sampled, fps


def format_initial_user_message_presample(args, example, template_components):
    """
    Build the first user message: question prompt, sampling notice, and frames.

    The message content is a list of blocks in this order:
    1. the ``initial`` template filled with the question and clip duration;
    2. the ``pre_sample`` template filled with the frame count and fps;
    3. for each sampled frame, a ``"Frame <id>"`` text block followed by the
       image as a base64 data URL.

    A plain-text rendering is built alongside for the conversation log.

    Returns:
        (message, message_text): the ``{"role": "user", "content": [...]}``
        dict and its newline-separated text rendering.
    """
    content, message_text = [], ""

    # initial message
    content_initial = (
        template_components["user"]["initial"]
        .replace("{question}", example["question"])
        .replace(
            "{duration}",
            get_duration(example["video"]["start"], example["video"]["end"]),
        )
    )
    content.append({"type": "text", "text": content_initial})
    message_text += content_initial + "\n"

    # pre-sampled frames. {num} must report how many frames are actually
    # attached: uniform_sample returns every frame (fewer than
    # args.max_frames) for short clips.
    filepaths_and_ids, fps = uniform_sample(args, example)
    content_uniform_sample = (
        template_components["user"]["pre_sample"]
        .replace("{num}", str(len(filepaths_and_ids)))
        .replace("{fps}", fps)
    )
    content.append({"type": "text", "text": content_uniform_sample})
    message_text += content_uniform_sample + "\n"

    for filepath, idx in filepaths_and_ids:
        content.append({"type": "text", "text": f"Frame {idx}"})
        # NOTE(review): frame files are .png but the data URL declares
        # image/jpeg; confirm what format encode_image returns.
        content.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{encode_image(filepath)}"},
            }
        )
        message_text += f"Frame: {idx}, {str(filepath)}\n"

    message = {"role": "user", "content": content}

    return message, message_text


def run_single_example(args, example, template_components, tools, instructions):
    """
    Run the multi-turn, tool-augmented dialogue for one QA example.

    Turn 0 registers the system prompt and the initial user message with the
    pre-sampled frames. Each later turn does the following:
    - Calls the model. Tools are offered with ``"auto"`` unless the previous
      turn executed a tool, in which case the call passes no tools and
      ``"none"``.
    - Executes at most the first tool call.
    - Extracts an ``<answer>...</answer>`` if one is present.
    - Optionally injects a cut-in user message (see ``cut_in``).
    - Trims images beyond ``args.max_total_images``.

    The loop ends after ``args.max_turn`` turns, or earlier once an answer
    exists, more than ``args.threshold_continue`` turns have passed, and the
    latest turn made no successful tool call.

    Args:
        args: parsed CLI arguments (``max_turn``, ``threshold_continue``, ...).
        example: QA example dict (``question``, ``video``, ...).
        template_components: prompt template pieces (``system``, ``user``).
        tools: tool schemas passed to the API when tools are offered.
        instructions: instruction data forwarded to ``call_function``.

    Returns:
        (messages, conversation_history, answer):
        - ``messages``: API-format messages.
        - ``conversation_history``: text rendering of the dialogue, one list
          of ``[role, text]`` pairs per turn.
        - ``answer``: the most recently extracted answer, or ``""`` if none.
    """
    count_turn = 0  # count conversational turns
    count_non_tool_output = 0  # consecutive turns without a successful tool call
    conversation_history = []
    messages = []

    # register turn 0
    system_developer_message = template_components["system"]
    messages.append({"role": "system", "content": system_developer_message})
    conversation_history.append(["system", system_developer_message])

    initial_user_message, initial_user_message_text = (
        format_initial_user_message_presample(args, example, template_components)
    )
    messages.append(initial_user_message)
    conversation_history.append(["user", initial_user_message_text])
    message_types_in_response = ["init"]

    flag_answer_predicted = False
    answer = ""
    flag_end_conversation = False
    while count_turn < args.max_turn:
        logging.info(f"Next: turn {count_turn + 1}")

        # api call: right after a tool result, tools are withheld ("none")
        if "tool_call" in message_types_in_response or flag_end_conversation:
            response, usage = call_api_single(args, "", messages, [], "none")
        else:
            response, usage = call_api_single(args, "", messages, tools, "auto")

        if not response:
            logging.error("No response returned from the API. Stop this example.")
            messages.append("Error")
            break

        latest_assistant_response = ""
        current_turn = []  # [role, text] pairs for this turn
        message_types_in_response = []
        if len(response.choices) > 1:  # sanity check
            logging.warning("[Debug] Multiple choices/candidates in response")

        message = response.choices[0].message

        if message.content:
            messages.append({"role": "assistant", "content": message.content})
            current_turn.append(["assistant", message.content])
            latest_assistant_response = message.content
            message_types_in_response.append("text")

        if message.tool_calls:
            # only the first tool call is echoed back and executed
            if len(message.tool_calls) > 1:
                logging.warning("Multiple tools are called")
            tool_call = message.tool_calls[0]
            logging.info(f"debug {tool_call=}")
            messages.append(
                {
                    "role": "assistant",
                    "tool_calls": [
                        {
                            "id": tool_call.id,
                            "type": tool_call.type,
                            "function": tool_call.function,
                        }
                    ],
                }
            )
            current_turn.append(
                [
                    "assistant",
                    f"{tool_call.function.name} {tool_call.function.arguments}",
                ]
            )

            # function call
            tool_response_message, tool_response_message_text, _ = call_function(
                args,
                example,
                template_components,
                instructions,
                tool_call.function,
            )
            if tool_response_message:
                messages.append(tool_response_message)
                current_turn.append(tool_response_message_text)
                message_types_in_response.append("tool_call")
            else:
                logging.error("no output returned from function call")
                content = "Error returned from function call. Make sure to specify tools, e.g., function names or arguments, properly."
                messages.append({"role": "function", "content": content})
                # Every history entry must be a [role, text] pair: the
                # `for role, content in current_turn` loop below unpacks them,
                # and a bare string would raise ValueError there.
                current_turn.append(["function", content])

        logging.info(f"{message_types_in_response=}")

        # check if answer is produced: keep the text between the last
        # <answer> and the following </answer>, dropping anything after it
        if "<answer>" in latest_assistant_response:
            flag_answer_predicted = True
            answer = latest_assistant_response.split("<answer>")[-1].split(
                "</answer>"
            )[0]

        # a successful tool call resets the streak of text-only turns
        if "tool_call" in message_types_in_response:
            count_non_tool_output = 0
        else:
            count_non_tool_output += 1

        # manual cut-in to prolong or end conversation
        message_cut_in, message_cut_in_text = cut_in(
            args,
            example,
            template_components,
            count_turn,
            count_non_tool_output,
            flag_answer_predicted,
        )
        if message_cut_in:
            messages.append(message_cut_in)
            current_turn.append(message_cut_in_text)

        # register conv history in text
        conversation_history.append(current_turn)
        for role, content in current_turn:
            logging.info(f"{role=}")
            logging.info(f"{content=}")

        # update contents if #images > max_#images
        messages = truncate_images_if_needed(args, messages)

        # termination condition. count_turn is only incremented at the end of
        # the loop, so the last allowed turn is count_turn + 1 == max_turn.
        if count_turn + 1 >= args.max_turn:
            logging.info(f"Reached {args.max_turn=}. Terminate this process.")
            flag_end_conversation = True
        elif flag_answer_predicted and count_turn > args.threshold_continue:
            if count_non_tool_output > 0:
                logging.info(
                    "Answer predicted & Latest response contains no tool call. "
                    "So finish now."
                )
                flag_end_conversation = True
            else:
                logging.info(
                    f"Although turn {args.threshold_continue} reached, "
                    "as a tool is called, continue this process"
                )

        if flag_end_conversation:
            break

        count_turn += 1

    return messages, conversation_history, answer


def check_if_prediction_exists(example, examples_prev):
    """
    Return a copy of a previous non-empty prediction for ``example``, if any.

    A previous entry matches only if it has every key of ``example`` with an
    equal value. Examples that share the same question text but differ in
    other fields (sequence, video segment, ...) are therefore not conflated.

    Entries that are not dicts (failed examples are saved as ``"Error"``) and
    predictions with a blank answer are ignored.

    Args:
        example: current QA example dict.
        examples_prev: previously saved output (a dict with an ``"examples"``
            list), or an empty list when there is no previous run.

    Returns:
        A deep copy of the last matching previous entry, or ``None``.
    """
    if not isinstance(examples_prev, dict):
        return None

    # Iterate backwards so the last matching entry wins.
    for example_prev in reversed(examples_prev.get("examples", [])):
        if not isinstance(example_prev, dict):
            continue
        if any(
            key not in example_prev or example_prev[key] != value
            for key, value in example.items()
        ):
            continue
        prediction = example_prev.get("prediction") or {}
        prediction_answer = prediction.get("answer") or ""
        if prediction_answer.strip():
            return deepcopy(example_prev)

    return None


def main(args):
    """
    Run inference over every QA example and save predictions incrementally.

    A previous run's output at the same path is loaded. Examples for which
    ``check_if_prediction_exists`` finds a prior prediction are copied over
    instead of being re-run. Failures are logged and stored as ``"Error"``.

    The output JSON is rewritten after every example through a temporary
    file, so an interruption never leaves a truncated output behind.
    """
    # NOTE(review): seeding once per run means a resumed run (with skipped
    # examples) draws different frame samples for the remaining examples
    # than an uninterrupted run would.
    random.seed(args.seed)

    # load data
    with open(args.filepath_qa, "r", encoding="utf-8") as f:
        examples = json.load(f)

    # load instruction
    toy2instruction = format_instruction(
        args.filepath_instruction,
        args.dirpath_instruction_image,
        args.dirpath_parts_image,
    )

    # load prompt template
    template_components = load_prompt_template(args.filepath_template)

    # load tools
    tools = load_tools(args)
    new_examples = []

    # output file; the same path is reused across runs so that finished
    # predictions can be picked up again
    filepath_output = (
        args.dirpath_output
        / f"{Path(args.model_id).name}_{args.reasoning}_presample.json"
    )
    # sibling temp file used to replace the output file in one step
    filepath_tmp = filepath_output.with_name(filepath_output.name + ".tmp")
    if filepath_output.exists():
        with open(filepath_output, "r", encoding="utf-8") as f:
            examples_prev = json.load(f)
        logging.info(
            f"Prev attempt exists. Restart from {len(examples_prev['examples'])}"
        )
    else:
        logging.info("Initial attempt")
        examples_prev = []

    metadata = {
        "data-created": get_date(),
        "args": {k: str(v) if isinstance(v, Path) else v for k, v in vars(args).items()},
    }

    # start inference
    logging.info("Prediction starts ...")
    for idx, example in enumerate(examples):
        logging.info(f"[Example {idx}]")

        # check if prediction already exists
        corresponding_example_prev = check_if_prediction_exists(example, examples_prev)
        if corresponding_example_prev:
            logging.info("Previous prediction exists. Skip.")
            new_examples.append(corresponding_example_prev)
        else:
            try:
                messages, conversation_history, answer = run_single_example(
                    args,
                    example,
                    template_components,
                    tools,
                    toy2instruction[example["toy_id"]],
                )
                new_example = deepcopy(example)
                new_example["prediction"] = {
                    "conversation_history": conversation_history,
                    "answer": answer,
                }
                new_examples.append(new_example)
            except Exception:
                # top-level boundary: keep going with the next example
                logging.exception("An error occurred")
                new_examples.append("Error")

        # Save after every example. Write to a temp file, then swap it in:
        # an interruption during json.dump can no longer truncate the real
        # output, which the next run would fail to load.
        final_output = {"metadata": metadata, "examples": new_examples}
        with open(filepath_tmp, "w", encoding="utf-8") as f:
            json.dump(final_output, f, indent=4)
            f.write("\n")
        filepath_tmp.replace(filepath_output)


if __name__ == "__main__":
    # CLI entry point: parse arguments, prepare directories and logging, run main().
    parser = ArgumentParser(description="prototype")
    parser.add_argument("--filepath_qa", type=Path, help="filepath for qa input")
    parser.add_argument(
        "--filepath_instruction", type=Path, help="filepath for instruction"
    )
    parser.add_argument(
        "--dirpath_instruction_image", type=Path, help="dirpath for instruction (image)"
    )
    parser.add_argument(
        "--dirpath_parts_image", type=Path, help="dirpath for parts (image)"
    )
    parser.add_argument("--dirpath_frame", type=Path, help="dirpath for frames")
    parser.add_argument(
        "--filepath_template", type=Path, help="filepath for prompt template"
    )
    parser.add_argument(
        "--dirpath_intermediate_output",
        type=Path,
        help="dirpath to store intermediate outputs",
    )
    parser.add_argument("--filepath_tool", type=Path, help="filepath for tools")
    parser.add_argument("--dirpath_output", type=Path, help="dirpath for output")
    parser.add_argument("--model_id", type=str, help="model id", default="")
    parser.add_argument("--temperature", type=float, help="temperature", default=0.1)
    parser.add_argument(
        "--max_tokens", type=int, help="max tokens to generate", default=2560
    )
    parser.add_argument("--max_frames", type=int, help="max frames to feed", default=30)
    parser.add_argument(
        "--max_total_images",
        type=int,
        help="max total frames to feed (100 is for anthropic api, but use 80 for 32MB limit)",
        default=70,
    )
    parser.add_argument("--max_turn", type=int, help="max turns", default=10)
    parser.add_argument(
        "--threshold_continue",
        type=int,
        help="the minimum turn to spend before generating answers for each example",
        default=5,
    )
    parser.add_argument(
        "--threshold_max_text_turn",
        type=int,
        help="the max #turn of non-function outputs",
        default=2,
    )
    parser.add_argument("--reasoning", action="store_true", help="Enable reasoning")
    parser.add_argument(
        "--reasoning_effort",
        type=str,
        help="reasoning level for openAI models",
        default="medium",
    )
    parser.add_argument(
        "--summary_type", type=str, help="summary type for openai api", default="auto"
    )
    parser.add_argument(
        "--budget_tokens",
        type=int,
        help="thinking budget tokens for anthropic api",
        default=2048,
    )
    parser.add_argument("--angle", type=str, help="angle", default="C10118_rgb")
    parser.add_argument("--resolution", type=str, help="resolution", default="360p")
    parser.add_argument("--color", type=str, help="color", default="rgb")
    parser.add_argument(
        "--api_key_for_qwen", type=str, help="dummy api key for qwen", default="dummpy"
    )
    parser.add_argument(
        "--base_url_qwen",
        type=str,
        help="base URL for qwen",
        default="http://localhost:8000/v1",
    )
    parser.add_argument("--wait_time", type=int, help="API call wait time", default=10)
    parser.add_argument("--seed", type=int, help="random seed", default=7)
    parser.add_argument("--dirpath_log", type=Path, help="dirpath for log")

    args = parser.parse_args()

    # Create both directories, including missing parents. exist_ok makes this
    # a no-op when they already exist.
    args.dirpath_log.mkdir(parents=True, exist_ok=True)
    args.dirpath_output.mkdir(parents=True, exist_ok=True)

    logging.basicConfig(
        format="%(asctime)s:%(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(
                args.dirpath_log / f"prototype_qwen_presample_{get_date()}.log"
            ),
        ],
    )

    logging.info("Arguments:")
    logging.info(pformat(vars(args), indent=4, width=100))

    main(args)
