"""
Prototype tool-augmented procedural activity assistant
model: gemini
method: presample in initial message & postsample in cut-in

"""

from argparse import ArgumentParser
import json
import logging
from pathlib import Path
from pprint import pformat
import sys
from copy import deepcopy
from collections import defaultdict
import random
import os

from google.genai import types

from utils import (
    get_date,
    format_instruction,
    # check_prompt,
)

from utils_api import (
    estimate_cost,
    call_api_single,
    GOOGLE_USAGE_KEYS,
)

from tools import (
    call_function,
    load_tools,
)

from prompts import (
    # format_initial_user_message,
    load_prompt_template,
    get_duration,
)

from google import genai


def cut_in(
    args,
    example,
    template_components,
    count_turn,
    count_non_tool_output,
    flag_answer_predicted,
    flag_cutin_w_image,
):
    """Decide whether a manual user "cut-in" message must follow the latest model turn.

    Decision table (``past`` means ``count_turn > args.threshold_continue``):

    * answer predicted, past                      -> nothing to add
    * answer predicted, not past, latest turn had
      no tool call and no image cut-in sent yet    -> image cut-in (explore more)
    * answer predicted, otherwise                 -> nothing to add
    * no answer, past                             -> text asking for the final answer
    * no answer, not past, and at least
      ``args.threshold_max_text_turn`` consecutive
      turns without a tool call                   -> image cut-in (explore more)
    * no answer, otherwise                        -> nothing to add

    Args:
        args: namespace providing ``threshold_continue`` and
            ``threshold_max_text_turn``; forwarded to ``uniform_sample`` when an
            image cut-in is built.
        example: QA example; ``example["question"]`` fills the answer template.
        template_components: prompt templates; reads ``["user"]["answer"]`` and
            ``["user"]["cutin_w_image"]``.
        count_turn: index of the current turn.
        count_non_tool_output: number of consecutive turns without a function call.
        flag_answer_predicted: whether an ``<answer>`` has already been produced.
        flag_cutin_w_image: whether an image cut-in has already been sent.

    Returns:
        ``(parts, parts_text, flag)``: the parts to append to the next user message
        (empty list when no cut-in is needed), a plain-text rendering of them for
        the conversation log, and the updated image cut-in flag.
    """
    logging.info("Check if cut-in is required")

    threshold_max_text_turn = args.threshold_max_text_turn
    threshold_continue = args.threshold_continue
    past_threshold = count_turn > threshold_continue

    # THINK: better to think more. create a diagram may help
    if flag_answer_predicted:
        if past_threshold:
            logging.info(
                f"{threshold_continue} turns have passed. No prolonging required."
            )
            return [], "", flag_cutin_w_image

        if count_non_tool_output > 0 and not flag_cutin_w_image:
            # encourage model to explore more if model's output is only answer
            logging.info(
                "Answer is provided, but let's encourage the model to explore more."
            )
            parts, parts_text = _build_image_cut_in(args, example, template_components)
            return parts, parts_text, True

        logging.info("tool_call detected or already cutin.")
        return [], "", flag_cutin_w_image

    if past_threshold:
        # encourage model to output answer in the specified format
        logging.info("Let's encourage the model to output an answer")
        cutin_text = template_components["user"]["answer"].replace(
            "{question}", example["question"]
        )
        return [types.Part(text=cutin_text)], cutin_text, flag_cutin_w_image

    # Here count_turn <= threshold_continue always holds, so only the number of
    # consecutive text-only turns needs to be checked.
    if count_non_tool_output >= threshold_max_text_turn:
        logging.info("Let's encourage the model to continue its exploration")
        parts, parts_text = _build_image_cut_in(args, example, template_components)
        return parts, parts_text, True

    logging.info("No need right now")
    return [], "", flag_cutin_w_image


