"""
Reproduction of Temporal Chain of Thought paper

model: MiMo VL
"""

from argparse import ArgumentParser
from copy import deepcopy
import json
import logging
from pathlib import Path
import yaml
from tqdm import tqdm
from pprint import pformat
import random

# import sys
import re

from utils import (
    get_date,
    format_instruction,
)

from utils_api import (
    call_api_single,
)

from prompts import (
    format_input_tcot_selection,
    format_input_tcot_answer,
)

from baseline_tcot_google import (
    create_segment,
    uniform_sample,
)


def postprocess_selection(response):
    """
    postprocess response for frame id selection
    1. try normal json.loads
    2. try re
    if both failed, fallback option to add all frame ids.

    """

    selected_frame_ids, justifications = [], []
    flag_all_parsing_failed = None

    if len(response.choices) > 1:
        logging.warning(f"debug: multiple choices in {response=}")
    message = response.choices[0].message

    if message.content:
        flag_first_parsing_failed = False

        logging.info(f"debug: target to parse: {message.content=}")

        # 1. parsing with json
        try:
            logging.info("debug: try parsing with json")
            if message.content.startswith("```json"):
                text = message.content.replace("```json", "")
                text = text.replace("```", "")
            elif message.content.startswith("```"):
                text = message.content.replace("```", "")
            else:
                text = message.content

            text = text.strip()
            output = json.loads(text)
            selected_frame_ids += output["frame_ids"]
            justifications.append(output["justification"])
        except Exception:
            flag_first_parsing_failed = True
            logging.warning("Error occured in the output parsing w/ json. Try re next")

        # 2. parsing with re
        if flag_first_parsing_failed:
            try:
                # Regex for frame_ids (captures everything inside the brackets)
                frame_ids_match = re.search(r'"frame_ids": \[(.*?)\]', message.content)
                if frame_ids_match:
                    selected_frame_ids += [
                        int(x.strip()) for x in frame_ids_match.group(1).split(",")
                    ]
                else:
                    if '"frame_ids": []' in message.content:
                        logging.info("debug: no ids chosen")
                        selected_frame_ids += []
                    else:
                        logging.warning("re did not find ids.")

                # Regex for justification
                justification_match = re.search(
                    r'"justification": "([^"]*)"', message.content
                )
                if justification_match:
                    justifications.append(justification_match.group(1))
                else:
                    logging.warning("re did not find justification.")
                    logging.warning(f"{message.content}")

            except Exception:
                logging.warning("The output cannot be parsed with re.")
                flag_all_parsing_failed = True

    return selected_frame_ids, justifications, flag_all_parsing_failed


def select_frame(args, example, template_components):
    """
    1st stage of selecting frames
    1. segment
    2. select using a model
    3. sample again if len(samples) > max

    """

    # 1. segment videos
    segments = create_segment(args, example)

    # 2. select frames
    selected_frames = []  # [filepath, ...]
    history = []
    for segment in segments:
        # fomrat input
        messages, messages_text, id2filepath = format_input_tcot_selection(
            args, example, segment, template_components
        )

        # call api
        try:
            response, _ = call_api_single(args, "", messages)
        except Exception:
            # when an error occurs, fallback option to include all frames
            logging.exception("Error in select_frame()")
            selected_frames.extend(segment)
            history.append(
                {
                    "prompt": messages_text,
                    "selected_frames": segment,
                    "justifications": "Error in select_frame, call api. Include all.",
                }
            )
            continue

        # postprocess
        selected_frame_ids_segment, justifications_segment, flag_parsing_failed = (
            postprocess_selection(response)
        )
        logging.info(f"debug: {selected_frame_ids_segment=}, {justifications_segment=}")
        selected_frames_segment = []
        if flag_parsing_failed:
            logging.warning("Parsing did not work. Fallback option.")
            selected_frames_segment = list(id2filepath.values())
            justifications_segment = ["Parsing did not work. Fallback option."]
        else:
            for idx, filepath in id2filepath.items():
                if idx in selected_frame_ids_segment:
                    selected_frames_segment.append(filepath)
        history.append(
            {
                "prompt": messages_text,
                "selected_frames": selected_frames_segment,
                "justifications": justifications_segment,
            }
        )

        selected_frames.extend(selected_frames_segment)

    logging.info(f"debug: {len(selected_frames)=}")

    # 3. sample if len(samples) > max
    max_selected_frames = args.max_frames - args.num_uniform
    if len(selected_frames) > (max_selected_frames):
        logging.info(
            f"#selected frames > max: {len(selected_frames)=} > {max_selected_frames=}"
            " Perform uniform sampling."
        )
        history.append(
            f"#selected frames > max: {len(selected_frames)=} > {max_selected_frames=}"
            " Perform uniform sampling."
        )
        selected_frames = random.sample(selected_frames, max_selected_frames)

    return selected_frames, history


