from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates

import torch
import copy
import numpy as np
from PIL import Image
import decord
from safetensors import safe_open
import os
from transformers import AutoTokenizer, SiglipTextModel

from .model import VTR_Model
class LLavaVideoEncoder(VTR_Model):
    """Embed videos with a LLaVA vision tower and text with a SigLIP text model.

    Both models are loaded lazily on first use. Video and text embeddings are
    L2-normalised, so their dot product can be used as a similarity score.
    """

    # Checkpoint used when the caller does not pass ``siglip_model_path``.
    DEFAULT_SIGLIP_MODEL_PATH = "/mnt/bn/wzr/models/siglip-so400m-patch14-384"
    # Longest token sequence fed to the text model; presumably the SigLIP text
    # tower's positional limit -- TODO confirm against the checkpoint config.
    MAX_TEXT_TOKENS = 64

    def __init__(self, model_path, device, siglip_model_path=None):
        """Store configuration; no weights are loaded here.

        Args:
            model_path: Path of the LLaVA checkpoint given to
                ``load_pretrained_model``.
            device: Device (or device map) the models are placed on.
            siglip_model_path: Directory holding the SigLIP checkpoint
                (``model.safetensors`` plus tokenizer files). Defaults to
                ``DEFAULT_SIGLIP_MODEL_PATH``.
        """
        super().__init__()
        self.device = device
        self.model_path = model_path
        if siglip_model_path is None:
            siglip_model_path = self.DEFAULT_SIGLIP_MODEL_PATH
        self.siglip_model_path = siglip_model_path
        self.siglip_text_model = None
        self.llava_model = None

    def load_video_model(self):
        """Load the LLaVA model and copy the SigLIP vision head into its tower.

        The ``vision_model.head.*`` tensors are read from the SigLIP
        ``model.safetensors`` file and loaded into the head of the LLaVA
        vision tower.
        """
        tokenizer, model, processor, _ = load_pretrained_model(
            self.model_path, None, "llava_qwen",
            torch_dtype="bfloat16",
            device_map=self.device,
        )
        head_prefix = "vision_model.head."
        checkpoint_file = os.path.join(self.siglip_model_path, "model.safetensors")
        with safe_open(checkpoint_file, framework="pt", device="cpu") as f:
            head_state = {
                key[len(head_prefix):]: f.get_tensor(key)
                for key in f.keys()
                if str(key).startswith(head_prefix)
            }
            model.get_model().get_vision_tower().vision_tower.vision_model.head.load_state_dict(head_state)
        # Publish on self only once the head is in place: a failure above must
        # leave ``self.llava_model`` as None so the lazy loader retries instead
        # of silently using a tower without the head weights.
        self.llava_tokenizer = tokenizer
        self.llava_processor = processor
        self.llava_model = model

    def load_text_model(self):
        """Load the SigLIP tokenizer and text model onto ``self.device``."""
        self.siglip_tokenizer = AutoTokenizer.from_pretrained(self.siglip_model_path)
        self.siglip_text_model = SiglipTextModel.from_pretrained(self.siglip_model_path).to(self.device)

    def get_text_embedding(self, text):
        """Return the L2-normalised pooled SigLIP embedding of ``text``.

        Args:
            text: A single string; it is tokenized as a batch of one.

        Returns:
            The text model's ``pooler_output`` divided by its norm along the
            last dimension.
        """
        if self.siglip_text_model is None:
            self.load_text_model()
        text_inputs = self.siglip_tokenizer([text], padding="max_length", return_tensors="pt")
        if text_inputs["input_ids"].shape[1] > self.MAX_TEXT_TOKENS:
            print("!!!!!!!!!!!!", text_inputs["input_ids"].shape)
        # Clip and move every tensor the tokenizer produced, not only
        # ``input_ids``: an ``attention_mask`` left at full length or on the
        # CPU would break the forward pass.
        model_inputs = {
            name: tensor[:, :self.MAX_TEXT_TOKENS].to(self.device)
            for name, tensor in text_inputs.items()
        }
        with torch.no_grad():
            text_outputs = self.siglip_text_model(**model_inputs)
        pooled = text_outputs.pooler_output
        text_embedding = pooled / pooled.norm(dim=-1, keepdim=True)
        return text_embedding

    def get_video_embedding(self,
                            frames, return_image_embeddings=False):
        """Return one L2-normalised embedding for a sequence of frames.

        Args:
            frames: Frames accepted by the LLaVA image processor's
                ``preprocess`` method.
            return_image_embeddings: Currently unused; kept so existing
                callers that pass it keep working.

        Returns:
            The pooled vision-tower output averaged over dim 0 (presumably the
            frame axis -- TODO confirm), normalised along the last dimension,
            with a leading singleton dimension added.
        """
        if self.llava_model is None:
            self.load_video_model()
        with torch.no_grad():
            pixel_values = self.llava_processor.preprocess(frames, return_tensors="pt")["pixel_values"]
            video = pixel_values.to(torch.bfloat16).to(self.device)
            # NOTE(review): relies on a vision tower whose forward returns a
            # (features, pooled_output) pair -- confirm against the installed
            # llava package.
            _, video_pooler_output = self.llava_model.get_model().get_vision_tower()(video)
            # Average in float32 rather than bfloat16 before normalising.
            visual_output = torch.mean(video_pooler_output.float(), dim=0)
            visual_output = visual_output / visual_output.norm(dim=-1, keepdim=True)
            video_embedding = visual_output.unsqueeze(0)
        return video_embedding
    
