"""
Reproduction of Temporal Chain of Thought paper

model: internvl 3
"""

from argparse import ArgumentParser
from copy import deepcopy
import json
import logging
from pathlib import Path
import yaml
from tqdm import tqdm
from pprint import pformat
import random

# import sys
import re

from utils import (
    get_date,
    format_instruction,
)

from utils_api import (
    call_api_single,
)

from prompts import (
    format_input_tcot_selection,
    format_input_tcot_answer,
)

from baseline_tcot_google import (
    create_segment,
    uniform_sample,
)


def _parse_selection_json(content):
    """
    Parse a frame-selection reply as JSON, tolerating a Markdown code fence.

    Returns (frame_ids, justifications). Raises if the text is not valid JSON
    or does not provide a "frame_ids" entry; "justification" is optional.

    """

    text = content.strip()
    if text.startswith("```json"):
        text = text.replace("```json", "").replace("```", "")
    elif text.startswith("```"):
        text = text.replace("```", "")

    output = json.loads(text.strip())
    # Read everything into locals first so a failure leaves no partial state.
    frame_ids = list(output["frame_ids"])
    justifications = [output["justification"]] if "justification" in output else []
    return frame_ids, justifications


def _parse_selection_regex(content):
    """
    Extract frame ids and justification from malformed JSON with regexes.

    Returns (frame_ids, justifications). frame_ids is None when no
    "frame_ids" list is present at all, and [] when the list is present but
    empty. Raises ValueError when a listed id is not an integer.

    """

    frame_ids = None
    # DOTALL lets the list span several lines; \s* tolerates any spacing
    # after the colon.
    ids_match = re.search(r'"frame_ids":\s*\[(.*?)\]', content, re.DOTALL)
    if ids_match:
        tokens = [token.strip() for token in ids_match.group(1).split(",")]
        # Skip empty tokens: "[]" splits into [""] and a trailing comma leaves
        # one behind; neither may be fed to int().
        frame_ids = [int(token) for token in tokens if token]
        if not frame_ids:
            logging.info("debug: no ids chosen")
    else:
        logging.warning("re did not find ids.")

    justifications = []
    justification_match = re.search(r'"justification":\s*"([^"]*)"', content)
    if justification_match:
        justifications.append(justification_match.group(1))
    else:
        logging.warning("re did not find justification.")
        logging.warning(f"{content}")

    return frame_ids, justifications


def postprocess_selection(response):
    """
    postprocess response for frame id selection
    1. try normal json.loads
    2. try re
    if both failed, the returned flag is True so that the caller can fall back
    to adding all frame ids.

    Args:
        response: object exposing ``choices[0].message.content``; only the
            first choice is read.

    Returns:
        (selected_frame_ids, justifications, flag_all_parsing_failed)
        The flag is None when ids were parsed (an explicitly empty id list
        counts as parsed) or when the reply has no content, and True when
        neither parser could recover a "frame_ids" list.

    """

    if len(response.choices) > 1:
        logging.warning(f"debug: multiple choices in {response=}")
    message = response.choices[0].message

    if not message.content:
        # Nothing to parse: keep reporting "no frames, no failure".
        logging.warning("debug: empty content, no frame ids selected")
        return [], [], None

    logging.info(f"debug: target to parse: {message.content=}")

    # 1. parsing with json
    try:
        logging.info("debug: try parsing with json")
        selected_frame_ids, justifications = _parse_selection_json(message.content)
        return selected_frame_ids, justifications, None
    except Exception:
        logging.warning("Error occured in the output parsing w/ json. Try re next")

    # 2. parsing with re
    try:
        selected_frame_ids, justifications = _parse_selection_regex(message.content)
    except Exception:
        logging.warning("The output cannot be parsed with re.")
        return [], [], True

    if selected_frame_ids is None:
        # Both parsers failed to locate any id list.
        return [], justifications, True

    return selected_frame_ids, justifications, None


