import torch
import warnings

# Silence all Python warnings for the whole process. NOTE(review): this runs at
# import time and also hides warnings raised by unrelated libraries and user code.
warnings.filterwarnings("ignore")

# NOTE: despite sitting next to the warning filters, `allow_tf32` does not
# suppress any warning; it enables TF32 for CUDA float32 matmuls (per the PyTorch
# docs: faster but lower-precision on GPUs that support TF32).
torch.backends.cuda.matmul.allow_tf32 = True
# Keep PyTorch's default of emitting some warnings only once per process.
torch.set_warn_always(False)

import os
import json
import copy
from datetime import timedelta
from typing import List, Optional, Tuple, Union

from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
from accelerate.state import AcceleratorState
from packaging import version
from tqdm import tqdm

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

from loguru import logger as eval_logger

try:
    from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
    from llava.conversation import conv_templates
    from llava.mm_utils import (
        get_model_name_from_path,
        process_images,
        tokenizer_image_token,
    )
    from llava.model.builder import load_pretrained_model
except Exception as e:
    eval_logger.debug("LLaVA is not installed. Please install LLaVA to use this model.\nError: %s" % e)

# Default `attn_implementation` for Llava.__init__ (forwarded to load_pretrained_model).
# Accepted values include "sdpa", "eager" and "flash_attention_2"; "sdpa" is picked
# on torch >= 2.1.2 and "eager" on anything older.
best_fit_attn_implementation = (
    "sdpa" if version.parse(torch.__version__) >= version.parse("2.1.2") else "eager"
)


