import os
import torch
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoProcessor
from collections import defaultdict
from PIL import Image
import random
import json
import numpy as np
from itertools import product
import datetime

from pruning_llava_utils import prune_model_llava_dynamic, batch_generate_llava
from chair_metrics import batch_compute_chair_metrics 
# Wall-clock start of the run; read again at the end of the script to report
# the total elapsed time in the log file.
start_time = datetime.datetime.now()
# Allocator option for PyTorch's CUDA caching allocator. It is set here, before
# the model is loaded below, so it is in place before this script creates any
# model weights on the GPU.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


# Checkpoint identifier shared by the model, tokenizer and processor loads.
model_id = "llava-hf/llava-1.5-7b-hf"

# Load the vision-to-sequence model in float16; `device_map="auto"` leaves
# device placement to the loader.
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

# Tokenizer and processor come from the same checkpoint as the model; both are
# passed to `batch_generate_llava` together with the model.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=True)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True, use_fast=True)


# ============ Parameter Settings ============
# Directory that image file names are joined onto before opening.
# NOTE(review): "path/toval2014" looks like it is missing a "/" when compared
# with `ann_file` below — confirm before running.
img_dir = "path/toval2014"
# Text file with one image file name per line (the candidate pool).
txt_file = "all_img_names.txt"
# Text file with image file names that are excluded from sampling.
txt_file_chosen = "holo.txt"
# Annotation JSON; read for its "images", "categories" and "annotations" keys.
ann_file = "path/to/val2014/annotations/instances_val2014.json"
# Main results log (truncated when the baseline starts, appended to afterwards).
log_file = "experiment_log_7b_combo.txt"
# Per-round JSONL dump, only written when `enable_detail_log` is True.
detail_log_file = "experiment_detail_7b_log.jsonl" 
N = 32                    # Number of images per round
n_rounds = 5              # Total number of experiment rounds
# Not referenced anywhere else in the visible part of this script.
n_heads_per_layer = 32 
enable_detail_log = False # Optional, detailed sentence-level results for each round

# 30x4 =120

# Layer ranges for the two pruning groups: group 1 covers layers 5..18 and
# group 2 covers layers 19..26. Head entries are filtered against these ranges
# before being handed to `prune_model_llava_dynamic`.
target_layers_1 = list(range(5, 19)) 
target_layers_2 = list(range(19, 27))  

# Candidate head sets. Each outer list holds one or more candidate sets; the
# i-th faithful set is paired (via zip) with the i-th hallucination set of the
# same group. Every entry is a two-element list whose first element is the
# layer index (it is compared against `target_layers_*`); the second element
# is presumably the head index within that layer — TODO confirm against
# `prune_model_llava_dynamic`.
# A head may appear in both the faithful and the hallucination set of a group;
# the sweep loop drops such heads from the hallucination set before pruning.

# Group 1 (layers 5..18): heads labelled "faithful".
faithful_heads_list_1 =[
 [[5, 4], [5, 10], [5, 12], [5, 17], [5, 19], [5, 20], [5, 24], [6, 1], [6, 6], [6, 7], [6, 10], [6, 13], [6, 22], [6, 26], [7, 0], [7, 1], [7, 2], [7, 6], [7, 7], [7, 12], [7, 14], [7, 28], [8, 1], [8, 2], [8, 12], [8, 19], [8, 20], [8, 21], [8, 23], [8, 30], [8, 31], [9, 1], [9, 5], [9, 6], [9, 10], [9, 15], [9, 20], [9, 21], [9, 24], [9, 25], [9, 29], [9, 31], [10, 6], [10, 7], [10, 8], [10, 19], [10, 21], [10, 28], [10, 29], [11, 6], [11, 14], [11, 21], [11, 24], [11, 28], [12, 2], [12, 4], [12, 7], [12, 8], [12, 15], [12, 19], [12, 25], [12, 28], [13, 14], [13, 18], [13, 24], [13, 28], [13, 30], [14, 1], [14, 5], [14, 7], [14, 14], [14, 21], [15, 4], [15, 7], [15, 15], [15, 19], [15, 29], [15, 30], [16, 4], [16, 7], [16, 11], [16, 12], [16, 13], [16, 16], [16, 17], [16, 29], [17, 0], [17, 12], [17, 22], [17, 26], [17, 31], [18, 2], [18, 7], [18, 20], [18, 21], [18, 23]],#0.38-0.42 c>=4
]

