import json
import re
import numpy as np
import argparse
import pandas as pd
import random
# import decord
import logging

# Yes/No question templates asking whether an event is ABSENT within a time window.
# Placeholders: {event} (event description), {st}/{ed} (window bounds, in seconds).
neg_occurrence = [
    "Is the event '{event}' absent from {st} to {ed} seconds in the video?",
    "Is the event '{event}' not present from {st} to {ed} seconds in the video?",
    "Does the event '{event}' not happen from {st} to {ed} seconds in the video?",
    "Is the event '{event}' missing from {st} to {ed} seconds in the video?"
]

# Yes/No question templates asking whether an event is PRESENT within a time window.
# Same placeholders as neg_occurrence.
pos_occurrence = [
    "Is the event '{event}' present from {st} to {ed} seconds in the video?",
    "Is the event '{event}' occurring from {st} to {ed} seconds in the video?",
    "Does the event '{event}' happen from {st} to {ed} seconds in the video?",
    "Is the event '{event}' included from {st} to {ed} seconds in the video?"
]

# Prompt templates keyed by task name. generate_question() looks up exactly these
# keys on its `prompt` argument; presumably callers pass this dict -- confirm.
#   "grounding"             -> asks for a free-form 'start - end seconds' answer
#   "pos" / "neg"           -> lists of occurrence templates; one is picked at random
#   "add_detail"            -> Yes/No answer instruction returned with non-grounding questions
#   "compositional"         -> attaches a time window to an existing question
#   "co_occurrence",
#   "sequential_after/before" -> relate two events, {target1} and {target2}
# NOTE(review): the last three templates end in "?." (a stray period after the
# question mark). Left as-is so prompts stay identical to earlier runs; confirm
# before changing.
prompt = {
    "grounding": "Please answer when the event '{event}' occurs in the video. The output format should be: 'start - end seconds'. Please return its start time and end time.",
    "pos": pos_occurrence,
    "neg": neg_occurrence,
    "add_detail": "You should only answer with 'Yes' or 'No'.",
    "compositional": "{question} from {st} to {ed} seconds in the video?",
    "co_occurrence": "Do the events '{target1}' and '{target2}' happen at the same time in the video?.",
    "sequential_after": "Does the event '{target1}' happen after the event '{target2}' in the video?.",
    "sequential_before": "Does the event '{target1}' happen before the event '{target2}' in the video?.",
}

# ANSI escape codes for colors
class Formatter(logging.Formatter):
    """``logging.Formatter`` that wraps each whole formatted line in an ANSI colour.

    The colour escape is chosen from ``record.levelname``. A level name that is
    missing from ``COLOR_CODES`` gets ``RESET_CODE`` as its prefix, i.e. no colour.
    """

    # Level name -> ANSI foreground-colour escape sequence.
    COLOR_CODES = {
        'DEBUG': '\033[94m',    # Blue
        'INFO': '\033[92m',     # Green
        'WARNING': '\033[93m',  # Yellow
        'ERROR': '\033[91m',    # Red
        'CRITICAL': '\033[95m', # Magenta
    }
    # Escape sequence that restores the terminal's default attributes.
    RESET_CODE = '\033[0m'

    def format(self, record):
        """Format ``record`` with the base formatter, then colour the result."""
        text = super().format(record)
        start = self.COLOR_CODES.get(record.levelname, self.RESET_CODE)
        return "".join((start, text, self.RESET_CODE))


def load_logger():
    """Return the logger named "Logger", set up with a coloured console handler.

    The logger and its console handler are both set to DEBUG, and the handler
    formats records as '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    through the colouring ``Formatter``.

    ``logging.getLogger`` returns the same object for the same name, so this
    function may be called repeatedly. The handler is attached only once, which
    stops every message from being printed once per call.
    """
    logger = logging.getLogger("Logger")
    logger.setLevel(logging.DEBUG)

    # Guard against stacking a new StreamHandler on every call.
    if not logger.handlers:
        # Console handler with colour formatter.
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.DEBUG)
        console_handler.setFormatter(
            Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )
        logger.addHandler(console_handler)

    return logger