def postprocess_answer(response):
    """
    postprocess answer response

    """

    answer, thinking = "", ""
    if len(response.choices) > 1:
        logging.warning(f"debug: multiple choices in {response=}")
    message = response.choices[0].message

    if message.content:
        splits = message.content.split("<answer>")
        thinking += "\n".join(splits[:-1])
        answer = splits[-1].replace("</answer>", "")

    return answer.strip(), thinking.strip()


def run_single_example(args, example, template_components, instructions):
    history = []

    # select frames: segment -> select
    logging.info("Frame selection")
    selected_frames, history_selection = select_frame(args, example, template_components)
    history.extend(history_selection)
    history.append(["selected_frames", [Path(x).stem for x in selected_frames]])
    logging.info(
        f"debug: selected_frames.keys() {[Path(x).stem for x in selected_frames]}"
    )

    # concat selected frames and uniformed ones
    frames_input = []
    filepaths_uniform_sampled = uniform_sample(args, example)
    history.append(
        ["filepaths_uniform_sampled", [Path(x).stem for x in filepaths_uniform_sampled]]
    )
    filepaths_all = selected_frames + filepaths_uniform_sampled
    logging.info(f"debug: filepaths_all {[Path(x).stem for x in filepaths_all]}")

    # de-duplication
    filepaths_all_deduplicated = sorted(list(set(filepaths_all)))
    logging.info(
        f"debug: filepaths_all_deduplicated"
        f" {[Path(x).stem for x in filepaths_all_deduplicated]}"
    )
    history.append(
        ["filepaths_all_deduplicated", [Path(x).stem for x in filepaths_all_deduplicated]]
    )
    frames_input = filepaths_all_deduplicated

    # format input
    messages, messages_text, _ = format_input_tcot_answer(
        args, example, frames_input, template_components, instructions
    )
    history.extend(messages_text)

    # call api
    logging.info("QA based on selected frames")
    try:
        response, _ = call_api_single(args, "", messages)
    except Exception:
        logging.exception("Error in run_single_example")
        raise Exception("Error in run_single_example")

    # postprocess
    answer, thinking = postprocess_answer(response)
    logging.info(f"answer: {answer}, thinking: {thinking}")

    history.append([["thinking", thinking], ["answer", answer]])

    return answer, history