# Group 1 (layers 5..18): heads labelled "hallucination".
hallucination_heads_list_1=[
   [[5, 2], [5, 3], [5, 5], [5, 7], [5, 8], [5, 15], [5, 16], [6, 0], [6, 3], [6, 12], [6, 15], [6, 17], [6, 18], [6, 22], [6, 25], [6, 27], [6, 29], [6, 30], [7, 10], [7, 13], [7, 14], [7, 15], [7, 17], [7, 19], [7, 21], [7, 23], [7, 26], [7, 28], [8, 3], [8, 4], [8, 5], [8, 8], [8, 10], [8, 11], [8, 13], [8, 16], [8, 20], [8, 23], [8, 24], [8, 25], [8, 26], [8, 29], [8, 30], [9, 4], [9, 7], [9, 18], [9, 23], [9, 28], [9, 30], [10, 1], [10, 2], [10, 4], [10, 5], [10, 7], [10, 16], [10, 22], [10, 23], [10, 26], [10, 27], [10, 30], [10, 31], [11, 0], [11, 1], [11, 2], [11, 7], [11, 9], [11, 11], [11, 12], [11, 13], [11, 26], [12, 5], [12, 6], [12, 8], [12, 10], [12, 13], [12, 14], [12, 18], [12, 21], [12, 23], [12, 24], [12, 26], [12, 29], [12, 31], [13, 1], [13, 2], [13, 3], [13, 7], [13, 8], [13, 12], [13, 17], [13, 25], [13, 27], [14, 2], [14, 3], [14, 4], [14, 8], [14, 9], [14, 11], [14, 13], [14, 15], [14, 16], [14, 24], [14, 27], [14, 28], [14, 31], [15, 1], [15, 2], [15, 3], [15, 5], [15, 6], [15, 8], [15, 10], [15, 12], [15, 17], [15, 18], [15, 20], [15, 23], [15, 24], [15, 26], [15, 28], [16, 0], [16, 1], [16, 2], [16, 6], [16, 9], [16, 10], [16, 15], [16, 20], [16, 21], [16, 22], [16, 23], [16, 27], [16, 28], [16, 29], [16, 31], [17, 4], [17, 5], [17, 11], [17, 13], [17, 18], [17, 19], [17, 21], [17, 24], [17, 25], [17, 27], [17, 28], [17, 29], [17, 30], [18, 1], [18, 5], [18, 6], [18, 8], [18, 10], [18, 12], [18, 19], [18, 22], [18, 27], [18, 28], [18, 29], [18, 30], [18, 31]],#5,5,5
]


# Group 2 (layers 19..26): heads labelled "faithful".
faithful_heads_list_2 =[
    [[19, 9], [19, 22], [19, 25], [19, 29], [20, 3], [20, 11], [20, 13], [20, 14], [20, 16], [20, 23], [20, 30], [21, 0], [21, 1], [21, 5], [21, 10], [21, 17], [22, 1], [22, 6], [22, 9], [22, 12], [22, 14], [22, 18], [22, 24], [22, 29], [22, 30], [23, 4], [23, 15], [23, 17], [23, 24], [23, 25], [23, 26], [23, 31], [24, 0], [24, 3], [24, 9], [24, 13], [24, 18], [24, 26], [25, 1], [25, 6], [25, 8], [25, 12], [25, 13], [25, 15], [25, 24], [25, 27], [25, 29], [25, 30], [26, 1], [26, 6], [26, 12], [26, 14], [26, 16]],#0.38-0.62,c>=4
]
# Group 2 (layers 19..26): heads labelled "hallucination".
hallucination_heads_list_2=[
    [[19, 0], [19, 2], [19, 4], [19, 6], [19, 13], [19, 15], [19, 17], [19, 19], [19, 20], [19, 21], [19, 24], [19, 26], [19, 28], [19, 31], [20, 0], [20, 1], [20, 2], [20, 5], [20, 7], [20, 8], [20, 9], [20, 12], [20, 13], [20, 18], [20, 26], [21, 4], [21, 6], [21, 7], [21, 9], [21, 12], [21, 13], [21, 16], [21, 18], [21, 26], [21, 27], [21, 31], [22, 3], [22, 5], [22, 6], [22, 14], [22, 16], [22, 17], [22, 20], [22, 21], [22, 25], [22, 26], [22, 28], [23, 0], [23, 1], [23, 3], [23, 6], [23, 8], [23, 10], [23, 18], [23, 22], [23, 30], [24, 4], [24, 5], [24, 9], [24, 10], [24, 11], [24, 12], [24, 17], [24, 19], [24, 20], [24, 22], [24, 28], [24, 30], [25, 4], [25, 9], [25, 14], [25, 19], [25, 21], [25, 22], [25, 29], [25, 30], [26, 21], [26, 26], [26, 27], [26, 30], [26, 31]],#4,4,6
]

