import os
import torch
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoProcessor
from collections import defaultdict
from PIL import Image
import random
import json
import numpy as np
from itertools import product
import datetime

# Configure the PyTorch CUDA caching allocator to use expandable segments
# (intended to reduce fragmentation when allocation sizes vary). This has to
# be in the environment before the first CUDA allocation happens.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")

# Wall-clock start of the whole run; used for the duration summary at the end.
start_time = datetime.datetime.now()

from pruning_llava_utils import prune_model_llava_dynamic, batch_generate_llava
from chair_metrics import batch_compute_chair_metrics 

model_id = "llava-hf/llava-1.5-13b-hf"

# Model, tokenizer and processor all come from the same hub checkpoint.
# Weights are loaded in fp16 and dispatched across the available devices by
# `device_map="auto"`.
# NOTE(review): llava-hf checkpoints are presumably supported natively by
# transformers, which would make trust_remote_code unnecessary — confirm
# before dropping it.
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
# Tokenizer first, then processor (the recommended entry point for
# multimodal inputs). Both use identical loading options.
tokenizer, processor = (
    loader.from_pretrained(model_id, trust_remote_code=True, use_fast=True)
    for loader in (AutoTokenizer, AutoProcessor)
)

# ============ Parameter Settings ============
# Input / output locations.
img_dir = "path/to/val2014"                                         # COCO val2014 image directory
txt_file = "./semantic_attribution/chosen_imgs_13b_5_18-hallu.txt"  # candidate image file names, one per line
txt_file_chosen = "holo.txt"                                        # image names excluded from sampling
ann_file = "path/to/val2014/annotations/instances_val2014.json"     # COCO instance annotations (GT object labels)
log_file = "experiment_log_13b_5_18-hallu.txt"                      # results log (CSV-like rows + section markers)

# Evaluation schedule: n_rounds batches of N images each.
N, n_rounds = 8, 20

# Decoder layers 5..18 (inclusive) whose attention heads may be pruned.
target_layers = [*range(5, 19)]
# NOTE(review): not referenced in this file section; presumably the number of
# attention heads per layer of the 13B language model — confirm.
n_heads_per_layer = 40

# Attention-head groups evaluated by the sweep. Each head is a [layer, head]
# pair. Entry i of faithful_heads_list is paired with entry i of
# hallucination_heads_list (the sweep zips the two lists), so the 6 groups
# are every hallucination-head set below crossed with each faithful-head set:
#   groups 1-3: faithful "0.41-0.59 c>=3" x hallucination sets (3,5,7), (4,5,5), (4,4,6)
#   groups 4-6: faithful "0.38-0.62 c>=4" x the same three hallucination sets
# The labels are kept from the original selection comments; they presumably
# record the thresholds used to pick the heads — TODO confirm.
# Each set is written once (the original repeated identical literals) and
# expanded into fresh, independent lists per group below.

# Faithful-head set labelled "0.41-0.59 c>=3".
_FAITHFUL_C3 = [[5, 1], [5, 6], [5, 14], [5, 16], [5, 24], [5, 30], [6, 0], [6, 6], [6, 16], [6, 17], [6, 22], [6, 26], [6, 30], [6, 36], [6, 37], [6, 39], [7, 6], [7, 10], [7, 24], [7, 25], [7, 28], [7, 30], [7, 37], [8, 0], [8, 3], [8, 12], [8, 13], [8, 19], [8, 30], [8, 32], [8, 33], [8, 35], [9, 0], [9, 3], [9, 12], [9, 13], [9, 15], [9, 24], [9, 26], [9, 30], [9, 32], [9, 33], [9, 34], [9, 38], [10, 0], [10, 2], [10, 6], [10, 7], [10, 17], [10, 23], [10, 31], [10, 36], [10, 38], [11, 1], [11, 9], [11, 10], [11, 14], [11, 18], [11, 21], [11, 22], [11, 23], [11, 25], [11, 26], [11, 33], [11, 35], [12, 4], [12, 6], [12, 8], [12, 20], [12, 22], [12, 27], [12, 29], [12, 30], [12, 33], [12, 37], [13, 1], [13, 7], [13, 8], [13, 9], [13, 15], [13, 17], [13, 18], [13, 19], [13, 25], [13, 27], [13, 29], [13, 31], [13, 33], [13, 39], [14, 3], [14, 4], [14, 20], [14, 23], [14, 26], [14, 30], [14, 31], [14, 33], [14, 36], [15, 2], [15, 5], [15, 7], [15, 8], [15, 17], [15, 30], [15, 31], [16, 1], [16, 8], [16, 16], [16, 23], [16, 26], [16, 30], [16, 31], [16, 32], [16, 33], [16, 34], [16, 38], [17, 2], [17, 4], [17, 7], [17, 10], [17, 12], [17, 18], [17, 22], [17, 23], [17, 27], [17, 32], [17, 39], [18, 1], [18, 2], [18, 6], [18, 8], [18, 9], [18, 13], [18, 20], [18, 23], [18, 24], [18, 25], [18, 29], [18, 32]]