@register_model("llava")
class Llava(lmms):
    """
    Llava Model
    """

    def __init__(
        self,
        pretrained: str = "liuhaotian/llava-v1.5-7b",
        truncation: Optional[bool] = True,
        device: Optional[str] = "cuda:0",
        batch_size: Optional[Union[int, str]] = 1,
        model_name=None,
        attn_implementation=best_fit_attn_implementation,
        device_map="cuda:0",
        conv_template="vicuna_v1",
        use_cache=True,
        tie_weights: bool = True,
        truncate_context=False,  # whether to truncate the context in generation, set it False for LLaVA-1.6
        customized_config=None,  # ends in json
        cfg=None,
        **kwargs,
    ) -> None:
        """Load a LLaVA checkpoint and set up devices, batching and analysis options.

        Args:
            pretrained: checkpoint path / hub id passed to ``load_pretrained_model``.
            truncation: stored on ``self.truncation``.
            device: device used only when a single process runs with ``device_map="auto"``.
            batch_size: per-process batch size (converted with ``int``).
            model_name: name passed to ``load_pretrained_model``; derived from
                ``pretrained`` via ``get_model_name_from_path`` when None.
            attn_implementation: forwarded to ``load_pretrained_model`` unless None.
            device_map: ``"auto"`` (single process) keeps the given map; with multiple
                processes, or any other value, the model goes to ``cuda:<local_process_index>``.
            conv_template: key into ``conv_templates``.
            use_cache: forwarded to ``generate``.
            tie_weights: call ``model.tie_weights()`` after loading.
            truncate_context: stored on ``self.truncate_context``.
            customized_config: forwarded to ``load_pretrained_model`` unless None.
            cfg: optional mapping whose ``"metadata"`` entry holds the head-masking /
                head-intervention options read below. A missing or None
                ``"metadata"`` is treated as an empty dict.

        Raises:
            AssertionError: on unexpected keyword arguments or an unsupported
                distributed type.
            ValueError: from ``select_heads`` when ``k_heads_intervention`` is enabled
                without a ``truthful_head_filepath``.
        """
        super().__init__()
        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

        self.cfg = cfg or {}
        if self.cfg.get("metadata") is None:
            # Other methods read `self.cfg.get('metadata').get(...)`; normalize a
            # missing/None metadata section so those calls never hit `None.get`.
            self.cfg = {**self.cfg, "metadata": {}}
        metadata = self.cfg.get("metadata")

        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        self.accelerator = accelerator
        if accelerator.num_processes == 1 and device_map == "auto":
            self._device = torch.device(device)
            self.device_map = device_map
        else:
            # Multi-process runs and plain single-process runs both pin to the local GPU.
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"

        llava_model_args = {"multimodal": True}
        if customized_config is not None:
            llava_model_args["customized_config"] = customized_config
        if attn_implementation is not None:
            llava_model_args["attn_implementation"] = attn_implementation

        model_name = model_name if model_name is not None else get_model_name_from_path(pretrained)
        self.model_name = model_name
        try:
            self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(
                pretrained, None, model_name, device_map=self.device_map, **llava_model_args
            )
        except TypeError:
            # The installed loader rejected the arguments (presumably a LLaVA build
            # without the `multimodal` keyword); retry without it.
            llava_model_args.pop("multimodal", None)
            self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(
                pretrained, None, model_name, device_map=self.device_map, **llava_model_args
            )
        self._config = self._model.config
        self.model.eval()
        if tie_weights:
            self.model.tie_weights()

        self.truncation = truncation
        self.batch_size_per_gpu = int(batch_size)
        self.conv_template = conv_template
        self.use_cache = use_cache
        self.truncate_context = truncate_context

        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [
                DistributedType.FSDP,
                DistributedType.MULTI_GPU,
                DistributedType.DEEPSPEED,
            ], "Unsupported distributed type provided. Only DDP, FSDP and DeepSpeed are supported."
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                ds_kwargs = {
                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **ds_kwargs)
                eval_logger.info(
                    "Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0"
                )

            if accelerator.distributed_type in (DistributedType.FSDP, DistributedType.DEEPSPEED):
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        elif accelerator.num_processes == 1 and device_map == "auto":
            eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
            self._rank = 0
            self._world_size = 1
        else:
            eval_logger.info(f"Using single device: {self._device}")
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1

        # Head-masking sweep ranges (inclusive); only defined when head masking is on.
        if metadata.get("head_masking", False):
            self.head_masking_start_layer = metadata.get("head_masking_start_layer", 0)
            self.head_masking_end_layer = metadata.get("head_masking_end_layer", 31)
            self.head_masking_start_head = metadata.get("head_masking_start_head", 0)
            self.head_masking_end_head = metadata.get("head_masking_end_head", 31)

        self.gate_truthful_head = metadata.get("gate_truthful_head", False)

        # Optional per-(layer, head) score matrix loaded from a .npy file.
        truthful_head_filepath = metadata.get("truthful_head_filepath", None)
        if truthful_head_filepath is not None:
            import numpy as np

            self.truthful_head = torch.from_numpy(np.load(truthful_head_filepath))
        else:
            self.truthful_head = None

        self.k_heads_intervention = metadata.get("k_heads_intervention", False)
        if self.k_heads_intervention:
            hyperparams = metadata.get("hyperparams") or {}
            self.k = hyperparams.get("k", 20)
            self.mode = hyperparams.get("mode", "bottom")
            self.selected_heads = self.select_heads(self.truthful_head, self.k, self.mode)
            if not self.gate_truthful_head:
                # The scores were only needed to pick heads; pass None to generate().
                self.truthful_head = None
        else:
            self.selected_heads = None

        self.hyperparams = metadata.get("hyperparams", None)


    def select_heads(self, score_matrix, k=20, mode="bottom"):
        """
        score_matrix: torch.Tensor of shape (num_layers, num_heads)
        k: number of heads to select
        mode: "top" or "bottom"
        return: list of (layer_idx, head_idx)
        """

        num_layers, num_heads = score_matrix.shape
        flat_scores = score_matrix.view(-1)  # (num_layers * num_heads)
        
        if mode == "top":
            values, indices = torch.topk(flat_scores, k)
        elif mode == "bottom":
            values, indices = torch.topk(flat_scores, k, largest=False)
        else:
            raise ValueError("mode must be 'top' or 'bottom'")
        
        selected = []
        for idx in indices.tolist():
            layer_idx = idx // num_heads
            head_idx = idx % num_heads
            selected.append((layer_idx, head_idx))
        selected = set(selected)
        return selected



    @property
    def config(self):
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        return self._max_length

    def pad_sequence(self, input_ids, batch_first, padding_value):
        if self.tokenizer.padding_side == "left":
            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
        if self.tokenizer.padding_side == "left":
            input_ids = torch.flip(input_ids, [1])
        return input_ids

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
        add_special_tokens = False if add_special_tokens is None else add_special_tokens
        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
        if left_truncate_len:
            encoding = encoding[-len(encoding):]  # no-op safe
        return encoding

    def tok_decode(self, tokens):
        try:
            return self.tokenizer.decode(tokens)
        except Exception:
            return self.tokenizer.decode([tokens])

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Score each request's continuation given its (multimodal) context.

        For every request the conversation prompt is tokenized twice: once with an
        empty assistant turn (context only) and once with the continuation filled in.
        The model loss is computed with the context positions labelled -100, so only
        the continuation tokens contribute.

        Returns:
            one ``(loss, is_greedy)`` pair per request. ``loss`` is ``outputs["loss"]``.
            ``is_greedy`` is True when argmax decoding reproduces every continuation
            token.
        """
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

        for contexts, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
            doc = self.task_dict[task][split][doc_id]
            continuation = doc_to_target if isinstance(doc_to_target, str) else doc_to_target(doc)
            visuals = self.flatten([doc_to_visual(doc)])
            image_sizes = [[visual.size[0], visual.size[1]] for visual in visuals]
            if visuals:
                image = process_images(visuals, self._image_processor, self._config)
                if isinstance(image, list):
                    image = [_image.to(dtype=torch.float16, device=self.device) for _image in image]
                else:
                    image = image.to(dtype=torch.float16, device=self.device)
            else:
                image = None

            context_text = contexts[0] if isinstance(contexts, list) else contexts
            prompts_input = context_text
            if image is not None and len(image) != 0 and DEFAULT_IMAGE_TOKEN not in prompts_input:
                # Images present but no placeholder in the text: prepend one per image.
                image_tokens = " ".join([DEFAULT_IMAGE_TOKEN] * len(visuals))
                prompts_input = image_tokens + "\n" + context_text

            if "llama_3" in self.conv_template:
                conv = copy.deepcopy(conv_templates[self.conv_template])
            else:
                conv = conv_templates[self.conv_template].copy()
            conv.append_message(conv.roles[0], prompts_input)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
            contxt_id = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
            conv.messages[1][1] = continuation

            prompt = conv.get_prompt()
            input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
            context_len = contxt_id.shape[1]
            labels = input_ids.clone()
            labels[0, :context_len] = -100
            with torch.inference_mode():
                outputs = self.model(input_ids=input_ids, labels=labels, images=image, use_cache=True, image_sizes=image_sizes)
            loss = outputs["loss"]
            logits = outputs["logits"]

            cont_toks = input_ids[:, context_len:]
            cont_len = cont_toks.shape[1]
            # logits[:, i] predicts token i + 1. With images, LLaVA presumably replaces
            # each image placeholder by several feature positions, so `logits` may be
            # longer than `input_ids`. The continuation ends both sequences, so align
            # from the end: the `cont_len` positions just before the final one predict
            # the continuation tokens.
            seq_len = logits.shape[1]
            greedy_tokens = logits[:, seq_len - cont_len - 1 : seq_len - 1].argmax(dim=-1)
            max_equal = (greedy_tokens == cont_toks).all()
            res.append((float(loss.item()), bool(max_equal)))
            pbar.update(1)
        pbar.close()
        return res

    def flatten(self, input):
        if not input or any(i is None for i in input):
            return []
        new_list = []
        for i in input:
            if i:
                for j in i:
                    new_list.append(j)
        return new_list

    def generate_until(
        self, requests: List[Instance], mask_info: Optional[dict] = None
    ) -> Union[List[str], Tuple[List[str], torch.Tensor]]:
        """Generate text for each request, optionally also returning attention maps.

        Requests are grouped by their generation kwargs and batched longest-context
        first. Besides the usual generation arguments, ``mask_info``,
        ``gate_truthful_head``, ``truthful_head``, ``selected_heads`` and
        ``hyperparams`` are forwarded to ``self.model.generate``. The installed model
        implementation must accept them.

        Args:
            requests: instances whose ``args`` are
                ``(context, gen_kwargs, doc_to_visual, doc_id, task, split)``.
                NOTE: ``gen_kwargs`` is mutated in place ("until" and
                "image_aspect_ratio" are popped; defaults and "image_sizes" are added).
            mask_info: opaque masking description forwarded to ``generate``.

        Returns:
            - ``(res, attentions)`` when ``metadata.head_masking`` is set.
              ``attentions`` is ``stack(cont.attentions[0])[3:15, :, :, -1, -1]``
              averaged over dim 0 and then over dim 1 of the result.
            - ``(res, attentions)`` with the full stacked ``cont.attentions[0]``
              (on CPU) when ``metadata.save_attention`` or
              ``metadata.save_visual_attn_entropy`` is set.
            - Otherwise just ``res``: decoded outputs in the original request order.
            Attention tensors come from the last processed batch only.

        Raises:
            ValueError: an ``until`` of the wrong type, or attentions requested for an
                empty request list.
        """
        import gc

        metadata = self.cfg.get("metadata") or {}
        head_masking = metadata.get("head_masking", False)
        save_full_attn = metadata.get("save_attention", False) or metadata.get("save_visual_attn_entropy", False)
        # In these modes the question is used verbatim, without a conversation template.
        use_raw_prompt = metadata.get("save_attention", False) or head_masking or metadata.get("no_system_prompt", False)
        return_dict = metadata.get("return_dict_in_generate", False)
        output_attentions = metadata.get("output_attentions", False)

        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        num_iters = (len(requests) + self.batch_size - 1) // self.batch_size  # ceil division
        # generate_until_with_head_masking calls this once per (layer, head), so no bar there.
        pbar = None if head_masking else tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")

        cont = None
        for chunk in chunks:
            contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
            task = task[0]
            split = split[0]
            batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]  # [B, N]
            flattened_visuals = self.flatten(batched_visuals)  # [B*N]
            # All requests in a chunk share gen kwargs (the Collator groups by them).
            gen_kwargs = all_gen_kwargs[0]

            # NOTE: `until` is validated and removed from gen_kwargs, but stop sequences
            # are not applied to the decoded outputs.
            if "until" in gen_kwargs:
                until = gen_kwargs.pop("until")
                if not isinstance(until, (str, list)):
                    raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")

            if "image_aspect_ratio" in gen_kwargs.keys() and "image_aspect_ratio" not in self._config.__dict__:
                # here we should pop it out of gen_kwargs so that it doesn't get passed to the model for next step of generation
                self._config.image_aspect_ratio = gen_kwargs.pop("image_aspect_ratio")
                eval_logger.info(f"Setting image aspect ratio: {self._config.image_aspect_ratio}")

            if flattened_visuals:
                image_tensor = process_images(flattened_visuals, self._image_processor, self._config)
                if isinstance(image_tensor, list):
                    image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor]
                else:
                    image_tensor = image_tensor.to(dtype=torch.float16, device=self.device)
            else:
                image_tensor = None

            question_input = []
            for visual, context in zip(batched_visuals, contexts):
                if image_tensor is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in context:
                    # Images are present but the context has no placeholder: prepend one
                    # image token per image, then a newline.
                    n_image_tokens = len(visual) if isinstance(visual, list) else 1
                    question = " ".join([DEFAULT_IMAGE_TOKEN] * n_image_tokens) + "\n" + context
                else:
                    question = context

                if use_raw_prompt:
                    question_input.append(question)
                    continue
                # deepcopy is safer for llama3 templates, which hold object-typed fields
                if "llama_3" in self.conv_template:
                    conv = copy.deepcopy(conv_templates[self.conv_template])
                else:
                    conv = conv_templates[self.conv_template].copy()
                conv.append_message(conv.roles[0], question)
                conv.append_message(conv.roles[1], None)
                question_input.append(conv.get_prompt())

            # preconfigure gen_kwargs with defaults
            gen_kwargs["image_sizes"] = [visual.size for visual in flattened_visuals]
            gen_kwargs.setdefault("max_new_tokens", 1024)
            gen_kwargs.setdefault("temperature", 0)
            gen_kwargs.setdefault("top_p", None)
            gen_kwargs.setdefault("num_beams", 1)

            input_ids_list = [tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") for prompt in question_input]
            pad_token_ids = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
            input_ids = self.pad_sequence(input_ids_list, batch_first=True, padding_value=pad_token_ids).to(self.device)
            attention_masks = input_ids.ne(pad_token_ids).to(self.device)

            try:
                with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
                    cont = self.model.generate(
                        input_ids,
                        attention_mask=attention_masks,
                        pad_token_id=pad_token_ids,
                        images=image_tensor,
                        image_sizes=gen_kwargs["image_sizes"],
                        do_sample=bool(gen_kwargs["temperature"] > 0),
                        # NOTE: temperature and top_p are not passed to generate().
                        num_beams=gen_kwargs["num_beams"],
                        max_new_tokens=gen_kwargs["max_new_tokens"],
                        use_cache=self.use_cache,
                        return_dict_in_generate=return_dict,
                        output_attentions=output_attentions,
                        mask_info=mask_info,
                        gate_truthful_head=self.gate_truthful_head,
                        truthful_head=self.truthful_head,
                        selected_heads=self.selected_heads,
                        hyperparams=self.hyperparams,
                    )
                sequences = cont.sequences if return_dict else cont
                text_outputs = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
            except Exception as e:
                # Log, then propagate. The previous `raise e` came first and left the
                # logging (and an unused fallback) unreachable.
                eval_logger.error(f"Error {e} in generating")
                raise

            res.extend(text_outputs)
            self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs)
            if pbar is not None:
                pbar.update(1)

        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)
        if pbar is not None:
            pbar.close()

        if not (head_masking or save_full_attn):
            return res
        if cont is None:
            raise ValueError("generate_until: no requests were processed, so there are no attentions to return")

        # cont.attentions[0] holds the attentions of the first generation step.
        if head_masking:
            step_attns = cont.attentions[0]
            if self.model_name == "llava-v1.6-vicuna-7b":
                step_attns = [attn.cpu() for attn in step_attns]
            attentions = torch.stack(list(step_attns))[3:15, :, :, -1, -1].mean(0).mean(1)
        else:
            attentions = torch.stack([attn.cpu() for attn in cont.attentions[0]])

        # Drop the (potentially large) attention tensors before returning.
        del cont.attentions
        gc.collect()
        torch.cuda.empty_cache()
        return res, attentions

    def generate_until_with_head_masking(self, requests: List[Instance]) -> List[str]:
        """Ablate attention heads one at a time and record the effect on target-word attention.

        For every request, the positive and negative target-word variants (built by
        ``self.update_samples_with_target_word``) are run once with no mask. They are
        then run once per ``(layer, head)`` in the inclusive ``head_masking_*`` ranges,
        with that head masked via ``self.update_mask_info``. ``self.get_attn_diff``
        turns each pair of attentions into (diff, pos, neg) values.

        One JSON line per sample is appended to
        ``<metadata.output_dir or cwd>/attn_diff_results/<model_name>/attn_diff_results_<timestamp>.jsonl``.
        Requires ``metadata.head_masking``, so that ``generate_until`` returns
        attentions and the range attributes exist.

        Samples that hit CUDA out-of-memory are reported and skipped. Other
        RuntimeErrors propagate.

        Returns:
            a single-element list holding a JSON status string.
        """
        import gc
        import json
        import os
        from datetime import datetime

        metadata = self.cfg.get("metadata") or {}
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        base_output_dir = metadata.get("output_dir", None) or os.getcwd()
        output_dir = os.path.join(base_output_dir, f"attn_diff_results/{self.model_name}")
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, f"attn_diff_results_{timestamp}.jsonl")

        print(f"Starting attention difference analysis for {len(requests)} samples...")
        print(f"Results will be saved to: {output_file}")

        with torch.no_grad():
            for request in tqdm(requests, total=len(requests), desc="Processing samples"):
                contexts, _doc_to_target, _doc_to_visual, doc_id, task, split = request.args
                sample_data = self.task_dict[task][split][doc_id]

                image_path = sample_data.get('image_path', '')
                text_input = contexts[0] if isinstance(contexts, list) else contexts
                pos_target_word = sample_data.get('positive_target_word', '')
                neg_target_word = sample_data.get('negative_target_word', '')

                hit_oom = False
                try:
                    single_pos_requests = self.update_samples_with_target_word([request], 'positive')
                    single_neg_requests = self.update_samples_with_target_word([request], 'negative')

                    # Baseline: nothing masked.
                    _, pos_attn = self.generate_until(single_pos_requests, None)
                    _, neg_attn = self.generate_until(single_neg_requests, None)
                    original_attn_diff, original_pos_attn, original_neg_attn = self.get_attn_diff(pos_attn, neg_attn)
                    original_result = {
                        'attn_diff': original_attn_diff[0].item(),
                        'pos_attn': original_pos_attn[0].item(),
                        'neg_attn': original_neg_attn[0].item(),
                    }

                    # Mask one head at a time over the inclusive configured ranges.
                    ablated_results = {}
                    layers = range(self.head_masking_start_layer, self.head_masking_end_layer + 1)
                    for layer in tqdm(layers, desc=f"Sample {doc_id} - Layers"):
                        heads = range(self.head_masking_start_head, self.head_masking_end_head + 1)
                        for head in tqdm(heads, desc=f"Sample {doc_id} - Heads"):
                            mask_info = self.update_mask_info(layer, head)
                            _, pos_attn = self.generate_until(single_pos_requests, mask_info)
                            _, neg_attn = self.generate_until(single_neg_requests, mask_info)
                            ablated_attn_diff, ablated_pos_attn, ablated_neg_attn = self.get_attn_diff(pos_attn, neg_attn)
                            ablated_results[(layer, head)] = {
                                'attn_diff': ablated_attn_diff[0].item(),
                                'pos_attn': ablated_pos_attn[0].item(),
                                'neg_attn': ablated_neg_attn[0].item()
                            }

                    # Sorted by signed attn_diff, largest first (not by absolute value).
                    # json.dumps serializes each ((layer, head), stats) pair as
                    # [[layer, head], stats].
                    sorted_attn_diff = sorted(
                        ablated_results.items(),
                        key=lambda item: item[1]['attn_diff'],
                        reverse=True,
                    )

                    sample_result = {
                        "image_path": image_path,
                        "text_input": text_input,
                        "pos_target_word": pos_target_word,
                        "neg_target_word": neg_target_word,
                        "sorted_attn_diff": sorted_attn_diff,  # All ablated results
                        "original_attn_diff": original_result,  # Unmasked baseline, kept separate
                        "doc_id": doc_id,
                        "timestamp": timestamp,
                        "output_file": output_file
                    }

                    with open(output_file, 'a', encoding='utf-8') as f:
                        f.write(json.dumps(sample_result, ensure_ascii=False) + '\n')

                    print(f"Sample {doc_id}: Completed attention difference analysis - saved to {output_file}")

                except RuntimeError as e:
                    if "out of memory" not in str(e).lower():
                        raise
                    hit_oom = True
                    print(f"❌ OOM ERROR on Sample {doc_id}!")
                    print(f"❌ Image path: {image_path}")
                    print(f"❌ Text length: {len(text_input)} chars")
                    if torch.cuda.is_available():
                        memory_after = torch.cuda.memory_allocated() / 1024**3  # GB
                        print(f"❌ GPU memory at OOM: {memory_after:.2f} GB")

                if hit_oom:
                    # Outside the except block `e` (and the frames its traceback kept
                    # alive) is released, so the failed attempt's tensors can be freed
                    # before moving on to the next sample.
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

        print(f"Completed attention difference analysis for all {len(requests)} samples!")
        print(f"All results saved to: {output_file}")

        # Results are already on disk; return a minimal acknowledgement for the framework.
        return [json.dumps({"status": "completed", "output_file": output_file, "num_samples": len(requests)})]

    def generate_until_with_head_masking_inference(self, requests: List[Instance]) -> List[str]:
        """
        Measure how masking each attention head changes the pos/neg attention difference.

        Samples are processed one at a time to limit GPU memory. For every request:

        1. Two single-sample requests are built with ``update_samples_with_target_word``,
           filling ``{target word}`` with the sample's positive / negative target word.
        2. The unmasked model's attention difference (``neg_attn - pos_attn``, see
           ``get_attn_diff``) is recorded under ``original_attn_diff``.
        3. For every (layer, head) in ``[head_masking_start_layer, head_masking_end_layer]`` x
           ``[head_masking_start_head, head_masking_end_head]`` the measurement is repeated
           with the mask returned by ``update_mask_info``. If ``head_masking`` is disabled in
           the metadata, that helper returns ``None`` and these runs are unmasked. Results are
           sorted by ``attn_diff`` (largest first) and stored under ``sorted_attn_diff`` as
           ``[[layer, head], result]`` pairs in the JSON.

        One JSON line per sample is appended to
        ``<output_dir>/attn_diff_results/<model_name>/attn_diff_results_<timestamp>.jsonl``.
        ``output_dir`` comes from ``cfg['metadata']`` and falls back to the current working
        directory. Samples whose ``RuntimeError`` message mentions "out of memory" are reported
        and skipped; any other ``RuntimeError`` propagates.

        Returns:
            A single-element list holding a JSON status string; the per-sample results live
            only in the output file.
        """
        import json
        import os
        from datetime import datetime

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Tolerate a missing/None ``metadata`` section; fall back to the current directory.
        metadata = self.cfg.get('metadata') or {}
        base_output_dir = metadata.get("output_dir") or os.getcwd()
        output_dir = os.path.join(base_output_dir, f"attn_diff_results/{self.model_name}")
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, f"attn_diff_results_{timestamp}.jsonl")

        print(f"Starting attention difference analysis for {len(requests)} samples...")
        print(f"Results will be saved to: {output_file}")

        def _measure(pos_requests, neg_requests, mask_info):
            # Run both variants under ``mask_info`` and reduce to Python scalars, so the
            # attention tensors are released as soon as this helper returns.
            _, pos_attn = self.generate_until(pos_requests, mask_info)
            _, neg_attn = self.generate_until(neg_requests, mask_info)
            attn_diff, pos_attn, neg_attn = self.get_attn_diff(pos_attn, neg_attn)
            return {
                'attn_diff': attn_diff[0].item(),
                'pos_attn': pos_attn[0].item(),
                'neg_attn': neg_attn[0].item(),
            }

        with torch.no_grad():
            for request in tqdm(requests, total=len(requests), desc="Processing samples"):
                contexts, doc_to_target, doc_to_visual, doc_id, task, split = request.args
                sample_data = self.task_dict[task][split][doc_id]

                image_path = sample_data.get('image_path', '')
                text_input = contexts[0] if isinstance(contexts, list) else contexts
                pos_target_word = sample_data.get('positive_target_word', '')
                neg_target_word = sample_data.get('negative_target_word', '')

                try:
                    single_pos_requests = self.update_samples_with_target_word([request], 'positive')
                    single_neg_requests = self.update_samples_with_target_word([request], 'negative')

                    # Baseline: no head masked.
                    original_result = _measure(single_pos_requests, single_neg_requests, None)

                    ablated_results = []
                    for layer in tqdm(range(self.head_masking_start_layer, self.head_masking_end_layer + 1), desc=f"Sample {doc_id} - Layers"):
                        for head in tqdm(range(self.head_masking_start_head, self.head_masking_end_head + 1), desc=f"Sample {doc_id} - Heads"):
                            mask_info = self.update_mask_info(layer, head)
                            ablated_results.append(
                                ((layer, head), _measure(single_pos_requests, single_neg_requests, mask_info))
                            )

                    # Largest attention difference first; the stable sort keeps
                    # layer/head order on ties.
                    sorted_attn_diff = sorted(
                        ablated_results,
                        key=lambda item: item[1]['attn_diff'],
                        reverse=True,
                    )

                    sample_result = {
                        "image_path": image_path,
                        "text_input": text_input,
                        "pos_target_word": pos_target_word,
                        "neg_target_word": neg_target_word,
                        "sorted_attn_diff": sorted_attn_diff,  # All ablated results
                        "original_attn_diff": original_result,  # Keep original separate
                        "doc_id": doc_id,
                        "timestamp": timestamp,
                        "output_file": output_file,
                    }

                    # Append immediately so finished samples survive a later crash.
                    with open(output_file, 'a', encoding='utf-8') as f:
                        f.write(json.dumps(sample_result, ensure_ascii=False) + '\n')
                    print(f"Sample {doc_id}: Completed attention difference analysis - saved to {output_file}")

                except RuntimeError as e:
                    if "out of memory" not in str(e).lower():
                        raise
                    print(f"❌ OOM ERROR on Sample {doc_id}!")
                    print(f"❌ Image path: {image_path}")
                    print(f"❌ Text length: {len(text_input)} chars")
                    if torch.cuda.is_available():
                        memory_after = torch.cuda.memory_allocated() / 1024**3
                        print(f"❌ GPU memory at OOM: {memory_after:.2f} GB")
                        # Return cached blocks from the failed pass before the next sample.
                        torch.cuda.empty_cache()
                    continue

        print(f"Completed attention difference analysis for all {len(requests)} samples!")
        print(f"All results saved to: {output_file}")
        return [json.dumps({"status": "completed", "output_file": output_file, "num_samples": len(requests)})]

    def get_attn_diff(self, pos_attn, neg_attn):
        """
        Return ``(neg_attn - pos_attn, pos_attn, neg_attn)``.

        The inputs are passed back unchanged next to their difference, so callers can
        unpack all three values from one call. The subtraction is element-wise, so the
        inputs may be scalars or same-shaped tensors (one value per batch sample).
        """
        return neg_attn - pos_attn, pos_attn, neg_attn

    def update_mask_info(self, layer, head):
        """
        Build the head-masking spec for one ``(layer, head)`` pair.

        Reads ``cfg['metadata']``. If ``head_masking`` is not truthy there, returns ``None``
        (no masking). Otherwise returns a dict with ``layer``, ``head``, ``mask_qkv``
        (default ``['q']``) and ``mask_scale_factor`` (default ``0``). A string scale
        factor is converted to ``float``.
        """
        metadata = self.cfg.get('metadata', {})
        if not metadata.get("head_masking", False):
            return None

        qkv_targets = metadata.get("mask_qkv", ['q'])
        scale = metadata.get("mask_scale_factor", 0)
        # YAML/CLI configs may deliver the factor as text.
        if isinstance(scale, str):
            scale = float(scale)

        return dict(
            layer=layer,
            head=head,
            mask_qkv=qkv_targets,
            mask_scale_factor=scale,
        )

    def update_samples_with_target_word(self, requests: List[Instance], target_type: str) -> List[Instance]:
        """
        Return new requests with the ``{target word}`` placeholder filled in.

        For ``target_type == 'positive'`` the sample's ``positive_target_word`` is used;
        any other value selects ``negative_target_word``. A missing word defaults to ''.
        Contexts may be a single string or a list of strings. Each result is a new
        ``Instance`` with the same ``request_type`` and ``idx``; the input requests are
        not modified.
        """
        word_key = 'positive_target_word' if target_type == 'positive' else 'negative_target_word'
        placeholder = '{target word}'

        rebuilt = []
        for req in requests:
            contexts, doc_to_target, doc_to_visual, doc_id, task, split = req.args
            word = self.task_dict[task][split][doc_id].get(word_key, '')

            if isinstance(contexts, list):
                new_contexts = [text.replace(placeholder, word) for text in contexts]
            else:
                new_contexts = contexts.replace(placeholder, word)

            rebuilt.append(
                Instance(
                    request_type=req.request_type,
                    arguments=(new_contexts, doc_to_target, doc_to_visual, doc_id, task, split),
                    idx=req.idx,
                    metadata={"task": task, "doc_id": doc_id, "repeats": req.repeats},
                )
            )
        return rebuilt

    def generate_until_with_save_attention(self, requests: List[Instance]) -> List[str]:
        """
        Save the unmasked model's attention for each sample's positive and negative variant.

        For every request, ``update_samples_with_target_word`` builds a positive and a
        negative single-sample request. ``generate_until`` runs on each with no head mask,
        and the returned attention object is written with ``torch.save`` to
        ``<output_dir>/save_attention_results/<model_name>/samples/{pos,neg}_attn/doc_id_<id>.pt``.
        One metadata JSON line per sample is appended to
        ``save_attention_results_<timestamp>.jsonl`` in the same directory. ``output_dir``
        comes from ``cfg['metadata']`` and falls back to the current working directory.

        Samples whose ``RuntimeError`` message mentions "out of memory" are reported and
        skipped; any other ``RuntimeError`` propagates.

        Returns:
            A single-element list holding a JSON status string; results live on disk.
        """
        import json
        import os
        from datetime import datetime

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Tolerate a missing/None ``metadata`` section; fall back to the current directory.
        metadata = self.cfg.get('metadata') or {}
        base_output_dir = metadata.get("output_dir") or os.getcwd()

        output_dir = os.path.join(base_output_dir, f"save_attention_results/{self.model_name}")
        pos_attn_dir = os.path.join(output_dir, "samples", "pos_attn")
        neg_attn_dir = os.path.join(output_dir, "samples", "neg_attn")
        # Create every output directory once, up front, instead of once per sample.
        for directory in (output_dir, pos_attn_dir, neg_attn_dir):
            os.makedirs(directory, exist_ok=True)
        output_file = os.path.join(output_dir, f"save_attention_results_{timestamp}.jsonl")

        print(f"Starting save attention results for {len(requests)} samples...")
        print(f"Results will be saved to: {output_file}")

        with torch.no_grad():
            for position, request in enumerate(tqdm(requests, total=len(requests), desc="Processing samples"), start=1):
                contexts, doc_to_target, doc_to_visual, doc_id, task, split = request.args
                sample_data = self.task_dict[task][split][doc_id]

                image_path = sample_data.get('image_path', '')
                text_input = contexts[0] if isinstance(contexts, list) else contexts
                pos_target_word = sample_data.get('positive_target_word', '')
                neg_target_word = sample_data.get('negative_target_word', '')

                print(f"Processing sample {position} of {len(requests)} (doc_id={doc_id})")

                try:
                    single_pos_requests = self.update_samples_with_target_word([request], 'positive')
                    single_neg_requests = self.update_samples_with_target_word([request], 'negative')
                    mask_info = None  # unmasked model

                    # Generate and save one variant at a time, so at most one attention
                    # object is held by this loop.
                    _, pos_attn = self.generate_until(single_pos_requests, mask_info)
                    torch.save(pos_attn, os.path.join(pos_attn_dir, f"doc_id_{doc_id}.pt"))
                    del pos_attn

                    _, neg_attn = self.generate_until(single_neg_requests, mask_info)
                    torch.save(neg_attn, os.path.join(neg_attn_dir, f"doc_id_{doc_id}.pt"))
                    del neg_attn

                    meta_data = {
                        "doc_id": doc_id,
                        "image_path": image_path,
                        "text_input": text_input,
                        "pos_target_word": pos_target_word,
                        "neg_target_word": neg_target_word,
                        "timestamp": timestamp,
                    }

                    # Append immediately so finished samples survive a later crash.
                    with open(output_file, 'a', encoding='utf-8') as f:
                        f.write(json.dumps(meta_data, ensure_ascii=False) + '\n')
                    print(f"Sample {doc_id}: Completed save attention results - saved to {output_file}")

                except RuntimeError as e:
                    if "out of memory" not in str(e).lower():
                        raise
                    print(f"❌ OOM ERROR on Sample {doc_id}!")
                    print(f"❌ Image path: {image_path}")
                    print(f"❌ Text length: {len(text_input)} chars")
                    if torch.cuda.is_available():
                        memory_after = torch.cuda.memory_allocated() / 1024**3  # GB
                        print(f"❌ GPU memory at OOM: {memory_after:.2f} GB")
                        # Return cached blocks from the failed pass before the next sample.
                        torch.cuda.empty_cache()
                    # Continue with next sample instead of crashing
                    continue

        print(f"Completed save attention results for all {len(requests)} samples!")
        print(f"All results saved to: {output_file}")

        # Minimal acknowledgement for framework compatibility; results are already on disk.
        return [json.dumps({"status": "completed", "output_file": output_file, "num_samples": len(requests)})]

    def generate_until_with_save_visual_attn_entropy(self, requests: List[Instance]) -> List[str]:
        """
        Compute and save the head-wise visual-attention entropy for each sample.

        Each request is run through ``generate_until`` on its own. The returned attention is
        reduced by ``calculate_visual_attn_entropy`` to a ``(num_layers, num_heads)`` tensor,
        which is appended as a nested list to
        ``<output_dir>/save_visual_attn_entropy_results/<task>/<model_name>/visual_attn_entropy_results_<timestamp>.jsonl``.
        ``output_dir`` comes from ``cfg['metadata']`` and falls back to the current working
        directory.

        Errors, including out-of-memory, are not caught and propagate to the caller.

        Returns:
            A single-element list holding a JSON status string; results live on disk.
        """
        import json
        import os
        from datetime import datetime

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Tolerate a missing/None ``metadata`` section; fall back to the current directory.
        metadata = self.cfg.get('metadata') or {}
        base_output_dir = metadata.get("output_dir") or os.getcwd()

        output_dir = os.path.join(base_output_dir, f"save_visual_attn_entropy_results/{self.cfg.get('task', None)}/{self.model_name}")
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, f"visual_attn_entropy_results_{timestamp}.jsonl")

        print(f"Starting save attention results for {len(requests)} samples...")
        print(f"Results will be saved to: {output_file}")

        with torch.no_grad():
            for position, request in enumerate(tqdm(requests, total=len(requests), desc="Processing samples"), start=1):
                contexts, doc_to_target, doc_to_visual, doc_id, task, split = request.args
                sample_data = self.task_dict[task][split][doc_id]

                image_path = sample_data.get('image_source', '')
                text_input = contexts[0] if isinstance(contexts, list) else contexts

                print(f"Processing sample {position} of {len(requests)} (doc_id={doc_id})")

                _, attn = self.generate_until([request])  # attn: (nl, bs=1, nh, q_len, k_len)
                visual_attn_entropy = self.calculate_visual_attn_entropy(attn)
                # Release this sample's attention now. Otherwise ``attn`` stays referenced
                # while the next sample's forward pass runs, and two samples' attention
                # maps are alive at once.
                del attn

                final_result = {
                    "doc_id": doc_id,
                    "image_path": image_path,
                    "text_input": text_input,
                    "timestamp": timestamp,
                    "visual_attn_entropy": visual_attn_entropy.cpu().numpy().tolist(),
                }
                del visual_attn_entropy

                # Append immediately so finished samples survive a later crash.
                with open(output_file, 'a', encoding='utf-8') as f:
                    f.write(json.dumps(final_result, ensure_ascii=False) + '\n')
                print(f"Sample {doc_id}: Completed save attention results - saved to {output_file}")

        print(f"Completed visual attention entropy analysis for all {len(requests)} samples!")
        print(f"All results saved to: {output_file}")

        # Minimal acknowledgement for framework compatibility; results are already on disk.
        return [json.dumps({"status": "completed", "output_file": output_file, "num_samples": len(requests)})]

    def calculate_visual_attn_entropy(self, attn):
        """
        Compute the entropy of the last query position's attention over the visual tokens.

        Args:
            attn: Attention tensor indexed as ``attn[layer, batch, head, query, key]``. The
                caller documents it as ``(num_layers, bs, num_heads, q_len, k_len)`` from the
                first generation step. Only batch element 0 is used.

        The visual-token span is ``[vis_start_idx, vis_end_idx)``, read from
        ``cfg['metadata']`` (defaults 1 and 577). For ``llava-v1.6-vicuna-7b`` the span is
        hard-coded to ``[1, 2145)``, regardless of config.

        Returns:
            Tensor of shape ``(num_layers, num_heads)``: the entropy of each head's
            attention, renormalized over the visual span.
        """
        # Tolerate a missing/None ``metadata`` section instead of raising AttributeError.
        metadata = self.cfg.get('metadata') or {}

        # llava-1.5 case (hard coding)
        vis_start_idx = metadata.get("vis_start_idx", 1)  # TODO: Check if this is correct
        vis_end_idx = metadata.get("vis_end_idx", 577)  # TODO: Check if this is correct

        if self.model_name == "llava-v1.6-vicuna-7b":
            # NOTE(review): this overrides any configured span for this model; confirm that
            # 2144 visual tokens holds for every input resolution.
            vis_start_idx = 1
            vis_end_idx = 2145

        visual_attn = attn[:, 0, :, :, vis_start_idx:vis_end_idx]  # (nl, nh, q_len, n_vis)
        visual_attn_last = visual_attn[:, :, -1, :]  # (nl, nh, n_vis)
        visual_attn_last_sum = visual_attn_last.sum(dim=-1, keepdim=True)  # (nl, nh, 1)

        # Renormalize into a distribution over the visual span only.
        visual_attn_norm = visual_attn_last / (visual_attn_last_sum + 1e-6).float()  # (nl, nh, n_vis)
        # NOTE(review): the 1e-3 log epsilon is comparable to a uniform probability over
        # ~576 tokens (~1.7e-3), so it biases the entropy downward. It is kept unchanged so
        # results stay comparable with previously saved runs.
        visual_attn_entropy = -torch.sum(visual_attn_norm * torch.log(visual_attn_norm + 1e-3), dim=-1)  # (nl, nh)

        return visual_attn_entropy



    def generate_until_with_bias_retrieved_samples_save_attention(self, requests: List[Instance]) -> List[str]:
        """
        Save the unmasked model's attention for a preselected subset of samples.

        ``cfg['metadata']['bias_retrieved_samples_list']`` maps the model name to a JSON file.
        Requests whose ``doc_id`` is not ``in`` that file's contents are skipped. If the key is
        absent, every request is skipped. For each selected request, the attention returned by
        ``generate_until`` is saved with ``torch.save`` to
        ``<output_dir>/save_attention_results/<task>/<model_name>/bias_retrieved_samples/doc_id_<id>.pt``.
        The sample's fields, minus ``image``, are appended as one JSON line to
        ``save_attention_results_<timestamp>.jsonl`` in the same directory.

        Errors, including out-of-memory, are not caught and propagate to the caller.

        Returns:
            A single-element list holding a JSON status string; results live on disk.
        """
        import json
        import os
        from datetime import datetime

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Tolerate a missing/None ``metadata`` section; fall back to the current directory.
        metadata = self.cfg.get('metadata') or {}
        base_output_dir = metadata.get("output_dir") or os.getcwd()

        # Maps model name -> path of a JSON file selecting the doc_ids to process
        # (presumably a list of doc_ids; confirm against the file producer).
        list_paths = metadata.get("bias_retrieved_samples_list", None)
        if list_paths is not None:
            with open(list_paths[self.model_name], encoding='utf-8') as f:
                bias_retrieved_samples_list = json.load(f)
        else:
            bias_retrieved_samples_list = []
        # Use a set for O(1) membership; keep the loaded object if its items are unhashable.
        try:
            selected_doc_ids = set(bias_retrieved_samples_list)
        except TypeError:
            selected_doc_ids = bias_retrieved_samples_list

        output_dir = os.path.join(base_output_dir, f"save_attention_results/{self.cfg.get('task', None)}/{self.model_name}")
        samples_dir = os.path.join(output_dir, "bias_retrieved_samples")
        # Create the directory holding both the .pt files and the JSONL file up front.
        os.makedirs(samples_dir, exist_ok=True)
        output_file = os.path.join(samples_dir, f"save_attention_results_{timestamp}.jsonl")

        print(f"Starting save attention results for {len(requests)} samples...")
        print(f"Results will be saved to: {output_file}")

        processed = 0
        with torch.no_grad():
            for position, request in enumerate(tqdm(requests, total=len(requests), desc="Processing samples"), start=1):
                contexts, doc_to_target, doc_to_visual, doc_id, task, split = request.args
                # Skip before touching the dataset so unselected rows are never loaded.
                if doc_id not in selected_doc_ids:
                    continue
                sample_data = self.task_dict[task][split][doc_id]

                print(f"Processing sample {position} of {len(requests)} (doc_id={doc_id})")

                mask_info = None  # unmasked model
                _, attn = self.generate_until([request], mask_info)
                torch.save(attn, os.path.join(samples_dir, f"doc_id_{doc_id}.pt"))
                del attn

                # Copy the row without its 'image' field instead of deleting the key in
                # place. That avoids mutating the dataset row and tolerates rows without
                # an image. The field is presumably an object json cannot serialize.
                record = {key: value for key, value in sample_data.items() if key != 'image'}
                with open(output_file, 'a', encoding='utf-8') as f:
                    f.write(json.dumps(record, ensure_ascii=False) + '\n')
                processed += 1
                print(f"Sample {doc_id}: Completed save attention results - saved to {output_file}")

        print(f"Completed save attention results for {processed} of {len(requests)} samples!")
        print(f"All results saved to: {output_file}")

        # Minimal acknowledgement for framework compatibility; results are already on disk.
        return [json.dumps({"status": "completed", "output_file": output_file, "num_samples": len(requests), "num_processed": processed})]


    def generate_until_multi_round(self, requests) -> List[str]:
        """Multi-round generation is not supported by this model; always raises.

        Raises:
            NotImplementedError: unconditionally.
        """
        message = "TODO: Implement multi-round generation for LLaVA"
        raise NotImplementedError(message)

    @classmethod
    def from_config(cls, cfg, model_args=None):
        """
        Instantiate from config dictionary.
        
        Args:
            cfg: Task configuration dictionary
            model_args: Model arguments string (e.g., "pretrained=liuhaotian/llava-v1.6-vicuna-7b,attn_implementation=eager")
        """
        # Extract model-specific config from the task config
        # You can access task-specific parameters here
        generation_kwargs = cfg.get("generation_kwargs", {})
        
        # Parse model arguments if provided
        if model_args:
            from lmms_eval.utils import simple_parse_args_string
            parsed_model_args = simple_parse_args_string(model_args)
            pretrained = parsed_model_args.get("pretrained", "liuhaotian/llava-v1.6-vicuna-7b")
            device = parsed_model_args.get("device", "cuda:0")
            batch_size = parsed_model_args.get("batch_size", 1)
            attn_implementation = parsed_model_args.get("attn_implementation", "eager")
            conv_template = parsed_model_args.get("conv_template", "vicuna_v1")
            use_cache = parsed_model_args.get("use_cache", True)
            truncate_context = parsed_model_args.get("truncate_context", False)
        else:
            pretrained = "liuhaotian/llava-v1.6-vicuna-7b"
            device = "cuda:0"
            batch_size = 1
            attn_implementation = "eager"
            conv_template = "vicuna_v1"
            use_cache = True
            truncate_context = False
        
        return cls(
            pretrained=pretrained,  # Extract from model_args instead of hardcoding
            device=device,
            batch_size=batch_size,
            attn_implementation=attn_implementation,
            conv_template=conv_template,
            use_cache=use_cache,
            truncate_context=truncate_context,
            cfg=cfg,  # Pass the full task config
        )

    