import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid

# from llava_med_v1.llava.model.llava import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava_med_v1.llava.conversation import conv_templates, SeparatorStyle
from llava_med_v1.llava import LlavaLlamaForCausalLM
from llava_med_v1.llava.utils import disable_torch_init
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from llava_med_v15.llava.mm_utils import get_model_name_from_path

# Special tokens used when building multimodal prompts.
# Placeholder for the image in a question; prepended to conversation prompts
# and stripped from the question in Llava_med_v1.evaluate.
DEFAULT_IMAGE_TOKEN = "<image>"
# Repeated once per image patch (image_token_len times) in Llava_med_v1.evaluate.
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
# Wrap the patch tokens when the model config sets mm_use_im_start_end.
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"

from PIL import Image
import math
from transformers import set_seed, logging

# Only let transformers' logger emit errors (silences warnings/info output).
logging.set_verbosity_error()

from baukit import TraceDict
from functools import partial
import json
import os
import re
import numpy as np
import copy
import time
# from utils import add_gaussian_noise_outside_bboxes

from PIL import Image
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria

def load_image(image_file):
    """Load an image from a local path or an http(s) URL as an RGB PIL image.

    Args:
        image_file: Local filesystem path, or a URL beginning with
            ``http://`` or ``https://``.

    Returns:
        PIL.Image.Image: The decoded image converted to ``"RGB"`` mode.
    """
    # Match the scheme explicitly so local paths that merely start with the
    # letters "http" (e.g. "http_imgs/x.jpg") are still read from disk.
    if image_file.startswith(("http://", "https://")):
        # Standard-library fetch: this file never imports ``requests``.
        from io import BytesIO
        from urllib.request import urlopen

        with urlopen(image_file) as response:
            data = response.read()
        with Image.open(BytesIO(data)) as img:
            return img.convert("RGB")
    # ``convert`` forces a full load, so the file can be closed right after.
    with Image.open(image_file) as img:
        return img.convert("RGB")