# Faithful-head set labelled "0.38-0.62 c>=4".
_FAITHFUL_C4 = [[5, 24], [5, 30], [6, 0], [6, 4], [6, 5], [6, 16], [6, 17], [6, 22], [6, 26], [6, 30], [6, 36], [6, 37], [6, 39], [7, 6], [7, 10], [7, 14], [7, 25], [7, 28], [7, 30], [7, 32], [7, 36], [7, 37], [8, 0], [8, 2], [8, 12], [8, 13], [8, 19], [8, 30], [8, 32], [8, 39], [9, 0], [9, 3], [9, 12], [9, 13], [9, 24], [9, 26], [9, 30], [9, 32], [9, 33], [9, 34], [9, 38], [10, 2], [10, 5], [10, 6], [10, 7], [10, 17], [10, 23], [10, 30], [10, 38], [11, 1], [11, 9], [11, 14], [11, 16], [11, 21], [11, 22], [11, 23], [11, 26], [11, 35], [12, 3], [12, 8], [12, 29], [12, 33], [12, 37], [12, 38], [13, 1], [13, 4], [13, 5], [13, 7], [13, 8], [13, 11], [13, 17], [13, 18], [13, 19], [13, 25], [13, 27], [13, 31], [13, 33], [13, 35], [13, 39], [14, 3], [14, 7], [14, 19], [14, 20], [14, 23], [14, 26], [14, 31], [14, 33], [14, 36], [15, 7], [15, 11], [15, 17], [15, 30], [15, 39], [16, 1], [16, 8], [16, 23], [16, 30], [16, 31], [16, 32], [16, 33], [16, 34], [17, 1], [17, 2], [17, 4], [17, 7], [17, 12], [17, 13], [17, 18], [17, 22], [17, 23], [17, 27], [17, 32], [17, 39], [18, 1], [18, 5], [18, 6], [18, 8], [18, 9], [18, 13], [18, 20], [18, 24], [18, 25], [18, 31], [18, 33]]

faithful_heads_list = [
    [list(head) for head in group]  # fresh copy per group, as with separate literals
    for group in (_FAITHFUL_C3,) * 3 + (_FAITHFUL_C4,) * 3
]

