"""
Reproduction of the Temporal Chain of Thought (TCoT) paper.

Model provider: Anthropic (Claude).

TODO: modify for Claude.
"""

from argparse import ArgumentParser
from collections import defaultdict
from copy import deepcopy
import json
import logging
from pathlib import Path
import yaml
from tqdm import tqdm
from pprint import pformat
import random

# import sys
import re

from utils import (
    get_date,
    format_instruction,
)

from utils_api import (
    call_api_single,
    estimate_cost,
    ANTHROPIC_USAGE_KEYS,
)

from prompts import (
    format_input_tcot_selection,
    format_input_tcot_answer,
)

from baseline_tcot_google import (
    create_segment,
    uniform_sample,
)


def postprocess_selection(response):
    """
    postprocess response for frame id selection
    1. try normal json.loads
    2. try re
    if both failed, fallback option to add all frame ids.

    """

    selected_frame_ids, justifications = [], []
    flag_all_parsing_failed = None

    response_text = ""
    for block in response.content:
        match block.type:
            case "text":
                response_text += block.text
            case "thinking":
                justifications.append(["reasoning", block.thinking])
            case _:
                logging.error(f"Undefined {block.type=}")

    # parsing
    if response_text:
        flag_first_parsing_failed = False

        logging.info(f"debug: target to parse: {response_text=}")

        # 1. parsing with json
        try:
            logging.info("debug: try parsing with json")
            if response_text.startswith("```json"):
                text = response_text.content.replace("```json", "")
                text = text.replace("```", "")
            elif response_text.startswith("```"):
                text = response_text.replace("```", "")
            else:
                text = response_text

            text = text.strip()
            output = json.loads(text)
            selected_frame_ids += output["frame_ids"]
            justifications.append(output["justification"])
        except Exception:
            flag_first_parsing_failed = True
            logging.warning("Error occured in the output parsing w/ json. Try re next")

        # 2. parsing with re
        if flag_first_parsing_failed:
            try:
                # Regex for frame_ids (captures everything inside the brackets)
                frame_ids_match = re.search(r'"frame_ids": \[(.*?)\]', response_text)
                if frame_ids_match:
                    selected_frame_ids += [
                        int(x.strip()) for x in frame_ids_match.group(1).split(",")
                    ]
                else:
                    if '"frame_ids": []' in response_text:
                        logging.info("debug: no ids chosen")
                        selected_frame_ids += []
                    else:
                        logging.warning("re did not find ids.")

                # Regex for justification
                justification_match = re.search(
                    r'"justification": "([^"]*)"', response_text
                )
                if justification_match:
                    justifications.append(justification_match.group(1))
                else:
                    logging.warning("re did not find justification.")

            except Exception:
                logging.warning("The output cannot be parsed with re.")
                flag_all_parsing_failed = True

    return selected_frame_ids, justifications, flag_all_parsing_failed