def main(args):
    random.seed(args.seed)

    with open(args.filepath_input, "r") as f:
        examples = json.load(f)

    toy2instruction = format_instruction(
        args.filepath_instruction,
        args.dirpath_instruction_image,
        args.dirpath_parts_image,
    )

    with open(args.filepath_template, "r") as f:
        template_components = yaml.safe_load(f)

    metadata = {
        "data-created": get_date(),
        "args": {k: str(v) if isinstance(v, Path) else v for k, v in vars(args).items()},
    }

    filepath_output = (
        args.dirpath_output
        / f"tcot_{Path(args.model_id).name}_{args.reasoning}_{args.filepath_input.name}"
    )
    if filepath_output.exists():
        with open(filepath_output, "r") as f:
            examples_prev = json.load(f)
        logging.info(
            f"Prev attempt exists. Restart from {len(examples_prev['examples'])}"
        )
    else:
        logging.info("Initial attempt")
        examples_prev = []

    new_examples = []
    for idx, example in tqdm(enumerate(examples), total=len(examples)):
        # skip if already done
        if examples_prev and idx < len(examples_prev["examples"]):
            new_examples.append(examples_prev["examples"][idx])
            continue

        logging.info(f"Example {idx=}")
        try:
            answer, history = run_single_example(
                args,
                example,
                template_components,
                toy2instruction[example["toy_id"]],
            )
            new_example = deepcopy(example)
            new_example["prediction"] = {
                "history": history,
                "answer": answer,
            }
            new_examples.append(new_example)
        except Exception:
            logging.exception("Error occurred.")
            new_example = deepcopy(example)
            new_example["prediction"] = "Error"
            new_examples.append(new_example)

        output = {"metadata": metadata, "examples": new_examples}
        with open(filepath_output, "w") as f:
            json.dump(output, f, indent=4)
            f.write("\n")


if __name__ == "__main__":
    parser = ArgumentParser(
        description="Temporal Chain of Thought reimplementation code for MiMo VL"
    )
    parser.add_argument("--filepath_input", type=Path, help="filepath for input")
    parser.add_argument(
        "--filepath_instruction", type=Path, help="filepath for instruction"
    )
    parser.add_argument(
        "--dirpath_instruction_image", type=Path, help="dirpath for instruction (image)"
    )
    parser.add_argument(
        "--dirpath_parts_image", type=Path, help="dirpath for parts (image)"
    )
    parser.add_argument("--dirpath_frame", type=Path, help="dirpath for frames")
    parser.add_argument(
        "--filepath_template", type=Path, help="filepath for prompt template"
    )
    parser.add_argument("--dirpath_output", type=Path, help="filepath for output")
    parser.add_argument("--model_id", type=str, help="model id")
    parser.add_argument("--temperature", type=float, help="temperature", default=0.0)
    parser.add_argument(
        "--max_tokens", type=int, help="max tokens to generate", default=4608
    )
    parser.add_argument(
        "--budget_tokens", type=int, help="max budget tokens for reasoning", default=4098
    )
    parser.add_argument("--reasoning", action="store_true", help="Enable reasoning")
    parser.add_argument(
        "--reasoning_effort",
        type=str,
        help="reasoning level for openAI models",
        default="medium",
    )
    parser.add_argument(
        "--summary_type", type=str, help="summary type for openai api", default="auto"
    )
    parser.add_argument(
        "--api_key_for_qwen",
        type=str,
        help="dummy api key for vllm",
        default="dummpy",
    )
    parser.add_argument(
        "--base_url_mimo",
        type=str,
        help="port for mimo",
        default="http://localhost:8088/v1",
    )
    parser.add_argument("--max_frames", type=int, help="max frames to feed", default=64)
    parser.add_argument(
        "--num_uniform",
        type=int,
        help="#frames sampled uniformly",
        default=16,
    )
    parser.add_argument(
        "--size_segment",
        type=int,
        help="max #frames per segment",
        default=32,
    )
    parser.add_argument("--num_segment", type=int, help="num segment", default=4)
    parser.add_argument("--angle", type=str, help="angle", default="C10118_rgb")
    parser.add_argument("--resolution", type=str, help="resolution", default="360p")
    parser.add_argument("--color", type=str, help="color", default="rgb")
    parser.add_argument("--wait_time", type=int, help="API call wait time", default=10)
    parser.add_argument("--seed", type=int, help="randome seed", default=7)
    parser.add_argument("--dirpath_log", type=Path, help="dirpath for log")

    args = parser.parse_args()

    if not args.dirpath_log.exists():
        args.dirpath_log.mkdir()

    if not args.dirpath_output.exists():
        args.dirpath_output.mkdir()

    logging.basicConfig(
        format="%(asctime)s:%(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(
                args.dirpath_log
                / f"baseline_tcot_{Path(args.model_id).name}_{get_date()}.log"
            ),
        ],
    )

    logging.info(pformat(vars(args)))

    main(args)