# Hallucination-head set labelled "3,5,7".
_HALLU_3_5_7 = [[5, 0], [5, 2], [5, 4], [5, 5], [5, 7], [5, 8], [5, 13], [5, 16], [5, 18], [5, 19], [5, 20], [5, 23], [5, 28], [5, 29], [5, 31], [5, 37], [5, 39], [6, 2], [6, 3], [6, 8], [6, 9], [6, 10], [6, 14], [6, 15], [6, 17], [6, 18], [6, 24], [6, 27], [6, 30], [6, 31], [6, 32], [6, 34], [6, 35], [7, 2], [7, 6], [7, 9], [7, 10], [7, 17], [7, 20], [7, 21], [7, 23], [7, 24], [7, 26], [7, 27], [7, 38], [7, 39], [8, 1], [8, 3], [8, 4], [8, 6], [8, 7], [8, 8], [8, 9], [8, 10], [8, 14], [8, 15], [8, 18], [8, 19], [8, 24], [8, 26], [8, 27], [8, 29], [8, 37], [8, 38], [9, 6], [9, 8], [9, 10], [9, 11], [9, 12], [9, 13], [9, 14], [9, 15], [9, 16], [9, 17], [9, 18], [9, 19], [9, 20], [9, 26], [9, 27], [9, 37], [10, 9], [10, 11], [10, 12], [10, 14], [10, 15], [10, 18], [10, 20], [10, 21], [10, 28], [10, 32], [10, 33], [10, 37], [11, 10], [11, 13], [11, 17], [11, 23], [11, 24], [11, 28], [11, 29], [11, 31], [11, 32], [11, 34], [12, 0], [12, 1], [12, 5], [12, 7], [12, 9], [12, 10], [12, 13], [12, 16], [12, 18], [12, 19], [12, 21], [12, 23], [12, 24], [12, 28], [12, 34], [12, 35], [13, 0], [13, 2], [13, 6], [13, 10], [13, 14], [13, 18], [13, 20], [13, 22], [13, 24], [13, 32], [13, 36], [13, 39], [14, 3], [14, 8], [14, 9], [14, 10], [14, 14], [14, 16], [14, 22], [14, 23], [14, 28], [14, 29], [14, 30], [14, 31], [14, 35], [14, 38], [15, 0], [15, 2], [15, 5], [15, 8], [15, 13], [15, 15], [15, 16], [15, 17], [15, 19], [15, 21], [15, 23], [15, 29], [15, 32], [15, 34], [15, 36], [15, 37], [15, 38], [15, 39], [16, 0], [16, 2], [16, 5], [16, 9], [16, 13], [16, 14], [16, 16], [16, 17], [16, 18], [16, 20], [16, 22], [16, 23], [16, 24], [16, 25], [16, 29], [16, 33], [16, 35], [16, 36], [16, 37], [16, 39], [17, 2], [17, 6], [17, 9], [17, 10], [17, 14], [17, 15], [17, 24], [17, 25], [17, 29], [17, 34], [17, 35], [17, 37], [17, 38], [18, 1], [18, 12], [18, 14], [18, 16], [18, 17], [18, 18], [18, 23], [18, 30], [18, 32], [18, 34], [18, 35]]

# Hallucination-head set labelled "4,5,5".
_HALLU_4_5_5 = [[5, 0], [5, 2], [5, 4], [5, 6], [5, 7], [5, 8], [5, 11], [5, 12], [5, 16], [5, 18], [5, 19], [5, 20], [5, 21], [5, 22], [5, 23], [5, 28], [5, 29], [5, 31], [5, 33], [5, 35], [5, 37], [5, 39], [6, 2], [6, 6], [6, 8], [6, 9], [6, 15], [6, 17], [6, 18], [6, 21], [6, 24], [6, 30], [6, 31], [6, 32], [6, 34], [6, 35], [7, 2], [7, 6], [7, 15], [7, 16], [7, 17], [7, 18], [7, 19], [7, 22], [7, 24], [7, 26], [7, 27], [7, 39], [8, 1], [8, 4], [8, 5], [8, 6], [8, 9], [8, 10], [8, 14], [8, 15], [8, 17], [8, 18], [8, 24], [8, 26], [8, 27], [8, 31], [8, 36], [8, 37], [8, 38], [9, 1], [9, 3], [9, 6], [9, 8], [9, 10], [9, 12], [9, 13], [9, 14], [9, 15], [9, 16], [9, 17], [9, 19], [9, 20], [9, 22], [9, 25], [9, 29], [9, 37], [10, 9], [10, 10], [10, 11], [10, 12], [10, 14], [10, 15], [10, 20], [10, 21], [10, 24], [10, 26], [10, 28], [10, 29], [10, 32], [10, 33], [10, 34], [11, 0], [11, 2], [11, 5], [11, 6], [11, 8], [11, 10], [11, 13], [11, 14], [11, 17], [11, 24], [11, 27], [11, 28], [11, 29], [11, 30], [11, 31], [11, 32], [11, 33], [11, 34], [11, 35], [12, 0], [12, 4], [12, 5], [12, 7], [12, 10], [12, 16], [12, 18], [12, 23], [12, 25], [12, 28], [12, 34], [12, 36], [12, 39], [13, 1], [13, 2], [13, 9], [13, 10], [13, 14], [13, 16], [13, 17], [13, 18], [13, 20], [13, 21], [13, 22], [13, 23], [13, 24], [13, 29], [13, 39], [14, 0], [14, 4], [14, 8], [14, 9], [14, 12], [14, 14], [14, 16], [14, 22], [14, 24], [14, 29], [14, 30], [14, 35], [14, 37], [14, 38], [15, 0], [15, 2], [15, 4], [15, 8], [15, 10], [15, 13], [15, 15], [15, 16], [15, 17], [15, 19], [15, 20], [15, 21], [15, 23], [15, 29], [15, 31], [15, 32], [15, 37], [15, 38], [16, 0], [16, 2], [16, 4], [16, 5], [16, 9], [16, 12], [16, 13], [16, 15], [16, 17], [16, 18], [16, 20], [16, 21], [16, 22], [16, 23], [16, 24], [16, 25], [16, 29], [16, 32], [16, 35], [16, 36], [17, 0], [17, 5], [17, 6], [17, 9], [17, 15], [17, 17], [17, 24], [17, 25], [17, 29], [17, 34], [17, 37], [18, 2], [18, 14], [18, 15], [18, 16], [18, 21], [18, 23], [18, 30], [18, 32], [18, 34], [18, 39]]

