import os
import torch
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoProcessor
from collections import defaultdict
from PIL import Image
import random
import json
import numpy as np
from itertools import product
import datetime
from pruning_llava_utils import prune_model_llava_dynamic, batch_generate_llava
from chair_metrics import batch_compute_chair_metrics 

start_time = datetime.datetime.now()

model_id = "llava-hf/llava-1.5-7b-hf"

# Keyword arguments shared by all three `from_pretrained` calls below.
_hub_kwargs = {"trust_remote_code": True}

# Vision-language model: fp16 weights, device placement delegated to `device_map="auto"`.
model = AutoModelForVision2Seq.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto", **_hub_kwargs
)

# Tokenizer and processor for the same checkpoint; both request the fast implementation.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, **_hub_kwargs)
processor = AutoProcessor.from_pretrained(model_id, use_fast=True, **_hub_kwargs)


# ============ Parameter settings ============
# Input data and output log locations.
img_dir = "path/to/val2014"
txt_file = "./semantic_attribution/chosen_imgs_7b_5_18_hallu_160.txt"
txt_file_chosen = "holo.txt"
ann_file = "path/to/val2014/annotations/instances_val2014.json"
log_file = "experiment_log_7b_5_18-hallu.txt"
# Optional per-sentence result log (not used by this script):
# detail_log_file = "experiment_detail_7b_log.jsonl"

N, n_rounds = 32, 5  # images per round, total number of rounds

# Decoder layers the pruning hooks act on. llava: range(5, 19); shikra: range(3, 14).
target_layers = [*range(5, 19)]

# Attention heads per layer. llava-7b: 32, shikra: 32, llava-13b: 40.
n_heads_per_layer = 32