def select_frame(args, example, template_components):
    """
    Stage 1: pick the frames that are handed to the answering stage.

    1. split the video into segments (create_segment)
    2. query the model once per segment and keep the frames it names
    3. if more than ``args.max_frames - args.num_uniform`` frames were kept,
       reduce them to that many with ``random.sample``

    Returns (selected_frames, history): the kept frame file paths and a log
    with one entry per segment (plus a note when step 3 was applied).

    """

    history = []
    selected_frames = []  # [filepath, ...]

    for segment in create_segment(args, example):
        # format input
        messages, messages_text, id2filepath = format_input_tcot_selection(
            args, example, segment, template_components
        )

        # call api; on any error keep the whole segment rather than lose it
        try:
            response, _ = call_api_single(args, "", messages)
        except Exception:
            logging.exception("Error in select_frame()")
            selected_frames.extend(segment)
            entry_on_error = {
                "prompt": messages_text,
                "selected_frames": segment,
                "justifications": "Error in select_frame, call api. Include all.",
            }
            history.append(entry_on_error)
            continue

        # postprocess
        chosen_ids, reasons, parsing_failed = postprocess_selection(response)
        logging.info(
            f"debug: selected_frame_ids_segment={chosen_ids!r},"
            f" justifications_segment={reasons!r}"
        )

        if parsing_failed:
            # unparseable reply: keep every frame offered for this segment
            logging.warning("Parsing did not work. Fallback option.")
            kept = list(id2filepath.values())
            reasons = ["Parsing did not work. Fallback option."]
        else:
            kept = [
                filepath
                for frame_id, filepath in id2filepath.items()
                if frame_id in chosen_ids
            ]

        history.append(
            {
                "prompt": messages_text,
                "selected_frames": kept,
                "justifications": reasons,
            }
        )
        selected_frames.extend(kept)

    logging.info(f"debug: {len(selected_frames)=}")

    # 3. shrink to the budget left after reserving the uniformly sampled frames
    max_selected_frames = args.max_frames - args.num_uniform
    if len(selected_frames) > max_selected_frames:
        note = (
            f"#selected frames > max: {len(selected_frames)=} > {max_selected_frames=}"
            " Perform uniform sampling."
        )
        logging.info(note)
        history.append(note)
        # NOTE: despite the message wording, this draws a random subset
        # (random.sample), so the order of the result is not preserved.
        selected_frames = random.sample(selected_frames, max_selected_frames)

    return selected_frames, history


def postprocess_answer(response):
    """
    postprocess answer response

    The reply is split on "<answer>": the text after the last opening tag, up
    to the first "</answer>", is the answer; everything before the last
    opening tag is the thinking. Text following the closing tag is discarded
    so that trailing commentary does not leak into the answer. A reply without
    any "<answer>" tag is returned entirely as the answer.

    Args:
        response: object exposing ``choices[0].message.content``; only the
            first choice is read.

    Returns:
        (answer, thinking), both stripped; ("", "") when the content is empty.

    """

    answer, thinking = "", ""
    if len(response.choices) > 1:
        logging.warning(f"debug: multiple choices in {response=}")
    message = response.choices[0].message

    if message.content:
        *reasoning_parts, tail = message.content.split("<answer>")
        thinking = "\n".join(reasoning_parts)
        # Cut at the closing tag instead of merely deleting it.
        answer, _, _ = tail.partition("</answer>")

    return answer.strip(), thinking.strip()


def run_single_example(args, example, template_components, instructions):
    """
    Run both stages for one example: frame selection, then question answering.

    The frames chosen by select_frame are merged with the uniformly sampled
    ones, de-duplicated and sorted, and sent to the model together with the
    instructions.

    Returns (answer, history) where history records the intermediate steps.
    Raises Exception when the answering API call fails.

    """

    def stems(filepaths):
        # file names without directory and suffix, used for compact logging
        return [Path(filepath).stem for filepath in filepaths]

    # select frames: segment -> select
    logging.info("Frame selection")
    selected_frames, history_selection = select_frame(args, example, template_components)
    history = list(history_selection)
    history.append(["selected_frames", stems(selected_frames)])
    logging.info(f"debug: selected_frames.keys() {stems(selected_frames)}")

    # concat selected frames and uniformly sampled ones
    filepaths_uniform_sampled = uniform_sample(args, example)
    history.append(["filepaths_uniform_sampled", stems(filepaths_uniform_sampled)])
    filepaths_all = selected_frames + filepaths_uniform_sampled
    logging.info(f"debug: filepaths_all {stems(filepaths_all)}")

    # de-duplication (sorted by file path)
    frames_input = sorted(set(filepaths_all))
    stems_deduplicated = stems(frames_input)
    logging.info(f"debug: filepaths_all_deduplicated {stems_deduplicated}")
    history.append(["filepaths_all_deduplicated", stems_deduplicated])

    # format input
    messages, messages_text, _ = format_input_tcot_answer(
        args, example, frames_input, template_components, instructions
    )
    # NOTE(review): extend() adds the elements of messages_text one by one;
    # presumably it is a list -- confirm, since a plain str would be split
    # into single characters here.
    history.extend(messages_text)

    # call api
    logging.info("QA based on selected frames")
    try:
        response, _ = call_api_single(args, "", messages)
    except Exception:
        logging.exception("Error in run_single_example")
        raise Exception("Error in run_single_example")

    # postprocess
    answer, thinking = postprocess_answer(response)
    logging.info(f"answer: {answer}, thinking: {thinking}")
    history.append([["thinking", thinking], ["answer", answer]])

    return answer, history