def _build_image_cut_in(args, example, template_components):
    """Build the cut-in text plus freshly sampled, uploaded frames.

    Returns ``(parts, parts_text)`` where ``parts`` holds the API parts and
    ``parts_text`` is the plain-text rendering stored in the conversation log.
    """
    # The client is created only here so that turns without an image cut-in
    # neither need the API key nor pay for building a client.
    client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])

    filepaths_and_ids, fps = uniform_sample(args, example)
    cutin_text = template_components["user"]["cutin_w_image"].replace("{fps}", fps)
    parts = [types.Part(text=cutin_text)]
    parts_text = cutin_text

    # each frame is sent as a "Frame <id>" label followed by the uploaded image
    for filepath, idx in filepaths_and_ids:
        parts.append(types.Part(text=f"Frame {idx}"))
        file = client.files.upload(file=filepath)
        parts.append(
            types.Part.from_uri(
                file_uri=file.uri,
                mime_type=file.mime_type,
            )
        )
        parts_text += f"Frame {idx}: {filepath}\n"

    return parts, parts_text


def count_image_in_contents(contents):
    """Count the image-bearing parts across all messages in ``contents``.

    A part counts once when it carries a function response whose ``"result"`` is
    not a string, and once when it carries ``file_data``.
    """
    total = 0
    for message in contents:
        for piece in message.parts:
            tool_reply = piece.function_response
            # a string result is plain text; anything else is counted as an image
            if tool_reply and not isinstance(tool_reply.response["result"], str):
                total += 1
            if piece.file_data:
                total += 1
    return total


def truncate_images_if_needed(args, contents):
    """
    Replace the earliest images with a text placeholder once the image budget is hit.

    Returns ``contents`` unchanged when the number of images is at most
    ``args.max_total_images``. Otherwise returns a new list of ``types.Content``
    in which the first excess images (in conversation order) are swapped for a
    placeholder; all other parts are reused as they are.
    """
    placeholder = "<image>Truncated because #images exceeds the max</image>"

    # 1. check #images
    total_images = count_image_in_contents(contents)
    if total_images <= args.max_total_images:
        logging.info(f"#images: {total_images} (not changed)")
        return contents

    # 2. truncation, oldest images first
    images_to_remove = total_images - args.max_total_images
    images_removed = 0
    logging.info(f"[debug] remove {images_to_remove} images")

    new_contents = []
    for content in contents:
        if not isinstance(content.parts, list):
            logging.error("this should not happen")
            sys.exit("force stop (truncate_images_if_needed)")

        new_parts = []
        for part in content.parts:
            replacement = None
            if part.function_response:
                # a non-string result is an image returned by a tool
                holds_image = not isinstance(
                    part.function_response.response["result"], str
                )
                if holds_image and images_removed < images_to_remove:
                    replacement = types.Part.from_function_response(
                        name=part.function_response.name,
                        response={"result": placeholder},
                    )
            elif part.file_data and images_removed < images_to_remove:
                replacement = types.Part(text=placeholder)

            if replacement is None:
                # nothing to drop here: keep the original part object
                new_parts.append(part)
                continue

            images_removed += 1
            logging.info(f"debug: {images_removed}/{images_to_remove} images removed")
            new_parts.append(replacement)

        new_contents.append(types.Content(role=content.role, parts=new_parts))

    final_count = count_image_in_contents(new_contents)
    logging.info(f"#image: {total_images} -> {final_count} (reduced)")

    return new_contents