def select_frame(args, example, template_components):
    """
    1st stage of Temporal Chain of Thought: select relevant frames.

    1. Segment the video of ``example`` (``create_segment``).
    2. For each segment, ask the model which frames to keep and map the
       returned frame ids back to file paths. If the API call fails, or the
       response cannot be parsed, every frame of the segment is kept.
    3. If more than ``args.max_frames - args.num_uniform`` frames were
       selected, randomly subsample them down to that budget. The surviving
       frames keep their original order.

    Args:
        args: parsed command-line arguments. ``max_frames`` and
            ``num_uniform`` are read here; ``args`` is also forwarded to the
            helpers.
        example: one input example.
        template_components: prompt template components loaded from YAML.

    Returns:
        tuple: ``(selected_frames, history, total_usage)``.
        ``selected_frames`` is a list of frame file paths. ``history`` records
        the prompt, selection and justifications of each segment.
        ``total_usage`` sums ``ANTHROPIC_USAGE_KEYS`` over all API calls made
        here.
    """

    total_usage = defaultdict(int)

    # 1. segment videos
    segments = create_segment(args, example)

    # 2. select frames
    selected_frames = []  # [filepath, ...]
    history = []
    for segment in segments:
        # format input
        messages, messages_text, id2filepath = format_input_tcot_selection(
            args, example, segment, template_components
        )

        # call api
        try:
            response, usage = call_api_single(args, "", messages, benchmark=True)
        except Exception:
            # when an error occurs, fallback option to include all frames
            logging.exception("Error in select_frame()")
            selected_frames.extend(segment)
            history.append(
                {
                    "prompt": messages_text,
                    "selected_frames": segment,
                    "justifications": "Error in select_frame, call api. Include all.",
                }
            )
            continue

        for key in ANTHROPIC_USAGE_KEYS:
            total_usage[key] += usage[key]

        # postprocess
        selected_frame_ids_segment, justifications_segment, flag_parsing_failed = (
            postprocess_selection(response)
        )
        logging.info(f"debug: {selected_frame_ids_segment=}, {justifications_segment=}")
        if flag_parsing_failed:
            logging.warning("Parsing did not work. Fallback option.")
            selected_frames_segment = list(id2filepath.values())
            justifications_segment = ["Parsing did not work. Fallback option."]
        else:
            selected_frames_segment = [
                filepath
                for idx, filepath in id2filepath.items()
                if idx in selected_frame_ids_segment
            ]
        history.append(
            {
                "prompt": messages_text,
                "selected_frames": selected_frames_segment,
                "justifications": justifications_segment,
            }
        )

        selected_frames.extend(selected_frames_segment)

    logging.info(f"debug: {len(selected_frames)=}")

    # 3. sample if len(samples) > max
    # Clamp at 0: with num_uniform > max_frames the budget would be negative
    # and random.sample() would raise ValueError.
    if args.max_frames < args.num_uniform:
        logging.warning(
            f"{args.max_frames=} < {args.num_uniform=}: no budget for selected frames."
        )
    max_selected_frames = max(0, args.max_frames - args.num_uniform)
    if len(selected_frames) > max_selected_frames:
        message = (
            f"#selected frames > max: {len(selected_frames)=} > {max_selected_frames=}"
            " Perform uniform sampling."
        )
        logging.info(message)
        history.append(message)
        # Sampling indices consumes the same random draws as sampling the list
        # itself; sorting them keeps the kept frames in their original order.
        kept_indices = sorted(
            random.sample(range(len(selected_frames)), max_selected_frames)
        )
        selected_frames = [selected_frames[i] for i in kept_indices]

    return selected_frames, history, total_usage


def postprocess_answer(response):
    """
    postprocess answer response

    """

    response_text, thinking = "", ""
    for block in response.content:
        match block.type:
            case "text":
                response_text += block.text
            case "thinking":
                thinking += block.thinking
            case _:
                logging.error(f"Undefined {block.type=}")

    # manual parsing
    if response_text:
        splits = response_text.split("<answer>")
        thinking += "\n".join(splits[:-1])
        answer = splits[-1].replace("</answer>", "")

    return answer.strip(), thinking.strip()


def run_single_example(args, example, template_components, instructions):
    """
    Run the two-stage Temporal Chain of Thought pipeline on one example.

    1. Select frames with the model (``select_frame``).
    2. Merge them with uniformly sampled frames (``uniform_sample``), then
       de-duplicate and sort the file paths.
    3. Ask the question on the merged frames and parse the answer.

    Args:
        args: parsed command-line arguments.
        example: one input example.
        template_components: prompt template components loaded from YAML.
        instructions: instructions for this example's toy (a value of the
            mapping built by ``format_instruction``).

    Returns:
        tuple: ``(answer, history, total_usage)``. ``total_usage`` sums
        ``ANTHROPIC_USAGE_KEYS`` over both stages.

    Raises:
        Exception: if the answering API call fails. The original error is
            chained as ``__cause__``.
    """
    history = []
    total_usage = defaultdict(int)

    # select frames: segment -> select
    logging.info("Frame selection")
    selected_frames, history_selection, usage_selection = select_frame(
        args, example, template_components
    )
    for key in ANTHROPIC_USAGE_KEYS:
        total_usage[key] += usage_selection[key]
    history.extend(history_selection)
    selected_stems = [Path(x).stem for x in selected_frames]
    history.append(["selected_frames", selected_stems])
    logging.info(f"debug: selected_frames.keys() {selected_stems}")

    # concat selected frames and uniformly sampled ones
    filepaths_uniform_sampled = uniform_sample(args, example)
    history.append(
        ["filepaths_uniform_sampled", [Path(x).stem for x in filepaths_uniform_sampled]]
    )
    filepaths_all = selected_frames + filepaths_uniform_sampled
    logging.info(f"debug: filepaths_all {[Path(x).stem for x in filepaths_all]}")

    # de-duplication, then sort by file path
    # NOTE(review): this presumably relies on frame file names sorting in
    # temporal order -- confirm against the frame naming scheme
    frames_input = sorted(set(filepaths_all))
    deduplicated_stems = [Path(x).stem for x in frames_input]
    logging.info(f"debug: filepaths_all_deduplicated {deduplicated_stems}")
    history.append(["filepaths_all_deduplicated", deduplicated_stems])

    # format input
    messages, messages_text, _ = format_input_tcot_answer(
        args, example, frames_input, template_components, instructions
    )
    # a plain string is stored as one entry instead of one entry per character
    if isinstance(messages_text, str):
        history.append(messages_text)
    else:
        history.extend(messages_text)

    # call api
    logging.info("QA based on selected frames")
    try:
        response, usage_answer = call_api_single(args, "", messages, benchmark=True)
    except Exception as err:
        logging.exception("Error in run_single_example")
        raise Exception("Error in run_single_example") from err

    for key in ANTHROPIC_USAGE_KEYS:
        total_usage[key] += usage_answer[key]

    # postprocess
    answer, thinking = postprocess_answer(response)
    logging.info(f"answer: {answer}, thinking: {thinking}")

    history.append([["thinking", thinking], ["answer", answer]])

    return answer, history, total_usage