class BaseOptions:
    """Command-line options for the evaluation scripts.

    ``initialize`` builds the argparse parser. ``parse`` builds it on first use,
    parses the arguments and adds dataset-derived paths to the result.
    """

    def __init__(self):
        self.parser = None        # argparse.ArgumentParser, built by initialize()
        self.initialized = False  # True once initialize() has run
        self.opt = None           # namespace returned by the last parse()

    def initialize(self):
        """Create the argument parser and store it on ``self.parser``."""
        self.initialized = True
        parser = argparse.ArgumentParser()
        parser.add_argument("--model_type", type=str, required=True, choices=['Video-ChatGPT', 'Video-LLaMA', 'VTimeLLM', 'TimeChat'])
        parser.add_argument("--dset_name", type=str, required=True, choices=['activitynet', 'charades'])
        parser.add_argument("--topn", type=int, default=3)
        parser.add_argument("--seed", type=int, default=1234)
        parser.add_argument('--overwrite', action="store_true")
        parser.add_argument('--correctness', action="store_true")
        parser.add_argument('--fine_tuned', action="store_true")
        parser.add_argument('--debug', action="store_true")
        parser.add_argument("--exp_id", type=str, default=None, help="id of this run, required at training")
        self.parser = parser

    def display(self, opt):
        """Print ``opt`` as a markdown table (string values truncated to 120 chars)."""
        print(dict_to_markdown(vars(opt), max_str_len=120))

    def parse(self, visualize=False, args=None):
        """Parse command-line options and add derived dataset paths.

        Args:
            visualize: when True, print the parsed options via ``display``.
            args: optional list of argument strings. When None (the default),
                argparse reads ``sys.argv[1:]``, as before.

        Returns:
            The parsed ``argparse.Namespace``. It has two extra attributes:
            ``video_root`` and ``test_path``, both derived from ``dset_name``.
            The namespace is also stored on ``self.opt``.
        """
        if not self.initialized:
            self.initialize()

        opt = self.parser.parse_args(args)
        opt.video_root = f"/data/video_datasets/{opt.dset_name}"
        opt.test_path = f"consistency_annotations/{opt.dset_name}_consistency_test.json"

        if visualize:
            self.display(opt)
        self.opt = opt

        return opt


def generate_question(task, prompt, query, duration, st=None, ed=None):
    """Build the question text for one probing task.

    Args:
        task: one of "grounding", "occurrence", "co_occurrence",
            "sequential_after", "sequential_before", "compositional".
        prompt: template mapping. It must contain "add_detail" and the key for
            ``task``; "occurrence" also needs "pos" and "neg" (lists of templates).
        query: event description (str). For the co-occurrence and sequential
            tasks, a list/tuple of two events. For "compositional", a question
            whose "?" characters are removed before templating.
        duration: video length in seconds; ``st``/``ed`` are clipped to it.
        st, ed: optional window bounds in seconds.

    Returns:
        ``(question, add_detail, choice)``. ``add_detail`` is None for
        "grounding" and ``prompt["add_detail"]`` otherwise. ``choice`` is "pos"
        or "neg"; it is drawn for every task but only selects the template
        list for "occurrence".

    Raises:
        ValueError: a pairwise task received a query that is not a list/tuple
            of at least two events.
        NotImplementedError: ``task`` is not recognised.
    """
    # Drawn unconditionally so the amount of randomness consumed per call is
    # the same for every task.
    choice = random.choice(["pos", "neg"])

    # Clip each bound on its own, using `is not None` rather than truthiness:
    # st == 0 is a valid start time and must not disable clipping of ed.
    if st is not None:
        st = min(st, duration)
    if ed is not None:
        ed = min(ed, duration)

    add_detail = prompt["add_detail"]
    if task == "grounding":
        question = prompt[task].format(event=query)
        add_detail = None  # free-form answer, no Yes/No instruction

    elif task == "occurrence":
        question = random.choice(prompt[choice]).format(event=query, st=st, ed=ed)

    elif task in ("co_occurrence", "sequential_after", "sequential_before"):
        if not isinstance(query, (list, tuple)) or len(query) < 2:
            raise ValueError(f"Invalid style of query: {query}")

        question = prompt[task].format(target1=query[0], target2=query[1])

    elif task == "compositional":
        # The template appends its own "... seconds in the video?" suffix.
        question = prompt[task].format(question=query.replace("?", ""), st=st, ed=ed)

    else:
        raise NotImplementedError(f"Not implemented task: {task}")

    return question, add_detail, choice