def uniform_sample(args, example):
    """List the frame files of an example, randomly subsampling when there are too many.

    One frame path ``<dirpath_frame>/<sequence_id>/<angle>/<i>.png`` is generated for
    every integer ``i`` in ``[start, end]`` (both taken from ``example["video"]``).
    Each path is paired with its offset from ``start`` as a zero-padded string.

    Args:
        args: namespace providing ``dirpath_frame`` (a ``Path``), ``angle`` and
            ``max_frames``.
        example: dict with ``"sequence_id"`` and ``"video"`` (``"start"``/``"end"``
            convertible to ``int``).

    Returns:
        ``(frames, fps)`` where ``frames`` is a list of ``[filepath, offset_id]``
        pairs in chronological order and ``fps`` is a string: ``"1"`` when every
        frame is kept, otherwise the kept fraction formatted with two decimals.
    """
    dirpath = args.dirpath_frame / example["sequence_id"] / args.angle
    start, end = int(example["video"]["start"]), int(example["video"]["end"])

    filepaths_and_ids = [
        [str(dirpath / f"{i}.png"), f"{(i - start):04d}"] for i in range(start, end + 1)
    ]
    num_frame_total = len(filepaths_and_ids)

    if num_frame_total <= args.max_frames:
        return filepaths_and_ids, "1"

    # Uses the module-level ``random`` state, so results follow ``random.seed``.
    filepaths_sampled = random.sample(filepaths_and_ids, args.max_frames)
    # Sort numerically: the ids are only padded to 4 digits, so a plain string
    # sort would place "10000" before "1001" for offsets of 10000 and above.
    filepaths_sampled.sort(key=lambda pair: int(pair[1]))
    fps = f"{len(filepaths_sampled) / num_frame_total:.2f}"

    return filepaths_sampled, fps


def format_initial_user_message_pre_sample(args, example, template_components):
    """Build the first user message: question text followed by pre-sampled frames.

    Returns ``(message, text_prompt, uploaded_files)``: the ``Content`` to send,
    a plain-text rendering of it (frame images shown as file paths) for logging,
    and the file objects returned by the upload calls.
    """
    client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
    user_templates = template_components["user"]

    # initial message
    intro = user_templates["initial"].replace("{question}", example["question"])
    intro = intro.replace(
        "{duration}",
        get_duration(example["video"]["start"], example["video"]["end"]),
    )

    # uniform pre-sample
    filepaths_and_ids, fps = uniform_sample(args, example)
    # NOTE(review): "{num}" is filled with args.max_frames even when fewer frames
    # are returned by uniform_sample -- confirm this matches the template wording.
    sample_note = user_templates["pre_sample"].replace("{num}", str(args.max_frames))
    sample_note = sample_note.replace("{fps}", fps)

    parts = [genai.types.Part(text=intro), genai.types.Part(text=sample_note)]
    text_chunks = [intro, sample_note + "\n"]
    uploaded_files = []

    # each frame is sent as a "Frame <id>" label followed by the uploaded image
    for filepath, idx in filepaths_and_ids:
        parts.append(genai.types.Part(text=f"Frame {idx}"))
        uploaded = client.files.upload(file=filepath)
        parts.append(
            genai.types.Part.from_uri(
                file_uri=uploaded.uri,
                mime_type=uploaded.mime_type,
            )
        )
        text_chunks.append(f"Frame {idx}: {filepath}\n")
        uploaded_files.append(uploaded)

    message = genai.types.Content(role="user", parts=parts)
    text_prompt = "".join(text_chunks)

    return message, text_prompt, uploaded_files


def delete_all_uploaded_files():
    """Delete every file that the Files API lists for ``GEMINI_API_KEY``.

    Note that this removes whatever ``files.list()`` returns, not only the files
    uploaded by the current run.
    """
    files_api = genai.Client(api_key=os.environ["GEMINI_API_KEY"]).files
    logging.info("Delete all uploaded files ...")
    for remote_file in files_api.list():
        files_api.delete(name=remote_file.name)


