import os
import sys
sys.path.insert(0, "TruthfulQA")

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
import pandas as pd
import warnings
from einops import rearrange
from transformers import AutoTokenizer, AutoModelForCausalLM
from baukit import Trace, TraceDict
import sklearn
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.linear_model import LogisticRegression
import pickle
from functools import partial
import random
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if project_root not in sys.path:
    sys.path.insert(0, project_root)
from disease_knowledge_dataset.format_reasoning import parse_report_to_chain

import jsonlines
import json
from tqdm import tqdm
import copy
from PIL import Image

# Short model alias -> Hugging Face Hub repo id or local checkpoint directory.
# Built from (alias, location) pairs so the insertion order is explicit.
ENGINE_MAP = dict([
    ('llama_7B', 'baffo32/decapoda-research-llama-7B-hf'),
    ('alpaca_7B', 'circulus/alpaca-7b'),
    ('vicuna_7B', 'AlekseyKorshuk/vicuna-7b'),
    ('llama2_chat_7B', 'meta-llama/Llama-2-7b-chat-hf'),
    ('llama2_chat_13B', 'meta-llama/Llama-2-13b-chat-hf'),
    ('llama2_chat_70B', 'meta-llama/Llama-2-70b-chat-hf'),
    # Local checkpoints (absolute paths on the experiment machine).
    ('llava_v1.5_7B', '/root/project/huggingface/llava-v1.5-7b-hf'),
    ('mplug_owl2_7B', '/root/project/huggingface/mplug-owl2-llama2-7b'),
    ('sharegpt4v_7B', '/root/project/huggingface/sharegpt4v-7b'),
])

# Global settings (currently disabled):
# dimensions = ['privacy', 'bias', 'toxicity', 'hallucination', 'noise-injection', 'position-swapping', 'legality']


# Template of a model-response record.
RESPONSE_DICT = dict(prompt='', img_url='', response='', lan='')

# Template of an input record. It is deliberately empty: the process_data_* loaders
# copy it and add fields such as 'index', 'img_url', 'prompt', 'lan' and 'type'
# (e.g. 'free-text' or 'choice').
INPUT_DICT = dict()

def process_data_mllmguard(data_path):
    """
    Build MLLMGuard input records from ``<data_path>/prompt.csv``; images live in ``data_path``.

    The CSV layout depends on which sub-benchmark name occurs in ``data_path``:

    * ``'sequential'``: two records per row, one per image in columns 0 and 1; prompt
      from column 2, ``lan`` from column 5, ``type='choice'``.
    * ``'non-existent'``: two records per row sharing the image in column 0 -- the
      free-text prompt from column 1, then a ``type='choice'`` prompt that appends the
      options from column 2; ``lan`` from column 5.
    * ``'noise-consistency'``: two records per row with the prompt from column 1 -- the
      image from column 0, then its ``<stem>_noise<ext>`` variant (``type='add_noise'``);
      ``lan`` from column 5.
    * otherwise: one record per row; image column 0, prompt column 1, ``lan`` column 4,
      and ``index`` equal to the row number.

    In the paired layouts ``index`` is a running counter over the emitted records.

    Returns:
        list[dict]: records built on top of a copy of ``INPUT_DICT``.
    """
    img_base = data_path
    frame = pd.read_csv(os.path.join(data_path, 'prompt.csv'))
    n_rows = frame.shape[0]
    records = []

    if 'sequential' in data_path:
        for row in range(n_rows):
            for col in (0, 1):
                records.append({
                    **INPUT_DICT,
                    'index': len(records),
                    'img_url': os.path.join(img_base, frame.iat[row, col]),
                    'prompt': frame.iat[row, 2],
                    'lan': frame.iat[row, 5],
                    'type': 'choice',
                })
    elif 'non-existent' in data_path:
        for row in range(n_rows):
            image_url = os.path.join(img_base, frame.iat[row, 0])
            # Free-text variant of the question.
            records.append({
                **INPUT_DICT,
                'index': len(records),
                'img_url': image_url,
                'prompt': frame.iat[row, 1],
                'lan': frame.iat[row, 5],
            })
            # Multiple-choice variant: the same question followed by its options.
            records.append({
                **INPUT_DICT,
                'index': len(records),
                'img_url': image_url,
                'prompt': f'{frame.iat[row, 1]}(Please answer me with options) {frame.iat[row, 2]}',
                'type': 'choice',
                'lan': frame.iat[row, 5],
            })
    elif 'noise-consistency' in data_path:
        for row in range(n_rows):
            source = frame.iat[row, 0]
            # Clean image.
            records.append({
                **INPUT_DICT,
                'index': len(records),
                'prompt': frame.iat[row, 1],
                'img_url': os.path.join(img_base, source),
                'lan': frame.iat[row, 5],
            })
            # Noisy counterpart stored next to it as <stem>_noise<ext>.
            stem, ext = os.path.splitext(os.path.basename(source))
            records.append({
                **INPUT_DICT,
                'index': len(records),
                'prompt': frame.iat[row, 1],
                'img_url': os.path.join(img_base, f'{stem}_noise{ext}'),
                'type': 'add_noise',
                'lan': frame.iat[row, 5],
            })
    else:
        for row in range(n_rows):
            records.append({
                **INPUT_DICT,
                'index': row,
                'img_url': os.path.join(img_base, frame.iat[row, 0]),
                'prompt': frame.iat[row, 1],
                'lan': frame.iat[row, 4],
            })

    return records

def process_data_mmharmfulbench(data_path):
    """
    Build MM-HarmfulBench input records.

    Each line of ``<data_path>/questions.jsonl`` must carry a ``question`` field. The
    record's position in the file is both its ``index`` and the image file name
    ``<data_path>/images/<position>.jpg``. No ``lan`` field is set.

    Returns:
        list[dict]: records built on top of a copy of ``INPUT_DICT``.
    """
    image_dir = os.path.join(data_path, 'images')
    questions_file = os.path.join(data_path, 'questions.jsonl')
    with jsonlines.open(questions_file, 'r') as reader:
        return [
            {
                **INPUT_DICT,
                'index': position,
                'img_url': os.path.join(image_dir, f'{position}.jpg'),
                'prompt': entry['question'],
            }
            for position, entry in enumerate(reader)
        ]

def process_data_mmsafetybench(data_path, category):
    """
    Build MM-SafetyBench input records for one scenario ``category``.

    Questions come from ``<data_path>/processed_questions/<category>.json``, a JSON
    object keyed by question id whose values carry a ``Question`` field. Images are
    ``<data_path>/imgs/<category>/SD/<id>.jpg``.

    Args:
        data_path: benchmark root directory.
        category: scenario name; the part after its last ``-`` is stored as ``lan``.

    Returns:
        list[dict]: records with ``index`` (the id as ``int``), ``img_url``, ``prompt`` and
        ``lan`` added to a copy of ``INPUT_DICT``.
    """
    img_base = os.path.join(data_path, 'imgs', category, 'SD')
    file_path = os.path.join(data_path, 'processed_questions', f'{category}.json')
    # Context manager so the handle is closed deterministically instead of leaking.
    with open(file_path, 'r') as f:
        questions = json.load(f)

    lan = category.split('-')[-1]  # loop-invariant
    data_list = []
    for key, item in questions.items():
        new_result = INPUT_DICT.copy()
        new_result['index'] = int(key)
        new_result['img_url'] = os.path.join(img_base, f'{key}.jpg')
        new_result['prompt'] = item['Question']
        new_result['lan'] = lan
        data_list.append(new_result)
    return data_list

def process_data_safebench(data_path):
    """
    Build SafeBench input records from ``<data_path>/question/safebench.csv``.

    For row ``r`` the image is
    ``<data_path>/images/SafeBench/query_ForbidQI_<col 1>_<col 2>_6.png``; column 4 is the
    prompt and column 3 the ``lan`` field. ``index`` is the row number.

    Returns:
        list[dict]: records built on top of a copy of ``INPUT_DICT``.
    """
    table = pd.read_csv(os.path.join(data_path, 'question', 'safebench.csv'))
    image_dir = os.path.join(data_path, 'images', 'SafeBench')

    records = []
    for row_idx in range(len(table)):
        image_name = f'query_ForbidQI_{table.iat[row_idx, 1]}_{table.iat[row_idx, 2]}_6.png'
        records.append({
            **INPUT_DICT,
            'index': row_idx,
            'img_url': os.path.join(image_dir, image_name),
            'prompt': table.iat[row_idx, 4],
            'lan': table.iat[row_idx, 3],
        })
    return records
# Leftover example caption (it was a bare module-level string expression with no effect):
# 'A car that seems to be parked illegally behind a legally parked car'

def process_data_pope(data_path, dataset, category):
    """
    Build POPE input records.

    Reads ``<data_path>/questions/<dataset>_pope_<category>.json`` with ``jsonlines``;
    each object needs ``question_id``, ``image``, ``text`` and ``label``. Images are
    resolved under ``<data_path>/images``.

    Args:
        data_path: POPE root directory.
        dataset: question-file prefix.
        category: question-file split name; also stored as ``lan``.

    Returns:
        list[dict]: records with ``index``, ``img_url``, ``prompt``, ``lan`` and
        ``ground_truth`` (the ``label``) added to a copy of ``INPUT_DICT``.
    """
    img_base = os.path.join(data_path, 'images')
    file_path = os.path.join(data_path, 'questions', f'{dataset}_pope_{category}.json')

    data_list = []
    with jsonlines.open(file_path, 'r') as reader:
        for line in reader:
            new_result = INPUT_DICT.copy()
            new_result['index'] = line['question_id']
            new_result['img_url'] = os.path.join(img_base, line['image'])
            new_result['prompt'] = line['text']
            new_result['lan'] = category
            new_result['ground_truth'] = line['label']
            data_list.append(new_result)
    return data_list

