import datetime
import os
import sys
import requests
import clip
import torch
import re
from PIL import Image
import requests
import random
from io import BytesIO
from transformers import TextStreamer, BartForConditionalGeneration, BartTokenizer
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    from PIL import Image
    BICUBIC = Image.BICUBIC
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize

from vtimellm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from vtimellm.conversation import conv_templates, SeparatorStyle
from vtimellm.model.builder import load_pretrained_model
from vtimellm.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria, VideoExtractor

from cons_utils import BaseOptions, prompt

# VTimeLLM answers with frame indices rather than seconds (see
# VTimeLLM.extract_time_vtimellm / frame_to_second), so the grounding
# question is rephrased to ask about frames.  The remaining task templates
# are converted on the fly in VTimeLLM.run via .replace("seconds", "frames").
prompt["grounding"] = "During which frames can we see the event '{event}'?"



class VTimeLLM_Options(BaseOptions):
    """Command-line options for the VTimeLLM wrapper.

    Adds checkpoint/weight locations and the per-video frame count on top of
    the options registered by ``BaseOptions``.
    """

    def initialize(self):
        BaseOptions.initialize(self)
        add_arg = self.parser.add_argument

        # String-valued checkpoint paths, registered in this exact order.
        path_defaults = (
            ("--clip_path", "/mnt/vtimellm/ViT-L-14.pt"),
            ("--model_base", "/mnt/vidllm/ckpt/vicuna-7b-v1.5"),
            ("--pretrain_mm_mlp_adapter", "/mnt/vtimellm/vtimellm-vicuna-v1-5-7b-stage1/mm_projector.bin"),
            ("--stage2", "/mnt/vtimellm/vtimellm-vicuna-v1-5-7b-stage2"),
            ("--stage3", "/mnt/vtimellm/vtimellm-vicuna-v1-5-7b-stage3"),
        )
        for flag, default_path in path_defaults:
            add_arg(flag, type=str, default=default_path)

        # Number of frames sampled from each video.
        add_arg("--n_frame", type=int, default=100)