def run_single_example(args, example, template_components, tools, instructions):
    """Run the multi-turn, tool-augmented conversation for one QA example.

    Every turn calls the model, executes the function calls it requested, lets
    ``cut_in`` append a manual user message if needed, and truncates old images
    when the image budget is exceeded. The loop stops when ``args.max_turn`` turns
    were spent, or when an answer exists, more than ``args.threshold_continue``
    turns have passed and the latest turn contained no tool call.

    Args:
        args: parsed CLI namespace; ``max_turn``, ``threshold_continue`` and
            ``model_id`` are read here, the rest is forwarded to the helpers.
        example: QA example, forwarded to the helpers.
        template_components: prompt templates; ``["system"]`` is the system message.
        tools: tool declarations forwarded to ``call_api_single``.
        instructions: forwarded to ``call_function``.

    Returns:
        ``(contents, conversation_history, answer, cost)``: the API-level message
        list, a text-only log of every turn, the text following the last
        ``<answer>`` tag (``""`` if none was produced) and the estimated cost.

    Raises:
        RuntimeError: the API returned no candidates before any answer was found.
        SystemExit: the API call reported ``"Error"`` or a tool returned no output.
    """
    count_turn = 0  # count conversational turns
    count_non_tool_output = 0
    conversation_history = []
    total_usage = defaultdict(int)
    uploaded_files_all = []

    system_developer_message = template_components["system"]
    flag_answer_predicted = False
    answer = ""
    flag_cutin_w_image = False

    # Uploaded files are removed in ``finally`` so that a failing example does not
    # leave its frames behind on the Files API.
    try:
        initial_user_message, initial_user_message_text, uploaded_files_initial = (
            format_initial_user_message_pre_sample(args, example, template_components)
        )
        uploaded_files_all += uploaded_files_initial

        # register turn 0
        conversation_history.append(
            [["system", system_developer_message], ["user", initial_user_message_text]]
        )

        contents = [initial_user_message]
        while count_turn < args.max_turn:
            logging.info(f"Next: turn {count_turn + 1}")

            # api call
            tool_choice = "auto"

            response, usage = call_api_single(
                args, system_developer_message, contents, tools, tool_choice
            )
            if response == "Error":
                sys.exit("stop")
            for key in GOOGLE_USAGE_KEYS:
                total_usage[key] += usage[key]

            latest_assistant_response = ""
            current_turn_assistant, current_turn_user = [], []
            message_types_in_response = []
            # note: diff from openai.
            # function_call under model and function_response under user
            user_parts = []
            logging.info(f"debug {response.candidates=}")

            if not response.candidates:
                if flag_answer_predicted:
                    logging.info(
                        "Answer is already predicted and no response generated. Finish."
                    )
                    break
                raise RuntimeError(
                    "API response contains no candidates and no answer was predicted"
                )

            if len(response.candidates) > 1:  # sanity check
                logging.warning("[Debug] Multiple candidates in response")
            # Always continue with the first candidate, also when several came back.
            candidate = response.candidates[0]

            for part in candidate.content.parts:
                if part.function_call:
                    current_turn_assistant.append(
                        f"{part.function_call.name} {part.function_call.args}"
                    )

                    # function call
                    tool_result_parts, tool_result_parts_text, uploaded_files = (
                        call_function(
                            args,
                            example,
                            template_components,
                            instructions,
                            part.function_call,
                        )
                    )
                    uploaded_files_all += uploaded_files

                    if tool_result_parts:
                        # note: use extend as multiple parts may be contained
                        user_parts.extend(tool_result_parts)
                        current_turn_user.append(tool_result_parts_text)
                    else:
                        logging.error("no output returned from function call")
                        sys.exit("Force Quit (function call no output)")
                    message_types_in_response.append("function_call")
                elif part.thought:
                    current_turn_assistant.append(part.text)
                    message_types_in_response.append("thought")
                elif part.text:
                    current_turn_assistant.append(part.text)
                    latest_assistant_response = part.text
                    message_types_in_response.append("text")
                else:
                    logging.warning(f"Uncovered type of {part=}")

            logging.info(f"{message_types_in_response=}")

            # add everything to the history
            contents.append(candidate.content)
            current_turn = [["assistant", current_turn_assistant]]
            logging.info("[assistant]")
            logging.info(current_turn_assistant)

            # check if answer is produced (only the last text part is inspected)
            if "<answer>" in latest_assistant_response:
                flag_answer_predicted = True
                answer = latest_assistant_response.split("<answer>")[-1].replace(
                    "</answer>", ""
                )

            # check output types
            if "function_call" in message_types_in_response:
                count_non_tool_output = 0
            else:
                count_non_tool_output += 1

            # manual cut-in to prolong or end conversation
            parts_cut_in, parts_cut_in_text, flag_cutin_w_image = cut_in(
                args,
                example,
                template_components,
                count_turn,
                count_non_tool_output,
                flag_answer_predicted,
                flag_cutin_w_image,
            )
            if parts_cut_in:
                user_parts.extend(parts_cut_in)
                current_turn_user.append(parts_cut_in_text)

            # create a user message possibly including function_call_output and
            # the cut-in parts
            if user_parts:
                contents.append(types.Content(role="user", parts=user_parts))
                logging.info("[user]")
                logging.info(current_turn_user)
                current_turn.append(["user", current_turn_user])

            # register conv history in text
            conversation_history.append(current_turn)

            # update contents if #images > max_#images
            contents = truncate_images_if_needed(args, contents)

            # termination condition
            # count_turn is still the index of the turn that just finished, so the
            # budget is exhausted when the next index would reach max_turn.
            if count_turn + 1 >= args.max_turn:
                logging.info(f"Reached {args.max_turn=}. Terminate this process.")
                break

            if flag_answer_predicted and count_turn > args.threshold_continue:
                if count_non_tool_output > 0:
                    logging.info(
                        "Answer predicted & Latest response contains no tool call. "
                        "So finish now."
                    )
                    break
                logging.info(
                    f"Although turn {args.threshold_continue} reached, "
                    "as a tool is called, continue this process"
                )

            count_turn += 1

        cost = estimate_cost(args.model_id, total_usage)
    finally:
        # delete files
        delete_all_uploaded_files()

    return contents, conversation_history, answer, cost


