import torch
import random
import os
import numpy as np
import math
import decord
from copy import deepcopy
from typing import List, Optional, Tuple, Union, Dict, Any

try:
    import torch_npu
    from torch_npu.contrib import transfer_to_npu
    NPU_AVAILABLE = True
except ImportError:
    NPU_AVAILABLE = False

from PIL import Image
from accelerate import Accelerator, DistributedType
from loguru import logger as eval_logger
from tqdm import tqdm
from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration
import torch.nn.functional as F

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

try:
    from qwen_vl_utils import process_vision_info
except ImportError:
    eval_logger.warning("qwen_vl_utils not found")

# Import unified TASM compressor
from lmms_eval.caching.tasm_compressor import (
    TASMCompressor,
    TASMCompressorConfig,
    create_tasm_compressor,
)


def set_seed(seed_value):
    """Seed every RNG used in this module and force deterministic cuDNN.

    Seeds PyTorch (CPU and all CUDA devices), NumPy and Python's ``random``,
    then disables cuDNN autotuning in favour of deterministic kernels.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed)
    for seeder in seeders:
        seeder(seed_value)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def compute_js_divergence(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Jensen-Shannon divergence between ``p`` and ``q`` along the last dim.

    Both inputs are shifted by ``eps`` (so zeros are safe under ``log``) and
    renormalised to sum to 1 over the last dimension. Uses the natural log,
    so the result lies in ``[0, ln 2]`` and has the leading dims of the inputs.
    """

    def _to_distribution(x: torch.Tensor) -> torch.Tensor:
        shifted = x + eps
        return shifted / shifted.sum(dim=-1, keepdim=True)

    p_dist = _to_distribution(p)
    q_dist = _to_distribution(q)
    log_mixture = (0.5 * (p_dist + q_dist)).log()

    kl_p = (p_dist * (p_dist.log() - log_mixture)).sum(dim=-1)
    kl_q = (q_dist * (q_dist.log() - log_mixture)).sum(dim=-1)
    return 0.5 * (kl_p + kl_q)


