"""
Reproduction of Temporal Chain of Thought paper

Note:
    Gemini is the first target backend; this script currently uses the
    google-genai client only.
"""

from argparse import ArgumentParser
from copy import deepcopy
from collections import defaultdict
import json
import logging
from pathlib import Path
import yaml
from tqdm import tqdm
from pprint import pformat
import random
import os

# import sys
import re

from google import genai

from utils import (
    get_date,
    format_instruction,
)

from utils_api import (
    call_api_single,
    estimate_cost,
    # OPENAI_MODELS,
    # ANTHROPIC_MODELS,
    # GOOGLE_MODELS,
    # QWEN_MODELS,
    GOOGLE_USAGE_KEYS,
)

from prompts import (
    format_input_tcot_selection,
    format_input_tcot_answer,
)


# TODO: move the TCoT helper functions below into a dedicated utils_tcot.py module.
def create_segment(args, example):
    """
    Split the frames of one clip into segments of frame file paths.

    The clip covers frame ids ``example["video"]["start"]`` (inclusive) to
    ``example["video"]["end"]`` (exclusive). Two cases:

    * short clip (duration < size_segment * num_segment): segments of
      ``args.size_segment`` frames each, the last one possibly shorter,
      eg, 60 frames with size 32 => [32, 28]
    * long clip: exactly ``args.num_segment`` segments; each one is then
      randomly down-sampled to ``args.size_segment`` frames,
      eg, 210 frames with 4 segments => [53, 53, 53, 51] => [32, 32, 32, 32]

    Args:
        args: namespace providing ``size_segment``, ``num_segment``,
            ``dirpath_frame`` (Path) and ``angle``.
        example: dict with ``sequence_id`` and ``video`` (``start``/``end``).

    Returns:
        list[list[str]]: one list of frame paths per segment, each list in
        ascending frame order.

    Raises:
        FileNotFoundError: if an expected frame file is missing on disk.
    """

    duration = example["video"]["end"] - example["video"]["start"]
    if duration < args.size_segment * args.num_segment:
        # eg, 60s < 32*4=128 => [32, 28]
        size_segment_actual = args.size_segment
        num_segment_actual = int(duration // size_segment_actual)
        if duration % size_segment_actual != 0:
            num_segment_actual += 1
    else:
        # eg, 210s > 32*4=128 =split=> 210//4=52, [53,53,53,51] =sample=> [32,32,32,32]
        num_segment_actual = args.num_segment
        size_segment_actual = int(duration // num_segment_actual)
        if duration % num_segment_actual != 0:
            size_segment_actual += 1

    logging.info(f"debug: {num_segment_actual=}, {size_segment_actual=}")

    segments = []
    dirpath = args.dirpath_frame / example["sequence_id"] / args.angle
    for i in range(num_segment_actual):
        start = size_segment_actual * i
        end = size_segment_actual * (i + 1)
        one_segment = []
        for j in range(start, end):
            file_id = int(j + example["video"]["start"])
            # the last segment may run past the clip; drop those frame ids
            if file_id < example["video"]["end"]:
                one_segment.append(str(dirpath / f"{file_id}.png"))
        segments.append(one_segment)

    # sanity check: every frame must exist on disk.
    # An explicit raise is used instead of assert so the check survives `python -O`.
    logging.info(f"{example['video']['start']=}, {example['video']['end']=}, {duration=}")
    for idx, segment in enumerate(segments):
        logging.info(f"debug: segment {(idx+1)=} {len(segment)=}")
        for filepath in segment:
            if not Path(filepath).exists():
                raise FileNotFoundError(f"Frame file not found: {filepath}")

    # down-sample segments that exceed the per-segment budget
    new_segments = []
    for segment in segments:
        if len(segment) > args.size_segment:
            # Sample positions rather than path strings: sorting the strings
            # would be lexicographic ("10.png" < "9.png") and break the
            # temporal order of the frames.
            kept_positions = sorted(
                random.sample(range(len(segment)), args.size_segment)
            )
            new_segments.append([segment[pos] for pos in kept_positions])
        else:
            new_segments.append(segment)

    # sanity check (sampled frames are a subset of the already verified ones)
    logging.info("uniform sampling is done")
    for idx, segment in enumerate(new_segments):
        logging.info(f"debug: segment {(idx+1)=} {len(segment)=}")

    return new_segments


def postprocess_selection(response):
    """
    Extract the selected frame ids and justifications from a model response.

    Every part of every candidate is inspected:
    * parts without text are skipped
    * thought parts are recorded as justifications
    * answer parts are parsed as JSON (after removing a Markdown code fence)
      with the keys ``frame_ids`` and ``justification``; if that fails, the
      two fields are searched for with regular expressions instead.

    An explicitly empty ``"frame_ids": []`` is a valid "nothing selected"
    answer. The failure flag is only raised when no frame id list can be
    recovered at all, or when the recovered ids are not integers.

    Args:
        response: object exposing ``candidates[*].content.parts[*]`` where each
            part has ``text`` and ``thought`` attributes.

    Returns:
        tuple: ``(selected_frame_ids, justifications, flag_all_parsing_failed)``.
        The flag is ``True`` when parsing failed for some answer part (the
        caller is expected to fall back to all frames) and ``None`` otherwise.
    """

    selected_frame_ids, justifications = [], []
    flag_all_parsing_failed = None
    for candidate in response.candidates:
        for part in candidate.content.parts:
            if not part.text:
                continue
            if part.thought:
                justifications.append(str(part.text))
                continue

            # 1. plain JSON, optionally wrapped in a ```json ... ``` fence
            text = part.text.strip()
            if text.startswith("```json"):
                text = text.replace("```json", "").replace("```", "")
            elif text.startswith("```"):
                text = text.replace("```", "")
            text = text.strip()
            try:
                output = json.loads(text)
                # read both fields before touching the accumulators so that a
                # missing key cannot leave ids behind that the regex pass
                # would then add a second time
                frame_ids = output["frame_ids"]
                justification = output["justification"]
                if not isinstance(frame_ids, list):
                    raise TypeError("frame_ids is not a list")
            except Exception:
                logging.warning("Error occured in the output parsing. Try re next")
            else:
                selected_frame_ids += frame_ids
                justifications.append(justification)
                continue

            # 2. regular expressions on the raw text
            # DOTALL lets the id list span several lines (pretty-printed JSON)
            frame_ids_match = re.search(
                r'"frame_ids"\s*:\s*\[(.*?)\]', part.text, re.DOTALL
            )
            if frame_ids_match:
                tokens = [x.strip() for x in frame_ids_match.group(1).split(",")]
                try:
                    # blank tokens come from an empty list or a trailing comma
                    frame_ids = [int(x) for x in tokens if x]
                except ValueError:
                    logging.warning("The output cannot be parsed with re.")
                    flag_all_parsing_failed = True
                else:
                    if not frame_ids:
                        logging.info("debug: no ids chosen")
                    selected_frame_ids += frame_ids
            else:
                # nothing recoverable: let the caller apply its fallback
                logging.warning("re did not find ids.")
                flag_all_parsing_failed = True

            justification_match = re.search(
                r'"justification"\s*:\s*"([^"]*)"', part.text
            )
            if justification_match:
                justifications.append(justification_match.group(1))
            else:
                logging.warning("re did not find justification.")
                logging.warning(f"{part.text}")

    return selected_frame_ids, justifications, flag_all_parsing_failed


def select_frame(args, example, template_components):
    """
    First stage: let the model pick the relevant frames of every segment.

    Steps:
    1. split the clip into segments (create_segment)
    2. query the model once per segment and keep the frames it selects;
       when the API call or the response parsing fails, keep every frame of
       that segment
    3. if more than ``args.max_frames - args.num_uniform`` frames were kept,
       randomly sample down to that number

    Frames that are not kept are deleted from the Files API.

    Returns:
        tuple: ``(selected_frames, history, total_usage)`` where
        ``selected_frames`` maps filepath -> uploaded file object, ``history``
        is a list with one record per segment and ``total_usage`` sums the
        usage counters listed in GOOGLE_USAGE_KEYS.
    """

    files_client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
    total_usage = defaultdict(int)

    selected_frames = {}  # {filepath: object, ...}
    history = []
    for frame_paths in create_segment(args, example):
        # build the request for this segment
        model_input, prompt_text, uploads = format_input_tcot_selection(
            args, example, frame_paths, template_components
        )

        try:
            response, call_usage = call_api_single(args, "", model_input)
        except Exception:
            # API failure: keep every frame of the segment
            logging.exception("Error in select_frame()")
            selected_frames.update(
                (entry["filepath"], entry["object"]) for entry in uploads.values()
            )
            history.append(
                {
                    "prompt": prompt_text,
                    "selected_frames": list(uploads),
                    "justifications": "Error occurred. Include all.",
                }
            )
            continue

        for usage_key in GOOGLE_USAGE_KEYS:
            total_usage[usage_key] += call_usage[usage_key]

        # parse the model output
        selected_frame_ids_segment, justifications_segment, parse_failed = (
            postprocess_selection(response)
        )
        logging.info(f"debug: {selected_frame_ids_segment=}, {justifications_segment=}")
        if parse_failed:
            logging.warning("Parsing did not work. Fallback option.")
            selected_frame_ids_segment = list(uploads)
            justifications_segment = [""]

        for frame_id, entry in uploads.items():
            if frame_id not in selected_frame_ids_segment:
                # rejected frame: remove it from the Files API
                files_client.files.delete(name=entry["object"].name)
                continue
            selected_frames[entry["filepath"]] = entry["object"]

        history.append(
            {
                "prompt": prompt_text,
                "selected_frames": selected_frame_ids_segment,
                "justifications": justifications_segment,
                "usage": call_usage,
            }
        )

    logging.info(f"debug: {history=}, {len(selected_frames)=}")

    # cap the number of selected frames
    max_selected_frames = args.max_frames - args.num_uniform
    if len(selected_frames) <= max_selected_frames:
        return selected_frames, history, total_usage

    logging.info(
        f"#selected frames > max: {len(selected_frames)=} > {max_selected_frames=}"
        " Perform uniform sampling."
    )
    kept_paths = random.sample(list(selected_frames), max_selected_frames)
    capped_frames = {}
    for frame_path, file_object in selected_frames.items():
        if frame_path in kept_paths:
            capped_frames[frame_path] = file_object
        else:
            files_client.files.delete(name=file_object.name)

    return capped_frames, history, total_usage


def uniform_sample(args, example):
    """
    Randomly pick up to ``args.num_uniform`` frame paths from the whole clip.

    Candidate frames are ``<dirpath_frame>/<sequence_id>/<angle>/<i>.png`` for
    every integer i from the clip start to the clip end (both inclusive).
    When there are no more candidates than ``args.num_uniform``, all of them
    are returned in ascending order; otherwise a random sample is returned in
    the order produced by ``random.sample``.
    """
    frame_dir = args.dirpath_frame / example["sequence_id"] / args.angle
    first_id = int(example["video"]["start"])
    last_id = int(example["video"]["end"])

    # NOTE(review): the end frame is included here, whereas create_segment only
    # keeps frame ids strictly below the end -- confirm which one is intended.
    candidates = [
        str(frame_dir / f"{frame_id}.png") for frame_id in range(first_id, last_id + 1)
    ]

    if len(candidates) <= args.num_uniform:
        return candidates
    return random.sample(candidates, args.num_uniform)


def postprocess_answer(response):
    """
    Split a model response into the final answer and the reasoning text.

    Thought parts go to the reasoning, other text parts to the answer. If the
    answer text contains ``<answer>...</answer>``, only the content after the
    last ``<answer>`` tag is returned as the answer and everything before it
    is appended to the reasoning. A missing closing tag is tolerated.

    Args:
        response: object exposing ``candidates[*].content.parts[*]`` where each
            part has ``text`` and ``thought`` attributes.

    Returns:
        tuple[str, str]: ``(answer, thinking)``, both stripped.
    """

    answer_chunks, thinking_chunks = [], []
    for candidate in response.candidates:
        for part in candidate.content.parts:
            if not part.text:
                # Textless parts carry nothing usable. Checking this first
                # keeps a thought part with text=None from being recorded as
                # the literal string "None".
                if not part.thought:
                    logging.warning(f"Undefined type: {part=}")
                continue
            if part.thought:
                thinking_chunks.append(f"{part.text}\n")
            else:
                answer_chunks.append(f"{part.text}\n")

    answer = "".join(answer_chunks)
    thinking = "".join(thinking_chunks)

    # Split on the opening tag unconditionally: without any tag this leaves
    # the answer untouched, and an unclosed "<answer>X" is still extracted.
    splits = answer.split("<answer>")
    thinking += "\n".join(splits[:-1])
    answer = splits[-1].replace("</answer>", "")

    return answer.strip(), thinking.strip()


def _frame_sort_key(filepath):
    """Sort key ordering frame paths by numeric frame id (file stem)."""
    stem = Path(filepath).stem
    try:
        return (0, int(stem), str(filepath))
    except ValueError:
        # non-numeric file names go last, ordered by name
        return (1, 0, str(filepath))


def _delete_files_best_effort(client, files):
    """Delete uploaded files, logging (not raising) individual failures."""
    for file_uploaded in files:
        try:
            client.files.delete(name=file_uploaded.name)
        except Exception:
            logging.exception("Failed to delete an uploaded file")


def run_single_example(args, example, template_components, instructions):
    """
    Run the two TCoT stages for one example.

    1. frame selection (select_frame)
    2. question answering on the selected frames plus randomly sampled
       frames of the whole clip (uniform_sample), de-duplicated and fed in
       ascending frame order

    All files uploaded for the answer call are deleted from the Files API
    afterwards, also when the answer call fails.

    Args:
        args: parsed command line arguments.
        example: one input example (dict).
        template_components: prompt template loaded from YAML.
        instructions: instruction data for the toy of this example.

    Returns:
        tuple: ``(contents_text, answer, history, total_usage)``.

    Raises:
        Exception: if the answer API call fails (the original error is
            attached as ``__cause__``).
    """
    client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])

    history = []
    total_usage = defaultdict(int)

    # select frames: segment -> select
    logging.info("Frame selection")
    selected_frames, history_selection, usage_selection = select_frame(
        args, example, template_components
    )
    for key in GOOGLE_USAGE_KEYS:
        total_usage[key] += usage_selection[key]
    history.extend(history_selection)
    history.append(["selected_frames", [Path(x).stem for x in selected_frames.keys()]])
    logging.info(
        f"debug: selected_frames.keys() {[Path(x).stem for x in selected_frames.keys()]}"
    )

    # concat selected frames and uniformed ones
    frames_input = []
    filepaths_uniform_sampled = uniform_sample(args, example)
    history.append(
        ["filepaths_uniform_sampled", [Path(x).stem for x in filepaths_uniform_sampled]]
    )
    filepaths_all = list(selected_frames.keys()) + filepaths_uniform_sampled
    logging.info(f"debug: filepaths_all {[Path(x).stem for x in filepaths_all]}")

    # de-duplication
    # Order by numeric frame id: a plain string sort would put "10.png"
    # before "9.png" and feed the frames out of temporal order.
    filepaths_all_deduplicated = sorted(set(filepaths_all), key=_frame_sort_key)
    logging.info(
        f"debug: filepaths_all_deduplicated"
        f" {[Path(x).stem for x in filepaths_all_deduplicated]}"
    )
    history.append(
        ["filepaths_all_deduplicated", [Path(x).stem for x in filepaths_all_deduplicated]]
    )
    for filepath in filepaths_all_deduplicated:
        if filepath in selected_frames:
            # already uploaded during frame selection
            frames_input.append(selected_frames[filepath])
        else:
            file_uploaded = client.files.upload(file=filepath)
            frames_input.append(file_uploaded)

    # format input
    contents, contents_text, files_uploaded = format_input_tcot_answer(
        args, example, frames_input, template_components, instructions
    )

    # call api
    logging.info("QA based on selected frames")
    try:
        response, usage_answer = call_api_single(
            args,
            "",
            contents,
        )
    except Exception as err:
        logging.exception("Error in run_single_example")
        # do not leave the uploads behind; cleanup errors must not hide `err`
        _delete_files_best_effort(client, frames_input + files_uploaded)
        raise Exception("Error in run_single_example") from err

    for key in GOOGLE_USAGE_KEYS:
        total_usage[key] += usage_answer[key]

    # postprocess
    answer, thinking = postprocess_answer(response)
    logging.info(f"answer: {answer}, thinking: {thinking}")

    # delete from File API
    for file_uploaded in frames_input + files_uploaded:
        client.files.delete(name=file_uploaded.name)

    history.append(thinking)

    return contents_text, answer, history, total_usage