def split_list(lst, n):
    """Split a sequence into ``n`` chunks of (roughly) equal size.

    Every non-empty chunk has ``ceil(len(lst) / n)`` items except possibly the
    last. When fewer non-empty chunks than ``n`` are produced (e.g. 9 items into
    4 chunks, or an empty input), the result is padded with empty slices of
    ``lst`` so that it always has exactly ``n`` entries for ``n >= 1``. This
    lets ``get_chunk(lst, n, k)`` work for every ``0 <= k < n``.

    Args:
        lst: Any sliceable sequence.
        n: Number of chunks to return.

    Returns:
        list: ``n`` slices of ``lst`` that concatenate back to ``lst``.
    """
    chunk_size = math.ceil(len(lst) / n)  # ceiling division
    if chunk_size > 0:
        chunks = [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
    else:
        # Empty input: range() with step 0 would raise.
        chunks = []
    # Pad with empty slices (same type as ``lst``) up to exactly n chunks.
    chunks.extend(lst[:0] for _ in range(n - len(chunks)))
    return chunks


def get_chunk(lst, n, k):
    """Return chunk number ``k`` (0-based) of ``split_list(lst, n)``."""
    return split_list(lst, n)[k]

class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any keyword shows up in the generated text.

    The prompt length is read from ``input_ids`` on the first invocation, and
    that first invocation never requests a stop. Each later invocation decodes
    the tokens after the prompt for batch element 0 and reports ``True`` as
    soon as one of ``keywords`` is a substring of that text.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.tokenizer = tokenizer
        self.input_ids = input_ids
        # Set lazily on the first __call__.
        self.start_len = None

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.start_len is None:
            self.start_len = self.input_ids.shape[1]
            return False
        continuation = output_ids[:, self.start_len:]
        text = self.tokenizer.batch_decode(continuation, skip_special_tokens=True)[0]
        return any(word in text for word in self.keywords)
    
class Llava_med_v1:
    """Wrapper around a LLaVA-Med v1 ``LlavaLlamaForCausalLM`` checkpoint.

    Offers plain generation, generation steered by baukit ``TraceDict``
    interventions, and activation extraction for single image/question pairs.

    NOTE(review): several methods reference names that are not defined or
    imported in this file: ``self.processor`` used as an HF-style processor
    (``__call__``/``batch_decode``; never assigned in ``__init__``),
    ``tokenizer_image_token``, ``IMAGE_TOKEN_INDEX``, ``process_images`` and
    ``add_gaussian_noise_outside_bboxes``. Those methods fail when execution
    reaches such a name until it is provided. They look ported from a
    LLaVA-1.5 / HF-LLaVA wrapper — confirm they are meant for this v1 model.
    """

    def __init__(self, model_name_or_path, gpu_id, **kwargs):
        """Load model, tokenizer and CLIP image processor onto ``cuda:{gpu_id}``.

        Args:
            model_name_or_path: Checkpoint path (``~`` is expanded) or id
                accepted by ``from_pretrained``.
            gpu_id: CUDA device index.
            **kwargs: Accepted for interface compatibility; unused.
        """
        model_path = os.path.expanduser(model_name_or_path)
        self.device = torch.device(f'cuda:{gpu_id}')
        self.model = LlavaLlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, use_cache=True).to(device=self.device, dtype=torch.float16)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.image_processor = CLIPImageProcessor.from_pretrained(self.model.config.mm_vision_tower, torch_dtype=torch.float16)
        self.path = model_name_or_path
        # The model's vision tower is held in an indexable container
        # (presumably a list); the first element is the CLIP tower itself.
        self.vision_tower = self.model.model.vision_tower[0].to(device=self.device, dtype=torch.float16)

        mm_use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
        self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            self.tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)

        # Record the image special-token ids on the vision config.
        self.vision_config = self.vision_tower.config
        self.vision_config.im_patch_token = self.tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
        self.vision_config.use_im_start_end = mm_use_im_start_end
        if mm_use_im_start_end:
            self.vision_config.im_start_token, self.vision_config.im_end_token = self.tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
        # Number of patch tokens one image occupies in the prompt.
        self.image_token_len = (self.vision_config.image_size // self.vision_config.patch_size) ** 2

    # ------------------------------------------------------------------ #
    # Private helpers shared by the public methods
    # ------------------------------------------------------------------ #
    def _head_names(self):
        """Return the ``model.layers.{i}.self_attn.head_out`` names for every layer."""
        return [f"model.layers.{i}.self_attn.head_out" for i in range(self.model.config.num_hidden_layers)]

    @staticmethod
    def _build_conv_prompt(prompt, conv_mode):
        """Prefix ``prompt`` with the image token and render it in ``conv_mode``.

        The user turn is ``DEFAULT_IMAGE_TOKEN + "\\n" + prompt``; the assistant
        turn is appended empty so the rendered text ends where the answer begins.

        Returns:
            tuple: ``(prompt_text, conv)``.
        """
        conv = conv_templates[conv_mode].copy()
        conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + prompt)
        conv.append_message(conv.roles[1], None)
        return conv.get_prompt(), conv

    @staticmethod
    def _conv_stop_str(conv):
        """Return the separator that ends an assistant turn in ``conv``."""
        return conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

    @staticmethod
    def _strip_stop_str(output, stop_str):
        """Strip whitespace and one trailing ``stop_str`` from decoded text.

        An empty/None ``stop_str`` is ignored (``endswith("")`` is always true
        and would otherwise erase the whole output).
        """
        output = output.strip()
        if stop_str and output.endswith(stop_str):
            output = output[:-len(stop_str)]
        return output.strip()

    @staticmethod
    def _text_after(output, marker):
        """Return the text after the first ``marker`` in ``output``.

        If ``marker`` is absent, ``output`` is returned unchanged instead of
        silently dropping its first ``len(marker) - 1`` characters.
        """
        idx = output.find(marker)
        if idx < 0:
            return output
        return output[idx + len(marker):]

    @staticmethod
    def _select_intervention(layer_map, edit_fn, **edit_kwargs):
        """Return ``(edit_output, layer_names)`` for a baukit ``TraceDict``.

        An empty (or None) ``layer_map`` gives an identity edit and no layers,
        making the trace a no-op. Otherwise ``edit_fn`` is bound to
        ``edit_kwargs`` and applied to every key of ``layer_map``.
        """
        if not layer_map:
            # Parameter names deliberately avoid baukit's ``output``/``layer``
            # keywords so values are passed positionally, as before.
            def identity(head_output, layer_name):
                return head_output
            return identity, []
        return partial(edit_fn, **edit_kwargs), list(layer_map.keys())

    @staticmethod
    def _apply_generator_offsets(interventions, query_hidden_states, num_heads):
        """Return a new intervention dict with generator offsets folded in.

        For each entry ``(head, direction, extra, generator)`` under a layer
        name, ``generator`` is applied to that head's slice of the last-token
        activation; the (fp16) offset is added to ``direction`` and the sum is
        re-normalised to unit L2 norm. The output entries are 3-tuples
        ``(head, direction, extra)``. ``interventions`` itself is not modified.

        Args:
            interventions: ``{layer_name: [(head, direction, extra, generator), ...]}``.
                Layer names are expected to look like
                ``"model.layers.{i}.self_attn.head_out"``; the layer index is
                the third dot-separated field from the end.
            query_hidden_states: Per-layer activations indexed by layer; the
                last row (``[-1]``) of each is split into ``num_heads`` heads.
            num_heads: Number of attention heads per layer.
        """
        adjusted = {}
        for name, entries in (interventions or {}).items():
            layer = int(name.split('.')[-3])
            query = query_hidden_states[layer][-1].reshape(num_heads, -1)
            new_entries = []
            for head, direction, extra, generator in entries:
                # The offset is detached immediately, so no autograd graph is needed.
                with torch.no_grad():
                    offset = generator(query[head].float()).half()
                direction = direction + offset.detach().cpu().numpy()
                new_entries.append((head, direction / np.linalg.norm(direction), extra))
            adjusted[name] = new_entries
        return adjusted

    def _offset_interventions(self, input_ids, images, interventions):
        """Run one traced forward pass and fold generator offsets into ``interventions``."""
        heads = self._head_names()
        with torch.inference_mode():
            with TraceDict(self.model, heads) as ret:
                self.model(input_ids, images=images, output_hidden_states=True)
        query_hidden_states = [ret[head].output.squeeze() for head in heads]
        return self._apply_generator_offsets(
            interventions, query_hidden_states, self.model.config.num_attention_heads)

    # ------------------------------------------------------------------ #
    # Generation
    # ------------------------------------------------------------------ #
    def evaluate(self, prompt, filepath):
        """Answer ``prompt`` about the image at ``filepath`` with greedy decoding.

        Any ``<image>`` placeholder in ``prompt`` is removed; the image is
        represented by ``self.image_token_len`` patch tokens (wrapped in
        start/end tokens when ``config.mm_use_im_start_end`` is set) appended
        after the question. The text is rendered with the ``"simple"``
        conversation template, and up to 128 new tokens are generated with a
        ``###`` keyword stopping criterion.

        Returns:
            str: The decoded tokens produced after the prompt.
        """
        qs = prompt.replace('<image>', '').strip()

        image = Image.open(filepath).convert("RGB")
        image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        print(image_tensor.shape)
        images = image_tensor.unsqueeze(0).half().to(device=self.device)

        patch_tokens = DEFAULT_IMAGE_PATCH_TOKEN * self.image_token_len
        if getattr(self.model.config, 'mm_use_im_start_end', False):
            patch_tokens = DEFAULT_IM_START_TOKEN + patch_tokens + DEFAULT_IM_END_TOKEN
        qs = qs + '\n' + patch_tokens

        conv = conv_templates["simple"].copy()
        conv.append_message(conv.roles[0], qs)
        prompt = conv.get_prompt()

        input_ids = torch.as_tensor(self.tokenizer([prompt]).input_ids).to(device=self.device)
        stopping_criteria = KeywordsStoppingCriteria(['###'], self.tokenizer, input_ids)

        with torch.inference_mode():
            output = self.model.generate(
                input_ids,
                images=images,
                max_new_tokens=128,
                do_sample=False,
                stopping_criteria=[stopping_criteria],
                use_cache=True)
        # generate() returns a bare LongTensor unless return_dict_in_generate is
        # enabled, in which case the ids are in ``.sequences``; accept both.
        output_ids = getattr(output, 'sequences', output)

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] Sample: {n_diff_input_output} output_ids are not the same as the input_ids')
        return self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]

    def evaluate_of_caption(self, prompt, filepath, caption):
        """Answer ``prompt`` using only a textual ``caption`` of the image.

        ``filepath`` is accepted for interface parity and is not read. Up to 32
        new tokens are generated; the text after ``'ASSISTANT: '`` is returned.

        NOTE(review): requires ``self.processor`` (HF-style processor), which
        ``__init__`` never assigns.
        """
        prompt = f"USER: The given image depicts the following scene: {caption}\n \
Please directly answer the following question from the image description, without guessing or reasoning. Question: \
{prompt}\nASSISTANT:"
        prompt, _conv = self._build_conv_prompt(prompt, "llava_v1")

        inputs = self.processor(text=prompt, return_tensors='pt').to(self.device)
        generate_ids = self.model.generate(**inputs, max_new_tokens=32)
        output = self.processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        return self._text_after(output, 'ASSISTANT: ')

    def evaluate_of_caption_img(self, prompt, filepath, caption):
        """Answer ``prompt`` given both the image and a caption of it.

        Up to 32 new tokens are generated; the text after ``'ASSISTANT: '`` is
        returned.

        NOTE(review): requires ``self.processor`` (HF-style processor), which
        ``__init__`` never assigns.
        """
        image = Image.open(filepath)
        prompt = f"<image>\nUSER: The given image describe that {caption} {prompt}\nASSISTANT:"

        inputs = self.processor(text=prompt, images=image, return_tensors='pt').to(self.device)
        generate_ids = self.model.generate(**inputs, max_new_tokens=32)
        output = self.processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        return self._text_after(output, 'ASSISTANT: ')

    def evaluate_with_intervention(self, prompt, filepath, interventions, intervention_fn):
        """Generate an answer while ``intervention_fn`` edits the given layers.

        Args:
            prompt: Question text (the image token is prepended here).
            filepath: Local image path.
            interventions: Mapping whose keys are module names to edit; empty
                means no editing.
            intervention_fn: baukit ``edit_output`` callable; bound with
                ``start_edit_location='lt'``.

        Returns:
            str: Decoded, stripped ``generate`` output (greedy, <=128 new tokens,
            ``vicuna_v1`` template).
        """
        intervene, layers_to_intervene = self._select_intervention(
            interventions, intervention_fn, start_edit_location='lt')

        prompt, conv = self._build_conv_prompt(prompt, "vicuna_v1")
        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)

        image = Image.open(filepath).convert("RGB")
        image_tensor = process_images([image], self.image_processor, self.model.config)[0]

        stopping_criteria = KeywordsStoppingCriteria([self._conv_stop_str(conv)], self.tokenizer, input_ids)

        with TraceDict(self.model, layers_to_intervene, edit_output=intervene):
            generate_ids = self.model.generate(
                input_ids,
                images=image_tensor.unsqueeze(0).half().to(self.device),
                max_new_tokens=128,
                do_sample=False,
                stopping_criteria=[stopping_criteria],
                use_cache=True)

        return self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0].strip()

    def evaluate_with_intervention_youare(self, prompt, filepath, interventions, intervention_fn):
        """Like ``evaluate_with_intervention`` but with the ``llava_v1`` template.

        Generation is greedy with ``use_cache=False`` and at most 128 new tokens;
        a trailing conversation stop string is removed from the result.
        """
        intervene, layers_to_intervene = self._select_intervention(
            interventions, intervention_fn, start_edit_location='lt')

        prompt, conv = self._build_conv_prompt(prompt, "llava_v1")
        image = load_image(filepath)

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
        # CLIP image-processor API, same object/call as in ``evaluate``.
        image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        stop_str = self._conv_stop_str(conv)

        with TraceDict(self.model, layers_to_intervene, edit_output=intervene):
            generate_ids = self.model.generate(
                input_ids,
                images=image_tensor.unsqueeze(0).half().to(self.device),
                max_new_tokens=128,
                do_sample=False,
                use_cache=False)

        output = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0]
        return self._strip_stop_str(output, stop_str)

    def evaluate_with_intervention_youare_offset(self, prompt, filepath, interventions, intervention_fn):
        """Intervened generation where each direction gets an input-dependent offset.

        A first forward pass records every layer's ``head_out`` activation. Each
        intervention's ``generator`` maps its head's last-token activation to an
        offset that is added to the direction (then re-normalised). The adjusted
        interventions are bound into ``intervention_fn`` as ``interventions=`` and
        applied during greedy generation (``use_cache=False``, <=128 new tokens).
        Prints the generation throughput (total sequence length / seconds).
        """
        prompt, conv = self._build_conv_prompt(prompt, "llava_v1")
        image = load_image(filepath)

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
        image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        images = image_tensor.unsqueeze(0).half().to(self.device)
        stop_str = self._conv_stop_str(conv)

        adjusted = self._offset_interventions(input_ids, images, interventions)
        intervene, layers_to_intervene = self._select_intervention(
            adjusted, intervention_fn, start_edit_location='lt', interventions=adjusted)

        start_time = time.time()
        with TraceDict(self.model, layers_to_intervene, edit_output=intervene):
            generate_ids = self.model.generate(
                input_ids,
                images=images,
                max_new_tokens=128,
                do_sample=False,
                use_cache=False)
        elapsed_time = time.time() - start_time
        print(generate_ids.shape[-1] / elapsed_time)

        output = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0]
        return self._strip_stop_str(output, stop_str)

    def evaluate_with_i2t(self, prompt, filepath, intervention):
        """Generate with ``intervention`` forwarded to ``model.generate``.

        NOTE(review): requires ``self.processor`` (HF-style processor), which
        ``__init__`` never assigns, and a ``generate`` that accepts an
        ``intervention`` keyword — confirm.
        """
        image = Image.open(filepath)
        prompt = f"<image>\nUSER: {prompt}\nASSISTANT:"
        inputs = self.processor(text=prompt, images=image, return_tensors='pt').to(self.device)

        generate_ids = self.model.generate(**inputs, max_length=128, intervention=intervention)
        output = self.processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        return self._text_after(output, 'ASSISTANT: ')

    def evaluate_with_multiple_intervention(self, prompt, filepath, interventions, intervention_fn=None):
        """Generate with ``intervention_fn`` told where image and last tokens are.

        ``intervention_fn`` is bound with ``special_tokens_location`` holding an
        image-token slice (``'img'``) and ``-1`` for the last token (``'lt'``).

        NOTE(review): requires ``self.processor`` and
        ``config.image_token_index`` (HF-LLaVA style), neither set up for this
        v1 model here. ``'img'``'s stop index is (#image tokens - sequence
        length), i.e. a negative offset from the end — confirm intent.
        """
        image = Image.open(filepath)
        prompt = f"<image>\nUSER: {prompt}\nASSISTANT:"
        inputs = self.processor(text=prompt, images=image, return_tensors='pt').to(self.device)

        special_image_token_mask = inputs['input_ids'] == self.model.config.image_token_index
        idx_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) - special_image_token_mask.shape[-1]

        intervene, layers_to_intervene = self._select_intervention(
            interventions, intervention_fn,
            special_tokens_location={'img': slice(1, idx_special_image_tokens.item()), 'lt': -1})

        with TraceDict(self.model, layers_to_intervene, edit_output=intervene):
            generate_ids = self.model.generate(**inputs, max_length=128)
        output = self.processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        return self._text_after(output, 'ASSISTANT: ')

    def evaluate_with_multiple_intervention2(self, prompt, filepath, interventions, intervention_fn=None):
        """Same as ``evaluate_with_multiple_intervention``; kept for existing callers."""
        return self.evaluate_with_multiple_intervention(prompt, filepath, interventions, intervention_fn)

    # ------------------------------------------------------------------ #
    # Activation extraction
    # ------------------------------------------------------------------ #
    def get_activations(self, prompt, filepath, region=None, pos_type=None, param=None):
        """Run one forward pass and return hidden-state and per-head activations.

        Args:
            prompt: Already-formatted prompt containing the image token.
            filepath: Image path, or a list of paths (processed as a batch).
            region: List of dicts with a ``'bbox'`` entry
                (``[x_min, y_min, x_max, y_max]``); used by both ``pos_type``
                modes below.
            pos_type: ``'addnoise'`` loads (or creates and caches next to
                ``filepath``) a noised copy of the image using ``param`` as
                sigma. ``'attn_qk'`` adds, on attention weights over image
                tokens inside any bbox, ``param * (row max - mean image
                attention)``. Anything else: no modification.
            param: Noise sigma or attention ratio, depending on ``pos_type``.

        Returns:
            tuple: ``(hidden_states, head_wise_hidden_states, None, None)`` where
            ``hidden_states`` is ``output.hidden_states`` stacked along a new
            leading axis and squeezed, and ``head_wise_hidden_states`` stacks
            each layer's ``self_attn.head_out`` output the same way (numpy).
        """
        heads = self._head_names()
        weights = [f"model.layers.{i}.self_attn.attn_weight" for i in range(self.model.config.num_hidden_layers)]

        def modify_fn(act, layer_name, bboxes_indices, img_indices, ratio=1):
            # Only attention-weight modules are edited; other outputs pass through.
            if 'weight' in layer_name:
                assert bboxes_indices.shape[0] == act.shape[-1]
                attn_max = act.max(dim=-1)[0]
                img_attn_mean = act[:, :, :, img_indices].mean(dim=-1)
                offset = torch.zeros_like(act)
                offset[:, :, :, bboxes_indices] = ratio * (attn_max.unsqueeze(-1) - img_attn_mean.unsqueeze(-1))
                act = act + offset
            return act

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)

        if isinstance(filepath, list):
            image = [Image.open(p).convert("RGB") for p in filepath]
            image_tensor = process_images(image, self.image_processor, self.model.config)
            image_tensor = image_tensor.half().to(self.device)
        else:
            image = Image.open(filepath).convert("RGB")

            # Optionally perturb the input image.
            if pos_type == 'addnoise':
                # NOTE(review): assumes a 3-character extension such as ".jpg".
                save_path = f'{filepath[:-4]}_{pos_type}_{param}.jpg'
                if os.path.exists(save_path):
                    image = Image.open(save_path).convert("RGB")
                else:
                    image = add_gaussian_noise_outside_bboxes(image, [r['bbox'] for r in region], sigma=param)
                    # Save only freshly generated images: re-saving the cached
                    # JPEG would re-encode it and drift its pixels every call.
                    image.save(save_path)

            image_tensor = process_images([image], self.image_processor, self.model.config)[0]
            image_tensor = image_tensor.unsqueeze(0).half().to(self.device)

        # Optionally edit internal attention weights.
        if pos_type == 'attn_qk':
            # NOTE(review): only valid for a single image path (``image.size``).
            width, height = image.size
            # Same grid as ``self.image_token_len``; the tower's config is the
            # one stored in __init__ (the tower itself sits inside a container).
            num_blocks = self.vision_config.image_size // self.vision_config.patch_size
            bboxes = [r['bbox'] for r in region]
            indices = [bbox_to_flattened_indices(b, width, height, num_blocks) for b in bboxes]
            # Union of all bbox masks.
            merged_indices = np.any(np.stack(indices), axis=0).astype(np.int32)

            (image_positions,) = torch.where(input_ids.squeeze() == IMAGE_TOKEN_INDEX)
            if image_positions.numel() != 1:
                raise ValueError("expected exactly one image token in the prompt")
            img_pos = image_positions.item()
            before_img_len, after_img_len = img_pos, input_ids.shape[-1] - img_pos - 1
            # Masks over the sequence after the image token is expanded into patches.
            img_indices = np.concatenate([np.zeros(before_img_len), np.ones(merged_indices.shape[0]), np.zeros(after_img_len)]).astype(bool)
            bboxes_indices = np.concatenate([np.zeros(before_img_len), merged_indices, np.zeros(after_img_len)]).astype(bool)
            intervene = partial(modify_fn, bboxes_indices=bboxes_indices, img_indices=img_indices, ratio=param)
            traced = heads + weights
        else:
            # 'addnoise' and any other value: trace heads only; no edits apply.
            intervene = partial(modify_fn, bboxes_indices=None, img_indices=None, ratio=None)
            traced = heads

        with torch.inference_mode():
            with TraceDict(self.model, traced, edit_output=intervene) as ret:
                output = self.model(input_ids, images=image_tensor, output_hidden_states=True)
            hidden_states = torch.stack(output.hidden_states, dim=0).squeeze().detach().cpu().numpy()
            head_wise_hidden_states = torch.stack(
                [ret[head].output.squeeze().detach().cpu() for head in heads], dim=0).squeeze().numpy()

        return hidden_states, head_wise_hidden_states, None, None

    def get_activations_only_text(self, prompt):
        """Run a text-only forward pass and return hidden, per-head and MLP activations.

        Returns:
            tuple: ``(hidden_states, head_wise_hidden_states, mlp_wise_hidden_states)``,
            each a squeezed numpy stack over layers.
        """
        heads = self._head_names()
        mlps = [f"model.layers.{i}.mlp" for i in range(self.model.config.num_hidden_layers)]

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)

        with torch.inference_mode():
            with TraceDict(self.model, heads + mlps) as ret:
                output = self.model(input_ids, output_hidden_states=True)
            hidden_states = torch.stack(output.hidden_states, dim=0).squeeze().detach().cpu().numpy()
            head_wise_hidden_states = torch.stack(
                [ret[head].output.squeeze().detach().cpu() for head in heads], dim=0).squeeze().numpy()
            mlp_wise_hidden_states = torch.stack(
                [ret[mlp].output.squeeze().detach().cpu() for mlp in mlps], dim=0).squeeze().numpy()

        return hidden_states, head_wise_hidden_states, mlp_wise_hidden_states

    def get_projected_activations(self, prompt, filepath):
        """Return projected image/text activations and the prompt's input ids.

        NOTE(review): requires ``self.processor`` (never assigned in
        ``__init__``) and a model method ``get_projected_activation`` — confirm.
        """
        image = Image.open(filepath)
        inputs = self.processor(text=prompt, images=image, return_tensors='pt').to(self.device)

        with torch.no_grad():
            img_activation, text_activation = self.model.get_projected_activation(inputs['input_ids'], inputs['pixel_values'])

        return img_activation.detach().cpu().numpy(), text_activation.detach().cpu().numpy(), inputs['input_ids']

    def batch_get_activation_after_intervention_offset(self, args, data, interventions=None, intervention_fn=None):
        """Collect last-token hidden states for each sample under offset interventions.

        For every ``sample`` in ``data`` (dicts with ``'prompt'`` and
        ``'img_url'``), the interventions are adjusted with generator offsets as
        in ``evaluate_with_intervention_youare_offset`` and a single forward pass
        is run with them applied to the intervened layers. The stacked
        last-token hidden states are saved with ``np.save`` under ``features/``
        using ``args.model`` in the file name. Returns None.
        """
        if interventions is None:
            interventions = {}
        all_layer_wise_activations = []
        for sample in tqdm(data):
            prompt, _conv = self._build_conv_prompt(sample['prompt'], "llava_v1")
            image = load_image(sample['img_url'])

            input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
            image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            images = image_tensor.unsqueeze(0).half().to(self.device)

            adjusted = self._offset_interventions(input_ids, images, interventions)
            intervene, layers_to_intervene = self._select_intervention(
                adjusted, intervention_fn, start_edit_location='lt', interventions=adjusted)

            with torch.inference_mode():
                # Edit only the layers that carry interventions; handing
                # intervention_fn layers absent from ``adjusted`` is what
                # evaluate_with_intervention_youare_offset avoids as well.
                with TraceDict(self.model, layers_to_intervene, edit_output=intervene):
                    output = self.model(input_ids, images=images, output_hidden_states=True)
                hidden_states = torch.stack(output.hidden_states, dim=0).squeeze().detach().cpu().numpy()

            all_layer_wise_activations.append(hidden_states[:, -1, :].copy())

        print("Saving layer wise activations")
        # Don't lose a full pass over ``data`` to a missing output directory.
        os.makedirs('features', exist_ok=True)
        np.save(f'features/{args.model}_POPE_sample2_YR_I+Q_layer_wise_after_intervene.npy', all_layer_wise_activations)