# [alpha, beta] pairs tried for group 1; each pair is forwarded unchanged as
# the `alpha` / `beta` keyword arguments of `prune_model_llava_dynamic`.
alpha_beta_list_1 = [
    [0.65, 0.9],
    [0.65, 0.95],
    [0.6, 0.95],
    [0.6, 0.9],
    [0.35,0.97],
]

# [alpha, beta] pairs tried for group 2. The sweep evaluates every group-1
# pair against every group-2 pair (5 x 4 = 20 combinations with these lists).
alpha_beta_list_2 = [
    [0.6, 0.99],
    [0.65,0.99],
    [0.65,0.97],
    [0.3,0.97],
]

# The three settings below are not referenced anywhere else in the visible
# part of this script.
ablation_scheme="mean"
circuit_mlps = []
include_mlps=False


# Start each run with an empty detail log; entries are appended per round.
if enable_detail_log:
    open(detail_log_file, "w", encoding="utf-8").close()

def write_log(message):
    """Echo *message* to stdout and append it verbatim to ``log_file``.

    The message is written exactly as given: no newline is appended, because
    the call sites in this script already terminate their messages with
    ``"\\n"``. This matches the later ``write_log`` definition in this file,
    which rebinds the name and is the one in effect when the experiment runs.

    The log is opened as UTF-8 so that it agrees with the run-time footer,
    which is appended to the same file with ``encoding="utf-8"``.

    Args:
        message: Text to print and append to the log file.
    """
    print(message)  # Print to the screen
    with open(log_file, "a", encoding="utf-8") as f:  # Append to the log file
        f.write(message)


# ---- Ground-truth labels: image file name -> list of category names ----
# JSON is read explicitly as UTF-8 rather than with the platform default.
with open(ann_file, encoding="utf-8") as f:
    coco = json.load(f)
imgid2fname = {img["id"]: img["file_name"] for img in coco["images"]}
catid2name = {cat["id"]: cat["name"] for cat in coco["categories"]}
# Collect categories per file in a set first so repeated annotations of the
# same category on one image are stored once.
fname2labels = defaultdict(set)
for ann in coco["annotations"]:
    fname2labels[imgid2fname[ann["image_id"]]].add(catid2name[ann["category_id"]])
fname2labels = {k: list(v) for k, v in fname2labels.items()}

# Single prompt used for every image in every round.
prompt = "<image>\nPlease describe the image in detail."

# ---- Candidate pool: all listed images minus the excluded ("chosen") ones ----
# Blank lines in either file are ignored.
with open(txt_file) as f:
    all_img_names = [x.strip() for x in f if x.strip()]
with open(txt_file_chosen) as f:
    chosen_names = {x.strip() for x in f if x.strip()}

valid_img_names = [x for x in all_img_names if x not in chosen_names]
required_imgs = n_rounds * N
# Explicit check instead of `assert`, so it still runs under `python -O`.
if len(valid_img_names) < required_imgs:
    raise ValueError(
        f"Not enough available images, required: {required_imgs}, actual: {len(valid_img_names)}"
    )

# One sample of n_rounds * N images, drawn once and reused (in N-sized slices)
# by the baseline and by every pruning configuration.
# NOTE(review): the random module is not seeded in the visible code, so this
# sample differs between runs — seed it if run-to-run reproducibility matters.
fixed_sets = random.sample(valid_img_names, required_imgs)  # Fixed sampling