def main(args):
    """
    Run TCoT over all input examples and write the predictions to a JSON file.

    The output file ``tcot_<model>_<reasoning>_<input name>`` inside
    ``args.dirpath_output`` is rewritten after every example, so an
    interrupted run can resume: examples already present in an existing
    output file are copied over and not re-run. An example whose processing
    raises is stored with ``"prediction": "Error"``.

    Args:
        args: parsed command line arguments (see the ``__main__`` block).
    """
    random.seed(args.seed)

    with open(args.filepath_input, "r") as f:
        examples = json.load(f)

    toy2instruction = format_instruction(
        args.filepath_instruction,
        args.dirpath_instruction_image,
        args.dirpath_parts_image,
    )

    with open(args.filepath_template, "r") as f:
        template_components = yaml.safe_load(f)

    metadata = {
        "data-created": get_date(),
        "args": {k: str(v) if isinstance(v, Path) else v for k, v in vars(args).items()},
    }

    filepath_output = (
        args.dirpath_output
        / f"tcot_{Path(args.model_id).name}_{args.reasoning}_{args.filepath_input.name}"
    )
    # The checkpoint is written here first and then moved over the real file,
    # so an interruption during the write cannot leave a truncated output
    # file that would break the resume path below.
    filepath_tmp = filepath_output.with_name(f"{filepath_output.name}.tmp")

    if filepath_output.exists():
        with open(filepath_output, "r") as f:
            examples_prev = json.load(f)
        logging.info(
            f"Prev attempt exists. Restart from {len(examples_prev['examples'])}"
        )
        total_cost = float(examples_prev["metadata"]["cost"])
    else:
        logging.info("Initial attempt")
        examples_prev = []
        total_cost = 0

    new_examples = []
    for idx, example in tqdm(enumerate(examples), total=len(examples)):
        # skip if already done
        if examples_prev and idx < len(examples_prev["examples"]):
            new_examples.append(examples_prev["examples"][idx])
            continue

        logging.info(f"Example {idx=}")
        try:
            text_prompt, answer, history, usage = run_single_example(
                args,
                example,
                template_components,
                toy2instruction[example["toy_id"]],
            )
            new_example = deepcopy(example)
            new_example["prediction"] = {
                "text_prompt": text_prompt,
                "answer": answer,
                "history": history,
                "usage": usage,
            }
            new_examples.append(new_example)
            total_cost += estimate_cost(args.model_id, usage)
        except Exception:
            # keep going; the failure is recorded in the output
            logging.exception("Error occurred.")
            new_example = deepcopy(example)
            new_example["prediction"] = "Error"
            new_examples.append(new_example)

        # checkpoint after every example
        metadata["cost"] = f"{total_cost:.3f}"
        output = {"metadata": metadata, "examples": new_examples}
        with open(filepath_tmp, "w") as f:
            json.dump(output, f, indent=4)
            f.write("\n")
        os.replace(filepath_tmp, filepath_output)

    logging.info(f"{total_cost=}")


