import json
import os
import re
import torch
import decord
import torch
decord.bridge.set_bridge('torch')

from video_llama.common.config import Config
from video_llama.common.dist_utils import get_rank
from video_llama.common.registry import registry
from video_llama.conversation.conversation_video import Chat, Conversation, default_conversation, SeparatorStyle, conv_llava_llama_2
from video_llama.processors.video_processor import ToTHWC, ToUint8, load_video
from cons_utils import BaseOptions, prompt, generate_question


class VideoLLaMA_Options(BaseOptions):
    """Command-line options for the Video-LLaMA runner, layered on BaseOptions."""

    def initialize(self):
        """Register the base options first, then the Video-LLaMA specific flags."""
        BaseOptions.initialize(self)

        # Bound after the base initializer has run, since that is the call
        # that is expected to make ``self.parser`` available.
        add_argument = self.parser.add_argument

        add_argument(
            "--cfg-path",
            default="video_llama/eval_configs/video_llama_eval_only_vl.yaml",
            help="path to configuration file.",
        )
        add_argument(
            "--gpu_id",
            type=int,
            default=0,
            help="specify the gpu to load the model.",
        )

        override_help = (
            "override some settings in the used config, the key-value pair "
            "in xxx=yyy format will be merged into config file (deprecate), "
            "change to --cfg-options instead."
        )
        add_argument("--options", nargs="+", help=override_help)

class VideoLLaMA:
    """Thin inference wrapper around a Video-LLaMA model and its ``Chat`` helper.

    The wrapper builds the model from the parsed options, encodes a video into
    features once, and then answers task-specific questions about it.
    """

    def __init__(self, args):
        """Build the model, the visual processor and the chat helper.

        Args:
            args: Parsed options. This method reads ``gpu_id``, ``debug``,
                ``fine_tuned`` and ``dset_name`` from it and hands the whole
                object to ``Config``.
        """
        cfg = Config(args)
        model_config = cfg.model_cfg
        model_config.device_8bit = args.gpu_id

        # The checkpoint override has to be applied BEFORE the model is built.
        # Previously it was assigned after ``from_config`` had already run and
        # ``model_config`` was never read again, so the fine-tuned checkpoint
        # selection had no effect.
        # NOTE(review): presumably ``from_config`` reads ``model_config.ckpt``
        # to decide which weights to load — confirm against the model class.
        if args.fine_tuned:
            if args.dset_name in ["charades"]:
                model_config.ckpt = "/mnt/Video-LLaMA/videollama_stage2_finetune_charades/20240906190/checkpoint_4.pth"
            else:
                model_config.ckpt = "/mnt/Video-LLaMA/videollama_stage2_finetune_activitynet/20240904171/checkpoint_2.pth"   # you can use our pretrained ckpt from https://huggingface.co/DAMO-NLP-SG/Video-LLaMA-2-13B-Pretrained/

        self.device = 'cuda:{}'.format(args.gpu_id)

        model_cls = registry.get_model_class(model_config.arch)
        self.model = model_cls.from_config(model_config).to(self.device)
        self.model.eval()

        vis_processor_cfg = cfg.datasets_cfg.webvid.vis_processor.train
        self.vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
        self.chat = Chat(self.model, self.vis_processor, device=self.device)
        self.debug = args.debug

    def initialize_chat(self, msg, add_detail=None):
        """Create a fresh conversation state seeded with the video placeholder.

        Args:
            msg: Text appended after the ``<Video><ImageHere></Video>``
                placeholder in the first user message. ``None`` is treated as
                an empty string so callers that have no message do not crash.
            add_detail: Optional extra text appended to the system prompt.

        Returns:
            A copy of ``default_conversation`` with the system prompt set and
            the first user message appended.
        """
        chat_state = default_conversation.copy()
        chat_state.system = "You are able to understand the visual content that the user provides. Follow the instructions carefully and explain your answers in detail."
        if add_detail:
            chat_state.system += add_detail

        # ``run`` defaults ``msg`` to None; concatenating None to a str would
        # raise TypeError, so fall back to an empty suffix.
        video_msg = "" if msg is None else msg
        chat_state.append_message(chat_state.roles[0], "<Video><ImageHere></Video> " + video_msg)

        return chat_state

    def inference(self, chat, chat_state, video_features, query):
        """Ask ``query`` in the given conversation and return the model's reply.

        Args:
            chat: Chat helper exposing ``ask`` and ``answer``.
            chat_state: Conversation state to extend with the query.
            video_features: Passed to ``chat.answer`` as ``img_list``.
            query: Question text.

        Returns:
            The first element of what ``chat.answer`` returns
            (presumably the generated text — confirm against ``Chat.answer``).
        """
        chat.ask(query, chat_state)
        llm_message = chat.answer(conv=chat_state,
                                  img_list=video_features,
                                  num_beams=1,
                                  temperature=0.05,
                                  max_new_tokens=300,
                                  max_length=2000)[0]

        return llm_message

    def load_video_features(self, video_path):
        """Load a video, sample frames and encode them with the model.

        Args:
            video_path: Path of the video file to load.

        Returns:
            A tuple ``(video_features, msg)`` where ``video_features`` is a
            one-element list holding the embedding returned by
            ``encode_videoQformer_visual`` and ``msg`` is the message returned
            by ``load_video``.
        """
        video, msg = load_video(
            video_path=video_path,
            n_frms=8,
            height=224,
            width=224,
            sampling="uniform",
            return_msg=True
        )
        video = self.vis_processor.transform(video)
        video = video.unsqueeze(0).to(self.device)

        # ``model.eval()`` does not disable autograd; this class only performs
        # inference, so avoid building a graph for the encoder pass.
        with torch.no_grad():
            image_emb, _ = self.model.encode_videoQformer_visual(video)

        return [image_emb], msg

    def run(self, task, video_features, query, duration, st=None, ed=None, msg=None):
        """Build the question for ``task``, query the model and post-process.

        Args:
            task: Task name forwarded to ``generate_question``.
            video_features: Features produced by ``load_video_features``.
            query: Task query forwarded to ``generate_question``.
            duration: Video duration forwarded to ``generate_question``.
            st: Optional start time forwarded to ``generate_question``.
            ed: Optional end time forwarded to ``generate_question``.
            msg: Optional message placed after the video placeholder in the
                first user turn (typically the one from ``load_video_features``).

        Returns:
            For ``"grounding"``: the two-number list from ``extract_time``.
            For ``"occurrence"``: ``[choice, answer]``.
            Otherwise: the raw answer.
        """
        # Question construction (including the per-task prompt choice) lives
        # in ``generate_question``.
        question, add_detail, choice = generate_question(task, prompt, query, duration, st, ed)

        chat_state = self.initialize_chat(msg, add_detail)
        answer = self.inference(self.chat, chat_state, video_features, question)

        if self.debug:
            print("Question:" + question)
            print("Answer:" + answer)
            print("")

        if task in ["grounding"]:
            return self.extract_time(answer)

        if task in ["occurrence"]:
            return [choice, answer]

        return answer

    def extract_time(self, text):
        """Return the first two numbers found in ``text``.

        Numbers containing a decimal point become ``float``; others become
        ``int``. If fewer than two numbers are present, ``[0, 0]`` is returned.
        """
        numbers = re.findall(r'\d+\.\d+|\d+', text)

        # Keep integers as int and decimals as float.
        numbers = [float(num) if '.' in num else int(num) for num in numbers]

        if len(numbers) < 2:
            return [0, 0]

        return numbers[:2]