# Candidate hallucination-head sets as [layer, head] pairs; suffixes are the
# "3,5,7" / "5,5,5" / "4,4,6" selection labels from the original notes.
_hallu_heads_3_5_7 = [[5, 0], [5, 2], [5, 5], [5, 8], [5, 11], [5, 13], [5, 15], [5, 16], [5, 17], [6, 0], [6, 1], [6, 3], [6, 4], [6, 9], [6, 11], [6, 15], [6, 21], [6, 22], [6, 23], [6, 25], [6, 26], [6, 27], [6, 28], [6, 29], [6, 30], [7, 6], [7, 12], [7, 14], [7, 17], [7, 19], [7, 20], [7, 21], [7, 23], [7, 28], [8, 0], [8, 4], [8, 6], [8, 8], [8, 11], [8, 20], [8, 23], [8, 24], [8, 26], [8, 28], [8, 29], [8, 31], [9, 3], [9, 4], [9, 5], [9, 6], [9, 8], [9, 12], [9, 18], [9, 28], [9, 30], [10, 0], [10, 16], [10, 22], [10, 25], [10, 30], [11, 2], [11, 3], [11, 7], [11, 9], [11, 11], [11, 26], [12, 5], [12, 6], [12, 10], [12, 13], [12, 14], [12, 17], [12, 18], [12, 21], [12, 23], [12, 24], [12, 30], [12, 31], [13, 1], [13, 2], [13, 3], [13, 5], [13, 7], [13, 8], [13, 12], [13, 17], [13, 23], [13, 27], [14, 1], [14, 2], [14, 3], [14, 4], [14, 6], [14, 9], [14, 11], [14, 16], [14, 18], [14, 24], [14, 25], [14, 28], [14, 31], [15, 1], [15, 5], [15, 8], [15, 10], [15, 12], [15, 17], [15, 18], [15, 20], [15, 23], [15, 24], [15, 26], [15, 31], [16, 0], [16, 1], [16, 2], [16, 6], [16, 9], [16, 10], [16, 15], [16, 19], [16, 20], [16, 23], [16, 25], [16, 27], [16, 29], [16, 31], [17, 1], [17, 3], [17, 4], [17, 5], [17, 11], [17, 13], [17, 14], [17, 17], [17, 18], [17, 19], [17, 21], [17, 25], [17, 28], [17, 29], [18, 0], [18, 1], [18, 5], [18, 6], [18, 9], [18, 12], [18, 13], [18, 14], [18, 16], [18, 18], [18, 19], [18, 22]]
_hallu_heads_5_5_5 = [[5, 2], [5, 3], [5, 5], [5, 7], [5, 8], [5, 15], [5, 16], [6, 0], [6, 3], [6, 12], [6, 15], [6, 17], [6, 18], [6, 22], [6, 25], [6, 27], [6, 29], [6, 30], [7, 10], [7, 13], [7, 14], [7, 15], [7, 17], [7, 19], [7, 21], [7, 23], [7, 26], [7, 28], [8, 3], [8, 4], [8, 5], [8, 8], [8, 10], [8, 11], [8, 13], [8, 16], [8, 20], [8, 23], [8, 24], [8, 25], [8, 26], [8, 29], [8, 30], [9, 4], [9, 7], [9, 18], [9, 23], [9, 28], [9, 30], [10, 1], [10, 2], [10, 4], [10, 5], [10, 7], [10, 16], [10, 22], [10, 23], [10, 26], [10, 27], [10, 30], [10, 31], [11, 0], [11, 1], [11, 2], [11, 7], [11, 9], [11, 11], [11, 12], [11, 13], [11, 26], [12, 5], [12, 6], [12, 8], [12, 10], [12, 13], [12, 14], [12, 18], [12, 21], [12, 23], [12, 24], [12, 26], [12, 29], [12, 31], [13, 1], [13, 2], [13, 3], [13, 7], [13, 8], [13, 12], [13, 17], [13, 25], [13, 27], [14, 2], [14, 3], [14, 4], [14, 8], [14, 9], [14, 11], [14, 13], [14, 15], [14, 16], [14, 24], [14, 27], [14, 28], [14, 31], [15, 1], [15, 2], [15, 3], [15, 5], [15, 6], [15, 8], [15, 10], [15, 12], [15, 17], [15, 18], [15, 20], [15, 23], [15, 24], [15, 26], [15, 28], [16, 0], [16, 1], [16, 2], [16, 6], [16, 9], [16, 10], [16, 15], [16, 20], [16, 21], [16, 22], [16, 23], [16, 27], [16, 28], [16, 29], [16, 31], [17, 4], [17, 5], [17, 11], [17, 13], [17, 18], [17, 19], [17, 21], [17, 24], [17, 25], [17, 27], [17, 28], [17, 29], [17, 30], [18, 1], [18, 5], [18, 6], [18, 8], [18, 10], [18, 12], [18, 19], [18, 22], [18, 27], [18, 28], [18, 29], [18, 30], [18, 31]]
_hallu_heads_4_4_6 = [[5, 3], [5, 5], [5, 15], [6, 1], [6, 3], [6, 9], [6, 11], [6, 12], [6, 15], [6, 18], [6, 22], [6, 27], [6, 28], [6, 29], [6, 30], [7, 12], [7, 13], [7, 15], [7, 17], [7, 19], [7, 21], [7, 23], [7, 28], [8, 0], [8, 4], [8, 6], [8, 8], [8, 11], [8, 13], [8, 20], [8, 23], [8, 24], [8, 26], [8, 28], [8, 29], [8, 30], [8, 31], [9, 4], [9, 18], [9, 28], [9, 30], [10, 0], [10, 1], [10, 4], [10, 5], [10, 7], [10, 16], [10, 23], [10, 25], [10, 27], [10, 30], [11, 0], [11, 2], [11, 7], [11, 9], [11, 11], [11, 26], [12, 6], [12, 10], [12, 13], [12, 14], [12, 18], [12, 21], [12, 23], [12, 24], [12, 31], [13, 1], [13, 2], [13, 3], [13, 7], [13, 8], [13, 17], [13, 23], [14, 2], [14, 3], [14, 4], [14, 6], [14, 11], [14, 13], [14, 15], [14, 16], [14, 17], [14, 18], [14, 24], [14, 28], [14, 31], [15, 1], [15, 3], [15, 5], [15, 6], [15, 10], [15, 12], [15, 17], [15, 18], [15, 20], [15, 23], [15, 24], [15, 26], [15, 28], [15, 31], [16, 0], [16, 1], [16, 2], [16, 9], [16, 10], [16, 15], [16, 19], [16, 20], [16, 22], [16, 27], [16, 29], [16, 31], [17, 4], [17, 11], [17, 13], [17, 14], [17, 18], [17, 19], [17, 21], [17, 24], [17, 25], [17, 27], [17, 28], [17, 29], [18, 1], [18, 5], [18, 6], [18, 12], [18, 18], [18, 19], [18, 22], [18, 27], [18, 28]]