if __name__ == "__main__":
    def get_video_frames_for_default_llava(video_path, max_frames_num=64,fps=1,force_sample=False):
        """Sample frames from a video and build the matching timing prompt.

        Args:
            video_path: Path of the video file opened with ``decord``.
            max_frames_num: Upper bound on the number of returned frames; when
                exceeded, frames are re-sampled uniformly over the whole video.
                ``0`` returns a single all-zero placeholder frame.
            fps: Requested sampling rate in frames per second.
            force_sample: Always use uniform sampling of ``max_frames_num``
                frames, even when the fps-based sampling fits the budget.

        Returns:
            A ``(frames, time_instruction)`` tuple: the sampled frames as a
            numpy array and a sentence describing the video length and the
            timestamps of the sampled frames.
        """
        if max_frames_num == 0:
            # Keep the return shape consistent with the normal path so callers
            # can always unpack two values.
            return np.zeros((1, 336, 336, 3)), ""
        vr = decord.VideoReader(video_path)
        total_frame_num = len(vr)
        avg_fps = vr.get_avg_fps()
        video_time = total_frame_num / avg_fps
        # Stride in source frames between samples; clamp to 1 so a video whose
        # frame rate is below the requested rate does not yield range(..., 0).
        frame_step = max(1, round(avg_fps / fps))
        frame_idx = list(range(0, total_frame_num, frame_step))
        if len(frame_idx) > max_frames_num or force_sample:
            frame_idx = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int).tolist()
        # Timestamps come from the video's real frame rate, not the stride, so
        # they are in seconds for any requested fps.
        frame_time = ",".join(f"{i / avg_fps:.2f}s" for i in frame_idx)
        spare_frames = vr.get_batch(frame_idx).asnumpy()
        time_instruction = f"The video lasts for {video_time:.2f} seconds, and {len(spare_frames)} frames are uniformly sampled from it. These frames are located at {frame_time}.Please answer the following questions related to this video."
        return spare_frames, time_instruction

    # Smoke test: embed one video and one query, then print their similarity.
    model = LLavaVideoEncoder("/mnt/bn/wzr/models/LLaVA-Video-7B-Qwen2", "cuda")
    frames, _ = get_video_frames_for_default_llava("/mnt/bn/wzr/datasets/VideoMME/data/_8lBR0E_Tx8.mp4")
    video_embedding = model.get_video_embedding(frames)
    text_embedding = model.get_text_embedding("What is the action in this video?")
    print(video_embedding @ text_embedding.T)