def check_if_prediction_exists(example, examples_prev):
    """
    Look up a reusable prediction for ``example`` from a previous run.

    Args:
        example: the current input example. Only its ``"question"`` is compared.
        examples_prev: previously saved output (a dict with an ``"examples"``
            list), or an empty list when there is no previous run.

    Returns:
        A deep copy of the LAST previous entry whose ``"question"`` equals
        ``example["question"]`` and whose ``prediction["answer"]`` is not
        blank, or None when no such entry exists.

    NOTE(review): matching uses the question text only. If several examples
    share a question, a prediction could be reused for the wrong example.
    """
    if "examples" not in examples_prev:
        return None

    def _is_reusable(candidate):
        # entries saved as the string "Error" fail the first membership test
        return (
            "question" in candidate
            and candidate["question"] == example["question"]
            and "prediction" in candidate
            and candidate["prediction"]["answer"].strip()
        )

    reusable = [
        candidate for candidate in examples_prev["examples"] if _is_reusable(candidate)
    ]
    return deepcopy(reusable[-1]) if reusable else None


def main(args):
    """
    Run TCoT on every example of ``args.filepath_input`` and save predictions.

    Results are written to
    ``args.dirpath_output / tcot_<model>_<reasoning>_<input file name>`` after
    every example, so an interrupted run can be resumed. On resume, examples
    whose question already has a non-blank predicted answer in that file are
    copied over instead of recomputed, and the saved cost is carried over.
    Examples that raise are stored as the string ``"Error"`` and are retried
    on the next run.

    Args:
        args: parsed command-line arguments.
    """
    random.seed(args.seed)

    with open(args.filepath_input, "r", encoding="utf-8") as f:
        examples = json.load(f)

    toy2instruction = format_instruction(
        args.filepath_instruction,
        args.dirpath_instruction_image,
        args.dirpath_parts_image,
    )

    with open(args.filepath_template, "r", encoding="utf-8") as f:
        template_components = yaml.safe_load(f)

    metadata = {
        "data-created": get_date(),
        "args": {k: str(v) if isinstance(v, Path) else v for k, v in vars(args).items()},
    }

    filepath_output = (
        args.dirpath_output
        / f"tcot_{Path(args.model_id).name}_{args.reasoning}_{args.filepath_input.name}"
    )
    # Checkpoints are first written to this temporary file and then moved over
    # filepath_output. An interruption during json.dump therefore cannot
    # corrupt the file needed to resume.
    filepath_tmp = filepath_output.with_name(filepath_output.name + ".tmp")

    if filepath_output.exists():
        with open(filepath_output, "r", encoding="utf-8") as f:
            examples_prev = json.load(f)
        logging.info(
            f"Prev attempt exists. Restart from {len(examples_prev['examples'])}"
        )
        total_cost = float(examples_prev["metadata"]["cost"])
    else:
        logging.info("Initial attempt")
        examples_prev = []
        total_cost = 0

    new_examples = []
    for idx, example in tqdm(enumerate(examples), total=len(examples)):
        logging.info(f"[Example {idx}]")

        # check if prediction already exists
        corresponding_example_prev = check_if_prediction_exists(example, examples_prev)
        if corresponding_example_prev:
            logging.info("Previous prediction exists. Skip.")
            new_examples.append(corresponding_example_prev)
        else:
            try:
                answer, history, usage = run_single_example(
                    args,
                    example,
                    template_components,
                    toy2instruction[example["toy_id"]],
                )
                new_example = deepcopy(example)
                new_example["prediction"] = {
                    "history": history,
                    "answer": answer,
                    "usage": usage,
                }
                new_examples.append(new_example)
                total_cost += estimate_cost(args.model_id, usage)
            except Exception:
                logging.exception("Error occurred.")
                new_examples.append("Error")

        # checkpoint after every example (atomic replace, see filepath_tmp)
        metadata["cost"] = f"{total_cost:.3f}"
        output = {"metadata": metadata, "examples": new_examples}
        with open(filepath_tmp, "w", encoding="utf-8") as f:
            json.dump(output, f, indent=4)
            f.write("\n")
        filepath_tmp.replace(filepath_output)

    logging.info(f"{total_cost=}")