@register_model("qwen2_vl_tasm")
class Qwen2VL_TASM(lmms):
    """Qwen2-VL evaluation model that reuses a TASM-compressed few-shot KV cache.

    For each task, the few-shot demonstrations are rendered once and fed
    through ``process_tasm_compression``. That method delegates to the model's
    ``context_organize_tasm``, or falls back to ``context_organize_pre``. The
    resulting cache is stored in ``self.fewshot_ctx_feat``. Every query of that
    task is then generated with ``custom_kv`` / ``custom_kv_pos_offset``
    pointing at the cached context.
    """

    # Number of characters sliced off rendered chat-template strings.
    # NOTE(review): presumably 58 == len("<|im_start|>system\nYou are a helpful
    # assistant.<|im_end|>\n") (the default system block) and
    # 22 == len("<|im_start|>assistant\n") (the generation prompt). Confirm
    # against the processor's chat template if it changes.
    _SYSTEM_PROMPT_CHARS = 58
    _GEN_PROMPT_CHARS = 22
    # Token ids used to locate assistant answers when building labels.
    # NOTE(review): presumably Qwen2's "assistant" and "<|im_end|>" ids. Confirm.
    _ASSISTANT_TOKEN_ID = 77091
    _IM_END_TOKEN_ID = 151645
    # File extensions that are treated as videos when building chat messages.
    _VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov")

    def __init__(
        self,
        pretrained: str = "Qwen/Qwen2-VL-7B-Instruct",
        device: Optional[str] = "cuda",
        device_map: Optional[str] = "cuda",
        batch_size: Optional[Union[int, str]] = 1,
        use_cache: bool = True,
        use_flash_attention_2: Optional[bool] = False,
        max_pixels: int = 256 * 28 * 28,
        min_pixels: int = 256 * 28 * 28,
        max_num_frames: int = 32,
        num_fewshot: int = 0,
        fewshot_split: str = "fewshot",
        task_vector_method: str = "combined",
        task_vector_weight: float = 0.3,
        enable_token_merging: bool = False,
        merge_similarity_threshold: float = 0.5,
        preserve_spatial: bool = True,
        spatial_window: int = 3,
        enable_dynamic_retrieval: bool = True,
        core_ratio: float = 0.2,
        latent_ratio: float = 0.4,
        retrieval_top_k: int = 96,
        js_threshold: float = 0.002,
        target_compression_ratio: float = 0.35,
        sink_tokens: int = 4,
        local_tokens: int = 96,
        attn_implementation: Optional[str] = "sdpa",
        **kwargs,
    ) -> None:
        """Load Qwen2-VL, its processor and tokenizer, and build the TASM compressor.

        Args:
            pretrained: Checkpoint name or path passed to ``from_pretrained``.
            device: Device used only for a single process with
                ``device_map="auto"``. Otherwise ``cuda:<local rank>`` is used.
            device_map: ``"auto"`` enables HF automatic placement in the
                single-process case. Any other value pins the model to
                ``cuda:<local rank>``.
            batch_size: Requests per ``generate`` call (converted with ``int``).
            use_cache: Forwarded to ``generate``.
            use_flash_attention_2: If true, overrides ``attn_implementation``
                with ``"flash_attention_2"``.
            max_pixels / min_pixels: Resolution bounds for the processor.
                ``max_pixels`` is also attached to video message entries.
            max_num_frames: Passed to ``process_vision_info`` for query inputs.
            num_fewshot: Number of demonstrations; 0 disables few-shot/TASM.
            fewshot_split: Split of ``self.task_dict[task]`` to draw demos from.
            task_vector_method ... local_tokens: Forwarded to
                ``TASMCompressorConfig``. ``enable_token_merging`` maps to
                ``enable_merging``; the other names match.
            attn_implementation: Attention backend; defaults to ``"sdpa"``.
            **kwargs: Logged and ignored.

        Raises:
            ValueError: If several processes are launched under an
                unsupported accelerate distributed type.
        """
        super().__init__()
        if kwargs:
            eval_logger.warning(f"Ignoring kwargs: {kwargs}")

        accelerator = Accelerator()
        if accelerator.num_processes == 1 and device_map == "auto":
            # Single process: let HF place the model across visible devices.
            self._device = torch.device(device)
            self.device_map = device_map
        else:
            # Data parallel, or a single pinned device: use the local rank's GPU.
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"

        _attn_impl = "flash_attention_2" if use_flash_attention_2 else (attn_implementation or "sdpa")
        self._model = Qwen2VLForConditionalGeneration.from_pretrained(
            pretrained,
            torch_dtype=torch.bfloat16,
            device_map=self.device_map,
            attn_implementation=_attn_impl,
        ).eval()

        self.processor = AutoProcessor.from_pretrained(
            pretrained, max_pixels=max_pixels, min_pixels=min_pixels
        )
        self.max_pixels = max_pixels
        self.min_pixels = min_pixels
        self.max_num_frames = max_num_frames
        self._tokenizer = AutoTokenizer.from_pretrained(pretrained)
        self._config = self.model.config

        self.batch_size_per_gpu = int(batch_size)
        self.use_cache = use_cache
        self.k_shot = num_fewshot
        self.fewshot_split = fewshot_split

        self.tasm_config = TASMCompressorConfig(
            task_vector_method=task_vector_method,
            task_vector_weight=task_vector_weight,
            enable_merging=enable_token_merging,
            merge_similarity_threshold=merge_similarity_threshold,
            preserve_spatial=preserve_spatial,
            spatial_window=spatial_window,
            enable_dynamic_retrieval=enable_dynamic_retrieval,
            core_ratio=core_ratio,
            latent_ratio=latent_ratio,
            retrieval_top_k=retrieval_top_k,
            js_threshold=js_threshold,
            target_compression_ratio=target_compression_ratio,
            sink_tokens=sink_tokens,
            local_tokens=local_tokens,
        )
        self.tasm_compressor = create_tasm_compressor(self._config, self.tasm_config).to(
            self._device, dtype=torch.bfloat16
        )

        self.js_threshold = js_threshold
        self.target_compression_ratio = target_compression_ratio

        # task -> split -> documents. Presumably populated by the evaluator. Verify against caller.
        self.task_dict = {}
        # task -> few-shot chat messages (alternating user/assistant turns).
        self.fewshot_ctx = {}
        # task -> (compressed KV cache, position offset) built from fewshot_ctx.
        self.fewshot_ctx_feat = {}
        self.current_task = None

        eval_logger.info("=" * 60)
        eval_logger.info("TASM (Task-Aware Structured Memory) Initialized")
        eval_logger.info("=" * 60)
        eval_logger.info("Innovation 1 - Task Vector:")
        eval_logger.info(f"  Method: {task_vector_method}, Weight: {task_vector_weight}")
        eval_logger.info("Innovation 2 - Token Merging:")
        eval_logger.info(f"  Enabled: {enable_token_merging}, Threshold: {merge_similarity_threshold}")
        eval_logger.info(f"  Spatial: {preserve_spatial}, Window: {spatial_window}")
        eval_logger.info("Innovation 3 - Dynamic Memory:")
        eval_logger.info(f"  Enabled: {enable_dynamic_retrieval}")
        eval_logger.info(f"  Core: {core_ratio*100:.0f}%, Latent: {latent_ratio*100:.0f}%")
        eval_logger.info(f"  Retrieval Top-K: {retrieval_top_k}, JS Threshold: {js_threshold}")
        eval_logger.info("General:")
        eval_logger.info(f"  Target Compression: {target_compression_ratio*100:.0f}%")
        eval_logger.info("=" * 60)

        if accelerator.num_processes > 1:
            supported = (
                DistributedType.FSDP,
                DistributedType.MULTI_GPU,
                DistributedType.DEEPSPEED,
                DistributedType.MULTI_NPU,
            )
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if accelerator.distributed_type not in supported:
                raise ValueError(f"Unsupported distributed type: {accelerator.distributed_type}")
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                self._model, *_ = accelerator.prepare(self.model)
            elif accelerator.distributed_type == DistributedType.FSDP:
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        else:
            self._rank = 0
            self._world_size = 1

    @property
    def config(self):
        """HF config of the loaded model."""
        return self._config

    @property
    def tokenizer(self):
        """Tokenizer loaded from ``pretrained``."""
        return self._tokenizer

    @property
    def model(self):
        """The underlying model, unwrapped from accelerate when distributed."""
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        return self._model

    @property
    def eot_token_id(self):
        """End-of-text token id (the tokenizer's EOS)."""
        return self.tokenizer.eos_token_id

    @property
    def batch_size(self):
        """Number of requests per generation batch on this process."""
        return self.batch_size_per_gpu

    @property
    def device(self):
        """Device this process runs on."""
        return self._device

    @property
    def rank(self):
        """Local process index (0 when not distributed)."""
        return self._rank

    @property
    def world_size(self):
        """Number of processes (1 when not distributed)."""
        return self._world_size

    @property
    def _input_device(self):
        """Device model inputs are moved to: ``"cuda"`` under ``device_map="auto"``."""
        return "cuda" if self.device_map == "auto" else self.device

    def flatten(self, input):
        """Flatten one level of nesting, e.g. ``[[a, b], [c]] -> [a, b, c]``."""
        return [item for group in input for item in group]

    def process_tasm_compression(
        self,
        context: List[Dict],
        outer_kv,
        custom_kv_pos_offset: int = 0,
        compressed_len: int = 0,
    ) -> Tuple[Any, int, float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode one slice of few-shot messages and fold it into the KV cache.

        Args:
            context: Alternating user/assistant chat messages; each pair of
                messages is one demonstration.
            outer_kv: Cache returned by the previous slice, or ``None`` for
                the first slice.
            custom_kv_pos_offset: Position offset associated with ``outer_kv``.
            compressed_len: Unused. Kept for interface compatibility; the
                returned length comes from the model call.

        Returns:
            ``(kv_cache, custom_kv_pos_offset, compressed_len, img_num, ques_num,
            ans_num, img_keep, ques_keep, ans_keep)``, as produced by
            ``context_organize_tasm`` (or the fallback). ``compressed_len`` is
            the model's ``kv_avg_len``.
        """
        device = self._input_device

        rendered = self.processor.apply_chat_template(
            context, tokenize=False, add_generation_prompt=True, add_vision_id=False
        )
        # Always drop the trailing generation prompt. On continuation slices,
        # also drop the system block (presumably already inside outer_kv).
        start = 0 if outer_kv is None else self._SYSTEM_PROMPT_CHARS
        context_texts = [rendered[start:-self._GEN_PROMPT_CHARS]]
        context_image_inputs, context_video_inputs = process_vision_info([context])
        context_inputs = self.processor(
            text=context_texts,
            images=context_image_inputs,
            videos=context_video_inputs,
            padding=False,
            return_tensors="pt",
        )

        # Supervise only the assistant answers. A span runs from the token
        # after each assistant marker up to and including the matching <|im_end|>.
        input_ids = context_inputs["input_ids"]
        labels = torch.full_like(input_ids, -100)
        label_starts = torch.nonzero(input_ids == self._ASSISTANT_TOKEN_ID)[:, 1] + 1
        im_ends = torch.nonzero(input_ids == self._IM_END_TOKEN_ID)
        # <|im_end|> alternates user/assistant turns. When the system block is
        # present, it contributes one extra leading <|im_end|>.
        label_ends = im_ends[2::2] if outer_kv is None else im_ends[1::2]
        for label_start_, label_end_ in zip(label_starts, label_ends[:, 1]):
            labels[0, label_start_:label_end_ + 1] = input_ids.squeeze(0)[label_start_:label_end_ + 1]

        context_inputs["labels"] = labels
        context_inputs = context_inputs.to(device)

        # The same demonstrations, one per batch row, rendered without system
        # block or generation prompt.
        sep_context = [context[i:i + 2] for i in range(0, len(context), 2)]
        sep_context_texts = [
            self.processor.apply_chat_template(
                sep_context_, tokenize=False, add_generation_prompt=True, add_vision_id=False
            )[self._SYSTEM_PROMPT_CHARS:-self._GEN_PROMPT_CHARS]
            for sep_context_ in sep_context
        ]
        sep_context_inputs = self.processor(
            text=sep_context_texts,
            images=context_image_inputs,
            videos=context_video_inputs,
            padding=True,
            return_tensors="pt",
        )

        sep_input_ids = sep_context_inputs["input_ids"]
        sep_labels = torch.full_like(sep_input_ids, -100)
        sep_label_starts = torch.nonzero(sep_input_ids == self._ASSISTANT_TOKEN_ID)
        for row in range(sep_labels.shape[0]):
            row_matches = sep_label_starts[sep_label_starts[:, 0] == row]
            if row_matches.numel() > 0:
                # Shift only the position (column 1), never the row index, so
                # each row is labelled from its own first answer token.
                start_pos = row_matches[0, 1].item() + 1
                sep_labels[row, start_pos:-1] = sep_input_ids[row, start_pos:-1]

        sep_context_inputs["labels"] = sep_labels
        sep_context_inputs = sep_context_inputs.to(device)
        icl_inputs = {"icl_" + k: v for k, v in sep_context_inputs.items()}

        with torch.no_grad():
            if hasattr(self.model, "context_organize_tasm"):
                outputs = self.model.context_organize_tasm(
                    **context_inputs, **icl_inputs,
                    use_cache=True,
                    custom_kv=outer_kv,
                    custom_kv_pos_offset=custom_kv_pos_offset,
                    tasm_compressor=self.tasm_compressor,
                )
            else:
                outputs = self._tasm_compress_fallback(
                    context_inputs, icl_inputs, outer_kv, custom_kv_pos_offset
                )
        (kv_cache, custom_kv_pos_offset, kv_avg_len, all_img_num, all_ques_num, all_ans_num,
         all_img_keep, all_ques_keep, all_ans_keep) = outputs

        torch.cuda.empty_cache()
        return kv_cache, custom_kv_pos_offset, kv_avg_len, all_img_num, all_ques_num, all_ans_num, \
               all_img_keep, all_ques_keep, all_ans_keep

    def _tasm_compress_fallback(
        self,
        context_inputs: Dict,
        sep_context_inputs: Dict,
        outer_kv,
        custom_kv_pos_offset: int,
    ) -> Tuple:
        """Fallback used when the model has no ``context_organize_tasm``.

        Runs ``context_organize_pre``. If dynamic retrieval is enabled, it then
        stores every layer's key/value states in ``self.tasm_compressor.memory``.
        Each entry's importance is the value-state norm over the last dimension,
        divided by its maximum along that same axis. Token merging is not
        applied on this path. Returns the 9-tuple from ``context_organize_pre``
        unchanged.
        """
        (kv_cache, custom_kv_pos_offset, kv_avg_len, all_img_num, all_ques_num, all_ans_num,
         all_img_keep, all_ques_keep, all_ans_keep) = self.model.context_organize_pre(
            **context_inputs, **sep_context_inputs,
            use_cache=True,
            custom_kv=outer_kv,
            custom_kv_pos_offset=custom_kv_pos_offset,
        )

        if self.tasm_config.enable_dynamic_retrieval:
            for layer_idx in range(len(kv_cache.key_cache)):
                key_states = kv_cache.key_cache[layer_idx]
                value_states = kv_cache.value_cache[layer_idx]
                importance = value_states.norm(dim=-1)
                importance = importance / (importance.max(dim=-1, keepdim=True)[0] + 1e-8)
                self.tasm_compressor.memory.store(
                    layer_idx=layer_idx,
                    key_states=key_states,
                    value_states=value_states,
                    importance_scores=importance,
                )

        return kv_cache, custom_kv_pos_offset, kv_avg_len, all_img_num, all_ques_num, all_ans_num, \
               all_img_keep, all_ques_keep, all_ans_keep

    def _extract_task_vector(
        self,
        fewshot_messages: List[Dict],
        task_name: str,
    ):
        """Run the demos through the model and hand hidden states to the compressor.

        Splits the last layer's hidden states at the sequence midpoint and passes
        the halves to ``tasm_compressor.extract_task_vector``.
        NOTE(review): the midpoint is only a rough question/answer proxy; it is
        not aligned with actual turn boundaries.
        """
        if not fewshot_messages:
            return

        text = self.processor.apply_chat_template(
            fewshot_messages,
            tokenize=False,
            add_generation_prompt=False,
        )
        image_inputs, video_inputs = process_vision_info(fewshot_messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs if image_inputs else None,
            videos=video_inputs if video_inputs else None,
            padding=True,
            return_tensors="pt",
        ).to(self._input_device)

        with torch.no_grad():
            outputs = self.model(
                **inputs,
                output_hidden_states=True,
                return_dict=True,
            )

        last_hidden = outputs.hidden_states[-1]
        half = last_hidden.shape[1] // 2
        q_hidden = last_hidden[:, :half, :]
        a_hidden = last_hidden[:, half:, :]
        self.tasm_compressor.extract_task_vector(q_hidden, a_hidden)

        eval_logger.debug(f"Task vector extracted for task: {task_name}")

    def _should_retrieve_from_latent(
        self,
        current_attention: torch.Tensor,
    ) -> bool:
        """Delegate the retrieval decision to ``tasm_compressor.memory.should_retrieve``."""
        return self.tasm_compressor.memory.should_retrieve(current_attention)

    def _build_user_message(self, visual, text: str) -> Dict:
        """Build one user chat turn pairing ``visual`` (when usable) with ``text``.

        ``visual`` may be a video path ending in .mp4/.avi/.mov (``~`` is
        expanded), a PIL image, or a list/tuple of PIL images. Anything else,
        including ``None``, yields a text-only turn.
        """
        if isinstance(visual, str) and visual.endswith(self._VIDEO_EXTENSIONS):
            visual_content = [
                {"type": "video", "video": os.path.expanduser(visual), "max_pixels": self.max_pixels}
            ]
        elif isinstance(visual, Image.Image):
            visual_content = [{"type": "image", "image": visual}]
        elif isinstance(visual, (list, tuple)) and all(isinstance(v, Image.Image) for v in visual):
            visual_content = [{"type": "image", "image": v} for v in visual]
        else:
            visual_content = []
        return {"role": "user", "content": visual_content + [{"type": "text", "text": text}]}

    def _build_fewshot_context(
        self,
        task: str,
        doc_to_visual: callable,
        doc_to_text: callable,
        doc_to_answer: callable,
    ) -> List[Dict]:
        """Sample few-shot demos for ``task`` and render them as chat messages.

        Tasks whose name contains "imagenet" take the first ``k_shot`` docs,
        after calling ``set_seed(0)``. Other tasks sample ``k_shot`` docs with
        ``random.seed(42)``. Both reseed global RNGs. The selection is capped at
        the split size.

        Returns:
            Alternating user/assistant messages, or ``[]`` when ``k_shot == 0``.

        Raises:
            ValueError: If ``doc_to_text`` or ``doc_to_answer`` is missing.
        """
        if self.k_shot == 0:
            return []
        if doc_to_text is None or doc_to_answer is None:
            raise ValueError(
                f"Task '{task}' requests {self.k_shot}-shot evaluation, but its requests do not "
                "provide doc_to_text/doc_to_answer needed to build demonstrations."
            )

        fewshot_docs = self.task_dict[task][self.fewshot_split]
        num_shots = min(self.k_shot, len(fewshot_docs))
        if "imagenet" in task.lower():
            set_seed(0)
            sel_ctx = list(range(num_shots))
        else:
            random.seed(42)
            sel_ctx = random.sample(range(len(fewshot_docs)), num_shots)

        docs = [fewshot_docs[idx] for idx in sel_ctx]
        ctx_text = [doc_to_text(doc) for doc in docs]
        ctx_ans = [doc_to_answer(doc) for doc in docs]
        ctx_visuals = self.flatten([doc_to_visual(doc) for doc in docs])

        eval_logger.info(f"Using {self.fewshot_split} split as fewshot demonstrations ({len(sel_ctx)} shots)")

        ctx_msg = []
        for ind, (text, answer) in enumerate(zip(ctx_text, ctx_ans)):
            # NOTE(review): after flattening, docs with several visuals shift this
            # index out of alignment with ctx_text. This matches the query path.
            visual = ctx_visuals[ind] if ind < len(ctx_visuals) else None
            ctx_msg.append(self._build_user_message(visual, text))
            ctx_msg.append({"role": "assistant", "content": [{"type": "text", "text": answer}]})
        return ctx_msg

    def _get_fewshot_kv(self, task: str, ctx_msg: List[Dict]):
        """Return cached ``(outer_kv, kv_offset)`` for ``task``, building it on first use.

        ``ctx_msg`` must be non-empty. Demos are compressed in slices of 20
        ("imagenet" tasks) or 4 (others). Each slice extends the cache returned
        by the previous slice. The compressor is cleared before the first slice.
        """
        if task in self.fewshot_ctx_feat:
            return self.fewshot_ctx_feat[task]

        def _as_number(x):
            return x.item() if torch.is_tensor(x) else x

        samples_per_iter = 20 if "imagenet" in task.lower() else 4
        step = samples_per_iter * 2  # one demo == user turn + assistant turn
        contexts_split = [ctx_msg[i:i + step] for i in range(0, len(ctx_msg), step)]

        outer_kv, kv_offset, compressed_len = None, 0, 0
        img_num = ques_num = ans_num = 0
        img_keep = ques_keep = ans_keep = 0

        self.tasm_compressor.clear()
        for sub_context in contexts_split:
            (outer_kv, kv_offset, compressed_len, all_img_num, all_ques_num, all_ans_num,
             all_img_keep, all_ques_keep, all_ans_keep) = self.process_tasm_compression(
                sub_context, outer_kv, kv_offset, compressed_len
            )
            img_num += _as_number(all_img_num)
            ques_num += _as_number(all_ques_num)
            ans_num += _as_number(all_ans_num)
            img_keep += _as_number(all_img_keep)
            ques_keep += _as_number(all_ques_keep)
            ans_keep += _as_number(all_ans_keep)

        stats = self.tasm_compressor.get_stats()
        eval_logger.info(f"TASM Compression for task '{task}':")
        eval_logger.info(f"  Final compressed length: {compressed_len} / {outer_kv.get_past_seq_len()}")
        eval_logger.info(f"  Compression ratio: {stats.get('compression_ratio', 0):.3f}")
        eval_logger.debug(
            f"  Kept tokens img {img_keep}/{img_num}, ques {ques_keep}/{ques_num}, ans {ans_keep}/{ans_num}"
        )

        self.fewshot_ctx_feat[task] = (outer_kv, kv_offset)
        return outer_kv, kv_offset

    def _pop_until(self, gen_kwargs: Dict) -> List[str]:
        """Remove and return the stop strings from ``gen_kwargs`` (default: decoded EOS)."""
        until = [self.tokenizer.decode(self.eot_token_id)]
        if "until" in gen_kwargs:
            until = gen_kwargs.pop("until")
            if isinstance(until, str):
                until = [until]
        return until

    @staticmethod
    def _fill_generation_defaults(gen_kwargs: Dict) -> None:
        """Fill missing generation settings in place (greedy, 128 new tokens)."""
        gen_kwargs.setdefault("max_new_tokens", 128)
        gen_kwargs.setdefault("temperature", 0)
        gen_kwargs.setdefault("top_p", None)
        gen_kwargs.setdefault("num_beams", 1)

    @staticmethod
    def _truncate_at_stops(answer: str, until: List[str]) -> str:
        """Cut ``answer`` at the first occurrence of each non-empty stop string, in order."""
        for term in until:
            if len(term) > 0:
                answer = answer.split(term)[0]
        return answer

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Not supported by this model."""
        raise NotImplementedError("Loglikelihood is not implemented for TASM model")

    def generate_until(self, requests: List[Instance]) -> List[str]:
        """Generate one answer per request, reusing the task's compressed few-shot cache.

        Request args are either ``(context, gen_kwargs, doc_to_visual, doc_id,
        task, split)`` or the same with ``doc_to_text, doc_to_answer`` inserted
        after ``doc_to_visual``. With ``num_fewshot > 0``, each query prompt has
        its system block removed and is generated on top of the cached context
        via ``custom_kv`` / ``custom_kv_pos_offset``. Answers are returned in the
        original request order.
        """
        if not requests:
            return []
        res = []

        def _collate(x):
            toks = self.tokenizer.encode(x[0])
            return -len(toks), x[0]

        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="TASM Model Responding")
        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        useful_len = len(requests[0].arguments)

        for chunk in chunks:
            if useful_len == 6:
                contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
                doc_to_text = doc_to_answer = None
            else:
                contexts, all_gen_kwargs, doc_to_visual, doc_to_text, doc_to_answer, doc_id, task, split = zip(*chunk)
                doc_to_text, doc_to_answer = doc_to_text[0], doc_to_answer[0]
            task, split = task[0], split[0]
            visuals = self.flatten([doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id])

            ctx_msg = None
            if self.k_shot > 0:
                if task not in self.fewshot_ctx:
                    self.fewshot_ctx[task] = self._build_fewshot_context(
                        task=task,
                        doc_to_visual=doc_to_visual[0],
                        doc_to_text=doc_to_text,
                        doc_to_answer=doc_to_answer,
                    )
                    if self.fewshot_ctx[task]:
                        self._extract_task_vector(self.fewshot_ctx[task], task)
                # An empty demo list (e.g. an empty few-shot split) means zero-shot.
                ctx_msg = self.fewshot_ctx[task] or None

            # Copy so that popping "until" / filling defaults does not mutate the request's args.
            gen_kwargs = dict(all_gen_kwargs[0])
            until = self._pop_until(gen_kwargs)

            contexts = [context.replace("<image>", "") for context in contexts]
            messages = [
                [self._build_user_message(visuals[i] if i < len(visuals) else None, context)]
                for i, context in enumerate(contexts)
            ]

            if ctx_msg is not None:
                # The system block is dropped from the query prompts; it is
                # presumably already part of the cached few-shot context.
                texts = [
                    self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)[
                        self._SYSTEM_PROMPT_CHARS:
                    ]
                    for msg in messages
                ]
                outer_kv, kv_offset = self._get_fewshot_kv(task, ctx_msg)
            else:
                outer_kv, kv_offset = None, 0
                texts = [
                    self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
                    for msg in messages
                ]

            # NOTE(review): the second positional arg is assumed to be a frame cap in
            # the qwen_vl_utils build in use. Confirm against its signature.
            image_inputs, video_inputs = process_vision_info(messages, self.max_num_frames)
            inputs = self.processor(
                text=texts,
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            ).to(self._input_device)

            self._fill_generation_defaults(gen_kwargs)

            if outer_kv is not None and self.tasm_config.enable_dynamic_retrieval:
                self.tasm_compressor.clear_retrieval_cache()

            # custom_kv / custom_kv_pos_offset are presumably consumed by a patched
            # model's generate(). Confirm.
            cont = self.model.generate(
                **inputs,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                do_sample=gen_kwargs["temperature"] > 0,
                temperature=gen_kwargs["temperature"],
                top_p=gen_kwargs["top_p"],
                num_beams=gen_kwargs["num_beams"],
                max_new_tokens=gen_kwargs["max_new_tokens"],
                use_cache=self.use_cache,
                custom_kv=outer_kv,
                custom_kv_pos_offset=kv_offset,
            )

            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs["input_ids"], cont)
            ]
            answers = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            answers = [self._truncate_at_stops(ans, until) for ans in answers]

            for ans, context in zip(answers, contexts):
                res.append(ans)
                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), ans)
                pbar.update(1)

        res = re_ords.get_original(res)
        pbar.close()
        return res

    def generate_until_multi_round(self, requests: List[Instance]) -> List[str]:
        """Generate answers one request at a time, without the few-shot cache.

        Each context is either a ready-made message list or a plain string. In
        the string case it is wrapped in a single user turn. If the matching
        visual is a PIL image and the first message's content is a string, the
        image is prepended to that message. Caller-owned message lists are not
        modified.
        """
        if not requests:
            return []
        res = []

        def _collate(x):
            toks = self.tokenizer.encode(x[0])
            return -len(toks), x[0]

        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="TASM Multi-Round")
        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)

        for chunk in chunks:
            contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
            task, split = task[0], split[0]
            visuals = self.flatten([doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id])

            # Copy so the request's own gen_kwargs is left untouched.
            gen_kwargs = dict(all_gen_kwargs[0])
            until = self._pop_until(gen_kwargs)
            self._fill_generation_defaults(gen_kwargs)

            for i, context in enumerate(contexts):
                if isinstance(context, list):
                    messages = list(context)  # shallow copy; messages[0] is replaced, not edited
                else:
                    messages = [{"role": "user", "content": context}]

                visual = visuals[i] if i < len(visuals) else None
                if isinstance(visual, Image.Image) and isinstance(messages[0]["content"], str):
                    messages[0] = {
                        **messages[0],
                        "content": [
                            {"type": "image", "image": visual},
                            {"type": "text", "text": messages[0]["content"]},
                        ],
                    }

                text = self.processor.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                image_inputs, video_inputs = process_vision_info([messages], self.max_num_frames)
                inputs = self.processor(
                    text=[text],
                    images=image_inputs,
                    videos=video_inputs,
                    padding=True,
                    return_tensors="pt",
                ).to(self._input_device)

                cont = self.model.generate(
                    **inputs,
                    eos_token_id=self.tokenizer.eos_token_id,
                    pad_token_id=self.tokenizer.pad_token_id,
                    do_sample=gen_kwargs["temperature"] > 0,
                    temperature=gen_kwargs["temperature"],
                    top_p=gen_kwargs["top_p"],
                    num_beams=gen_kwargs["num_beams"],
                    max_new_tokens=gen_kwargs["max_new_tokens"],
                    use_cache=self.use_cache,
                )

                generated_ids = cont[0][len(inputs["input_ids"][0]):]
                ans = self.processor.decode(generated_ids, skip_special_tokens=True)
                res.append(self._truncate_at_stops(ans, until))
                pbar.update(1)

        res = re_ords.get_original(res)
        pbar.close()
        return res

    def clear_tasm_cache(self, task: Optional[str] = None):
        """Drop cached demos/KV for ``task`` (or all tasks) and reset the compressor."""
        if task is not None:
            self.fewshot_ctx.pop(task, None)
            self.fewshot_ctx_feat.pop(task, None)
        else:
            self.fewshot_ctx.clear()
            self.fewshot_ctx_feat.clear()
        self.tasm_compressor.clear()
        torch.cuda.empty_cache()
        eval_logger.info(f"TASM cache cleared for task: {task if task else 'all'}")

    def get_compression_stats(self) -> Dict[str, Any]:
        """Return ``tasm_compressor.get_stats()``."""
        return self.tasm_compressor.get_stats()