# Hallucination-head set labelled "4,4,6".
_HALLU_4_4_6 = [[5, 0], [5, 7], [5, 8], [5, 11], [5, 12], [5, 16], [5, 18], [5, 19], [5, 23], [5, 27], [5, 28], [5, 29], [5, 31], [5, 37], [5, 39], [6, 2], [6, 8], [6, 9], [6, 15], [6, 17], [6, 18], [6, 24], [6, 31], [6, 32], [6, 34], [7, 2], [7, 6], [7, 18], [7, 20], [7, 24], [7, 26], [7, 39], [8, 1], [8, 5], [8, 6], [8, 10], [8, 14], [8, 15], [8, 18], [8, 19], [8, 20], [8, 24], [8, 26], [8, 27], [8, 29], [8, 37], [8, 38], [9, 3], [9, 8], [9, 12], [9, 13], [9, 14], [9, 15], [9, 16], [9, 17], [9, 18], [9, 19], [9, 25], [9, 29], [9, 37], [10, 4], [10, 9], [10, 11], [10, 12], [10, 14], [10, 15], [10, 18], [10, 20], [10, 21], [10, 28], [10, 32], [10, 33], [11, 5], [11, 6], [11, 10], [11, 13], [11, 17], [11, 27], [11, 28], [11, 29], [11, 30], [11, 31], [11, 32], [11, 34], [12, 0], [12, 10], [12, 13], [12, 16], [12, 18], [12, 19], [12, 21], [12, 23], [12, 24], [12, 28], [12, 36], [13, 1], [13, 6], [13, 10], [13, 14], [13, 20], [13, 22], [13, 29], [13, 32], [13, 39], [14, 8], [14, 9], [14, 12], [14, 14], [14, 16], [14, 22], [14, 29], [14, 30], [14, 35], [14, 38], [15, 2], [15, 4], [15, 8], [15, 13], [15, 15], [15, 16], [15, 17], [15, 21], [15, 23], [15, 26], [15, 29], [15, 31], [15, 32], [15, 37], [15, 38], [16, 0], [16, 2], [16, 4], [16, 5], [16, 9], [16, 13], [16, 17], [16, 18], [16, 20], [16, 21], [16, 22], [16, 24], [16, 25], [16, 29], [16, 36], [16, 37], [17, 5], [17, 9], [17, 15], [17, 25], [17, 29], [17, 34], [17, 37], [18, 14], [18, 15], [18, 16], [18, 21], [18, 23], [18, 30], [18, 32], [18, 34], [18, 39]]

hallucination_heads_list = [
    [list(head) for head in group]  # fresh copy per group, as with separate literals
    for group in (_HALLU_3_5_7, _HALLU_4_5_5, _HALLU_4_4_6) * 2
]

# Pruning-strength grid: the sweep evaluates every (alpha, beta) combination
# drawn from these two lists, in this order.
alpha_list = [0.3, 0.35, 0.5, 0.6, 0.65]
beta_list = [0.8, 0.85, 0.9, 0.99, 0.97]

# NOTE(review): the three settings below are not referenced in this file
# section; presumably leftovers for ablation variants — confirm before removing.
ablation_scheme = "mean"
circuit_mlps = list()
include_mlps = False

# NOTE: `write_log` is defined once, after the dataset/sampling setup below.
# Nothing before that point writes to the log.

# Read COCO ground truth. The result maps each val2014 file name to the list
# of object category names annotated in that image. Images with no
# annotations get no entry; lookups use fname2labels.get(fn, []).
with open(ann_file, encoding="utf-8") as f:
    coco = json.load(f)
imgid2fname = {img["id"]: img["file_name"] for img in coco["images"]}
catid2name = {cat["id"]: cat["name"] for cat in coco["categories"]}
fname2labels = defaultdict(set)
for ann in coco["annotations"]:
    fname2labels[imgid2fname[ann["image_id"]]].add(catid2name[ann["category_id"]])