if __name__ == "__main__":
    parser = ArgumentParser(
        description=(
            "Temporal Chain of Thought reimplementation code for Anthropic (Claude)"
            " models"
        )
    )
    # Arguments marked required=True are dereferenced unconditionally (open(),
    # Path(), .exists()/.mkdir()). Omitting them used to crash with a
    # TypeError/AttributeError, so argparse now reports them instead.
    parser.add_argument(
        "--filepath_input", type=Path, help="filepath for input", required=True
    )
    parser.add_argument(
        "--filepath_instruction", type=Path, help="filepath for instruction"
    )
    parser.add_argument(
        "--dirpath_instruction_image", type=Path, help="dirpath for instruction (image)"
    )
    parser.add_argument(
        "--dirpath_parts_image", type=Path, help="dirpath for parts (image)"
    )
    parser.add_argument("--dirpath_frame", type=Path, help="dirpath for frames")
    parser.add_argument(
        "--filepath_template",
        type=Path,
        help="filepath for prompt template",
        required=True,
    )
    parser.add_argument(
        "--dirpath_output", type=Path, help="dirpath for output", required=True
    )
    parser.add_argument("--model_id", type=str, help="model id", required=True)
    parser.add_argument("--temperature", type=float, help="temperature", default=0.0)
    parser.add_argument(
        "--max_tokens", type=int, help="max tokens to generate", default=4608
    )
    parser.add_argument(
        "--budget_tokens", type=int, help="max budget tokens for reasoning", default=4098
    )
    parser.add_argument("--reasoning", action="store_true", help="Enable reasoning")
    # NOTE(review): the next four options look like leftovers for OpenAI / Qwen
    # backends; kept so that the shared API helpers still find them on `args`
    parser.add_argument(
        "--reasoning_effort",
        type=str,
        help="reasoning level for openAI models",
        default="medium",
    )
    parser.add_argument(
        "--summary_type", type=str, help="summary type for openai api", default="auto"
    )
    parser.add_argument(
        "--api_key_for_qwen", type=str, help="dummy api key for qwen", default="dummpy"
    )
    parser.add_argument(
        "--base_url_qwen",
        type=str,
        help="port for qwen",
        default="http://localhost:8000/v1",
    )
    parser.add_argument("--max_frames", type=int, help="max frames to feed", default=64)
    parser.add_argument(
        "--num_uniform",
        type=int,
        help="#frames sampled uniformly",
        default=16,
    )
    parser.add_argument(
        "--size_segment",
        type=int,
        help="max #frames per segment",
        default=32,
    )
    parser.add_argument("--num_segment", type=int, help="num segment", default=4)
    parser.add_argument("--angle", type=str, help="angle", default="C10118_rgb")
    parser.add_argument("--resolution", type=str, help="resolution", default="360p")
    parser.add_argument("--color", type=str, help="color", default="rgb")
    parser.add_argument("--wait_time", type=int, help="API call wait time", default=10)
    parser.add_argument("--seed", type=int, help="randome seed", default=7)
    parser.add_argument(
        "--dirpath_log", type=Path, help="dirpath for log", required=True
    )

    args = parser.parse_args()

    # parents=True creates missing intermediate directories; exist_ok=True
    # makes an already existing directory a no-op
    args.dirpath_log.mkdir(parents=True, exist_ok=True)
    args.dirpath_output.mkdir(parents=True, exist_ok=True)

    logging.basicConfig(
        format="%(asctime)s:%(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(
                args.dirpath_log
                / f"baseline_tcot_{Path(args.model_id).name}_{get_date()}.log"
            ),
        ],
    )

    logging.info(pformat(vars(args)))

    main(args)