def bbox_to_flattened_indices(bbox, w, h, num_blocks):
    """Mark which cells of a ``num_blocks x num_blocks`` grid a bounding box covers.

    The ``w x h`` image is split into a uniform grid whose cells are
    ``w / num_blocks`` wide and ``h / num_blocks`` high (the image size need
    not be divisible by ``num_blocks``). A cell is marked when
    ``x_min < cell_x_max and x_max > cell_x_min`` and the same holds for y,
    so a box edge lying exactly on a cell boundary does not mark the
    neighbouring cell.

    Args:
        bbox: ``[x_min, y_min, x_max, y_max]`` in absolute pixel coordinates.
        w: Image width in pixels.
        h: Image height in pixels.
        num_blocks: Number of grid cells along each side.

    Returns:
        numpy.ndarray: int32 mask of length ``num_blocks ** 2`` in row-major
        order (index ``row * num_blocks + col``): 1 for covered cells, else 0.
    """
    x_min, y_min, x_max, y_max = bbox

    block_width = w / num_blocks
    block_height = h / num_blocks

    # Candidate column/row ranges from the box extent, clamped to the grid.
    j_min = max(0, int(np.floor(x_min / block_width)))
    j_max = min(num_blocks - 1, int(np.floor(x_max / block_width)))
    i_min = max(0, int(np.floor(y_min / block_height)))
    i_max = min(num_blocks - 1, int(np.floor(y_max / block_height)))

    cols = np.arange(j_min, j_max + 1)
    rows = np.arange(i_min, i_max + 1)
    # The overlap test is separable: x depends only on the column, y only on
    # the row, so one 1-D test per axis plus an outer AND covers all cells.
    col_hit = (x_min < (cols + 1) * block_width) & (x_max > cols * block_width)
    row_hit = (y_min < (rows + 1) * block_height) & (y_max > rows * block_height)

    mask = np.zeros(num_blocks * num_blocks, dtype=np.int32)
    # The clamping above keeps every (row, col) inside the grid, so all flat
    # indices are valid without further filtering.
    flat = rows[:, None] * num_blocks + cols[None, :]
    mask[flat[row_hit[:, None] & col_hit[None, :]]] = 1
    return mask