# Candidate faithful-head sets; suffixes are the original selection notes
# ("0.38-0.42 c>=4" and "0.42-0.58, c>=3").
_faithful_heads_038_042_c4 = [[5, 4], [5, 10], [5, 12], [5, 17], [5, 19], [5, 20], [5, 24], [6, 1], [6, 6], [6, 7], [6, 10], [6, 13], [6, 22], [6, 26], [7, 0], [7, 1], [7, 2], [7, 6], [7, 7], [7, 12], [7, 14], [7, 28], [8, 1], [8, 2], [8, 12], [8, 19], [8, 20], [8, 21], [8, 23], [8, 30], [8, 31], [9, 1], [9, 5], [9, 6], [9, 10], [9, 15], [9, 20], [9, 21], [9, 24], [9, 25], [9, 29], [9, 31], [10, 6], [10, 7], [10, 8], [10, 19], [10, 21], [10, 28], [10, 29], [11, 6], [11, 14], [11, 21], [11, 24], [11, 28], [12, 2], [12, 4], [12, 7], [12, 8], [12, 15], [12, 19], [12, 25], [12, 28], [13, 14], [13, 18], [13, 24], [13, 28], [13, 30], [14, 1], [14, 5], [14, 7], [14, 14], [14, 21], [15, 4], [15, 7], [15, 15], [15, 19], [15, 29], [15, 30], [16, 4], [16, 7], [16, 11], [16, 12], [16, 13], [16, 16], [16, 17], [16, 29], [17, 0], [17, 12], [17, 22], [17, 26], [17, 31], [18, 2], [18, 7], [18, 20], [18, 21], [18, 23]]
_faithful_heads_042_058_c3 = [[5, 0], [5, 9], [5, 10], [5, 19], [5, 24], [5, 30], [6, 1], [6, 6], [6, 10], [6, 16], [7, 1], [7, 2], [7, 3], [7, 7], [7, 10], [7, 28], [8, 2], [8, 5], [8, 20], [8, 21], [8, 22], [8, 31], [9, 1], [9, 5], [9, 10], [9, 14], [9, 15], [9, 18], [9, 21], [9, 24], [9, 25], [9, 29], [10, 6], [10, 11], [10, 13], [10, 17], [10, 28], [11, 4], [11, 6], [11, 14], [11, 28], [12, 7], [12, 8], [12, 11], [12, 12], [12, 15], [12, 19], [12, 25], [12, 27], [12, 28], [13, 14], [13, 18], [13, 19], [13, 24], [13, 30], [14, 0], [14, 7], [14, 11], [14, 14], [14, 21], [15, 0], [15, 7], [15, 19], [15, 21], [15, 24], [15, 29], [15, 30], [16, 4], [16, 5], [16, 7], [16, 11], [16, 12], [16, 16], [16, 19], [16, 24], [16, 29], [17, 0], [17, 4], [17, 12], [17, 22], [17, 26], [17, 31], [18, 2], [18, 7], [18, 8], [18, 13], [18, 20], [18, 21], [18, 23], [18, 24]]

# Six experiment groups, consumed pairwise via zip(): each hallucination set is
# paired with each faithful set (groups 1-3 use the first faithful set, 4-6 the second).
# The repeated entries are shared objects, not copies -- treat them as read-only.
hallucination_heads_list = [_hallu_heads_3_5_7, _hallu_heads_5_5_5, _hallu_heads_4_4_6] * 2
faithful_heads_list = [_faithful_heads_038_042_c4] * 3 + [_faithful_heads_042_058_c3] * 3