def main(args):
    """
    Run the pipeline over every example of the input file.

    Results are written to
    ``<dirpath_output>/tcot_<model name>_<reasoning>_<input file name>`` after
    each processed example. When that file already exists, the entries it
    holds are copied over unchanged (whatever their prediction is) and
    processing resumes after them. A failing example gets the prediction
    "Error" instead of aborting the run.

    """

    random.seed(args.seed)

    with open(args.filepath_input, "r") as f:
        examples = json.load(f)

    toy2instruction = format_instruction(
        args.filepath_instruction,
        args.dirpath_instruction_image,
        args.dirpath_parts_image,
    )

    with open(args.filepath_template, "r") as f:
        template_components = yaml.safe_load(f)

    args_serializable = {
        key: (str(value) if isinstance(value, Path) else value)
        for key, value in vars(args).items()
    }
    metadata = {"data-created": get_date(), "args": args_serializable}

    filename_output = (
        f"tcot_{Path(args.model_id).name}_{args.reasoning}_{args.filepath_input.name}"
    )
    filepath_output = args.dirpath_output / filename_output

    # examples finished by a previous attempt, if there was one
    if filepath_output.exists():
        with open(filepath_output, "r") as f:
            examples_done = json.load(f)["examples"]
        logging.info(f"Prev attempt exists. Restart from {len(examples_done)}")
    else:
        logging.info("Initial attempt")
        examples_done = []

    new_examples = []
    for idx, example in tqdm(enumerate(examples), total=len(examples)):
        # skip if already done
        if idx < len(examples_done):
            new_examples.append(examples_done[idx])
            continue

        logging.info(f"Example {idx=}")
        try:
            answer, history = run_single_example(
                args,
                example,
                template_components,
                toy2instruction[example["toy_id"]],
            )
            prediction = {
                "history": history,
                "answer": answer,
            }
        except Exception:
            logging.exception("Error occurred.")
            prediction = "Error"

        new_example = deepcopy(example)
        new_example["prediction"] = prediction
        new_examples.append(new_example)

        # checkpoint: rewrite the whole output after every example
        with open(filepath_output, "w") as f:
            json.dump({"metadata": metadata, "examples": new_examples}, f, indent=4)
            f.write("\n")


# Command-line entry point: parse the options, prepare the log and output
# directories, configure logging (console + file) and run main().
if __name__ == "__main__":
    parser = ArgumentParser(
        description="Temporal Chain of Thought reimplementation code for InternVL3"
    )
    # Options marked required are dereferenced unconditionally below or in
    # main(); omitting them used to fail later with an opaque error.
    parser.add_argument(
        "--filepath_input", type=Path, help="filepath for input", required=True
    )
    parser.add_argument(
        "--filepath_instruction", type=Path, help="filepath for instruction"
    )
    parser.add_argument(
        "--dirpath_instruction_image", type=Path, help="dirpath for instruction (image)"
    )
    parser.add_argument(
        "--dirpath_parts_image", type=Path, help="dirpath for parts (image)"
    )
    parser.add_argument("--dirpath_frame", type=Path, help="dirpath for frames")
    parser.add_argument(
        "--filepath_template",
        type=Path,
        help="filepath for prompt template",
        required=True,
    )
    parser.add_argument(
        "--dirpath_output", type=Path, help="filepath for output", required=True
    )
    parser.add_argument("--model_id", type=str, help="model id", required=True)
    parser.add_argument("--temperature", type=float, help="temperature", default=0.0)
    parser.add_argument(
        "--max_tokens", type=int, help="max tokens to generate", default=4608
    )
    parser.add_argument(
        "--budget_tokens", type=int, help="max budget tokens for reasoning", default=4098
    )
    parser.add_argument("--reasoning", action="store_true", help="Enable reasoning")
    parser.add_argument(
        "--reasoning_effort",
        type=str,
        help="reasoning level for openAI models",
        default="medium",
    )
    parser.add_argument(
        "--summary_type", type=str, help="summary type for openai api", default="auto"
    )
    parser.add_argument(
        "--api_key_for_qwen",
        type=str,
        help="dummy api key for internvl",
        default="dummpy",
    )
    parser.add_argument(
        "--base_url_internvl",
        type=str,
        help="port for internvl",
        default="http://localhost:8007/v1",
    )
    parser.add_argument("--max_frames", type=int, help="max frames to feed", default=64)
    parser.add_argument(
        "--num_uniform",
        type=int,
        help="#frames sampled uniformly",
        default=16,
    )
    parser.add_argument(
        "--size_segment",
        type=int,
        help="max #frames per segment",
        default=32,
    )
    parser.add_argument("--num_segment", type=int, help="num segment", default=4)
    parser.add_argument("--angle", type=str, help="angle", default="C10118_rgb")
    parser.add_argument("--resolution", type=str, help="resolution", default="360p")
    parser.add_argument("--color", type=str, help="color", default="rgb")
    parser.add_argument("--wait_time", type=int, help="API call wait time", default=10)
    parser.add_argument("--seed", type=int, help="randome seed", default=7)
    parser.add_argument(
        "--dirpath_log", type=Path, help="dirpath for log", required=True
    )

    args = parser.parse_args()

    # parents=True allows nested paths; exist_ok=True makes the call safe when
    # the directory already exists or is created concurrently.
    args.dirpath_log.mkdir(parents=True, exist_ok=True)
    args.dirpath_output.mkdir(parents=True, exist_ok=True)

    logging.basicConfig(
        format="%(asctime)s:%(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(
                args.dirpath_log
                / f"baseline_tcot_{Path(args.model_id).name}_{get_date()}.log"
            ),
        ],
    )

    logging.info(pformat(vars(args)))

    main(args)