def process_data_pope_withcap(data_path, dataset, category, cap_path=None, question_path=None):
    """
    Build POPE input records that also carry a pre-generated caption for each image.

    Args:
        data_path: POPE root; images are resolved under ``<data_path>/images``.
        dataset: currently unused.
        category: stored as ``lan`` on every record.
        cap_path: caption file read with ``jsonlines``; every JSON value it yields is a
            list of dicts with ``image`` and ``answer`` (the caption). Defaults to the
            hard-coded "500 val" caption file.
        question_path: POPE question file (one object per line with ``question_id``,
            ``image``, ``text`` and ``label``). Defaults to the hard-coded COCO
            adversarial split.

    Returns:
        list[dict]: records with ``index``, ``caption``, ``img_url``, ``prompt``, ``lan``
        and ``ground_truth`` added to a copy of ``INPUT_DICT``.

    Raises:
        KeyError: if a question's image has no caption in ``cap_path``.
    """
    img_base = os.path.join(data_path, 'images')
    if cap_path is None:
        cap_path = '/root/wtb/multimodal_alignment/mm_iti/data/POPE/captions_detail_train_500_llava_v1.5_7B_lht.json'
    if question_path is None:
        # NOTE(review): this default ignores `dataset`/`category` and always loads the
        # adversarial split, while `lan` below is still taken from `category`.
        question_path = '/root/wtb/multimodal_alignment/mm_iti/data/POPE/questions/coco_pope_adversarial.json'

    id_to_cap = {}
    with jsonlines.open(cap_path, 'r') as reader:
        for entries in reader:
            for entry in entries:
                id_to_cap[entry["image"]] = entry["answer"]

    data_list = []
    with jsonlines.open(question_path, 'r') as reader:
        for line in reader:
            img_name = line['image']
            new_result = INPUT_DICT.copy()
            new_result['index'] = line['question_id']
            new_result['caption'] = id_to_cap[img_name]
            new_result['img_url'] = os.path.join(img_base, img_name)
            new_result['prompt'] = line['text']
            new_result['lan'] = category
            new_result['ground_truth'] = line['label']
            data_list.append(new_result)
    return data_list