# Ablation options (not read anywhere in this script).
ablation_scheme = "mean"
circuit_mlps = []
include_mlps = False

# Load COCO ground-truth object labels, keyed by image file name.
# (The image-name list `all_img_names` is read once, later in the script, with
# blank lines filtered out; an earlier unfiltered duplicate read was dead work.)
# JSON is UTF-8 by specification, so decode it explicitly rather than with the
# platform's locale encoding.
with open(ann_file, encoding="utf-8") as f:
    coco = json.load(f)
imgid2fname = {img["id"]: img["file_name"] for img in coco["images"]}
catid2name = {cat["id"]: cat["name"] for cat in coco["categories"]}

# Collect the distinct category names annotated on each image, then freeze to lists.
fname2labels = defaultdict(set)
for ann in coco["annotations"]:
    fname2labels[imgid2fname[ann["image_id"]]].add(catid2name[ann["category_id"]])
fname2labels = {k: list(v) for k, v in fname2labels.items()}

prompt = "<image>\nPlease describe the image in detail."

# Candidate image names (blank lines skipped); names listed in txt_file_chosen are excluded.
with open(txt_file, encoding="utf-8") as f:
    all_img_names = [x.strip() for x in f if x.strip()]
with open(txt_file_chosen, encoding="utf-8") as f:
    chosen_names = {x.strip() for x in f if x.strip()}

valid_img_names = [x for x in all_img_names if x not in chosen_names]
required_imgs = n_rounds * N
# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if len(valid_img_names) < required_imgs:
    raise ValueError(
        f"Not enough available images, required: {required_imgs}, actual: {len(valid_img_names)}"
    )

# Fixed sampling: a seeded private RNG makes the evaluation subset (and hence the
# baseline rows) identical across runs without touching the global `random` state.
sample_seed = 0
fixed_sets = random.Random(sample_seed).sample(valid_img_names, required_imgs)

def write_log(message, path=None):
    """Print *message* and append it, followed by a newline, to the log file.

    Args:
        message: Text to record. Messages may contain non-ASCII characters
            (e.g. the "α"/"β" labels), so the file is always written as UTF-8
            instead of the platform's default encoding.
        path: File to append to. Defaults to the module-level ``log_file``.
    """
    print(message)  # Echo to the console
    target = log_file if path is None else path
    with open(target, "a", encoding="utf-8") as f:  # Append to the log file
        f.write(message + "\n")

# ============ Write log file header ============
# Truncate any previous log and start it with the CSV column names.
_LOG_HEADER = "Alpha,Beta,Round,Type,CHAIR-s,CHAIR-i,F1,Len\n"
with open(log_file, "w") as log_fh:
    log_fh.write(_LOG_HEADER)

# Hyper-parameter grid; every (alpha, beta) combination is swept for each head group.
alpha_list = [0.65, 0.35, 0.3, 0.6]
beta_list = [0.8, 0.9, 0.95, 0.97, 0.99]

# ============ Global baseline pre-computation ============
# Unpruned model, one pass per round over that round's slice of fixed_sets.
# Each row is (CHAIR-s, CHAIR-i, F1, Len); baseline_array is the reference for
# every change rate computed in the pruning sweep.
print("=== Computing baseline metrics for all rounds ===")
baseline_results = []
for round_idx in range(1, n_rounds + 1):
    start = (round_idx - 1) * N
    names = fixed_sets[start:start + N]
    imgs = [Image.open(os.path.join(img_dir, fn)).convert("RGB") for fn in names]
    gt_labels = [fname2labels.get(fn, []) for fn in names]

    with torch.no_grad():
        preds = batch_generate_llava(
            model, tokenizer, processor, imgs, [prompt] * len(names),
            device="cuda", max_new_tokens=128,
        )
    metrics = batch_compute_chair_metrics(preds, gt_labels)

    row = (metrics['CHAIR-s'], metrics['CHAIR-i'], metrics['F1'], metrics['Len'])
    write_log(f"0.00,0.00,{round_idx},Baseline,{row[0]:.4f},{row[1]:.4f},{row[2]:.4f},{row[3]:.2f}\n")
    baseline_results.append(row)