fname2labels = {k: list(v) for k, v in fname2labels.items()}

prompt = "USER: <image>\nPlease describe the image in detail.\nASSISTANT:"

# Candidate images: one file name per line, blank lines ignored.
with open(txt_file) as f:
    all_img_names = [x.strip() for x in f if x.strip()]
# Images listed in txt_file_chosen are excluded from sampling.
with open(txt_file_chosen) as f:
    chosen_names = {x.strip() for x in f if x.strip()}

valid_img_names = [x for x in all_img_names if x not in chosen_names]
required_imgs = n_rounds * N
# Explicit check rather than `assert`, which is stripped under `python -O`.
if len(valid_img_names) < required_imgs:
    raise ValueError(f"Insufficient images: need {required_imgs}, but got {len(valid_img_names)}")

# Draw every round's images up front, without replacement. The baseline and
# every pruning configuration are then evaluated on the same images.
# Set sample_seed to an int to make the selection reproducible across runs.
# None keeps the global RNG unseeded.
sample_seed = None
if sample_seed is not None:
    random.seed(sample_seed)
fixed_sets = random.sample(valid_img_names, required_imgs)

def write_log(text):
    """Echo *text* to stdout and append it verbatim to ``log_file``.

    No newline is added in either place; callers supply their own trailing
    newline. The file is written as UTF-8 because log lines contain non-ASCII
    characters (the "α"/"β" sweep labels). Those would fail to encode under a
    non-UTF-8 default locale.
    """
    print(text, end="")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(text)

# ============ Write log file header ============
# Truncate the log exactly once. It starts with the column header for the
# per-round baseline rows written below.
with open(log_file, "w", encoding="utf-8") as f:
    f.write("Round,CHAIR-s,CHAIR-i,F1,Len\n")

# ============ Global Baseline Precomputation ============
print("=== Computing baseline metrics for all rounds ===")

baseline_results = []
for round_idx in range(1, n_rounds + 1):
    round_img_names = fixed_sets[(round_idx - 1) * N: round_idx * N]
    imgs = []
    for fn in round_img_names:
        # `with` closes the file handle; convert() returns a separate, fully
        # loaded image that remains usable after the file is closed.
        with Image.open(os.path.join(img_dir, fn)) as im:
            imgs.append(im.convert("RGB"))
    prompts = [prompt] * len(round_img_names)
    with torch.no_grad():
        preds_base = batch_generate_llava(
            model, tokenizer, processor,
            imgs, prompts,
            device="cuda", max_new_tokens=128
        )
    labels = [fname2labels.get(fn, []) for fn in round_img_names]
    mb = batch_compute_chair_metrics(preds_base, labels)
    print(f"Round{round_idx} Baseline: CHAIR-s={mb['CHAIR-s']:.4f}, CHAIR-i={mb['CHAIR-i']:.4f}, F1={mb['F1']:.4f}, Len={mb['Len']:.2f}")

    write_log(f"{round_idx},{mb['CHAIR-s']:.4f},{mb['CHAIR-i']:.4f},"
              f"{mb['F1']:.4f},{mb['Len']:.2f}\n")
    # Row order matches the metric order used for the pruned runs.
    baseline_results.append((mb['CHAIR-s'], mb['CHAIR-i'], mb['F1'], mb['Len']))
baseline_array = np.array(baseline_results)

# Column header for the pruned-run rows appended by the head-group sweep.
# Previously this header was written first and then wiped by truncating the
# log a second time.
with open(log_file, "a", encoding="utf-8") as f:
    f.write("Alpha,Beta,Round,Type,CHAIR-s,CHAIR-i,F1,Len\n")

# ============ Loop over head groups ============


def _mean_relative_change(pruned, baseline):
    """Return the per-metric mean relative change of *pruned* vs *baseline*.

    Both arguments are (rounds, metrics) array-likes whose rows are aligned by
    round. For each metric, rounds whose baseline value is exactly 0 are
    skipped. A relative change against zero is undefined, and dividing by a
    tiny epsilon instead would create huge outliers that dominate the mean.
    A metric whose baseline is 0 in every round yields NaN.
    """
    pruned = np.asarray(pruned, dtype=float)
    baseline = np.asarray(baseline, dtype=float)
    defined = baseline != 0
    ratios = np.divide(pruned - baseline, baseline,
                       out=np.zeros_like(baseline), where=defined)
    counts = defined.sum(axis=0)
    return np.divide(ratios.sum(axis=0), counts,
                     out=np.full(counts.shape, np.nan), where=counts > 0)