def write_log(text):
    """Echo *text* to stdout and append it verbatim to ``log_file``.

    No newline is added to the file output; callers include their own
    trailing ``"\\n"``. (``print`` still adds its usual line break on screen.)

    The file is opened as UTF-8 so that every write to the log uses the same
    encoding as the run-time footer appended at the end of the script, and so
    that non-ASCII messages do not depend on the platform's default encoding.

    Args:
        text: Text to print and append to the log file.
    """
    print(text)
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(text)


# ==================== Baseline Pre-computation ====================
# Runs the unmodified model on every round's images and records one metrics row
# per round; these rows are the reference for the change rates reported later.
print("=== Computing baseline metrics ===")
# Opening with "w" starts a fresh log file for this run.
with open(log_file, "w") as f:
    f.write("Round,CHAIR-s,CHAIR-i,F1,Len\n")

rb_list = []
for round_idx in range(1, n_rounds+1):
    # Round k uses the k-th contiguous slice of N names from the fixed sample.
    slice_start = (round_idx - 1) * N
    round_fnames = fixed_sets[slice_start : slice_start + N]

    imgs = []
    for img_name in round_fnames:
        img_path = os.path.join(img_dir, img_name)
        imgs.append(Image.open(img_path).convert("RGB"))
    prompts = [prompt for _ in round_fnames]
    # Images without annotations fall back to an empty label list.
    round_labels = [fname2labels.get(img_name, []) for img_name in round_fnames]

    with torch.no_grad():
        preds = batch_generate_llava(
            model, tokenizer, processor, imgs, prompts,
            device="cuda", max_new_tokens=100,
        )
    metrics = batch_compute_chair_metrics(preds, round_labels)

    write_log(f"{round_idx},{metrics['CHAIR-s']:.4f},{metrics['CHAIR-i']:.4f},"
              f"{metrics['F1']:.4f},{metrics['Len']:.2f}\n")
    rb_list.append([metrics['CHAIR-s'], metrics['CHAIR-i'], metrics['F1'], metrics['Len']])

# rb: one row per round, columns [CHAIR-s, CHAIR-i, F1, Len].
rb = np.array(rb_list)
# Per-metric mean / std over rounds (not read again in the visible code).
mb_mean = rb.mean(axis=0)
mb_std = rb.std(axis=0)

# ============ Write log file header ============
with open(log_file, "a") as f:
    f.write("Group1_Alpha,Group1_Beta,Group2_Alpha,Group2_Beta,Round,Type,CHAIR-s,CHAIR-i,F1,Len\n")