def check_if_prediction_exists(example, examples_prev):
    """Look up a finished prediction for ``example`` in a previous output.

    An entry of ``examples_prev["examples"]`` matches when its ``"question"``
    equals the example's and its ``["prediction"]["answer"]`` is not blank.
    Returns a deep copy of the last matching entry, or ``None`` if there is none
    (including when ``examples_prev`` has no ``"examples"`` key).
    """
    found = None
    if "examples" not in examples_prev:
        return found

    for previous in examples_prev["examples"]:
        if "question" not in previous:
            continue
        if example["question"] != previous["question"]:
            continue
        if "prediction" not in previous:
            continue
        # a blank answer means the earlier attempt did not really finish
        if not previous["prediction"]["answer"].strip():
            continue
        found = deepcopy(previous)

    return found


def main(args):
    """Run inference over all QA examples and write the results after each one.

    Examples that already have a non-blank answer in an existing output file are
    copied over instead of being re-run. A failing example is logged and stored as
    the string ``"Error"``.
    """
    random.seed(args.seed)

    # load data
    with open(args.filepath_qa, "r") as qa_file:
        examples = json.load(qa_file)

    # load instruction, prompt template and tools
    toy2instruction = format_instruction(
        args.filepath_instruction,
        args.dirpath_instruction_image,
        args.dirpath_parts_image,
    )
    template_components = load_prompt_template(args.filepath_template)
    tools = load_tools(args)
    # logging.info(pformat(tools[0].function_declarations, width=100))

    results = []
    total_cost = 0

    # output file (also the source of earlier predictions when it already exists)
    output_name = f"{args.model_id}_{args.reasoning}_pre-postsample.json"
    filepath_output = args.dirpath_output / output_name
    if not filepath_output.exists():
        logging.info("Initial attempt")
        examples_prev = []
    else:
        with open(filepath_output, "r") as prev_file:
            examples_prev = json.load(prev_file)
        # continue from the cost recorded in the previous output
        total_cost = float(examples_prev["metadata"]["cost"])

    # Path values are stringified so that the arguments can be dumped as JSON
    serialisable_args = {}
    for name, value in vars(args).items():
        serialisable_args[name] = str(value) if isinstance(value, Path) else value
    metadata = {"data-created": get_date(), "args": serialisable_args}

    # start inference
    logging.info("Prediction starts ...")
    for idx, example in enumerate(examples):
        logging.info(f"[Example {idx}]")

        # check if prediction already exists
        cached = check_if_prediction_exists(example, examples_prev)
        if not cached:
            try:
                _, conversation_history, answer, cost = run_single_example(
                    args,
                    example,
                    template_components,
                    tools,
                    toy2instruction[example["toy_id"]],
                )
                finished = deepcopy(example)
                finished["prediction"] = {
                    "conversation_history": conversation_history,
                    "answer": answer,
                    "cost": cost,
                }
                results.append(finished)
                total_cost += cost
            except Exception:
                logging.exception("An error occurred")
                results.append("Error")
        else:
            logging.info("Previous prediction exists. Skip.")
            results.append(cached)

        # save output every time
        metadata["cost"] = f"{total_cost:.3f}"
        with open(filepath_output, "w") as out_file:
            json.dump({"metadata": metadata, "examples": results}, out_file, indent=4)
            out_file.write("\n")

    logging.info(f"{total_cost=}")