if __name__ == "__main__":
    # Command line entry point: parse arguments, prepare the log/output
    # directories, configure logging (console + file) and run main().
    #
    # Arguments marked required are the ones this script dereferences before
    # any example is processed; omitting them previously ended in an
    # AttributeError/TypeError on None instead of a usage message.
    parser = ArgumentParser(description="Temporal Chain of Thought reimplementation code")
    parser.add_argument(
        "--filepath_input", type=Path, help="filepath for input", required=True
    )
    parser.add_argument(
        "--filepath_instruction", type=Path, help="filepath for instruction"
    )
    parser.add_argument(
        "--dirpath_instruction_image", type=Path, help="dirpath for instruction (image)"
    )
    parser.add_argument(
        "--dirpath_parts_image", type=Path, help="dirpath for parts (image)"
    )
    parser.add_argument("--dirpath_frame", type=Path, help="dirpath for frames")
    parser.add_argument(
        "--filepath_template",
        type=Path,
        help="filepath for prompt template",
        required=True,
    )
    parser.add_argument(
        "--dirpath_output", type=Path, help="filepath for output", required=True
    )
    parser.add_argument("--model_id", type=str, help="model id", required=True)
    parser.add_argument("--temperature", type=float, help="temperature", default=0.0)
    parser.add_argument(
        "--max_tokens", type=int, help="max tokens to generate", default=4608
    )
    parser.add_argument(
        "--budget_tokens", type=int, help="max budget tokens for reasoning", default=4098
    )
    parser.add_argument("--reasoning", action="store_true", help="Enable reasoning")
    parser.add_argument(
        "--reasoning_effort",
        type=str,
        help="reasoning level for openAI models",
        default="medium",
    )
    parser.add_argument(
        "--summary_type", type=str, help="summary type for openai api", default="auto"
    )
    parser.add_argument(
        "--api_key_for_qwen", type=str, help="dummy api key for qwen", default="dummpy"
    )
    parser.add_argument(
        "--base_url_qwen",
        type=str,
        help="port for qwen",
        default="http://localhost:8000/v1",
    )
    parser.add_argument("--max_frames", type=int, help="max frames to feed", default=64)
    parser.add_argument(
        "--num_uniform",
        type=int,
        help="#frames sampled uniformly",
        default=16,
    )
    parser.add_argument(
        "--size_segment",
        type=int,
        help="max #frames per segment",
        default=32,
    )
    parser.add_argument("--num_segment", type=int, help="num segment", default=4)
    parser.add_argument("--angle", type=str, help="angle", default="C10118_rgb")
    parser.add_argument("--resolution", type=str, help="resolution", default="360p")
    parser.add_argument("--color", type=str, help="color", default="rgb")
    parser.add_argument("--wait_time", type=int, help="API call wait time", default=10)
    parser.add_argument("--seed", type=int, help="randome seed", default=7)
    parser.add_argument(
        "--dirpath_log", type=Path, help="dirpath for log", required=True
    )

    args = parser.parse_args()

    # parents/exist_ok: nested directories work and an existing directory
    # (or one created concurrently) is not an error
    args.dirpath_log.mkdir(parents=True, exist_ok=True)
    args.dirpath_output.mkdir(parents=True, exist_ok=True)

    logging.basicConfig(
        format="%(asctime)s:%(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(
                args.dirpath_log
                / f"baseline_tcot_{Path(args.model_id).name}_{get_date()}.log"
            ),
        ],
    )

    logging.info(pformat(vars(args)))

    main(args)