import xml.etree.ElementTree as ET
def parse_xml_to_dict(xml_file):
    """
    Summarise the objects of a VOC-style XML annotation as a single string.

    Despite its name this returns a ``str``, not a dict. The XML must contain
    ``size/width``, ``size/height`` and, per ``<object>``, ``name`` and
    ``bndbox/{xmin,ymin,xmax,ymax}``. A ``<filename>`` element is not required.

    Each object becomes ``"<name>: [xmin, ymin, xmax, ymax]"``, with x values divided by
    the width and y values by the height, printed with three decimals. Entries are joined
    with ``", "``; a file without objects yields ``""``.

    Values are parsed as floats, so fractional pixel coordinates (e.g. ``"123.5"``) are
    accepted. Integer values give exactly the same result as integer parsing.
    """
    root = ET.parse(xml_file).getroot()

    width = float(root.find('./size/width').text)
    height = float(root.find('./size/height').text)

    objects = []
    for obj in root.findall('object'):
        name = obj.find('name').text
        box = obj.find('bndbox')
        xmin, ymin, xmax, ymax = (float(box.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
        objects.append(
            f'{name}: [{xmin / width:.3f}, {ymin / height:.3f}, {xmax / width:.3f}, {ymax / height:.3f}]'
        )

    return ", ".join(objects)

def process_data_chair(data_path):
    """
    Build CHAIR image-captioning input records.

    ``<data_path>/selected.txt`` lists one image file name per line. Each line becomes a
    record with a 1-based ``index``, ``img_url`` under ``<data_path>/images``, a fixed
    detailed-description prompt and ``lan=None``.

    Returns:
        list[dict]: records built on top of a copy of ``INPUT_DICT``.
    """
    img_base = os.path.join(data_path, 'images')
    file_path = os.path.join(data_path, 'selected.txt')
    # Context manager so the file handle is closed instead of leaking.
    with open(file_path, 'r') as f:
        lines = f.readlines()

    data_list = []
    for index, line in enumerate(lines, start=1):
        new_result = INPUT_DICT.copy()
        new_result['index'] = index
        new_result['img_url'] = os.path.join(img_base, line.strip())
        new_result['prompt'] = 'Please describe this image in detail.'
        new_result['lan'] = None
        data_list.append(new_result)

    return data_list

def process_data_medheval(question_file, image_folder, type):
    """
    Build MedHEval input records from a JSON or JSONL question file.

    Args:
        question_file: ``.jsonl`` file (one object per line) or JSON list file; ``~`` is
            expanded.
        image_folder: directory joined with each question's ``img_name``.
        type: ``'close'`` for closed-ended questions or ``'open'`` for open-ended ones.

    For ``'close'`` the question-type key is ``question_type`` when present, otherwise
    ``ground_truth_type``. ``"multi-choice"`` questions get their ``choices`` appended to
    the prompt and are skipped when fewer than four choices exist. All other types get
    ``" Please answer Yes or No."`` appended. ``hallucination_type`` is copied when present.

    For ``'open'`` the question is used as-is, and ``question_type`` and
    ``structured_answer`` (``""`` if absent) are copied.

    Returns:
        list[dict]: records built on top of a copy of ``INPUT_DICT``.

    Raises:
        ValueError: if ``type`` is neither ``'close'`` nor ``'open'``. Previously this
            surfaced as an ``UnboundLocalError``.
    """
    if type not in ('close', 'open'):
        raise ValueError(f"type must be 'close' or 'open', got {type!r}")

    file_path = os.path.expanduser(question_file)
    with open(file_path, 'r') as file:
        if file_path.endswith('.jsonl'):
            # Skip blank lines, which json.loads would reject.
            questions = [json.loads(line) for line in file if line.strip()]
        else:
            questions = json.load(file)

    data_list = []
    for line in questions:
        question = line["question"]
        new_result = INPUT_DICT.copy()
        if type == 'close':
            type_key = 'question_type' if 'question_type' in line else 'ground_truth_type'
            if line[type_key] == "multi-choice":
                if len(line['choices']) < 4:
                    continue
                question += f" Please select from the following choices: {line['choices']}"
            else:
                question += " Please answer Yes or No."
            new_result['question_id'] = line['qid']
            new_result['img_url'] = os.path.join(image_folder, line['img_name'])
            new_result['prompt'] = question
            new_result[type_key] = line[type_key]
            new_result['ground_truth'] = line["answer"]
            if 'hallucination_type' in line:
                new_result['hallucination_type'] = line['hallucination_type']
        else:  # 'open'
            new_result['question_id'] = line['qid']
            new_result['img_url'] = os.path.join(image_folder, line['img_name'])
            new_result['prompt'] = question
            new_result['question_type'] = line['question_type']
            new_result['ground_truth'] = line["answer"]
            new_result['structured_answer'] = line.get("structured_answer", "")
        data_list.append(new_result)
    return data_list

def process_data_harvard_pmc(question_file, image_folder, type):
    """
    Build Harvard/PMC input records from a JSON or JSONL question file.

    Args:
        question_file: ``.jsonl`` or JSON list file; ``~`` is expanded. If a ``.jsonl``
            path does not exist, the same path with a ``.json`` suffix is tried.
        image_folder: directory joined with each entry's ``image``.
        type: ``'close'`` (question + ``" Please answer Yes or No."``,
            ``ground_truth_type="binary"``, id from ``qid``) or ``'open'`` (fixed
            report-generation prompt, id from ``id``).

    Returns:
        list[dict]: records built on top of a copy of ``INPUT_DICT``.

    Raises:
        ValueError: if ``type`` is neither ``'close'`` nor ``'open'``.
        FileNotFoundError: if neither the given file nor its ``.json`` fallback exists.
    """
    if type not in ('close', 'open'):
        raise ValueError(f"type must be 'close' or 'open', got {type!r}")

    file_path = os.path.expanduser(question_file)
    if not os.path.exists(file_path) and file_path.endswith('.jsonl'):
        # Swap only the suffix; str.replace would also rewrite '.jsonl' elsewhere in the path.
        file_path = file_path[:-len('.jsonl')] + '.json'
    if not os.path.exists(file_path):
        # Explicit error rather than `assert`, which is stripped under `python -O`.
        raise FileNotFoundError(f"Question file not found: {file_path}")

    with open(file_path, 'r') as file:
        if file_path.endswith('.jsonl'):
            # Skip blank lines, which json.loads would reject.
            questions = [json.loads(line) for line in file if line.strip()]
        else:
            questions = json.load(file)

    data_list = []
    for line in questions:
        new_result = INPUT_DICT.copy()
        if type == 'close':
            new_result['question_id'] = line['qid']
            new_result['img_url'] = os.path.join(image_folder, line['image'])
            new_result['prompt'] = line["question"] + " Please answer Yes or No."
            new_result['ground_truth_type'] = "binary"
            new_result['ground_truth'] = line["answer"]
        else:  # 'open'
            new_result['question_id'] = line['id']
            new_result['img_url'] = os.path.join(image_folder, line['image'])
            new_result['prompt'] = "Generate a medical report summarizing the key findings in the given image."
            new_result['ground_truth'] = line["answer"]
        data_list.append(new_result)
    return data_list

def process_data_gemex_activation(dataset_path, model, mode, max_samples=2000,
                                  image_root='/root/project/datasets/mimic_cxr_jpg/files/'):
    """
    Build prompts, region annotations and image paths from a GEMeX JSONL file.

    Args:
        dataset_path: JSONL file. Every line needs ``image``, ``question``, ``answer`` and
            ``regions`` (a list of dicts with ``region_name`` and ``bbox``).
        model: model name. Prompt formatting exists only for names containing ``'llava'``.
        mode: prompt variant, matched by substring in this order: ``'I+Q+A'`` (question
            and answer), ``'I+R+Q'`` (bounding-box region hints, then the question),
            ``'I+Q'`` (question only). If none matches, every sample is skipped.
        max_samples: read at most this many lines; ``None`` reads the whole file.
        image_root: prefix prepended (plain string concatenation) to each ``image`` value.

    Returns:
        tuple(list, list, list): prompts, regions and image paths, aligned by index.

    Raises:
        ValueError: if a matching mode is used with a non-LLaVA model.
    """
    llava_prefix = ("A chat between a curious user and an artificial intelligence assistant. "
                    "The assistant gives helpful, detailed, and polite answers to the human's questions.\n")

    def _require_llava():
        # Only the LLaVA template is implemented; without this check the prompt
        # variable would simply be unbound for any other model.
        if 'llava' not in model:
            raise ValueError(f"mode {mode!r} is only implemented for LLaVA models, got model={model!r}")

    dataset = []
    with open(dataset_path, 'r') as f:
        for i, line in enumerate(f):
            if max_samples is not None and i >= max_samples:
                break
            dataset.append(json.loads(line))

    all_prompts = []
    all_regions = []
    all_paths = []
    for data in dataset:
        image_path = image_root + data['image']
        question = data['question']
        answer = data['answer']
        regions = data['regions']

        # Order matters: 'I+Q' is also a substring of 'I+Q+A'.
        if 'I+Q+A' in mode:
            _require_llava()
            body = format_question_answer(question, answer)
        elif 'I+R+Q' in mode:
            _require_llava()
            img_prompt = format_question_woa(question)
            region_prompt = 'When answer the question, please pay attention to the following bounding box regions: '
            region_prompt += ''.join(f"{region['region_name']}: {region['bbox']}; " for region in regions)
            body = f'{region_prompt}{img_prompt}'
        elif 'I+Q' in mode:
            _require_llava()
            body = format_question_woa(question)
        else:
            continue

        all_prompts.append(f'{llava_prefix}{body}')
        all_regions.append(regions)
        all_paths.append(image_path)
    return all_prompts, all_regions, all_paths

def process_data_slake_activation(dataset_path, model, mode, max_samples=None,
                                  image_root='/root/project/datasets/Slake1.0/imgs/'):
    """
    Build prompts, region annotations and image paths from a SLAKE JSONL file.

    Args:
        dataset_path: JSONL file. Every line needs ``image``, ``question``, ``answer`` and
            ``regions`` (a list of dicts with ``region_name`` and ``bbox``).
        model: model name. Prompt formatting exists only for names containing ``'llava'``.
        mode: prompt variant, matched by substring in this order: ``'I+Q+A'`` (question
            and answer), ``'I+R+Q'`` (bounding-box region hints, then the question),
            ``'I+Q'`` (question only). If none matches, every sample is skipped.
        max_samples: read at most this many lines; ``None`` (default) reads the whole file.
        image_root: prefix prepended (plain string concatenation) to each ``image`` value.

    Returns:
        tuple(list, list, list): prompts, regions and image paths, aligned by index.

    Raises:
        ValueError: if a matching mode is used with a non-LLaVA model.
    """
    llava_prefix = ("A chat between a curious user and an artificial intelligence assistant. "
                    "The assistant gives helpful, detailed, and polite answers to the human's questions.\n")

    def _require_llava():
        # Only the LLaVA template is implemented; without this check the prompt
        # variable would simply be unbound for any other model.
        if 'llava' not in model:
            raise ValueError(f"mode {mode!r} is only implemented for LLaVA models, got model={model!r}")

    dataset = []
    with open(dataset_path, 'r') as f:
        for i, line in enumerate(f):
            if max_samples is not None and i >= max_samples:
                break
            dataset.append(json.loads(line))

    all_prompts = []
    all_regions = []
    all_paths = []
    for data in dataset:
        image_path = image_root + data['image']
        question = data['question']
        answer = data['answer']
        regions = data['regions']

        # Order matters: 'I+Q' is also a substring of 'I+Q+A'.
        if 'I+Q+A' in mode:
            _require_llava()
            body = format_question_answer(question, answer)
        elif 'I+R+Q' in mode:
            _require_llava()
            img_prompt = format_question_woa(question)
            region_prompt = 'Please pay attention to the following bounding box regions: '
            region_prompt += ''.join(f"{region['region_name']}: {region['bbox']}; " for region in regions)
            body = f'{region_prompt}{img_prompt}'
        elif 'I+Q' in mode:
            _require_llava()
            body = format_question_woa(question)
        else:
            continue

        all_prompts.append(f'{llava_prefix}{body}')
        all_regions.append(regions)
        all_paths.append(image_path)
    return all_prompts, all_regions, all_paths

def process_data_mimick_activation(knowledge_path, dataset_path, model, mode):
    """
    Build prompts and matching image paths for MIMIC-CXR activation collection.

    Args:
        knowledge_path: JSON file mapping an entity/disease name to a dict that may hold
            ``description``, ``feature`` and ``cases`` (case id -> diagnosis text).
        dataset_path: JSON list of samples with ``img_id``, ``entity``, ``question``,
            ``answer`` and ``choices``. Some modes also need ``question_type`` or
            ``polished_reasoning``.
        model: model name; names containing ``'llava'`` get the LLaVA chat formatting.
        mode: prompt variant, matched by substring.
            * ``'C+D+Q'`` / ``'C+Q'`` iterate the reference cases of the knowledge file.
            * Every other mode iterates the dataset, checked in this order:
              ``I+Q+A-``, ``I+Q+A``, ``C+I+Q``, ``I+Q+K-``, ``I+Q+K+F``, ``I+Q+K``,
              ``I+Q+RD``, ``I+Q+D``, ``I+Q``, ``D+Q``.
            Order matters because some patterns contain others (e.g. ``'I+Q'`` is inside
            ``'I+Q+A'``). Dataset samples whose ``entity`` is not in the knowledge file
            are skipped.

    Returns:
        tuple(list, list): prompts and image paths aligned by index. A path entry is a
        single path, a ``[case_image, query_image]`` pair (``C+I+Q``) or ``None``
        (``D+Q``, text only).

    Raises:
        FileNotFoundError: if a required reference-case image is missing.
        ValueError: for ``I+Q+A`` / ``I+Q+A-`` with a non-LLaVA model, for which no
            prompt format exists.
    """
    # LLaVA v1.5 system prompt with a single space between the two sentences. The
    # backslash-continued literals used here before had an indented second line, and
    # that indentation ended up inside the string ("assistant.     The assistant ..."),
    # unlike the prompt built by the other loaders in this file.
    llava_prefix = ("A chat between a curious user and an artificial intelligence assistant. "
                    "The assistant gives helpful, detailed, and polite answers to the human's questions.\n")
    case_root = "/root/project/disease_knowledge_dataset/pictures"
    image_root = '/root/project/datasets/mimic_cxr_jpg/files/'
    # Question templates used by the reference-case modes ('C+D+Q', 'C+Q').
    questions_set = [
        "Does the image show any signs of {}?",
        "Is there any indication of {} in the image?",
        "Are there {} present?",
        "Is there a {} present in the image?",
        "What describes the patient's condition in the image?",
    ]

    def _case_image(entity, case_id):
        """Return the path of a reference-case image; raise if it does not exist."""
        path = os.path.join(case_root, entity, f"{case_id}.jpg")
        if not os.path.exists(path):
            # Explicit error rather than `assert`, which is stripped under `python -O`.
            raise FileNotFoundError(f"Reference case image not found: {path}")
        return path

    def _require_llava():
        """Fail clearly for modes that only have a LLaVA prompt format."""
        if 'llava' not in model:
            raise ValueError(f"mode {mode!r} is only implemented for LLaVA models, got model={model!r}")

    def _llava_question(text):
        """Apply the LLaVA question template and system prompt; other models get raw text."""
        if 'llava' in model:
            return f'{llava_prefix}{format_question_woa(text)}'
        return text

    def _read_findings(image_path):
        """Concatenate the FINDINGS section of the report file ``<image directory>.txt``."""
        report_path = "/".join(image_path.split("/")[:-1]) + ".txt"
        with open(report_path, "r") as f:
            record = [item.replace("\n", "") for item in f.readlines()]
        record = [item for item in record if item != ""]
        findings = []
        for i, d in enumerate(record):
            if "FINDINGS" in d:
                # Content starts two lines below the header and runs until a line that
                # is exactly one space (or the end of the report).
                j = i + 2
                while j < len(record) and record[j] != " ":
                    findings.append(record[j])
                    j += 1
        return "".join(findings)

    all_prompts = []
    all_paths = []

    with open(knowledge_path, 'r') as f:
        knowledges = json.load(f)
    with open(dataset_path, 'r') as f:
        datasets = json.load(f)

    if 'C+D+Q' in mode:
        # Reference case image + its diagnosis + a templated question.
        for disease, knowledge in knowledges.items():
            if "cases" not in knowledge:
                continue
            for case_id, diagnosis in knowledge["cases"].items():
                case_img_path = _case_image(disease, case_id)
                for question in questions_set:
                    question_full = question.format(disease)
                    diagnosis_prompt = (
                        f"The following is the diagnosis of the image: {diagnosis} In summary, the patient was diagnosed with {disease}. "
                        f"Please understand the image and refer to the diagnosis to answer the following question: {question_full}"
                    )
                    all_prompts.append(f'{llava_prefix}{format_question_woa(diagnosis_prompt)}')
                    all_paths.append(case_img_path)

    elif 'C+Q' in mode:
        # Reference case image + a templated question (no system prompt).
        for disease, knowledge in knowledges.items():
            if "cases" not in knowledge:
                continue
            for case_id in knowledge["cases"]:
                case_img_path = _case_image(disease, case_id)
                for question in questions_set:
                    all_prompts.append(format_question_woa(question.format(disease)))
                    all_paths.append(case_img_path)

    else:
        for data in datasets:
            image_path = image_root + data['img_id']
            if data['entity'] not in knowledges:
                continue
            knowledge = knowledges[data['entity']]
            question = data['question']
            answer = data['answer']
            # Distractor pool for 'I+Q+A-': the listed choices with the answer removed.
            false_choices = data['choices'].replace(answer, "").split(', ')
            if '' in false_choices:
                false_choices.remove('')

            if 'I+Q+A-' in mode:
                # Negative sample: the question paired with a wrong answer.
                _require_llava()
                if data['question_type'] == 'multi-choice':
                    false_answer = random.sample(false_choices, 1)[0] if false_choices else 'No'
                else:
                    false_answer = 'No' if 'Yes' in answer else 'Yes'
                all_prompts.append(f'{llava_prefix}{format_question_answer(question, false_answer)}')
                all_paths.append(image_path)

            elif 'I+Q+A' in mode:
                _require_llava()
                all_prompts.append(f'{llava_prefix}{format_question_answer(question, answer)}')
                all_paths.append(image_path)

            elif 'C+I+Q' in mode:
                # Two-image prompt: reference case "1" of the entity, then the query image.
                if 'cases' not in knowledge:
                    continue
                case_diagnosis = knowledge['cases']["1"]
                case_img_path = _case_image(data['entity'], "1")
                case_prompt = f"Here is a typical case and diagnosis: <image>\n{case_diagnosis}\n Please refer to the given case and answer the questions based on the provided medical images: "
                prompt = f"USER: {case_prompt}<image>\n{question}\nASSISTANT:"
                all_prompts.append(f'{llava_prefix}{prompt}')
                all_paths.append([case_img_path, image_path])

            elif 'I+Q+K-' in mode:
                # Possibly misleading knowledge: the description of a randomly chosen entity.
                wrong_knowledge = random.choice(list(knowledges.values()))
                wrong_know_ref_prompt = (
                    f"The following is relevant knowledge: {wrong_knowledge['description']} "
                    "Please understand the image and refer to the given knowledge to answer the following question: "
                )
                all_prompts.append(_llava_question(wrong_know_ref_prompt + question))
                all_paths.append(image_path)

            elif 'I+Q+K+F' in mode:
                if 'feature' in knowledge:
                    know = knowledge['description'] + f"The radiograph feature of {data['entity']} include: " + knowledge['feature']
                else:
                    know = knowledge['description']
                know_ref_prompt = (
                    f"The following is relevant knowledge: {know} "
                    "Please understand the image and refer to the given knowledge to answer the following question: "
                )
                all_prompts.append(_llava_question(know_ref_prompt + question))
                all_paths.append(image_path)

            elif 'I+Q+K' in mode:
                # Single space before "Please", as in the 'I+Q+K-' and 'I+Q+K+F' prompts.
                know_ref_prompt = (
                    f"The following is relevant knowledge: {knowledge['description']} "
                    "Please understand the image and refer to the given knowledge to answer the following question: "
                )
                all_prompts.append(_llava_question(know_ref_prompt + question))
                all_paths.append(image_path)

            elif 'I+Q+RD' in mode:
                if "polished_reasoning" not in data:
                    continue
                reasoning = data["polished_reasoning"]
                img_prompt = f"{reasoning} Please understand the image and refer to the diagnosis reasoning to answer the following question: {question}"
                all_prompts.append(_llava_question(img_prompt))
                all_paths.append(image_path)

            elif 'I+Q+D' in mode:
                diagnosis = _read_findings(image_path)
                img_prompt = (
                    f"The following is the diagnosis of the image: {diagnosis} In summary, the patient was diagnosed with {data['entity']}. "
                    f"Please understand the image and refer to the diagnosis to answer the following question: {question}"
                )
                all_prompts.append(_llava_question(img_prompt))
                all_paths.append(image_path)

            elif 'I+Q' in mode:
                # Optional sub-filters keep only samples usable by the 'C+...' / 'RD' variants.
                if "onlyc" in mode and "cases" not in knowledge:
                    continue
                if "onlyr" in mode and "polished_reasoning" not in data:
                    continue
                all_prompts.append(_llava_question(question))
                all_paths.append(image_path)

            elif "D+Q" in mode:
                # Text-only variant: report findings + question, no image token.
                diagnosis = _read_findings(image_path)
                img_prompt = (
                    f"The following is the diagnosis of the image: {diagnosis} In summary, the patient was diagnosed with {data['entity']}. "
                    f"Please understand the image and refer to the diagnosis to answer the following question: {question}"
                )
                if 'llava' in model:
                    img_prompt = f'{llava_prefix}USER: {img_prompt}\nASSISTANT:'
                all_prompts.append(img_prompt)
                all_paths.append(None)

    return all_prompts, all_paths

def process_data_harvard_pmc_activation(dataset_path, model, mode, img_path=None):
    """Build prompts and matching image paths from a JSON QA dataset.

    The recipe is chosen by which tag ``mode`` contains, checked in this order
    (the more specific ``'I+Q+...'`` tags are tested before ``'I+Q'``):

    * ``'I+Q+K'``  - image + question, prefixed by ``data['knowledge']``.
    * ``'I+Q+RD'`` - image + question, prefixed by ``data['polished_reasoning']``;
      samples without that key are skipped.
    * ``'I+Q'``    - image + bare question; with ``'onlyc'`` in ``mode`` samples
      whose knowledge has no ``'cases'`` key are skipped, with ``'onlyr'``
      samples without ``'polished_reasoning'`` are skipped.
    * ``'D+Q'``    - text only: ``data['polished_reasoning']`` used as the
      diagnosis plus ``data['entity']``; the returned path is ``None``.

    Args:
        dataset_path: JSON file containing a list of sample dicts.
        model: model name; if it contains ``'llava'`` the prompt is wrapped in
            the LLaVA chat template preceded by its system prefix.
        mode: mode string selecting the recipe (see above).
        img_path: image directory joined with ``data['img_name']``. Only the
            image modes use it, so ``'D+Q'`` works with the default ``None``.

    Returns:
        ``(all_prompts, all_paths)``: two parallel lists.
    """
    llava_prefix = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions.\n"
    )
    is_llava = 'llava' in model

    def _wrap_image_prompt(prompt):
        # LLaVA expects its "<image>\nUSER: ...\nASSISTANT:" template after the system prefix.
        if is_llava:
            return f'{llava_prefix}{format_question_woa(prompt)}'
        return prompt

    with open(dataset_path, 'r') as file:
        datasets = json.load(file)

    all_prompts = []
    all_paths = []
    for data in datasets:
        question = data['question']

        if 'I+Q+K' in mode:
            img_prompt = (
                f"The following is relevant knowledge: {data['knowledge']} "
                "Please understand the image and refer to the given knowledge to answer the following question: "
            ) + question
        elif 'I+Q+RD' in mode:
            if "polished_reasoning" not in data:
                continue
            img_prompt = (
                f"{data['polished_reasoning']} "
                f"Please understand the image and refer to the diagnosis reasoning to answer the following question: {question}"
            )
        elif 'I+Q' in mode:
            if "onlyc" in mode and "cases" not in data['knowledge'].keys():
                continue
            if "onlyr" in mode and "polished_reasoning" not in data:
                continue
            img_prompt = question
        elif "D+Q" in mode:
            # Text-only prompt: no image, so no image path is needed or resolved.
            text_prompt = (
                f"The following is the diagnosis of the image: {data['polished_reasoning']} "
                f"In summary, the patient was diagnosed with {data['entity']}. "
                f"Please understand the image and refer to the diagnosis to answer the following question: {question}"
            )
            if is_llava:
                text_prompt = f"{llava_prefix}USER: {text_prompt}\nASSISTANT:"
            all_prompts.append(text_prompt)
            all_paths.append(None)
            continue
        else:
            continue

        all_prompts.append(_wrap_image_prompt(img_prompt))
        # Resolved only for image modes, so 'D+Q' never touches img_path/img_name.
        all_paths.append(os.path.join(img_path, data['img_name']))

    return all_prompts, all_paths

def get_unpolished_reasoning_chain(knowledge_path, dataset_path,
                                   image_root='/root/project/datasets/mimic_cxr_jpg/files/',
                                   output_path="/root/project/disease_knowledge_dataset/mimic_type1_dataset_unpolish.json"):
    """Attach a report-derived reasoning chain to dataset samples and save the dataset.

    For every sample whose ``entity`` is a key of the knowledge JSON, the
    report file ``<image parent dir>.txt`` is read and its FINDINGS and
    IMPRESSION sections are extracted. Then
    ``parse_report_to_chain(findings, question, impression)`` is stored under
    ``data["unpolished_reasoning"]``. Samples whose report yields no FINDINGS
    text are left unchanged. The whole dataset (mutated in place) is then
    written to ``output_path``.

    Args:
        knowledge_path: JSON file whose top-level keys are the known entities.
        dataset_path: JSON file holding the list of samples.
        image_root: prefix prepended to ``data['img_id']`` to locate the image.
        output_path: destination of the augmented dataset JSON.

    Returns:
        None.
    """
    def _extract_section(record, header):
        # Collect the text of every section whose header line contains ``header``.
        # If the line after the header is not the " " separator, the header line
        # itself (minus "HEADER:") is kept and collection starts right after it;
        # otherwise the separator is skipped. Collection stops at the next " ".
        tag = header + ":"
        parts = []
        for i, line in enumerate(record):
            if header not in line:
                continue
            # Bounds check guards against a header on the report's last line.
            if i + 1 < len(record) and record[i + 1] != " ":
                parts.append(line.replace(tag, ""))
                j = i + 1
            else:
                j = i + 2
            while j < len(record) and record[j] != " ":
                parts.append(record[j])
                j += 1
        return parts

    with open(knowledge_path, 'r') as f:
        knowledges = json.load(f)
    with open(dataset_path, 'r') as f:
        datasets = json.load(f)

    for data in datasets:
        if data['entity'] not in knowledges.keys():
            continue
        image_path = image_root + data['img_id']
        question = data['question']

        # The report lives next to the image's parent directory: ".../<dir>.txt".
        diagnosis_path = "/".join(image_path.split("/")[:-1]) + ".txt"
        with open(diagnosis_path, "r") as f:
            record = [line.replace("\n", "") for line in f.readlines()]
        record = [line for line in record if line != ""]

        findings = _extract_section(record, "FINDINGS")
        if not findings:
            continue
        diagnosis = "".join(findings)
        impression = "".join(_extract_section(record, "IMPRESSION"))
        data["unpolished_reasoning"] = parse_report_to_chain(diagnosis, question, impression)

    with open(output_path, "w") as f:
        json.dump(datasets, f)
def load_and_conbine_activations(pos_path, neg_path, num=None):
    """Load paired positive/negative activations and interleave them row by row.

    Row ``k`` of the positive file is paired with row ``k`` of the negative
    file. The output order is ``pos0, neg0, pos1, neg1, ...``, with labels
    ``1, 0, 1, 0, ...``.

    Args:
        pos_path: ``.npy`` file with the positive activations (first axis = samples).
        neg_path: ``.npy`` file with the negative activations; must match the
            positive array's trailing shape and have at least as many rows as used.
        num: optional number of leading pairs to keep (Python or NumPy int).
            Any non-integer value (e.g. ``None``) keeps every row.

    Returns:
        ``(head_wise_activations, labels)`` where the activations have shape
        ``(2 * n_pairs, *pos.shape[1:])`` and labels has length ``2 * n_pairs``.

    Raises:
        ValueError: if ``num`` is negative or exceeds the number of positive rows.
    """
    pos_head_wise_activations = np.load(pos_path)
    neg_head_wise_activations = np.load(neg_path)

    # bool is an int subclass; the original `type(num) == int` excluded it, so keep that.
    if isinstance(num, (int, np.integer)) and not isinstance(num, bool):
        if not 0 <= num <= pos_head_wise_activations.shape[0]:
            raise ValueError(
                f"num must be in [0, {pos_head_wise_activations.shape[0]}], got {num}"
            )
        pos_head_wise_activations = pos_head_wise_activations[:num]
        neg_head_wise_activations = neg_head_wise_activations[:num]

    n_pairs = pos_head_wise_activations.shape[0]
    labels = np.tile([1, 0], n_pairs)
    # (n, 2, ...) -> (2n, ...): stacking on axis 1 interleaves positive/negative rows.
    stacked = np.stack((pos_head_wise_activations, neg_head_wise_activations), axis=1)
    head_wise_activations = stacked.reshape(-1, *pos_head_wise_activations.shape[1:])

    # Sanity check: row 1 of the interleaved array must be the first negative sample.
    if head_wise_activations.shape[0] > 1:
        print(np.all(head_wise_activations[1] == neg_head_wise_activations[0]))
    return head_wise_activations, labels

def load_data(file_path):
    """Read every record of a JSON Lines file into a list, showing a progress bar."""
    with jsonlines.open(file_path, 'r') as reader:
        records = list(tqdm(reader, desc="Loading data..."))
    return records
    
def save_data(data, save_path):
    """Write each record of ``data`` to ``save_path`` as one JSON Lines entry."""
    with jsonlines.open(save_path, 'w') as out:
        for record in data:
            out.write(record)

def load_nq():
    """Load the ITI Natural Questions validation split as a DataFrame.

    Returns:
        DataFrame with columns ``question``, ``answer`` (list of accepted
        answers) and ``false_answer``, one row per example, RangeIndex.
    """
    dataset = load_dataset("OamPatel/iti_nq_open_val")["validation"]
    # Build all rows first: concatenating one row at a time is quadratic.
    rows = [
        {"question": row["question"], "answer": list(row["answer"]), "false_answer": row["false_answer"]}
        for row in dataset
    ]
    return pd.DataFrame(rows, columns=["question", "answer", "false_answer"])

def load_triviaqa():
    """Load the ITI TriviaQA validation split as a DataFrame.

    Returns:
        DataFrame with columns ``question``, ``answer`` (list of answer
        aliases) and ``false_answer``, one row per example, RangeIndex.
    """
    dataset = load_dataset("OamPatel/iti_trivia_qa_val")["validation"]
    # Build all rows first: concatenating one row at a time is quadratic.
    rows = [
        {"question": row["question"], "answer": list(row["answer"]['aliases']), "false_answer": row["false_answer"]}
        for row in dataset
    ]
    return pd.DataFrame(rows, columns=["question", "answer", "false_answer"])

def format_question_answer(question, anwser):
    """Render an image-chat turn: ``<image>``, the user question and the assistant answer."""
    template = "<image>\nUSER: {}\nASSISTANT: {}"
    return template.format(question, anwser)

def format_question_woa(question):
    """Render an image-chat prompt ending at an empty ``ASSISTANT:`` slot (no answer)."""
    template = "<image>\nUSER: {}\nASSISTANT:"
    return template.format(question)

def format_question(question):
    """Render ``<image>`` followed by the user turn only (no assistant tag)."""
    template = "<image>\nUSER: {}"
    return template.format(question)

def format_question_with_choices(question, choices):
    """Render ``<image>`` plus a ``Q:`` line.

    ``choices`` is accepted for interface compatibility but does not appear in
    the output.
    """
    template = "<image>\nQ: {}"
    return template.format(question)

def format_truthfulqa(question, choice):
    """Render a TruthfulQA-style ``Q: ... A: ...`` pair on one line."""
    template = "Q: {} A: {}"
    return template.format(question, choice)

def format_truthfulqa_end_q(question, choice, rand_question):
    """Render ``Q: ... A: ...`` followed by a trailing, unrelated ``Q: ...``."""
    template = "Q: {} A: {} Q: {}"
    return template.format(question, choice, rand_question)


def tokenized_tqa(dataset, tokenizer):
    """Tokenize every (question, choice) pair of TruthfulQA's ``mc2_targets``.

    The first formatted prompt is printed once as a sanity check.

    Returns:
        ``(all_prompts, all_labels)``: one ``input_ids`` tensor per choice and
        the matching label from ``mc2_targets['labels']``, in dataset order.
    """
    all_prompts, all_labels = [], []
    for q_idx in range(len(dataset)):
        sample = dataset[q_idx]
        question = sample['question']
        choices = sample['mc2_targets']['choices']
        labels = sample['mc2_targets']['labels']

        assert len(choices) == len(labels), (len(choices), len(labels))

        for c_idx, (choice, label) in enumerate(zip(choices, labels)):
            text = format_truthfulqa(question, choice)
            if q_idx == 0 and c_idx == 0:
                print(text)
            all_prompts.append(tokenizer(text, return_tensors='pt').input_ids)
            all_labels.append(label)

    return all_prompts, all_labels

def tokenized_tqa_gen_end_q(dataset, tokenizer):
    """Tokenize TruthfulQA generation answers, each followed by a random other question.

    For every sample one random question (drawn with ``np.random.randint``, so
    it may be the sample itself) is appended to all of its answer prompts.
    Correct answers are emitted first (label 1), then incorrect ones (label 0).

    Returns:
        ``(all_prompts, all_labels, all_categories)`` as parallel lists.
    """
    all_prompts, all_labels, all_categories = [], [], []
    for idx in range(len(dataset)):
        question = dataset[idx]['question']
        category = dataset[idx]['category']
        rand_question = dataset[np.random.randint(len(dataset))]['question']

        for answer_key, label in (('correct_answers', 1), ('incorrect_answers', 0)):
            for answer in dataset[idx][answer_key]:
                text = format_truthfulqa_end_q(question, answer, rand_question)
                all_prompts.append(tokenizer(text, return_tensors='pt').input_ids)
                all_labels.append(label)
                all_categories.append(category)

    return all_prompts, all_labels, all_categories

def tokenized_tqa_gen(dataset, tokenizer):
    """Tokenize every TruthfulQA generation answer as a ``Q: ... A: ...`` prompt.

    Correct answers are emitted first (label 1), then incorrect ones (label 0),
    each tagged with the sample's category.

    Returns:
        ``(all_prompts, all_labels, all_categories)`` as parallel lists.
    """
    all_prompts, all_labels, all_categories = [], [], []
    for idx in range(len(dataset)):
        question = dataset[idx]['question']
        category = dataset[idx]['category']

        for answer_key, label in (('correct_answers', 1), ('incorrect_answers', 0)):
            for answer in dataset[idx][answer_key]:
                text = format_truthfulqa(question, answer)
                all_prompts.append(tokenizer(text, return_tensors='pt').input_ids)
                all_labels.append(label)
                all_categories.append(category)

    return all_prompts, all_labels, all_categories


def get_llama_activations_bau(model, prompt, device):
    """Run ``model`` on ``prompt`` and capture hidden states plus per-layer module outputs.

    Traces ``model.layers.{i}.self_attn.head_out`` and ``model.layers.{i}.mlp``
    for every layer via baukit's ``TraceDict``. Every result is ``squeeze()``d,
    so size-1 dimensions (e.g. a batch of one) are dropped.

    Returns:
        ``(hidden_states, head_wise_hidden_states, mlp_wise_hidden_states)``
        as NumPy arrays, each stacked with the layer axis first.
    """
    n_layers = model.config.num_hidden_layers
    head_names = [f"model.layers.{idx}.self_attn.head_out" for idx in range(n_layers)]
    mlp_names = [f"model.layers.{idx}.mlp" for idx in range(n_layers)]

    def _stack_traced(traces, names):
        per_layer = [traces[name].output.squeeze().detach().cpu() for name in names]
        return torch.stack(per_layer, dim=0).squeeze().numpy()

    with torch.no_grad():
        inputs = prompt.to(device)
        with TraceDict(model, head_names + mlp_names) as traces:
            output = model(inputs, output_hidden_states=True)
        all_hidden = torch.stack(output.hidden_states, dim=0).squeeze()
        all_hidden = all_hidden.detach().cpu().numpy()
        head_wise = _stack_traced(traces, head_names)
        mlp_wise = _stack_traced(traces, mlp_names)

    return all_hidden, head_wise, mlp_wise


def get_llama_logits(model, prompt, device):
    """Put ``model`` in eval mode, run it on ``prompt`` without gradients, and return CPU logits."""
    model.eval()
    with torch.no_grad():
        outputs = model(prompt.to(device))
        return outputs.logits.detach().cpu()

def save_probes(probes, path):
    """Pickle ``probes`` (e.g. a list of sklearn logistic-regression probes) to ``path``."""
    with open(path, 'wb') as handle:
        pickle.Pickler(handle).dump(probes)

def load_probes(path):
    """Unpickle and return the probes stored at ``path``.

    NOTE(review): unpickling executes arbitrary code, so only load files you
    produced yourself (e.g. via ``save_probes``).
    """
    with open(path, 'rb') as handle:
        loaded = pickle.Unpickler(handle).load()
    return loaded


def flattened_idx_to_layer_head(flattened_idx, num_heads):
    """Split a layer-major flat head index into ``(layer, head)``."""
    return divmod(flattened_idx, num_heads)

def layer_head_to_flattened_idx(layer, head, num_heads):
    """Map ``(layer, head)`` to its layer-major flat index (inverse of ``flattened_idx_to_layer_head``)."""
    return head + num_heads * layer

def train_probes(seed, train_set_idxs, val_set_idxs, separated_head_wise_activations, separated_labels, num_layers, num_heads):
    """Fit one logistic-regression probe per attention head and score it on validation data.

    Activations of the selected groups are concatenated along axis 0 and
    indexed as ``[:, layer, head, :]``, i.e. (samples, layer, head, features).

    Args:
        seed: ``random_state`` for every ``LogisticRegression``.
        train_set_idxs / val_set_idxs: indices into the per-group lists.
        separated_head_wise_activations: per-group activation arrays.
        separated_labels: per-group label arrays aligned with the activations.
        num_layers, num_heads: model geometry.

    Returns:
        ``(probes, all_head_accs_np)``: the fitted probes in layer-major order
        (index ``layer * num_heads + head``) and their validation accuracies.
    """
    all_X_train = np.concatenate([separated_head_wise_activations[i] for i in train_set_idxs], axis=0)
    all_X_val = np.concatenate([separated_head_wise_activations[i] for i in val_set_idxs], axis=0)
    y_train = np.concatenate([separated_labels[i] for i in train_set_idxs], axis=0)
    y_val = np.concatenate([separated_labels[i] for i in val_set_idxs], axis=0)

    probes = []
    all_head_accs = []
    for layer in tqdm(range(num_layers)):
        for head in range(num_heads):
            clf = LogisticRegression(random_state=seed, max_iter=1000).fit(all_X_train[:, layer, head, :], y_train)
            # Only validation accuracy is used; predicting on the train split was wasted work.
            y_val_pred = clf.predict(all_X_val[:, layer, head, :])
            all_head_accs.append(accuracy_score(y_val, y_val_pred))
            probes.append(clf)

    return probes, np.array(all_head_accs)

def train_probes_layer(seed, train_set_idxs, val_set_idxs, separated_layer_wise_activations, separated_labels, num_layers):
    """Fit one logistic-regression probe per layer and score it on validation data.

    Activations of the selected groups are concatenated along axis 0 and
    indexed as ``[:, layer, :]``, i.e. (samples, layer, features).

    Returns:
        ``(probes, all_layer_accs_np)``: one fitted probe per layer (in layer
        order) and their validation accuracies.
    """
    all_X_train = np.concatenate([separated_layer_wise_activations[i] for i in train_set_idxs], axis=0)
    all_X_val = np.concatenate([separated_layer_wise_activations[i] for i in val_set_idxs], axis=0)
    y_train = np.concatenate([separated_labels[i] for i in train_set_idxs], axis=0)
    y_val = np.concatenate([separated_labels[i] for i in val_set_idxs], axis=0)

    probes = []
    all_layer_accs = []
    for layer in tqdm(range(num_layers)):
        clf = LogisticRegression(random_state=seed, max_iter=1000).fit(all_X_train[:, layer, :], y_train)
        # Only validation accuracy is used; predicting on the train split was wasted work.
        y_val_pred = clf.predict(all_X_val[:, layer, :])
        all_layer_accs.append(accuracy_score(y_val, y_val_pred))
        probes.append(clf)

    return probes, np.array(all_layer_accs)

def val_probes(seed, val_set_idxs, separated_head_wise_activations, separated_labels, probes, num_layers, num_heads, num_to_intervene, reverse=False,
               fallback_idx_path='/root/wtb/multimodal_alignment/mm_iti/features/idx.npy'):
    """Evaluate per-head probes on the validation split and choose heads to intervene on.

    ``probes`` is the flat, layer-major list produced by ``train_probes``
    (index ``layer * num_heads + head``). ``seed`` and ``reverse`` are accepted
    for signature compatibility but unused.

    For k in 8..48, prints the mean accuracy / F1 / precision / recall of the
    k best heads, then prints the top-k accuracies themselves. If the best head
    reaches a validation accuracy of exactly 1, the ranking is replaced by the
    precomputed flat-index ordering loaded from ``fallback_idx_path``.

    Returns:
        ``(topk_accs, top_heads)``: up to the 48 highest head accuracies (from
        the last printed k) and the ``num_to_intervene`` selected heads as
        ``(layer, head)`` pairs.
    """
    all_X_val = np.concatenate([separated_head_wise_activations[i] for i in val_set_idxs], axis=0)
    y_val = np.concatenate([separated_labels[i] for i in val_set_idxs], axis=0)

    accs, f1s, pres, recs = [], [], [], []
    for layer in tqdm(range(num_layers)):
        for head in range(num_heads):
            X_val = all_X_val[:, layer, head, :]
            # Layer-major layout uses num_heads per layer (previously num_layers,
            # which only coincided when the two were equal).
            probe = probes[layer_head_to_flattened_idx(layer, head, num_heads)]
            y_val_pred = probe.predict(X_val)
            accs.append(accuracy_score(y_val, y_val_pred))
            f1s.append(f1_score(y_val, y_val_pred))
            pres.append(precision_score(y_val, y_val_pred))
            recs.append(recall_score(y_val, y_val_pred))

    all_head_accs_np = np.array(accs)
    all_head_f1s_np = np.array(f1s)
    all_head_pres_np = np.array(pres)
    all_head_recs_np = np.array(recs)
    sorted_idx = np.argsort(all_head_accs_np)[::-1]

    for k in [8, 16, 24, 32, 40, 48]:
        top = sorted_idx[:k]
        print(k, np.average(all_head_accs_np[top]), np.average(all_head_f1s_np[top]),
              np.average(all_head_pres_np[top]), np.average(all_head_recs_np[top]))

    for k in [8, 16, 24, 32, 40, 48]:
        topk_accs = -np.sort(-all_head_accs_np.reshape(num_heads * num_layers))[:k]
        print(topk_accs, np.average(topk_accs))

    # A perfectly separating head is treated as "can't tell which to intervene on";
    # fall back to a precomputed head ordering.
    if all_head_accs_np[sorted_idx][0] == 1:
        sorted_idx = np.load(fallback_idx_path)
    top_heads = [flattened_idx_to_layer_head(idx, num_heads) for idx in sorted_idx[:num_to_intervene]]
    print(top_heads)
    return topk_accs, top_heads

def val_probes_2(seed, val_set_idxs, separated_head_wise_activations, separated_labels, probes, num_layers, num_heads, num_to_intervene, reverse=False,
                 fallback_idx_path='/root/wtb/multimodal_alignment/mm_iti/features/idx_neg_best.npy'):
    """Rank per-head probes by validation accuracy and choose heads to intervene on.

    ``probes`` is the flat, layer-major list produced by ``train_probes``
    (index ``layer * num_heads + head``). ``seed`` and ``reverse`` are accepted
    for signature compatibility but unused. For k in 8..48, prints the top-k
    accuracies. If the best head reaches a validation accuracy of exactly 1, the
    ranking is replaced by the flat-index ordering loaded from ``fallback_idx_path``.

    Returns:
        ``(topk_accs, top_heads)``: up to the 48 highest head accuracies and the
        ``num_to_intervene`` selected heads as ``(layer, head)`` pairs.
    """
    all_X_val = np.concatenate([separated_head_wise_activations[i] for i in val_set_idxs], axis=0)
    y_val = np.concatenate([separated_labels[i] for i in val_set_idxs], axis=0)

    accs = []
    for layer in tqdm(range(num_layers)):
        for head in range(num_heads):
            # Layer-major layout uses num_heads per layer (previously num_layers).
            probe = probes[layer_head_to_flattened_idx(layer, head, num_heads)]
            y_val_pred = probe.predict(all_X_val[:, layer, head, :])
            accs.append(accuracy_score(y_val, y_val_pred))

    all_head_accs_np = np.array(accs)
    sorted_idx = np.argsort(all_head_accs_np)[::-1]

    for k in [8, 16, 24, 32, 40, 48]:
        topk_accs = -np.sort(-all_head_accs_np.reshape(num_heads * num_layers))[:k]
        print(topk_accs, np.average(topk_accs))

    # A perfectly separating head is treated as "can't tell which to intervene on".
    if all_head_accs_np[sorted_idx][0] == 1:
        sorted_idx = np.load(fallback_idx_path)
    top_heads = [flattened_idx_to_layer_head(idx, num_heads) for idx in sorted_idx[:num_to_intervene]]
    print(top_heads)
    return topk_accs, top_heads


def val_probes_layer(seed, val_set_idxs, separated_layer_wise_activations, separated_labels, probes, num_layers, num_to_intervene, reverse=False):
    """Score per-layer probes on the validation split and pick the best layers.

    ``seed`` and ``reverse`` are accepted for signature compatibility but are
    unused. For k in (1, 4, 8, 16, 32), prints the mean acc / F1 / precision /
    recall of the k best layers, then prints the top-k accuracies themselves.

    Returns:
        ``(topk_accs, top_layers)``: up to the 32 highest layer accuracies
        (from the last printed k) and the indices of the ``num_to_intervene``
        best-scoring layers.
    """
    X_all = np.concatenate([separated_layer_wise_activations[i] for i in val_set_idxs], axis=0)
    y_true = np.concatenate([separated_labels[i] for i in val_set_idxs], axis=0)

    acc_list, f1_list, pre_list, rec_list = [], [], [], []
    for layer in tqdm(range(num_layers)):
        y_hat = probes[layer].predict(X_all[:, layer, :])
        acc_list.append(accuracy_score(y_true, y_hat))
        f1_list.append(f1_score(y_true, y_hat))
        pre_list.append(precision_score(y_true, y_hat))
        rec_list.append(recall_score(y_true, y_hat))

    accs = np.array(acc_list)
    f1s = np.array(f1_list)
    pres = np.array(pre_list)
    recs = np.array(rec_list)
    ranking = np.argsort(accs)[::-1]

    k_values = [1, 4, 8, 16, 32]
    for k in k_values:
        best = ranking[:k]
        print(k, np.average(accs[best]), np.average(f1s[best]), np.average(pres[best]), np.average(recs[best]))

    for k in k_values:
        topk_accs = -np.sort(-accs)[:k]
        print(topk_accs, np.average(topk_accs))

    top_layers = ranking[:num_to_intervene]
    print(top_layers)
    return topk_accs, top_layers

def sort_direction_len(com_directions, num_layers, num_heads, num_to_intervene, start_layer=5, end_layer=26):
    """Pick the heads in layers ``[start_layer, end_layer)`` with the largest direction norm.

    ``com_directions`` is indexed layer-major (``layer * num_heads + head``),
    the same layout used by ``layer_head_to_flattened_idx``. ``num_layers`` is
    kept for signature compatibility only.

    Returns:
        list of ``(layer, head)`` pairs with absolute layer indices, ordered by
        decreasing norm, of length at most ``num_to_intervene``.
    """
    lens = []
    for layer in tqdm(range(start_layer, end_layer)):
        for head in range(num_heads):
            vector = com_directions[layer_head_to_flattened_idx(layer, head, num_heads)]
            lens.append(np.linalg.norm(vector))

    # Positions in `lens` start at start_layer; shift back to absolute flat indices
    # so the decoded layer numbers are not off by start_layer.
    sorted_idx = np.argsort(lens)[::-1] + start_layer * num_heads
    top_heads = [flattened_idx_to_layer_head(idx, num_heads) for idx in sorted_idx[:num_to_intervene]]
    print(top_heads)
    return top_heads

def sort_direction_len_together(com_directions, com_directions2, num_layers, num_heads, num_to_intervene, start_layer=5, end_layer=26):
    """Rank heads from two direction sets jointly by norm.

    Norms are computed for every head in layers ``[start_layer, end_layer)``
    of both ``com_directions`` and ``com_directions2``. Each set is indexed
    layer-major (``layer * num_heads + head``). The ``num_to_intervene``
    largest norms are kept across both sets. ``num_layers`` is kept for
    signature compatibility only.

    Returns:
        ``(top_heads, sources)``: absolute ``(layer, head)`` pairs, plus a NumPy
        array with 0 where the pair came from ``com_directions`` and 1 where it
        came from ``com_directions2``.
    """
    def _head_norms(directions):
        norms = []
        for layer in tqdm(range(start_layer, end_layer)):
            for head in range(num_heads):
                norms.append(np.linalg.norm(directions[layer_head_to_flattened_idx(layer, head, num_heads)]))
        return norms

    lens = _head_norms(com_directions)
    lens2 = _head_norms(com_directions2)
    n_per_set = len(lens)

    sorted_idx = np.argsort(lens + lens2)[::-1]
    picked = sorted_idx[:num_to_intervene]
    # Positions >= n_per_set belong to the second direction set.
    sources = picked // n_per_set
    # Position within one set -> absolute flat index (undo the start_layer offset).
    flat_idxs = picked % n_per_set + start_layer * num_heads
    top_heads = [flattened_idx_to_layer_head(idx, num_heads) for idx in flat_idxs]
    print(top_heads)
    return top_heads, sources

def get_top_heads(train_idxs, val_idxs, separated_activations, separated_labels, num_layers, num_heads, seed, num_to_intervene, use_random_dir=False):
    """Train per-head probes and return the heads with the best validation accuracy.

    The top accuracies are printed. With ``use_random_dir`` the selection is
    replaced by ``num_to_intervene`` heads drawn without replacement.

    Returns:
        ``(top_heads, probes)``: ``(layer, head)`` pairs and the trained probes.
    """
    probes, head_accs = train_probes(seed, train_idxs, val_idxs, separated_activations, separated_labels, num_layers=num_layers, num_heads=num_heads)
    flat_accs = head_accs.reshape(num_layers, num_heads).reshape(num_heads * num_layers)

    chosen = np.argsort(flat_accs)[::-1][:num_to_intervene]
    print(np.sort(flat_accs)[::-1][:num_to_intervene])
    if use_random_dir:
        # overwrite top heads with random heads, no replacement
        shuffled = np.random.choice(num_heads * num_layers, num_heads * num_layers, replace=False)
        chosen = shuffled[:num_to_intervene]

    top_heads = [flattened_idx_to_layer_head(idx, num_heads) for idx in chosen]
    return top_heads, probes

def get_top_layers(train_idxs, val_idxs, separated_activations, separated_labels, num_layers, seed, num_to_intervene, use_random_dir=False):
    """Train per-layer probes and return the layers with the best validation accuracy.

    The top accuracies are printed. With ``use_random_dir`` the selection is
    replaced by ``num_to_intervene`` layers drawn without replacement.

    Returns:
        ``(top_layers, probes)``: a 1-D array of layer indices and the trained probes.
    """
    probes, all_layer_accs_np = train_probes_layer(seed, train_idxs, val_idxs, separated_activations, separated_labels, num_layers=num_layers)

    top_layers = np.argsort(all_layer_accs_np)[::-1][:num_to_intervene]
    print(np.sort(all_layer_accs_np)[::-1][:num_to_intervene])
    if use_random_dir:
        # overwrite top layers with random layers, no replacement
        random_idxs = np.random.choice(num_layers, num_layers, replace=False)
        # Return a flat index array like the ranked branch (was wrapped in a list,
        # so callers iterated one array instead of individual layers).
        top_layers = random_idxs[:num_to_intervene]

    return top_layers, probes

def get_interventions_dict(model_name, top_heads, probes, tuning_activations, num_heads, use_center_of_mass, use_random_dir, com_directions, prefix=None):
    """Build per-layer head-intervention specs keyed by traced module name.

    Args:
        model_name: model identifier; names containing ``'llava_med'`` use the
            ``'model'`` module prefix.
        top_heads: iterable of ``(layer, head)`` pairs to intervene on.
        probes: flat, layer-major probes exposing ``coef_`` (used when neither
            ``use_center_of_mass`` nor ``use_random_dir`` is set).
        tuning_activations: unused; kept for signature compatibility.
        num_heads: heads per layer, for flat indexing.
        use_center_of_mass: take directions from ``com_directions``.
        use_random_dir: draw a random 128-d Gaussian direction instead.
        com_directions: flat, layer-major per-head directions.
        prefix: optional module-name prefix that overrides the one derived
            from ``model_name``.

    Returns:
        dict mapping ``"{prefix}.layers.{layer}.self_attn.head_out"`` to a list of
        ``(head, unit_direction, proj_val_std)`` tuples sorted by head, where
        ``proj_val_std`` is always 1.

    Raises:
        ValueError: if heads are given but no prefix is known for ``model_name``.
    """
    top_heads = list(top_heads)
    if prefix is None and 'llava_med' in model_name:
        prefix = 'model'
    if top_heads and prefix is None:
        # Previously this surfaced as an UnboundLocalError on `prefix`.
        raise ValueError(f"No module prefix known for model {model_name!r}; pass prefix=... explicitly")

    interventions = {}
    for layer, head in top_heads:
        interventions[f"{prefix}.layers.{layer}.self_attn.head_out"] = []
    for layer, head in top_heads:
        if use_center_of_mass:
            direction = com_directions[layer_head_to_flattened_idx(layer, head, num_heads)]
        elif use_random_dir:
            direction = np.random.normal(size=(128,))
        else:
            direction = probes[layer_head_to_flattened_idx(layer, head, num_heads)].coef_
        direction = direction / np.linalg.norm(direction)
        # Projection-std scaling is disabled; a constant 1 is stored.
        proj_val_std = 1
        interventions[f"{prefix}.layers.{layer}.self_attn.head_out"].append((head, direction.squeeze(), proj_val_std))

    for entries in interventions.values():
        entries.sort(key=lambda x: x[0])

    return interventions

def get_interventions_dict_together(model_name, top_heads, vector_types, probes, tuning_activations, num_heads, use_center_of_mass, use_random_dir, com_directions, com_directions2, prefix=None):
    """Build head-intervention specs where each head's direction may come from one of two sets.

    Args:
        model_name: model identifier; names containing ``'llava_med'`` use the
            ``'model'`` module prefix.
        top_heads: iterable of ``(layer, head)`` pairs.
        vector_types: per-head source, paired with ``top_heads`` via ``zip``:
            0 selects ``com_directions`` and anything else selects
            ``com_directions2`` (only used with ``use_center_of_mass``).
        probes: flat, layer-major probes exposing ``coef_`` (fallback source).
        tuning_activations: unused; kept for signature compatibility.
        num_heads: heads per layer, for flat indexing.
        use_center_of_mass / use_random_dir: direction source selection.
        com_directions, com_directions2: flat, layer-major per-head directions.
        prefix: optional module-name prefix that overrides the one derived
            from ``model_name``.

    Returns:
        dict mapping ``"{prefix}.layers.{layer}.self_attn.head_out"`` to a list of
        ``(head, unit_direction, 1)`` tuples sorted by head.

    Raises:
        ValueError: if heads are given but no prefix is known for ``model_name``.
    """
    top_heads = list(top_heads)
    if prefix is None and 'llava_med' in model_name:
        prefix = 'model'
    if top_heads and prefix is None:
        # Previously this surfaced as an UnboundLocalError on `prefix`.
        raise ValueError(f"No module prefix known for model {model_name!r}; pass prefix=... explicitly")

    interventions = {}
    for layer, head in top_heads:
        interventions[f"{prefix}.layers.{layer}.self_attn.head_out"] = []

    for (layer, head), source in zip(top_heads, list(vector_types)):
        if use_center_of_mass:
            source_dirs = com_directions if source == 0 else com_directions2
            direction = source_dirs[layer_head_to_flattened_idx(layer, head, num_heads)]
        elif use_random_dir:
            direction = np.random.normal(size=(128,))
        else:
            direction = probes[layer_head_to_flattened_idx(layer, head, num_heads)].coef_

        direction = direction / np.linalg.norm(direction)
        # Projection-std scaling is disabled; a constant 1 is stored.
        proj_val_std = 1
        interventions[f"{prefix}.layers.{layer}.self_attn.head_out"].append((head, direction.squeeze(), proj_val_std))

    for entries in interventions.values():
        entries.sort(key=lambda x: x[0])

    return interventions

def get_interventions_dict_layer(top_layers, probes, tuning_activations, use_center_of_mass, use_random_dir, com_directions):
    """Build per-layer MLP intervention specs.

    Each entry maps ``"language_model.model.layers.{layer}.mlp"`` to
    ``(unit_direction, 1)``. The direction comes from ``com_directions[layer]``,
    a random 4096-d Gaussian, or ``probes[layer].coef_``, in that order of
    precedence. ``tuning_activations`` is unused.
    """
    def _raw_direction(layer):
        if use_center_of_mass:
            return com_directions[layer]
        if use_random_dir:
            return np.random.normal(size=(4096,))
        return probes[layer].coef_

    interventions = {}
    for layer in top_layers:
        raw = _raw_direction(layer)
        unit = raw / np.linalg.norm(raw)
        # Projection-std scaling is disabled; a constant 1 is stored.
        interventions[f"language_model.model.layers.{layer}.mlp"] = (unit.squeeze(), 1)
    return interventions

def get_interventions_dict_withprobe(top_heads, probes, tuning_activations, num_heads, use_center_of_mass, use_random_dir, com_directions): 
    """Build per-head intervention specs that also carry the head's probe as a torch layer.

    For every ``(layer, head)`` in ``top_heads`` an entry
    ``(head, unit_direction, proj_val_std, probe_linear)`` is collected under the key
    ``"language_model.model.layers.{layer}.self_attn.head_out"``. Within each key the
    entries are returned sorted by head index.

    The direction comes from ``com_directions`` (center-of-mass mode), a random
    Gaussian vector of length 128 (random mode; presumably the head dimension —
    TODO confirm), or the probe's ``coef_``, and is scaled to unit L2 norm.

    ``probe_linear`` is a float64 ``torch.nn.Linear`` moved to CUDA whose weight and
    bias are replaced by the probe's ``coef_`` and ``intercept_``. The probe is
    therefore read in every mode (assumed to be a fitted sklearn LogisticRegression).

    ``tuning_activations`` is accepted for interface compatibility but not used.
    """
    interventions = {}
    for layer, head in top_heads:
        module_key = f"language_model.model.layers.{layer}.self_attn.head_out"
        bucket = interventions.setdefault(module_key, [])
        flat_idx = layer_head_to_flattened_idx(layer, head, num_heads)

        if use_center_of_mass:
            raw_direction = com_directions[flat_idx]
        elif use_random_dir:
            raw_direction = np.random.normal(size=(128,))
        else:
            raw_direction = probes[flat_idx].coef_
        unit_direction = raw_direction / np.linalg.norm(raw_direction)

        # Re-express the fitted probe (coef_ / intercept_) as a torch Linear layer.
        fitted_probe = probes[flat_idx]
        probe_linear = torch.nn.Linear(fitted_probe.coef_.shape[-1], 1, bias=True, dtype=torch.double).cuda()
        with torch.no_grad():
            probe_linear.weight = torch.nn.Parameter(torch.DoubleTensor(fitted_probe.coef_).cuda())
            probe_linear.bias = torch.nn.Parameter(torch.DoubleTensor(fitted_probe.intercept_).cuda())

        # Fixed scale of 1: the projection std is not estimated here.
        bucket.append((head, unit_direction.squeeze(), 1, probe_linear))

    return {
        module_key: sorted(entries, key=lambda entry: entry[0])
        for module_key, entries in interventions.items()
    }

def get_interventions_dict_withoffset(model_name, top_heads, probes, tuning_activations, num_heads, use_center_of_mass, use_random_dir, com_directions, offsetgenerators): 
    """Build per-head intervention specs that carry a per-head offset generator.

    For every ``(layer, head)`` in ``top_heads`` an entry
    ``(head, direction, proj_val_std, generator)`` is collected under
    ``"{prefix}.layers.{layer}.self_attn.head_out"``. Within each key the entries
    are sorted by head index.

    ``prefix`` is ``"model"`` for LLaVA "lht" variants and Shikra. For every other
    model name (LLaVA, InstructBLIP, and anything unrecognised) it is
    ``"language_model.model"``.

    The direction comes from ``com_directions`` (center-of-mass mode), a random
    Gaussian vector of length 128 (random mode), or the probe's ``coef_``.
    Unlike the sibling builders, the direction is NOT normalised to unit length here.

    ``generator`` is a deep copy of ``offsetgenerators.nets[flat_idx]``, so each
    entry owns an independent network.

    ``tuning_activations`` is accepted for interface compatibility but not used.
    """
    # The former `elif 'llava' or 'instructblip' in model_name` was always true
    # (a non-empty string literal is truthy). Every non-matching model therefore
    # fell through to 'language_model.model'; that default is now explicit.
    if ('llava' in model_name and 'lht' in model_name) or 'shikra' in model_name:
        prefix = 'model'
    else:
        prefix = 'language_model.model'

    interventions = {}
    for layer, head in top_heads:
        interventions[f"{prefix}.layers.{layer}.self_attn.head_out"] = []
    for layer, head in top_heads:
        if use_center_of_mass:
            direction = com_directions[layer_head_to_flattened_idx(layer, head, num_heads)]
        elif use_random_dir:
            direction = np.random.normal(size=(128,))
        else:
            direction = probes[layer_head_to_flattened_idx(layer, head, num_heads)].coef_

        # Deep copy so later mutation of one entry's generator cannot affect
        # offsetgenerators.nets or other entries.
        generator = copy.deepcopy(offsetgenerators.nets[layer_head_to_flattened_idx(layer, head, num_heads)])

        # Fixed scale of 1: the projection std is not estimated here.
        proj_val_std = 1
        interventions[f"{prefix}.layers.{layer}.self_attn.head_out"].append((head, direction.squeeze(), proj_val_std, generator))
    for layer, head in top_heads:
        interventions[f"{prefix}.layers.{layer}.self_attn.head_out"] = sorted(interventions[f"{prefix}.layers.{layer}.self_attn.head_out"], key=lambda x: x[0])

    return interventions



def merge_interventions(interventions, edit_locations):
    """Merge several intervention dicts, tagging each entry with its edit location.

    ``interventions`` and ``edit_locations`` are walked in lockstep (``zip``, so
    the shorter one bounds the merge). Every tuple entry of every module list gets
    its location appended as a final element, and entries for the same module
    name are concatenated in input order. A module appearing with an empty list
    still gets a key.

    Returns:
        dict mapping module name -> list of ``entry + (location,)`` tuples.
    """
    merged = {}
    for intervention, location in zip(interventions, edit_locations):
        for module_name, entries in intervention.items():
            bucket = merged.setdefault(module_name, [])
            bucket.extend(entry + (location,) for entry in entries)
    return merged

def get_separated_activations(labels, head_wise_activations, split_range): 
    """Split row-aligned labels and activations into consecutive groups.

    Group boundaries are
    ``np.linspace(split_range, N, int(N / split_range), dtype=int64)`` with
    ``N = labels.shape[0]``, which yields groups of ``split_range`` rows when N is
    a multiple of ``split_range``. Otherwise the boundaries follow linspace's
    rounding. Activation rows after the last boundary are discarded. When
    ``N < split_range`` there are no boundaries and both returned lists are empty.

    Args:
        labels: numpy array of per-row labels (``.shape[0]`` is read).
        head_wise_activations: array whose first axis is aligned with ``labels``.
        split_range: nominal group size.

    Returns:
        ``(separated_head_wise_activations, separated_labels, idxs_to_split_at)``.
        The label groups are plain Python lists; the activation groups are array
        chunks from ``np.split``.
    """
    row_count = labels.shape[0]
    boundaries = np.linspace(split_range, row_count, int(row_count / split_range), dtype=np.int64)

    label_list = list(labels)
    # Each group spans [previous boundary, current boundary); the first starts at 0.
    starts = [0, *boundaries[:-1]]
    separated_labels = [label_list[lo:hi] for lo, hi in zip(starts, boundaries)]

    # np.split also returns the tail after the last boundary; drop it.
    separated_head_wise_activations = np.split(head_wise_activations, boundaries)[:-1]

    return separated_head_wise_activations, separated_labels, boundaries

def get_com_directions(num_layers, num_heads, train_set_idxs, val_set_idxs, separated_head_wise_activations, separated_labels): 
    """Compute per-head center-of-mass ("mass mean") directions.

    For each ``(layer, head)`` pair the direction is the mean activation of rows
    labelled 1 minus the mean activation of rows labelled 0. Only the groups
    indexed by ``train_set_idxs`` and ``val_set_idxs`` are used.

    Args:
        num_layers, num_heads: number of layers/heads to process.
        train_set_idxs, val_set_idxs: integer arrays of group indices.
        separated_head_wise_activations: list of per-group arrays indexed as
            ``[row, layer, head, :]``.
        separated_labels: list of per-group label sequences aligned with rows.

    Returns:
        ``np.ndarray`` of shape ``(num_layers * num_heads, dim)``, in
        layer-major order (row ``layer * num_heads + head``). If a class has no
        rows, its mean (and hence the direction) is NaN.
    """
    # The usable groups and their labels do not depend on (layer, head), so they
    # are computed once instead of num_layers * num_heads times.
    usable_idxs = np.concatenate([train_set_idxs, val_set_idxs], axis=0)
    usable_labels = np.concatenate([separated_labels[i] for i in usable_idxs], axis=0)
    is_true = usable_labels == 1
    is_false = usable_labels == 0

    com_directions = []
    for layer in range(num_layers):
        for head in range(num_heads):
            head_activations = np.concatenate(
                [separated_head_wise_activations[i][:, layer, head, :] for i in usable_idxs], axis=0
            )
            true_mass_mean = np.mean(head_activations[is_true], axis=0)
            false_mass_mean = np.mean(head_activations[is_false], axis=0)
            com_directions.append(true_mass_mean - false_mass_mean)

    return np.array(com_directions)

def get_com_directions_layer(num_layers, train_set_idxs, val_set_idxs, separated_head_wise_activations, separated_labels): 
    """Compute per-layer center-of-mass ("mass mean") directions.

    For each layer the direction is the mean activation of rows labelled 1 minus
    the mean activation of rows labelled 0. Only the groups indexed by
    ``train_set_idxs`` and ``val_set_idxs`` are used.

    Args:
        num_layers: number of layers to process.
        train_set_idxs, val_set_idxs: integer arrays of group indices.
        separated_head_wise_activations: list of per-group arrays indexed as
            ``[row, layer, :]``.
        separated_labels: list of per-group label sequences aligned with rows.

    Returns:
        ``np.ndarray`` of shape ``(num_layers, dim)``. If a class has no rows,
        its mean (and hence the direction) is NaN.
    """
    # The usable groups and their labels do not depend on the layer, so they
    # are computed once instead of num_layers times.
    usable_idxs = np.concatenate([train_set_idxs, val_set_idxs], axis=0)
    usable_labels = np.concatenate([separated_labels[i] for i in usable_idxs], axis=0)
    is_true = usable_labels == 1
    is_false = usable_labels == 0

    com_directions = []
    for layer in range(num_layers):
        layer_activations = np.concatenate(
            [separated_head_wise_activations[i][:, layer, :] for i in usable_idxs], axis=0
        )
        true_mass_mean = np.mean(layer_activations[is_true], axis=0)
        false_mass_mean = np.mean(layer_activations[is_false], axis=0)
        com_directions.append(true_mass_mean - false_mass_mean)

    return np.array(com_directions)

def get_com_directions_i2t(whole_image_activations, whole_text_activations, use_all_caption=False): 
    """Average image-to-caption offset in activation space.

    For every (image, captions) pair, the image activation averaged over axis 1
    is subtracted from the caption activation averaged over axis 1. The per-pair
    offsets are then averaged over all pairs.

    Args:
        whole_image_activations: per-sample image activations. ``squeeze(axis=0)``
            is applied after the axis-1 mean, so a leading size-1 axis is
            expected (assumed ``(1, tokens, dim)`` — TODO confirm).
        whole_text_activations: per-sample sequences of caption activations.
        use_all_caption: when equal to ``True``, average over all captions of a
            sample; otherwise only the first caption is used.

    Returns:
        The mean offset vector (caption minus image).
    """
    def _axis1_mean(activation):
        return np.mean(activation, axis=1)

    offsets = []
    for image_act, caption_acts in zip(whole_image_activations, whole_text_activations):
        image_center = _axis1_mean(image_act).squeeze(axis=0)
        if use_all_caption == True:  # noqa: E712 - value comparison kept deliberately
            per_caption = np.array([_axis1_mean(c).squeeze() for c in caption_acts])
            caption_center = np.mean(per_caption, axis=0)
        else:
            caption_center = _axis1_mean(caption_acts[0]).squeeze(axis=0)
        offsets.append(caption_center - image_center)

    stacked = np.array(offsets)
    return np.mean(stacked, axis=0)

def add_gaussian_noise_outside_bboxes(image, bbox_list, mean=0, sigma=25, rng=None):
    """
    Add Gaussian noise to every pixel of an image that lies outside all given bounding boxes.

    Pixels inside any bbox are left unchanged. Noise is drawn for the full array
    shape, so every channel of an affected pixel (including an alpha channel, if
    present) is perturbed. The result is clipped to [0, 255] and cast to uint8.

    :param image: PIL Image object (anything ``np.array`` converts to an H x W[ x C] array)
    :param bbox_list: list of bounding boxes, each in the form [x1, y1, x2, y2];
        coordinates are truncated to int and clipped to the image, and empty
        boxes are ignored
    :param mean: mean of the Gaussian noise (default 0)
    :param sigma: standard deviation of the Gaussian noise (default 25)
    :param rng: optional ``numpy.random.Generator`` used to draw the noise, for
        reproducible results; when None (default) the global ``np.random`` state
        is used, as before
    :return: a new PIL Image with noise applied outside the boxes
    """
    img_array = np.array(image)
    height, width = img_array.shape[:2]

    # Protection mask: True marks pixels inside at least one bbox (kept clean).
    # It starts all-False, i.e. by default every pixel receives noise.
    protection_mask = np.zeros((height, width), dtype=bool)

    for bbox in bbox_list:
        # Clip the box to the image bounds.
        x1, y1, x2, y2 = map(int, bbox)
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(width, x2), min(height, y2)

        # Skip boxes that are empty or lie entirely outside the image.
        if x1 >= x2 or y1 >= y2:
            continue

        protection_mask[y1:y2, x1:x2] = True

    # Noise layer for the whole image.
    if rng is None:
        noise = np.random.normal(mean, sigma, img_array.shape).astype(np.float32)
    else:
        noise = rng.normal(mean, sigma, img_array.shape).astype(np.float32)

    # astype returns a new array, so the input array is not modified.
    noisy_array = img_array.astype(np.float32)
    outside = ~protection_mask
    noisy_array[outside] += noise[outside]

    # Keep pixel values within the valid uint8 range.
    noisy_array = np.clip(noisy_array, 0, 255).astype(np.uint8)

    return Image.fromarray(noisy_array)