# Ablation sweep: every group-1 head set is combined with every group-2 head
# set, and each combination is evaluated under every (alpha1, beta1) x
# (alpha2, beta2) setting on the same fixed image rounds as the baseline.
for g1_idx, (faithful1, hallu1) in enumerate(zip(faithful_heads_list_1, hallucination_heads_list_1)):
    for g2_idx, (faithful2, hallu2) in enumerate(zip(faithful_heads_list_2, hallucination_heads_list_2)):

        # A head listed as both faithful and hallucination is kept as faithful
        # only: drop it from the hallucination list.
        hallu1_clean = [h for h in hallu1 if h not in faithful1]
        hallu2_clean = [h for h in hallu2 if h not in faithful2]

        # len(alpha_beta_list_1) x len(alpha_beta_list_2) settings per head-set
        # pairing (5 x 4 = 20 with the lists defined above).
        for (alpha1, beta1) in alpha_beta_list_1:
            for (alpha2, beta2) in alpha_beta_list_2:

                # Defensive reset: make sure no forward hooks from an earlier
                # configuration are still attached to the attention modules.
                # NOTE(review): `_forward_hooks` is a private torch attribute.
                for layer in model.model.language_model.layers:
                    layer.self_attn._forward_hooks.clear()

                group_tag = (
                    f"G1({g1_idx})_A({alpha1:.2f})_B({beta1:.2f})__"
                    f"G2({g2_idx})_A({alpha2:.2f})_B({beta2:.2f})"
                )
                write_log(f"\n===== {group_tag} =====\n")

                # One [CHAIR-s, CHAIR-i, F1, Len] row per round, same layout
                # as the baseline array `rb`.
                pruned_list = []
                for round_idx in range(1, n_rounds+1):
                    round_fnames = fixed_sets[(round_idx-1)*N : round_idx*N]
                    imgs = [Image.open(os.path.join(img_dir, fn)).convert("RGB") for fn in round_fnames]
                    prompts = [prompt] * len(round_fnames)

                    # Handles returned by the pruning helper; each exposes
                    # `.remove()`. The try/finally guarantees they are removed
                    # even if registration, generation, or metric computation
                    # raises, so a failed round cannot leave the model pruned.
                    hooks = []
                    try:
                        # Group 1: only heads whose layer is in target_layers_1.
                        hooks += prune_model_llava_dynamic(
                            model,
                            faithful_heads=[h for h in faithful1 if h[0] in target_layers_1],
                            hallucination_heads=[h for h in hallu1_clean if h[0] in target_layers_1],
                            target_layers=target_layers_1,
                            alpha=alpha1, beta=beta1,
                        )
                        # Group 2: only heads whose layer is in target_layers_2.
                        hooks += prune_model_llava_dynamic(
                            model,
                            faithful_heads=[h for h in faithful2 if h[0] in target_layers_2],
                            hallucination_heads=[h for h in hallu2_clean if h[0] in target_layers_2],
                            target_layers=target_layers_2,
                            alpha=alpha2, beta=beta2,
                        )

                        # Pruned inference
                        with torch.no_grad():
                            preds_p = batch_generate_llava(
                                model, tokenizer, processor,
                                imgs, prompts,
                                device="cuda", max_new_tokens=100
                            )
                        metrics_p = batch_compute_chair_metrics(
                            preds_p, [fname2labels.get(fn, []) for fn in round_fnames]
                        )
                        if enable_detail_log:
                            entry = {
                                "round": round_idx,
                                "names": round_fnames,    # List of image file names
                                "predictions": preds_p,   # List of model outputs
                                "metrics": metrics_p      # CHAIR-s, CHAIR-i, F1, Len, etc.
                            }
                            with open(detail_log_file, "a", encoding="utf-8") as f:
                                f.write(json.dumps(entry, ensure_ascii=False) + "\n")
                        write_log(
                            f"{alpha1:.2f},{beta1:.2f},{alpha2:.2f},{beta2:.2f},"
                            f"{round_idx},Pruned,"
                            f"{metrics_p['CHAIR-s']:.4f},{metrics_p['CHAIR-i']:.4f},"
                            f"{metrics_p['F1']:.4f},{metrics_p['Len']:.2f}\n"
                        )
                        pruned_list.append([
                            metrics_p['CHAIR-s'],
                            metrics_p['CHAIR-i'],
                            metrics_p['F1'],
                            metrics_p['Len'],
                        ])
                    finally:
                        # Remove all hooks registered for this round.
                        for hook in hooks:
                            hook.remove()

                # Relative change versus the baseline, computed per round and
                # per metric, then averaged over rounds. The small epsilon
                # avoids division by zero when a baseline metric is 0.
                rp = np.array(pruned_list)
                cr = ((rp - rb) / (rb + 1e-6)).mean(axis=0)
                write_log(
                    f"{group_tag} Average change rate: "
                    f"CHAIR-s={cr[0]*100:.2f}%, "
                    f"CHAIR-i={cr[1]*100:.2f}%, "
                    f"F1={cr[2]*100:.2f}%, "
                    f"Len={cr[3]*100:.2f}%\n"
                )

print("Experiment complete. Logs in", log_file)

# Elapsed wall-clock time since `start_time`, split into whole hours and the
# whole minutes left over after the hours are taken out.
end_time = datetime.datetime.now()
delta = end_time - start_time
total_seconds = delta.total_seconds()
hours = total_seconds // 3600
remainder = total_seconds % 3600
minutes = remainder // 60

# Footer appended to the results log: start / end timestamps, total duration,
# and a separator line.
footer_lines = [
    f"Start time: {start_time.strftime('%Y-%m-%d %H:%M:%S')}\n",
    f"End time: {end_time.strftime('%Y-%m-%d %H:%M:%S')}\n",
    f"Total time: {int(hours)} hours {int(minutes)} minutes\n",
    "-" * 40 + "\n",
]
with open(log_file, "a", encoding="utf-8") as f:
    f.writelines(footer_lines)