# Script entry point: parse the CLI arguments, prepare the log/output directories,
# configure logging to both the console and a dated log file, then run ``main``.
if __name__ == "__main__":
    parser = ArgumentParser(description="prototype")
    parser.add_argument("--filepath_qa", type=Path, help="filepath for qa input")
    parser.add_argument(
        "--filepath_instruction", type=Path, help="filepath for instruction"
    )
    parser.add_argument(
        "--dirpath_instruction_image", type=Path, help="dirpath for instruction (image)"
    )
    parser.add_argument(
        "--dirpath_parts_image", type=Path, help="dirpath for parts (image)"
    )
    parser.add_argument("--dirpath_frame", type=Path, help="dirpath for frames")
    parser.add_argument(
        "--filepath_template", type=Path, help="filepath for prompt template"
    )
    parser.add_argument(
        "--dirpath_intermediate_output",
        type=Path,
        help="dirpath to store intermediate outputs",
    )
    parser.add_argument("--filepath_tool", type=Path, help="filepath for tools")
    parser.add_argument("--dirpath_output", type=Path, help="filepath for output")
    parser.add_argument("--model_id", type=str, help="model id", default="")
    parser.add_argument("--temperature", type=float, help="temperature", default=0.1)
    parser.add_argument(
        "--max_tokens", type=int, help="max tokens to generate", default=2500
    )
    parser.add_argument("--max_frames", type=int, help="max frames to feed", default=30)
    parser.add_argument(
        "--max_total_images",
        type=int,
        help="max total frames to feed (100 is for anthropic api, but use 80 for 32MB limit)",
        default=70,
    )
    parser.add_argument("--max_turn", type=int, help="max turns", default=10)
    parser.add_argument(
        "--threshold_continue",
        type=int,
        help="the minimum turn to spend before generating answers for each example",
        default=5,
    )
    parser.add_argument(
        "--threshold_max_text_turn",
        type=int,
        help="the max #turn of non-function outputs",
        default=2,
    )
    parser.add_argument("--reasoning", action="store_true", help="Enable reasoning")
    parser.add_argument(
        "--reasoning_effort",
        type=str,
        help="reasoning level for openAI models",
        default="medium",
    )
    parser.add_argument(
        "--summary_type", type=str, help="summary type for openai api", default="auto"
    )
    parser.add_argument(
        "--budget_tokens",
        type=int,
        help="thinking budget tokens for anthropic api",
        default=2048,
    )
    parser.add_argument("--angle", type=str, help="angle", default="C10118_rgb")
    parser.add_argument("--resolution", type=str, help="resolution", default="360p")
    parser.add_argument("--color", type=str, help="color", default="rgb")
    parser.add_argument("--wait_time", type=int, help="API call wait time", default=10)
    parser.add_argument("--seed", type=int, help="randome seed", default=7)
    parser.add_argument("--dirpath_log", type=Path, help="dirpath for log")

    args = parser.parse_args()

    # Create both directories the same way: missing parents are created and an
    # already existing directory is accepted without a separate exists() check.
    args.dirpath_log.mkdir(parents=True, exist_ok=True)
    args.dirpath_output.mkdir(parents=True, exist_ok=True)

    logging.basicConfig(
        format="%(asctime)s:%(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(
                args.dirpath_log / f"prototype_pre-postsample{get_date()}.log"
            ),
        ],
    )

    logging.info("Arguments:")
    logging.info(pformat(vars(args), indent=4, width=100))

    main(args)