class VTimeLLM:
    """Wrapper that runs VTimeLLM on CLIP-encoded video frames for temporal QA.

    Timestamps are exchanged with the model as frame indices on a
    ``0..n_frame`` scale (see ``second_to_frame`` / ``frame_to_second``) and
    converted back to seconds for grounding answers.
    """

    # Finds the first "<a> to <b>" / "<a> and <b>" span in a model answer,
    # e.g. "From 05 to 42.".  ``\d+`` (not a fixed two digits) so that
    # single-digit indices are found and three-digit ones such as 100 are not
    # truncated to their first two digits.
    _SPAN_PATTERN = re.compile(r"(\d+)\s+(?:to|and)\s+(\d+)")

    def __init__(self, args):
        """Load the LLM, the CLIP encoder and the frame extractor onto CUDA.

        Args:
            args: parsed options; this method reads ``stage2``, ``stage3``,
                ``clip_path``, ``n_frame`` and ``debug`` (plus whatever
                ``load_pretrained_model`` reads from it).
        """
        # Patch out default Linear/LayerNorm init before building the model.
        self.disable_torch_init()
        self.tokenizer, self.model, self.context_len = load_pretrained_model(args, args.stage2, args.stage3)
        self.model = self.model.cuda()
        self.model = self.model.to(torch.float16)

        self.clip_model, _ = clip.load(args.clip_path)
        self.clip_model = self.clip_model.cuda()
        self.video_loader = VideoExtractor(N=args.n_frame)
        # CLIP preprocessing without ToTensor: frames are already tensors and
        # are scaled to [0, 1] in load_video_features before this runs.
        self.transform = Compose([
            Resize(224, interpolation=BICUBIC),
            CenterCrop(224),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

        self.n_frame = args.n_frame
        self.debug = args.debug

    def load_video_features(self, video_path):
        """Sample frames from ``video_path`` and encode them with CLIP.

        Args:
            video_path: path passed to the frame extractor.

        Returns:
            ``(video_features, None)``: the CLIP image embeddings of the
            sampled frames (computed in float16 on CUDA) and a ``None``
            placeholder (presumably to match a shared ``(features, extra)``
            interface; confirm with callers).

        Raises:
            Exception: whatever the frame extractor raised when the video
                cannot be read.  The error is re-raised rather than swallowed,
                because without frames there is nothing to encode.
        """
        try:
            _, images = self.video_loader.extract({'id': None, 'video': video_path})
        except Exception:
            print(f"pass the video {video_path}")
            raise

        # Presumably uint8 pixel values; scale to [0, 1] before CLIP normalization.
        images = self.transform(images / 255.0)
        images = images.to(torch.float16)

        with torch.no_grad():
            video_features = self.clip_model.encode_image(images.to('cuda'))

        return video_features, None

    def extract_time_vtimellm(self, sentence):
        """Parse the first "<start> to/and <end>" frame span from a model answer.

        Args:
            sentence: the model's textual answer.

        Returns:
            ``[start, end]`` as floats (frame indices), or ``[0, 0]`` when no
            span is found.
        """
        match = self._SPAN_PATTERN.search(sentence)
        if not match:
            return [0, 0]
        return [float(match.group(1)), float(match.group(2))]

    def second_to_frame(self, seconds, duration):
        """Convert a ``[start, end]`` span in seconds to integer frame indices.

        Indices are truncated toward zero on a ``0..n_frame`` scale, so a time
        equal to ``duration`` maps to ``n_frame``.
        """
        return [int(seconds[0] / duration * self.n_frame), int(seconds[1] / duration * self.n_frame)]

    def frame_to_second(self, frames, duration):
        """Convert a ``[start, end]`` frame-index span to seconds, rounded to 2 decimals."""
        return [round(frames[0] / self.n_frame * duration, 2), round(frames[1] / self.n_frame * duration, 2)]

    def run(self, task, video_features, query, duration, st=None, ed=None, msg=None):
        """Ask the model one task-specific question about a video.

        Args:
            task: one of "grounding", "occurrence", "co_occurrence",
                "sequential_after", "sequential_before", "compositional".
            video_features: CLIP features from ``load_video_features``.
            query: event description (str), or a two-element list of events
                for the co-occurrence / sequential tasks.
            duration: video length in seconds.
            st, ed: optional span in seconds; clipped to ``duration`` and
                converted to frame indices before being put in the prompt.
            msg: unused here (presumably kept for a shared call signature).

        Returns:
            For "grounding": ``[start_sec, end_sec]`` parsed from the answer.
            For "occurrence": ``[choice, answer]``, where ``choice`` is the
                sampled prompt pool ("pos" or "neg").
            For other tasks: the raw answer string.

        Raises:
            ValueError: a co-occurrence/sequential query is not a list.
            NotImplementedError: ``task`` is not supported.
        """
        # Drawn for every task (not only "occurrence") so the RNG sequence is
        # the same regardless of which task is being run.
        choice = random.choice(["pos", "neg"])
        if st is not None and ed is not None:
            st, ed = min(st, duration), min(ed, duration)
            st, ed = self.second_to_frame([st, ed], duration)

        # Templates are written in terms of seconds; the model works in frames.
        if task in ["grounding"]:
            question = prompt[task].replace("seconds", "frames").format(event=query)

        elif task in ["occurrence"]:
            question = random.choice(prompt[choice]).replace("seconds", "frames").format(event=query, st=st, ed=ed)

        elif task in ["co_occurrence", "sequential_after", "sequential_before"]:
            if not isinstance(query, list):
                raise ValueError(f"Invalid style of query: {query}")
            question = prompt[task].replace("seconds", "frames").format(target1=query[0], target2=query[1])

        elif task in ["compositional"]:
            query = query.replace("?", "")
            question = prompt[task].replace("seconds", "frames").format(question=query, st=st, ed=ed)

        else:
            raise NotImplementedError(f"Not implemented task: {task}")

        answer = self.inference(video_features=video_features, inp=question)

        if self.debug:
            print("Question:" + question)
            print("Answer:" + answer)
            print("")

        if task in ["grounding"]:
            return self.frame_to_second(self.extract_time_vtimellm(answer), duration)

        if task in ["occurrence"]:
            return [choice, answer]

        return answer

    def inference(self, video_features, inp, return_conv=False):
        """Generate the model's answer to ``inp`` conditioned on ``video_features``.

        Args:
            video_features: frame features; a batch dim is added before generation.
            inp: the question text.
            return_conv: also return the conversation object with the answer filled in.

        Returns:
            The decoded answer string, or ``(answer, conv)`` if ``return_conv``.
        """
        conv = conv_templates['v1'].copy()
        roles = conv.roles
        # NOTE(review): the user role is prepended to the text here as well as
        # being passed to append_message below; if the template also renders
        # the role, the prompt contains it twice.  Kept unchanged so results
        # stay comparable; confirm whether this is intended.
        inp = f"{roles[0]}: {inp}"
        inp = DEFAULT_IMAGE_TOKEN + '\n' + inp

        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        # Named full_prompt so it does not shadow the module-level `prompt` dict.
        full_prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(full_prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 # plain:sep(###) v1:sep2(None)
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

        # Run model inference (near-greedy sampling via a very low temperature).
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=video_features[None,].cuda(),
                do_sample=True,
                temperature=0.05,
                max_new_tokens=256,
                repetition_penalty=1.0,
                length_penalty=1,
                use_cache=True,
                stopping_criteria=[stopping_criteria]
            )

        # Decode only the newly generated tokens (skip the echoed prompt).
        outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True).strip()
        conv.messages[-1][-1] = outputs

        if return_conv:
            return outputs, conv

        return outputs

    def disable_torch_init(self):
        """
        Disable the redundant torch default initialization to accelerate model creation.

        Note: this patches ``torch.nn.Linear`` and ``torch.nn.LayerNorm``
        globally, affecting every module created afterwards in this process.
        """
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)