def load_jsonl(filename):
    """Load a JSON-Lines file and return its records as a list.

    Lines that are empty after stripping whitespace are skipped instead of
    making ``json.loads`` raise.

    NOTE(review): every single quote (') is deleted from each line before
    parsing, which alters string values (e.g. "don't" -> "dont"). This is kept
    for compatibility with existing callers; confirm whether it is still needed.
    """
    records = []
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:  # stream lines instead of materialising readlines()
            text = line.replace("'", "").strip()
            if text:
                records.append(json.loads(text))
    return records


def save_jsonl(data, filename):
    """Write each item of the list ``data`` to ``filename`` as one JSON line.

    Records are separated by newlines; no newline follows the last record, and
    an empty list produces an empty file.
    """
    encoded = map(json.dumps, data)
    with open(filename, "w") as out:
        out.write("\n".join(encoded))


def save_json(data, filename, save_pretty=False, sort_keys=False):
    """Serialize ``data`` as JSON into ``filename``.

    Args:
        data: any ``json``-serializable object.
        filename: destination path (overwritten).
        save_pretty: indent with 4 spaces when True; compact single line otherwise.
        sort_keys: sort object keys. Applied in both modes.
    """
    # Serialize first: if encoding fails, the existing file is left untouched
    # instead of being truncated. sort_keys is passed in both modes so it is
    # never silently ignored for compact output.
    text = json.dumps(data, indent=4 if save_pretty else None, sort_keys=sort_keys)
    with open(filename, "w") as f:
        f.write(text)


def load_json(filename):
    """Parse the JSON document stored at ``filename`` and return it."""
    with open(filename, "r") as handle:
        content = json.load(handle)
    return content


def get_iou(A, B):
    """Return the IoU of intervals ``A`` and ``B``, rounded to 2 decimals.

    Only ``A[0]``, ``A[1]``, ``B[0]`` and ``B[1]`` are read, as (start, end).
    The union is taken as ``max(ends) - min(starts)``. This equals the true
    union when the intervals overlap; disjoint intervals have zero
    intersection, so the result is 0.0 either way.

    Returns 0 instead of raising in two cases:
    - the inputs are malformed (e.g. ``None``, too short, or non-numeric);
    - the union span is empty or negative, e.g. two identical zero-length
      intervals.
    Unexpected errors propagate.
    """
    try:
        inter = min(A[1], B[1]) - max(A[0], B[0])
        union = max(A[1], B[1]) - min(A[0], B[0])
        # Explicit guard: numpy scalars would otherwise yield nan rather than
        # raising ZeroDivisionError.
        if union <= 0:
            return 0
        return round(max(inter, 0) / union, 2)
    except (TypeError, IndexError, KeyError, ValueError):
        # Malformed prediction (e.g. an unparsable model answer): score it 0.
        return 0


def dict_to_markdown(d, max_str_len=120):
    """Render ``d`` as a two-column (key, value) markdown table.

    List values are replaced by their ``repr``. When ``max_str_len`` is not
    None, every string value (including those list reprs) is cut down to its
    last ``max_str_len`` characters.
    """
    def _cell(value):
        if isinstance(value, list):
            value = repr(value)
        if max_str_len is not None and isinstance(value, str):
            value = value[-max_str_len:]  # keep the tail (e.g. the end of a path)
        return value

    row = {key: _cell(val) for key, val in d.items()}
    frame = pd.DataFrame(row, index=[0])
    return frame.T.to_markdown()