baseline_array = np.array(baseline_results)

# ============ Loop over head groups ============
for grp_idx, (raw_hallu, faithful) in enumerate(zip(hallucination_heads_list, faithful_heads_list), start=1):
    # Remove overlap: a head listed as faithful is never also passed as a hallucination head.
    faith_set = set(map(tuple, faithful))
    hallu = [h for h in raw_hallu if tuple(h) not in faith_set]

    write_log(f"\n===== Head Group {grp_idx} =====\n")
    for alpha, beta in product(alpha_list, beta_list):
        write_log(f"\nα={alpha:.2f}, β={beta:.2f}\n")
        pruned_list = []

        # Defensive reset: drop every forward hook currently attached to the
        # attention modules before this (alpha, beta) run.
        # NOTE(review): the `model.model.language_model.layers` path depends on the
        # installed transformers version's Llava module layout -- confirm.
        for layer in model.model.language_model.layers:
            layer.self_attn._forward_hooks.clear()

        for rnd in range(1, n_rounds + 1):
            # Same image slice as baseline round `rnd`, so rows line up with baseline_array.
            names = fixed_sets[(rnd - 1) * N : rnd * N]
            imgs = [Image.open(os.path.join(img_dir, fn)).convert("RGB") for fn in names]

            # Register hooks; only heads that live in target_layers are passed on.
            hooks = prune_model_llava_dynamic(
                model,
                faithful_heads=[h for h in faithful if h[0] in target_layers],
                hallucination_heads=[h for h in hallu if h[0] in target_layers],
                target_layers=target_layers,
                alpha=alpha, beta=beta
            )
            try:
                with torch.no_grad():
                    preds_p = batch_generate_llava(model, tokenizer, processor,
                                                   imgs, [prompt] * len(names),
                                                   device="cuda", max_new_tokens=128)
            finally:
                # Always detach this round's hooks, even if generation raises, so the
                # model is never left pruned for later work in the same process
                # (e.g. re-running the baseline in a notebook session).
                for hook in hooks:
                    hook.remove()

            m_p = batch_compute_chair_metrics(preds_p, [fname2labels.get(fn, []) for fn in names])
            write_log(f"{alpha:.2f},{beta:.2f},{rnd},Pruned,"
                      f"{m_p['CHAIR-s']:.4f},{m_p['CHAIR-i']:.4f},"
                      f"{m_p['F1']:.4f},{m_p['Len']:.2f}\n")
            pruned_list.append((m_p['CHAIR-s'], m_p['CHAIR-i'], m_p['F1'], m_p['Len']))

        # Mean over rounds of the per-round relative change vs. the baseline,
        # per metric column; the 1e-6 term avoids division by zero.
        rp = np.array(pruned_list)
        cr = ((rp - baseline_array) / (baseline_array + 1e-6)).mean(axis=0)
        write_log(f"α={alpha:.2f},β={beta:.2f} Average change rate: "
                  f"CHAIR-s={cr[0]*100:.2f}%, "
                  f"CHAIR-i={cr[1]*100:.2f}%, "
                  f"F1={cr[2]*100:.2f}%, "
                  f"Len={cr[3]*100:.2f}%\n")

print("Experiment complete. Logs saved to", log_file)

# Wall-clock summary appended to the end of the log (whole hours and minutes, truncated).
end_time = datetime.datetime.now()
total_seconds = (end_time - start_time).total_seconds()
hours, leftover = divmod(total_seconds, 3600)
minutes = leftover // 60

time_fmt = '%Y-%m-%d %H:%M:%S'
summary_lines = [
    f"Start time: {start_time.strftime(time_fmt)}\n",
    f"End time: {end_time.strftime(time_fmt)}\n",
    f"Total time: {int(hours)} hours {int(minutes)} minutes\n",
    "-" * 40 + "\n",
]
with open(log_file, "a", encoding="utf-8") as f:
    f.writelines(summary_lines)