# Groups are paired positionally. Refuse to let zip() silently drop
# unmatched groups.
if len(hallucination_heads_list) != len(faithful_heads_list):
    raise ValueError(
        "hallucination_heads_list and faithful_heads_list must have the same length; "
        f"got {len(hallucination_heads_list)} vs {len(faithful_heads_list)}"
    )

for grp_idx, (raw_hallu, faithful) in enumerate(zip(hallucination_heads_list, faithful_heads_list), start=1):
    # Remove overlaps: a head listed as faithful is never pruned as a hallucination head.
    faith_set = set(map(tuple, faithful))
    hallu = [h for h in raw_hallu if tuple(h) not in faith_set]
    # Only heads inside target_layers are passed to the pruning hooks. The
    # result does not depend on alpha/beta/round, so compute it once per group.
    faithful_targets = [h for h in faithful if h[0] in target_layers]
    hallu_targets = [h for h in hallu if h[0] in target_layers]

    write_log(f"\n===== Head Group {grp_idx} =====\n")
    for alpha, beta in product(alpha_list, beta_list):
        write_log(f"\nα={alpha:.2f}, β={beta:.2f}\n")
        pruned_list = []

        # Defensively drop every forward hook on the attention modules. This
        # goes through the private `_forward_hooks` dict, so it also removes
        # hooks registered by anything else.
        # NOTE(review): the `model.model.language_model.layers` path depends
        # on the transformers version — confirm.
        for layer in model.model.language_model.layers:
            layer.self_attn._forward_hooks.clear()

        for rnd in range(1, n_rounds + 1):
            names = fixed_sets[(rnd - 1) * N: rnd * N]
            imgs = []
            for fn in names:
                with Image.open(os.path.join(img_dir, fn)) as im:
                    imgs.append(im.convert("RGB"))

            # Register hooks, only prune on target_layers
            hooks = prune_model_llava_dynamic(
                model,
                faithful_heads=faithful_targets,
                hallucination_heads=hallu_targets,
                target_layers=target_layers,
                alpha=alpha, beta=beta
            )
            try:
                with torch.no_grad():
                    preds_p = batch_generate_llava(model, tokenizer, processor,
                                                   imgs, [prompt] * N,
                                                   device="cuda", max_new_tokens=128)
            finally:
                # Always detach this round's hooks, even if generation fails,
                # so they never leak into later runs.
                for h in hooks:
                    h.remove()

            m_p = batch_compute_chair_metrics(preds_p, [fname2labels.get(fn, []) for fn in names])
            write_log(f"{alpha:.2f},{beta:.2f},{rnd},Pruned,"
                      f"{m_p['CHAIR-s']:.4f},{m_p['CHAIR-i']:.4f},"
                      f"{m_p['F1']:.4f},{m_p['Len']:.2f}\n")
            pruned_list.append((m_p['CHAIR-s'], m_p['CHAIR-i'], m_p['F1'], m_p['Len']))

        # Average relative change versus the per-round baseline, per metric.
        cr = _mean_relative_change(pruned_list, baseline_array)
        write_log(f"α={alpha:.2f},β={beta:.2f} Average change rate: "
                  f"CHAIR-s={cr[0]*100:.2f}%, "
                  f"CHAIR-i={cr[1]*100:.2f}%, "
                  f"F1={cr[2]*100:.2f}%, "
                  f"Len={cr[3]*100:.2f}%\n")

print("Experiment complete. Logs saved to", log_file)

# ---- Wall-clock summary appended to the end of the log ----
end_time = datetime.datetime.now()
delta = end_time - start_time
total_seconds = delta.total_seconds()
# Whole elapsed minutes, split into hours plus leftover minutes. This equals
# flooring hours and minutes separately.
hours, minutes = divmod(int(total_seconds // 60), 60)

timestamp_fmt = '%Y-%m-%d %H:%M:%S'
summary_lines = [
    f"Start time: {start_time.strftime(timestamp_fmt)}\n",
    f"End time: {end_time.strftime(timestamp_fmt)}\n",
    f"Total duration: {int(hours)} hours {int(minutes)} minutes\n",
    "-" * 40 + "\n",
]
with open(log_file, "a", encoding="utf-8") as f:
    f.writelines(summary_lines)
