import math
import os
from datetime import timedelta
from typing import List, Optional, Tuple, Union
import pandas as pd
import numpy as np
import copy
import torch
from torch import nn
from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
from accelerate.state import AcceleratorState
import decord
from llava.constants import (
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    IGNORE_INDEX,
    IMAGE_TOKEN_INDEX,
)
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    tokenizer_image_token,
)
from llava.model.builder import load_pretrained_model
# from llava.model.language_model.llava_llama import LlavaConfig
# from llava.model.language_model.llava_qwen import LlavaQwenConfig
from llava.model.language_model.llava_qwen_mpe import MPELlavaQwenConfig

from loguru import logger as eval_logger
from PIL import Image
from tqdm import tqdm
# from transformers import AutoConfig, AutoModelForCausalLM
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSequenceClassification
)

from llava.utils import get_sampling_strategy
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model


# Register the MPE config class under the "llava_qwen_mpe" model type so that
# transformers' AutoConfig can resolve it.
AutoConfig.register("llava_qwen_mpe", MPELlavaQwenConfig)


class HookTool:
    """Forward-hook receiver that remembers the most recent module output."""

    def __init__(self):
        # Latest captured output; stays None until the hook fires once.
        self.fea = None

    def hook_fun(self, module, fea_in, fea_out):
        """Forward-hook callback: keep a detached CPU copy of ``fea_out``.

        ``module`` and ``fea_in`` are part of the hook signature and are ignored.
        """
        captured = fea_out.detach()
        self.fea = captured.cpu()

def get_gating_logit_by_hook(model):
    """Attach output-capturing forward hooks to the gate layers of the vision model.

    A sub-module is selected when its qualified name contains both ``"gate"``
    and ``"vision_model"`` and it is an ``nn.Linear``.

    Args:
        model: The ``nn.Module`` whose ``named_modules()`` are scanned.

    Returns:
        A list of ``HookTool`` objects, one per matched module, in
        ``named_modules()`` order. Each object also carries the registration
        handle as ``.handle`` so the hook can be detached later with
        ``hook.handle.remove()``. Previously the handle was discarded, so
        hooks could never be removed and accumulated on repeated calls.
    """
    fea_hooks = []
    for name, module in model.named_modules():
        if "gate" not in name or "vision_model" not in name:
            continue
        if not isinstance(module, nn.Linear):
            continue
        cur_hook = HookTool()
        # Keep the RemovableHandle so callers can unregister the hook.
        cur_hook.handle = module.register_forward_hook(cur_hook.hook_fun)
        fea_hooks.append(cur_hook)
    return fea_hooks

@register_model("llava_vid")
class LlavaVid(lmms):
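    """
    lmms-eval wrapper for an MPE LLaVA-Video model.

    A small sequence-classification "inductor" model scores each text query
    against eight task types; those scores are used when sampling video frames
    and are passed on to the main model's ``generate`` call.
    """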

    def __init__(
        self,
        pretrained: str = "/mnt/new_cpfs/yqs/model/lmms-lab/LLaVA-Video-7B-Qwen2-Soft4-expanded",
        truncation: Optional[bool] = True,
        torch_dtype: Optional[Union[str, torch.dtype]] = "bfloat16",
        device: Optional[str] = "cuda:0",
        batch_size: Optional[Union[int, str]] = 1,
        attn_implementation=("sdpa" if torch.__version__ >= "2.1.2" else "eager"),
        # inference implementation for attention, can be "sdpa", "eager", "flash_attention_2". Seems FA2 is not effective during inference: https://discuss.huggingface.co/t/flash-attention-has-no-effect-on-inference/73453/5
        device_map="cuda:0",
        conv_template="vicuna_v1",
        use_cache=True,
        truncate_context=False,  # whether to truncate the context in generation, set it False for LLaVA-1.6
        max_frames_num: int = 20,
        video_fps: int = 1,
        mm_resampler_type: str = "spatial_pool",
        mm_spatial_pool_stride: int = 2,
        mm_spatial_pool_out_channels: int = 1024,
        mm_spatial_pool_mode: str = "bilinear",
        mm_resampler_location: str = "before",
        mm_newline_position: str = "grid",
        overwrite: bool = True,
        video_decode_backend: str = "decord",
        delay_load: bool = False,
        tie_weights: bool = True,
        force_sample: bool = False,
        add_time_instruction: bool = True,
        # add_faster_video: bool = False,
        faster_token_stride: int = 10,
        inductor_path: str = "/mnt/new_cpfs/yqs/model/HuggingFaceTB/SmolLM2-135M-Instruct/results/5e-5-3e/checkpoint-8466/",
        **kwargs,
    ) -> None:
        """Load the video model, its tokenizer/image processor and the inductor classifier.

        Args:
            pretrained: Path or id handed to ``load_pretrained_model``.
            truncation: Stored on the instance; not otherwise used here.
            torch_dtype: Forwarded to ``load_pretrained_model``; also decides
                whether video tensors are cast to bfloat16 or half later on.
            device: Device used when a single process runs with a sharded
                ``device_map`` ("auto" / "balanced_low_0").
            batch_size: Per-GPU batch size (converted with ``int``).
            attn_implementation: Forwarded to ``load_pretrained_model``.
                NOTE(review): the default compares version strings
                lexicographically — confirm this is good enough for the
                torch versions in use.
            device_map: Requested device map; replaced by ``cuda:<local rank>``
                unless a single process asks for "auto" / "balanced_low_0".
            conv_template: Key into ``conv_templates`` used to build prompts.
            use_cache: Forwarded to ``generate``.
            truncate_context: Stored on the instance; not otherwise used here.
            max_frames_num, video_fps: Stored as ints (``max_frames_num``,
                ``fps``).
            mm_resampler_type, mm_spatial_pool_out_channels: Stored on the
                instance only (not written into the config overrides).
            mm_spatial_pool_stride, mm_spatial_pool_mode,
            mm_resampler_location, mm_newline_position, delay_load: Written
                into ``overwrite_config`` when ``overwrite`` is truthy.
            overwrite: Whether to pass the overrides above to the loader.
            video_decode_backend: "decord", "pyav" or "image"; selects the
                loader used during generation.
            tie_weights: Call ``model.tie_weights()`` after loading.
            force_sample: Stored and forwarded to the video loaders.
            add_time_instruction: Prepend timing/task hints to the prompt.
            faster_token_stride: Accepted for compatibility; unused.
            inductor_path: Checkpoint of the sequence-classification
                "inductor" model and its tokenizer (defaults to the
                previously hard-coded path).
            **kwargs: Must be empty; anything else is rejected.
        """
        print("!!!!!!!!!!!!!!!!!LLavaVid!!!!!!!!!!!!!!!!!!")

        super().__init__()
        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])

        # A single process may shard the model itself; every other setup pins
        # the model to this process's local CUDA device.
        shards_model = accelerator.num_processes == 1 and device_map in ("auto", "balanced_low_0")
        if shards_model:
            self._device = torch.device(device)
            self.device_map = device_map
        else:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"

        self.pretrained = pretrained
        self.model_name = get_model_name_from_path(pretrained)
        self.video_decode_backend = video_decode_backend
        self.overwrite = overwrite
        self.mm_resampler_type = mm_resampler_type
        self.mm_spatial_pool_stride = int(mm_spatial_pool_stride)
        self.mm_spatial_pool_out_channels = int(mm_spatial_pool_out_channels)
        self.mm_spatial_pool_mode = mm_spatial_pool_mode
        self.max_frames_num = int(max_frames_num)
        self.fps = int(video_fps)
        self.mm_resampler_location = mm_resampler_location
        self.delay_load = delay_load
        self.force_sample = force_sample
        self.add_time_instruction = add_time_instruction
        print("force sample:", self.force_sample)
        self.torch_dtype = torch_dtype

        if self.overwrite:
            # mm_resampler_type and attn_implementation are deliberately not
            # part of the overrides (they were commented out originally).
            overwrite_config = {
                "mm_spatial_pool_stride": self.mm_spatial_pool_stride,
                "mm_spatial_pool_mode": self.mm_spatial_pool_mode,
                "mm_pooling_position": self.mm_resampler_location,
                "mm_newline_position": mm_newline_position,
                "delay_load": self.delay_load,
            }
            self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(
                pretrained,
                None,
                self.model_name,
                device_map=self.device_map,
                torch_dtype=self.torch_dtype,
                overwrite_config=overwrite_config,
                attn_implementation=attn_implementation,
            )
        else:
            self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(
                pretrained,
                None,
                self.model_name,
                device_map=self.device_map,
                torch_dtype=self.torch_dtype,
                attn_implementation=attn_implementation,
            )

        self._config = self._model.config

        if self._tokenizer.pad_token_id is None:
            if "qwen" in self._tokenizer.name_or_path.lower():
                print("Setting pad token to bos token for qwen model.")
                self._tokenizer.pad_token_id = 151643

        self.model.eval()
        if tie_weights:
            self.model.tie_weights()
        self.truncation = truncation
        self.batch_size_per_gpu = int(batch_size)
        self.conv_template = conv_template
        self.use_cache = use_cache
        self.truncate_context = truncate_context

        # Inductor (MPE): a frozen sequence classifier over the task types
        # listed below. It is not moved to a device here; `induct` sends its
        # inputs to wherever the classifier lives.
        self.inductor_path = inductor_path
        self.inductor = AutoModelForSequenceClassification.from_pretrained(self.inductor_path, torch_dtype=torch.bfloat16)
        self.inductor_tokenizer = AutoTokenizer.from_pretrained(self.inductor_path)
        self.inductor.eval()
        for p in self.inductor.parameters():
            p.requires_grad = False
        self.task_types = [
            "Static Attributes Recognition",  # 0
            "Action Recognition",  # 1
            "Emotion & Fine-grained Recognition",  # 2
            "Temporal sequence positioning",  # 3
            "Spatial Orientation & Navigation",  # 4
            "Summary & Generalization",  # 5
            "Reasoning & Logic",  # 6
            "Text Recognition (OCR) & Cross-modal Alignment",  # 7
        ]

        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
            # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
            # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works
            # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work.
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                ds_kwargs = {
                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **ds_kwargs)
                eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
            if accelerator.distributed_type in (DistributedType.FSDP, DistributedType.DEEPSPEED):
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        elif shards_model:
            # The loader already placed the model according to device_map, so
            # it must not be moved. This branch used to set `_word_size`
            # (typo), which left `world_size` undefined, and it ignored
            # "balanced_low_0".
            eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
            self._rank = 0
            self._world_size = 1
        else:
            eval_logger.info(f"Using single device: {self._device}")
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1

    @property
    def config(self):
        """Return the config object taken from the loaded model (``self._model.config``)."""
        return self._config

    @property
    def tokenizer(self):
        """Return the tokenizer that was loaded together with the main model."""
        return self._tokenizer

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        """Return the tokenizer's EOS token id, used here as the end-of-text id."""
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Return the maximum length value reported by ``load_pretrained_model``."""
        return self._max_length

    def pad_sequence(self, input_ids, batch_first, padding_value):
        if self.tokenizer.padding_side == "left":
            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
        if self.tokenizer.padding_side == "left":
            input_ids = torch.flip(input_ids, [1])
        return input_ids

    @property
    def batch_size(self):
        """Return the per-GPU batch size set in ``__init__``."""
        return self.batch_size_per_gpu

    @property
    def device(self):
        """Return the ``torch.device`` chosen for this process in ``__init__``."""
        return self._device

    @property
    def rank(self):
        """Return this process's rank (the local process index under Accelerate, else 0)."""
        return self._rank

    @property
    def world_size(self):
        """Return the number of processes recorded in ``__init__``."""
        # NOTE(review): the `device_map == "auto"` branch of __init__ assigns
        # `_word_size` (typo), so this attribute may be missing on that path.
        return self._world_size

    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
        """ """
        add_special_tokens = False if add_special_tokens is None else add_special_tokens
        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]
        return encoding

    def load_image(self, image_path):
        frame_files = [os.path.join(image_path, f) for f in os.listdir(image_path) if os.path.isfile(os.path.join(image_path, f))]
        frame_files.sort()  # Ensure the frames are sorted if they are named sequentially

        # TODO: Hard CODE: Determine the indices for uniformly sampling 10 frames
        num_frames_to_sample = 10

        total_frames = len(frame_files)

        sampled_indices = np.linspace(0, total_frames - 1, num_frames_to_sample, dtype=int)

        # Read and store the sampled frames
        video = []
        for idx in sampled_indices:
            frame_path = frame_files[idx]
            try:
                with Image.open(frame_path) as img:
                    frame = img.convert("RGB")
                    video.append(frame)
            except IOError:
                print(f"Failed to read frame at path: {frame_path}")
        return video

    def load_video_decord(self, video_path, probs, prompt_len, force_sample=False):
        if not isinstance(video_path, str) or not os.path.exists(video_path):
            raise ValueError(f"无效视频路径: {video_path}")

        # 初始化decord参数（强制CPU、单线程）
        decord.bridge.set_bridge('native')  # 避免潜在的torch冲突
        ctx = decord.cpu(0)
        try:
            # 使用更稳定的视频读取方式
            vr = decord.VideoReader(video_path, ctx=ctx, num_threads=1)
        except Exception as e:
            raise ValueError(f"无法打开视频文件：{video_path}") from e

        try:
            # 获取基础元数据（修正版）
            total_frames = len(vr)
            # 获取视频原始尺寸的正确方式
            try:
                meta = vr.get_meta_data()  # 标准方法获取元数据
                original_width = int(meta['width'])
                original_height = int(meta['height'])
            except (KeyError, AttributeError, IndexError):
                # 元数据获取失败时的备选方案
                if total_frames > 0:
                    first_frame = vr[0] # .asnumpy()
                    original_height, original_width = first_frame.shape[:2]
                else:
                    original_width = original_height = 384  # 默认尺寸

            avg_fps = vr.get_avg_fps()  # 直接获取平均帧率
            # 智能帧率处理
            actual_fps = avg_fps if not math.isnan(avg_fps) and avg_fps > 0 else 30.0

            # 计算视频总时长（三种方法）
            duration_methods = [
                lambda: total_frames / actual_fps,  # 方法1：总帧数/帧率
                lambda: vr.get_frame_timestamp(total_frames-1)[-1] if total_frames > 0 else 0,  # 方法2：最后一帧时间戳
                lambda: len(vr) / 30.0  # 方法3：保底策略（假设30fps）
            ]

            total_duration = 0.0
            for method in duration_methods:
                try:
                    total_duration = method()
                    if total_duration > 0:
                        break
                except Exception as e:
                    print(f"时长计算方法失败: {str(e)}")
                    continue
            if total_duration <= 0:
                total_duration = 0.0

            # 处理时间范围参数
            start_sec = 0.0
            end_sec = total_duration

            # 边界保护（双重验证）
            start_sec = max(0.0, start_sec)
            end_sec = max(start_sec, min(end_sec, total_duration))
            duration_sec = end_sec - start_sec
            if duration_sec <= 0:
                raise ValueError("无效的时间范围")
            # 获取采样策略
            task_type = np.argmax(probs)
            sampling_interval, resolution = get_sampling_strategy(int(duration_sec), task_type, prompt_len)

            # 生成采样时间点（带边界检查）
            num_samples = max(1, int(duration_sec / sampling_interval))
            times = np.linspace(
                start=start_sec,
                stop=end_sec,
                num=num_samples,
                endpoint=False,
                dtype=np.float32
            )

            # 转换时间戳为帧索引（增强版）
            frame_indices = []
            valid_times = []
            for t in times:
                idx = int(t * actual_fps)
                if 0 <= idx < total_frames:
                    frame_indices.append(idx)
                    valid_times.append(t)
                elif idx >= total_frames:
                    break  # 超出视频长度

            # 批量读取帧（带容错机制）
            try:
                if frame_indices:
                    frames = vr.get_batch(frame_indices).asnumpy()
                else:
                    # 当frame_indices为空时
                    if total_frames == 0:
                        # 视频无帧，生成384x384的数组
                        frames = np.zeros((0, 384, 384, 3), dtype=np.uint8)
                    else:
                        # 视频有帧但时间范围无效，生成原尺寸数组
                        frames = np.zeros((0, original_height, original_width, 3), dtype=np.uint8)
            except decord.DECORDError as e:
                print(f"批量读取失败，转为逐帧读取: {str(e)}")
                frames = []
                for idx in frame_indices:
                    try:
                        frame = vr[idx].asnumpy()
                        frames.append(frame)
                    except:
                        # 生成黑色占位帧，使用original_width和original_height
                        black_frame = np.zeros((original_height, original_width, 3), dtype=np.uint8)
                        frames.append(black_frame)
                # 转换为numpy数组
                frames = np.stack(frames) if frames else np.zeros(
                    (0, 384, 384, 3) if total_frames == 0 else (0, original_height, original_width, 3), dtype=np.uint8
                )

            # 处理空帧，保持原始尺寸
            processed_frames = []
            for frame in frames:
                if frame.size == 0:
                    # 空帧，生成黑色帧
                    black_frame = np.zeros((384, 384, 3), dtype=np.uint8)
                    processed_frames.append(black_frame)
                else:
                    # 保留原尺寸
                    processed_frames.append(frame)

            # 转换为最终的numpy数组
            if processed_frames:
                frame_array = np.stack(processed_frames)
            else:
                # 没有有效帧，根据视频是否有帧决定尺寸
                if total_frames == 0:
                    frame_array = np.zeros((0, 384, 384, 3), dtype=np.uint8)
                else:
                    frame_array = np.zeros((0, original_height, original_width, 3), dtype=np.uint8)

            # 重新计算实际duration（基于有效采样点）
            actual_duration = valid_times[-1] - valid_times[0] if valid_times else 0
            # # ========== 取帧查看逻辑 ==========
            # # 构建输出目录路径
            # base_dir = "/mnt/new_cpfs/yqs/model/frame_decord"
            # video_filename = os.path.basename(video_path)
            # base_name = os.path.splitext(video_filename)[0]
            # # 
            # # 处理目录命名冲突
            # target_dir_base = os.path.join(base_dir, base_name)
            # counter = 0
            # target_dir = target_dir_base
            # while os.path.exists(target_dir):
            #     target_dir = f"{target_dir_base}_{counter}"
            #     counter += 1
            # os.makedirs(target_dir, exist_ok=True)
            # # 保存所有调整后的帧
            # for frame_idx, frame in enumerate(frame_array):
            #     img = Image.fromarray(frame)
            #     frame_path = os.path.join(target_dir, f"frame_{frame_idx:04d}.jpg")
            #     img.save(frame_path)
            # # ========== 结束取帧 ==============

            return frame_array, actual_duration, sampling_interval, resolution
        finally:
            del vr # 显式释放资源（防止内存泄漏）
            decord.bridge.set_bridge('torch')  # 恢复默认bridge

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Score a target continuation for each request.

        For every request the prompt is built twice from the conversation
        template — once without and once with the continuation — so that the
        context tokens can be masked out of the loss.

        Args:
            requests: Instances whose ``args`` are
                ``(contexts, doc_to_target, doc_to_visual, doc_id, task, split)``.

        Returns:
            One ``(loss, is_greedy)`` tuple per request: the model's loss over
            the continuation tokens, and whether the arg-max predictions equal
            the continuation tokens.
        """
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

        for contexts, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
            doc = self.task_dict[task][split][doc_id]
            # The target is either given directly or derived from the document.
            if isinstance(doc_to_target, str):
                continuation = doc_to_target
            else:
                continuation = doc_to_target(doc)
            visuals = self.flatten([doc_to_visual(doc)])

            videos = []
            for visual in visuals:
                # NOTE(review): `load_video_pyav` is not defined in the visible
                # part of this file, and `generate_until` calls it with a
                # different argument list — confirm it exists with this signature.
                video, video_time, _ = self.load_video_pyav(visual, self.max_frames_num, self.fps, force_sample=self.force_sample)
                video = self._image_processor.preprocess(video, return_tensors="pt")["pixel_values"].cuda()
                if self.torch_dtype == "bfloat16":
                    video = video.bfloat16()
                else:
                    video = video.half()
                videos.append(video)

            qs = contexts
            if self.model.config.mm_use_im_start_end:
                qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs
            else:
                qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

            # Prompt without the answer: only its token length is needed.
            conv = conv_templates[self.conv_template].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            contxt_id = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)

            # Prompt with the answer appended: this is what the model scores.
            conv = conv_templates[self.conv_template].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], continuation)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()

            labels = input_ids.clone()
            # Context part no need to calculate for loss
            labels[0, : contxt_id.shape[1]] = -100

            with torch.inference_mode():
                outputs = self.model(input_ids=input_ids, labels=labels, images=videos, modalities="video")

            loss = outputs["loss"]
            logits = outputs["logits"]
            greedy_tokens = logits.argmax(dim=-1)
            cont_toks = input_ids[:, contxt_id.shape[1] :]  # [1, seq]
            greedy_tokens = greedy_tokens[:, contxt_id.shape[1] : input_ids.shape[1]]  # [1, seq]
            max_equal = (greedy_tokens == cont_toks).all()
            res.append((float(loss.item()), bool(max_equal)))
            pbar.update(1)
        pbar.close()
        return res

    def flatten(self, input):
        new_list = []
        for i in input:
            for j in i:
                new_list.append(j)
        return new_list
    
    def induct(self, text):
        def get_prompt_len(prompt):
            return len(self.inductor_tokenizer.tokenize(prompt))

        def process_single_text(single_text):
            batch_max_length = min(get_prompt_len(single_text), 8192) # max_len: 8192
            batch = {"input_ids": [], "attention_mask": []}
            encoding = self.inductor_tokenizer(
                single_text,
                truncation=True,
                max_length=batch_max_length,
                padding=False
            )
    
            pad_len = batch_max_length - len(encoding["input_ids"])
            input_ids = encoding["input_ids"] + [self.inductor_tokenizer.pad_token_id] * pad_len
            attention_mask = encoding["attention_mask"] + [0] * pad_len
    
            batch["input_ids"].append(input_ids)
            batch["attention_mask"].append(attention_mask)
    
            inputs = {
                "input_ids": torch.tensor(batch["input_ids"], dtype=torch.long).to(self.inductor.device),
                "attention_mask": torch.tensor(batch["attention_mask"], dtype=torch.long).to(self.inductor.device),
            }
            outputs = self.inductor(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)[0].to(torch.float32).cpu().numpy()
            return probs
    
        # Check if input is a single string or a list of strings
        if isinstance(text, str):
            return [process_single_text(text)], [get_prompt_len(text)]
        elif isinstance(text, list) and all(isinstance(t, str) for t in text):
            return [process_single_text(t) for t in text], [get_prompt_len(t) for t in text]
        else:
            raise ValueError("Input for Inductor must be a string or a list of strings.")

    def generate_until(self, requests) -> List[str]:
        """Generate one free-form answer per request.

        For each request the query is classified by the inductor, the visual
        input is loaded with the configured backend, an optional timing/task
        hint is prepended to the question, and the model generates until the
        conversation template's stop string.

        Args:
            requests: Instances whose ``args`` are
                ``(contexts, gen_kwargs, doc_to_visual, doc_id, task, split)``.
                Missing ``max_new_tokens``/``temperature``/``top_p``/
                ``num_beams`` entries are filled into ``gen_kwargs`` in place.

        Returns:
            One decoded answer string per request. When the visual input
            cannot be loaded, the entry is an error message instead.

        Raises:
            ValueError: If a query is longer than 8192 inductor tokens.
        """
        # Gate-logit forward hooks are no longer registered here: their only
        # consumer is disabled, and registering on every call stacked hooks on
        # the model without ever removing them. For routing analysis, register
        # them once via get_gating_logit_by_hook(self.model).
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

        for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
            # Use the request's own visuals and question. Leftover debug lines
            # used to overwrite both with a fixed test clip and prompt.
            visuals = doc_to_visual(self.task_dict[task][split][doc_id])

            probs, prompt_len = self.induct([contexts])
            if prompt_len[0] > 8192:
                raise ValueError(f"Prompt length exceeds 8192 tokens: {prompt_len[0]}")

            # Reset per request so values never leak over from a previous
            # iteration; not every loading path provides all of them.
            video_time = None
            interval = None
            resolution = None
            videos = []
            try:
                if len(visuals) == 1:  # a single video (or frame directory)
                    if self.video_decode_backend == "decord":
                        video, video_time, interval, resolution = self.load_video_decord(
                            visuals[0], probs[0], prompt_len[0], force_sample=self.force_sample
                        )
                    elif self.video_decode_backend == "pyav":
                        # NOTE(review): `load_video_pyav` is not defined in the
                        # visible part of this file — confirm it exists.
                        video, video_time, interval, resolution = self.load_video_pyav(
                            visuals[0], probs[0], prompt_len[0], force_sample=self.force_sample
                        )
                    elif self.video_decode_backend == "image":
                        video = self.load_image(visuals[0])
                    else:
                        raise ValueError(f"Unsupported video_decode_backend: {self.video_decode_backend}")
                else:
                    if task == "seedbench":
                        video = visuals
                        video_time = 1
                    elif "mvbench" in task:
                        # Reference: https://github.com/jayleicn/TVQA/blob/dfb0e5fe4582efca574dfddfeafd1008db3b33ef/data/README.md?plain=1#L50C34-L50C60
                        fps = 3
                        video_time = len(visuals) / fps
                        sampled_indices = np.linspace(0, len(visuals) - 1, self.max_frames_num, dtype=int)
                        video = [visuals[i] for i in sampled_indices.tolist()]
                    else:
                        raise ValueError(f"Multiple visuals are not supported for task: {task}")

                video = self._image_processor.preprocess(video, return_tensors="pt")["pixel_values"].cuda()
                if self.torch_dtype == "bfloat16":
                    video = video.bfloat16()
                else:
                    video = video.half()
                videos.append(video)
            except Exception as e:
                eval_logger.info(f"{e}")
                eval_logger.info(f"Video {visuals} can not load, check the source")
                # Visuals may be non-string objects (e.g. images), so stringify.
                video_path = "\n".join(str(v) for v in visuals)
                res.append(f"Video {video_path} can not load, check the source")
                pbar.update(1)
                continue

            qs = contexts
            if self.add_time_instruction:
                # Task type = index of the largest inductor probability.
                task_type = int(probs[0].argmax())
                # NOTE(review): the original inline labels for these branches
                # (frames-first / resolution-first / ...) read as the opposite
                # of the instruction texts; verify against get_sampling_strategy.
                if task_type == 0 or task_type == 3:
                    inductor_instruction = "you'll see higher resolution with fewer frames, so focus on video details."
                elif task_type == 2:
                    inductor_instruction = "you'll see lower resolution with more frames, so you should focus on what's happening with video as a whole."
                elif task_type == 1 or task_type == 5:
                    inductor_instruction = "you'll see balanced frames and resolution (a bit more resolution oriented), so analyse the task carefully."
                else:  # task types 4, 6, 7
                    inductor_instruction = "you'll see balanced frames and resolution (a bit more frames oriented), so analyse the task carefully."
                inductor_instruction = f"The type of Query is {self.task_types[task_type]}, {inductor_instruction}"
                if video_time is not None and interval is not None:
                    time_instruciton = f"The video lasts for {video_time:.2f} seconds, and {len(video)} frames are uniformly sampled from it. The time interval between each frame is {interval:.2f} second. "
                else:
                    # Timing is unknown on the image / multi-frame paths.
                    time_instruciton = ""
                qs = f"{time_instruciton}{inductor_instruction}\n{qs}"
            if self.model.config.mm_use_im_start_end:
                qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs
            else:
                qs = DEFAULT_IMAGE_TOKEN * len(videos) + "\n" + qs

            # This is much safer for llama3, as we now have some object type in it
            if "llama_3" in self.conv_template:
                conv = copy.deepcopy(conv_templates[self.conv_template])
            else:
                conv = conv_templates[self.conv_template].copy()

            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
            cur_prompt = qs

            # Build input_ids and the attention mask; the frames have not gone
            # through the vision tower or projector yet.
            input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
            pad_token_ids = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
            if "llama_3" in self.conv_template:
                pad_token_ids = 0  # lmms-lab/llama3-llava-8b is trained on this pad token id. You may need to customize this for other models.
            attention_masks = input_ids.ne(pad_token_ids).long().cuda()

            # Stop once the template's separator shows up in the output
            # (per the original comment, "<|im_end|>" here).
            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
            keywords = [stop_str]
            stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

            gen_kwargs.setdefault("max_new_tokens", 1024)
            gen_kwargs.setdefault("temperature", 0)
            gen_kwargs.setdefault("top_p", None)
            gen_kwargs.setdefault("num_beams", 1)

            with torch.inference_mode():
                output_ids = self.model.generate(
                    inputs=input_ids,
                    images=videos,
                    attention_mask=attention_masks,
                    modalities="video",
                    use_cache=self.use_cache,
                    stopping_criteria=[stopping_criteria],
                    do_sample=gen_kwargs["temperature"] > 0,
                    # Use the requested temperature; it used to be pinned to 0
                    # even when sampling was switched on.
                    temperature=gen_kwargs["temperature"],
                    top_p=gen_kwargs["top_p"],
                    num_beams=gen_kwargs["num_beams"],
                    max_new_tokens=gen_kwargs["max_new_tokens"],
                    # llava-mpe args
                    probs=probs,
                    # NOTE(review): None on the image / multi-frame paths —
                    # confirm the model accepts a missing resolution.
                    resolution=resolution,
                )

            outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
            eval_logger.debug(f"Question: {cur_prompt}")
            eval_logger.debug(f"Answer: {outputs}")
            res.append(outputs)
            pbar.update(1)
        pbar.close()
        return res

    def generate_until_multi_round(self, requests) -> List[str]:
        """Multi-round generation entry point; not implemented for this model.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("TODO: Implement multi-round generation for LLaVAVid")
