import os
import torch
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoProcessor
from collections import defaultdict
from PIL import Image
import random
import json
import numpy as np
from itertools import product
import datetime
from pruning_llava_utils import prune_model_llava_dynamic, batch_generate_llava
from chair_metrics import batch_compute_chair_metrics 

# Allocator setting for PyTorch's CUDA caching allocator; presumably it must be
# in place before the first CUDA allocation (the model load below) to take effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

model_id = "llava-hf/llava-1.5-13b-hf"

# Taken before the checkpoint is loaded, so the reported total duration
# at the end of the script includes model loading time.
start_time = datetime.datetime.now()

# Vision-language model in fp16; device placement is delegated to device_map="auto".
model = AutoModelForVision2Seq.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True
)

# Tokenizer and processor come from the same checkpoint with identical options;
# both are handed to batch_generate_llava later on.
_load_kwargs = {"trust_remote_code": True, "use_fast": True}
tokenizer = AutoTokenizer.from_pretrained(model_id, **_load_kwargs)
processor = AutoProcessor.from_pretrained(model_id, **_load_kwargs)

# ============ Parameter Settings ============
img_dir = "path/to/val2014"
txt_file = "hallu_13b.txt"
txt_file_chosen = "holo.txt"
# txt_file_chosen = "semantic_attribution/chosen_imgs_13b_complex.txt"
ann_file = "path/to/val2014/annotations/instances_val2014.json"
log_file = "experiment_log_13b_combo.txt"
detail_log_file = "experiment_detail_13b_log.jsonl" # Optional, detailed per-round sentence-level results

N = 10                   # Number of images per round
n_rounds = 25            # Total number of experiment rounds, adjust as needed
n_heads_per_layer = 40 

target_layers_1 = list(range(5, 19)) 
target_layers_2 = list(range(19, 27))  

# Head lists: each outer list holds one ablation group; each head is a [layer, head]
# pair (h[0] is matched against the target_layers_* indices by the sweep).
# The sweep pairs faithful_heads_list_k[i] with hallucination_heads_list_k[i] via
# zip(), so each faithful/hallucination pair of lists should have the same length.
# A head present in both lists of a group is later treated as faithful only.

# Segment 1 (layers 5..18): heads kept/boosted as "faithful".
faithful_heads_list_1 =[
    [[5, 1], [5, 6], [5, 7], [5, 14], [5, 22], [5, 30], [5, 36], [5, 39], [6, 4], [6, 6], [6, 9], [6, 16], [6, 20], [6, 21], [6, 23], [6, 31], [6, 38], [7, 9], [7, 14], [7, 17], [8, 2], [8, 17], [8, 18], [8, 19], [8, 21], [8, 23], [8, 36], [8, 38], [8, 39], [9, 0], [9, 1], [9, 2], [9, 6], [9, 9], [9, 11], [9, 14], [9, 15], [9, 16], [9, 20], [9, 24], [9, 32], [9, 35], [9, 37], [10, 2], [10, 4], [10, 5], [10, 13], [10, 19], [10, 21], [10, 24], [10, 31], [10, 35], [11, 7], [11, 9], [11, 13], [11, 14], [11, 15], [11, 20], [11, 24], [11, 36], [12, 1], [12, 2], [12, 16], [12, 18], [12, 22], [12, 27], [12, 32], [13, 1], [13, 5], [13, 8], [13, 10], [13, 14], [13, 16], [13, 18], [13, 26], [13, 28], [13, 31], [14, 4], [14, 7], [14, 10], [14, 14], [14, 24], [14, 26], [14, 29], [14, 32], [14, 33], [15, 6], [15, 11], [15, 14], [15, 15], [15, 20], [15, 22], [15, 30], [15, 35], [15, 37], [15, 38], [16, 0], [16, 13], [16, 15], [16, 17], [16, 25], [16, 28], [16, 32], [17, 8], [17, 9], [17, 14], [17, 15], [17, 20], [17, 22], [17, 24], [17, 34], [18, 2], [18, 6], [18, 8], [18, 11], [18, 14], [18, 18], [18, 29], [18, 32], [18, 34], [18, 35], [18, 37], [18, 39]],
]
# Segment 1 (layers 5..18): heads labelled as "hallucination" heads.
# NOTE(review): the exact treatment (alpha vs. beta) is decided inside
# prune_model_llava_dynamic, which is not visible here.
hallucination_heads_list_1=[

    [[5, 0], [5, 3], [5, 10], [5, 12], [5, 13], [5, 16], [5, 19], [5, 21], [5, 23], [5, 27], [5, 28], [5, 33], [5, 37], [6, 1], [6, 5], [6, 8], [6, 11], [6, 15], [6, 19], [6, 21], [6, 22], [6, 25], [6, 27], [6, 29], [6, 31], [6, 36], [7, 6], [7, 7], [7, 10], [7, 12], [7, 13], [7, 22], [7, 25], [7, 39], [8, 0], [8, 3], [8, 5], [8, 7], [8, 8], [8, 10], [8, 13], [8, 17], [8, 19], [8, 20], [8, 27], [8, 28], [8, 31], [8, 32], [8, 33], [8, 36], [9, 3], [9, 4], [9, 7], [9, 14], [9, 16], [9, 30], [9, 33], [9, 34], [9, 35], [9, 36], [10, 0], [10, 9], [10, 11], [10, 17], [10, 27], [10, 29], [10, 30], [10, 31], [11, 0], [11, 1], [11, 3], [11, 4], [11, 11], [11, 12], [11, 15], [11, 16], [11, 18], [11, 19], [11, 21], [11, 22], [11, 23], [11, 24], [11, 25], [11, 26], [11, 28], [11, 30], [11, 31], [11, 35], [11, 38], [11, 39], [12, 1], [12, 4], [12, 7], [12, 11], [12, 14], [12, 17], [12, 18], [12, 26], [12, 29], [12, 32], [12, 34], [13, 2], [13, 4], [13, 8], [13, 9], [13, 11], [13, 12], [13, 13], [13, 15], [13, 22], [13, 23], [13, 25], [13, 26], [13, 27], [13, 33], [13, 38], [13, 39], [14, 4], [14, 9], [14, 13], [14, 15], [14, 16], [14, 22], [14, 23], [14, 24], [14, 25], [14, 27], [14, 28], [14, 30], [14, 32], [14, 37], [15, 0], [15, 1], [15, 2], [15, 4], [15, 8], [15, 9], [15, 10], [15, 11], [15, 12], [15, 13], [15, 14], [15, 20], [15, 21], [15, 28], [15, 30], [15, 31], [15, 38], [16, 2], [16, 3], [16, 5], [16, 6], [16, 10], [16, 11], [16, 12], [16, 16], [16, 19], [16, 20], [16, 22], [16, 27], [16, 36], [17, 4], [17, 7], [17, 10], [17, 12], [17, 13], [17, 18], [17, 19], [17, 20], [17, 25], [17, 27], [17, 28], [17, 30], [17, 31], [17, 33], [17, 36], [18, 3], [18, 12], [18, 13], [18, 15], [18, 16], [18, 17], [18, 19], [18, 21], [18, 22], [18, 24], [18, 25], [18, 27], [18, 30], [18, 31], [18, 38]],
]


# Segment 2 (layers 19..26): "faithful" heads.
faithful_heads_list_2 =[
    [[19, 0], [19, 16], [19, 19], [19, 24], [19, 25], [19, 32], [19, 33], [19, 34], [20, 0], [20, 1], [20, 7], [20, 12], [20, 17], [20, 19], [20, 20], [20, 23], [20, 28], [20, 30], [20, 32], [21, 2], [21, 3], [21, 8], [21, 13], [21, 14], [21, 17], [21, 20], [21, 22], [21, 30], [21, 35], [22, 0], [22, 1], [22, 6], [22, 9], [22, 13], [22, 16], [22, 20], [22, 21], [22, 22], [22, 24], [22, 27], [22, 32], [23, 7], [23, 8], [23, 9], [23, 10], [23, 11], [23, 12], [23, 15], [23, 19], [23, 25], [23, 31], [23, 35], [24, 5], [24, 6], [24, 9], [24, 24], [24, 28], [24, 33], [24, 35], [24, 39], [25, 6], [25, 11], [25, 15], [25, 16], [25, 17], [25, 21], [25, 31], [25, 35], [25, 36], [25, 39], [26, 3], [26, 4], [26, 5], [26, 15], [26, 18], [26, 19], [26, 20], [26, 21], [26, 23], [26, 32], [26, 33]]
]
# Segment 2 (layers 19..26): "hallucination" heads.
hallucination_heads_list_2=[
    [[19, 2], [19, 3], [19, 4], [19, 6], [19, 10], [19, 12], [19, 13], [19, 21], [19, 29], [19, 30], [19, 33], [19, 37], [19, 39], [20, 2], [20, 4], [20, 6], [20, 7], [20, 9], [20, 10], [20, 11], [20, 13], [20, 15], [20, 16], [20, 18], [20, 19], [20, 24], [20, 25], [20, 27], [20, 33], [20, 38], [20, 39], [21, 4], [21, 6], [21, 7], [21, 9], [21, 11], [21, 23], [21, 24], [21, 26], [21, 27], [21, 29], [21, 31], [21, 33], [21, 39], [22, 6], [22, 8], [22, 15], [22, 28], [22, 29], [22, 31], [22, 32], [22, 34], [22, 37], [22, 38], [23, 1], [23, 2], [23, 3], [23, 5], [23, 7], [23, 18], [23, 20], [23, 22], [23, 23], [23, 29], [23, 31], [23, 36], [23, 38], [23, 39], [24, 1], [24, 2], [24, 3], [24, 7], [24, 8], [24, 10], [24, 13], [24, 15], [24, 18], [24, 19], [24, 20], [24, 22], [24, 23], [24, 26], [24, 29], [24, 30], [24, 32], [24, 37], [24, 38], [25, 0], [25, 3], [25, 5], [25, 8], [25, 12], [25, 13], [25, 19], [25, 21], [25, 27], [25, 29], [26, 1], [26, 2], [26, 4], [26, 6], [26, 7], [26, 9], [26, 11], [26, 14], [26, 16], [26, 17], [26, 18], [26, 21], [26, 25], [26, 31], [26, 35], [26, 37], [26, 38], [26, 39]],
]

# (alpha, beta) settings passed to prune_model_llava_dynamic for each layer segment.
# The sweep evaluates every segment-1 pair against every segment-2 pair.
alpha_beta_list_1 = [
    [0.6, 0.8],
    [0.6, 0.85],
    [0.3, 0.9],
]
alpha_beta_list_2 = [[0.3, 0.9], [0.3, 0.8], [0.35, 0.85], [0.65, 0.8], [0.6, 0.8]]

# Ablation options (not referenced by the visible code).
ablation_scheme = "mean"
include_mlps = False
circuit_mlps = []

def write_log(message):
    """Echo ``message`` to stdout and append it as one line to ``log_file``.

    The log file (the module-level ``log_file`` path) is opened in append mode
    and closed again on every call, so each message is handed to the OS before
    this function returns.

    Args:
        message: Text to log. A newline is appended here, so callers should not
            end ``message`` with one.
    """
    print(message)  # Print to the screen
    # Explicit UTF-8 (as the timing summary writer uses) instead of the
    # platform-dependent default encoding.
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(message + "\n")

# ============ Data loading ============
# Candidate image names (blank lines skipped so they can never be sampled).
with open(txt_file, encoding="utf-8") as f:
    all_img_names = [line.strip() for line in f if line.strip()]

# COCO ground truth: map each image file name to its list of unique object category names.
with open(ann_file, encoding="utf-8") as f:
    coco = json.load(f)
imgid2fname = {img["id"]: img["file_name"] for img in coco["images"]}
catid2name = {cat["id"]: cat["name"] for cat in coco["categories"]}
fname2labels = defaultdict(set)
for ann in coco["annotations"]:
    fname2labels[imgid2fname[ann["image_id"]]].add(catid2name[ann["category_id"]])
fname2labels = {k: list(v) for k, v in fname2labels.items()}

prompt = "USER: <image>\nPlease describe the image in detail.\nASSISTANT:"

# Images listed in txt_file_chosen are excluded from the evaluation pool.
with open(txt_file_chosen, encoding="utf-8") as f:
    chosen_names = {line.strip() for line in f if line.strip()}

valid_img_names = [x for x in all_img_names if x not in chosen_names]
required_imgs = n_rounds * N
# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if len(valid_img_names) < required_imgs:
    raise ValueError(
        f"Insufficient available images, required: {required_imgs}, actual: {len(valid_img_names)}"
    )

# Fixed sampling: a dedicated, seeded RNG makes the per-round image sets
# reproducible across runs without touching the global `random` state.
sample_seed = 0
fixed_sets = random.Random(sample_seed).sample(valid_img_names, required_imgs)

# ==================== Baseline Precomputation ====================
# Run the unmodified model once per round. The per-round metrics `rb`
# (rows: rounds; columns: CHAIR-s, CHAIR-i, F1, Len) are the reference for the
# relative changes reported for every ablation setting later on.
print("=== Computing baseline metrics ===")
# Mode "w": every run starts a fresh log file.
with open(log_file, "w", encoding="utf-8") as f:
    f.write("Round,CHAIR-s,CHAIR-i,F1,Len\n")
rb_list = []
for round_idx in range(1, n_rounds + 1):
    round_fnames = fixed_sets[(round_idx - 1) * N : round_idx * N]
    imgs = []
    for fn in round_fnames:
        # Close each file promptly; convert() returns an independent in-memory image.
        with Image.open(os.path.join(img_dir, fn)) as im:
            imgs.append(im.convert("RGB"))
    prompts = [prompt] * len(round_fnames)
    with torch.no_grad():
        preds = batch_generate_llava(model, tokenizer, processor, imgs, prompts,
                                     device="cuda", max_new_tokens=128)
    metrics = batch_compute_chair_metrics(preds, [fname2labels.get(fn, []) for fn in round_fnames])
    # write_log appends the newline itself; an extra "\n" here produced blank CSV rows.
    write_log(f"{round_idx},{metrics['CHAIR-s']:.4f},{metrics['CHAIR-i']:.4f},"
              f"{metrics['F1']:.4f},{metrics['Len']:.2f}")
    rb_list.append([metrics['CHAIR-s'], metrics['CHAIR-i'], metrics['F1'], metrics['Len']])
rb = np.array(rb_list)
# Baseline mean/std over rounds (not referenced elsewhere in the visible script).
mb_mean, mb_std = rb.mean(axis=0), rb.std(axis=0)

# ============ Write log file header ============
with open(log_file, "a", encoding="utf-8") as f:
    f.write("Group1_Alpha,Group1_Beta,Group2_Alpha,Group2_Beta,Round,Type,CHAIR-s,CHAIR-i,F1,Len\n")

# 3D ablation group: G1 × G2
for g1_idx, (faithful1, hallu1) in enumerate(zip(faithful_heads_list_1, hallucination_heads_list_1)):
    for g2_idx, (faithful2, hallu2) in enumerate(zip(faithful_heads_list_2, hallucination_heads_list_2)):

        # A head listed as both faithful and hallucination is kept as faithful only:
        # drop it from the hallucination list. Heads are [layer, head] lists, so they
        # are compared as tuples via a set.
        faithful1_keys = {tuple(h) for h in faithful1}
        faithful2_keys = {tuple(h) for h in faithful2}
        hallu1_clean = [h for h in hallu1 if tuple(h) not in faithful1_keys]
        hallu2_clean = [h for h in hallu2 if tuple(h) not in faithful2_keys]

        # Every (alpha1, beta1) x (alpha2, beta2) pair is evaluated:
        # len(alpha_beta_list_1) * len(alpha_beta_list_2) combinations per ablation group.
        for (alpha1, beta1) in alpha_beta_list_1:
            for (alpha2, beta2) in alpha_beta_list_2:

                # Defensive reset: drop any forward hooks still attached to the attention
                # modules before this combination registers its own.
                # NOTE(review): this attribute path depends on the transformers Llava layout.
                for layer in model.model.language_model.layers:
                    layer.self_attn._forward_hooks.clear()

                group_tag = (
                    f"G1({g1_idx})_A({alpha1:.2f})_B({beta1:.2f})__"
                    f"G2({g2_idx})_A({alpha2:.2f})_B({beta2:.2f})"
                )
                write_log(f"\n===== {group_tag} =====\n")

                pruned_list = []
                # Run for each round
                for round_idx in range(1, n_rounds + 1):
                    round_fnames = fixed_sets[(round_idx - 1) * N : round_idx * N]
                    imgs = []
                    for fn in round_fnames:
                        # Close each file promptly; convert() returns an independent copy.
                        with Image.open(os.path.join(img_dir, fn)) as im:
                            imgs.append(im.convert("RGB"))
                    prompts = [prompt] * len(round_fnames)

                    hooks = []
                    try:
                        # Register hooks for both layer segments; each call only receives
                        # the heads that fall inside its own target layers.
                        hooks += prune_model_llava_dynamic(
                            model,
                            faithful_heads=[h for h in faithful1 if h[0] in target_layers_1],
                            hallucination_heads=[h for h in hallu1_clean if h[0] in target_layers_1],
                            target_layers=target_layers_1,
                            alpha=alpha1, beta=beta1,
                        )
                        hooks += prune_model_llava_dynamic(
                            model,
                            faithful_heads=[h for h in faithful2 if h[0] in target_layers_2],
                            hallucination_heads=[h for h in hallu2_clean if h[0] in target_layers_2],
                            target_layers=target_layers_2,
                            alpha=alpha2, beta=beta2,
                        )

                        # Pruned inference
                        with torch.no_grad():
                            preds_p = batch_generate_llava(
                                model, tokenizer, processor,
                                imgs, prompts,
                                device="cuda", max_new_tokens=128
                            )
                    finally:
                        # Always detach this round's hooks, even if generation fails,
                        # so they never leak into later rounds or combinations.
                        for hook in hooks:
                            hook.remove()

                    metrics_p = batch_compute_chair_metrics(
                        preds_p, [fname2labels.get(fn, []) for fn in round_fnames]
                    )
                    # No trailing "\n": write_log already terminates the line.
                    write_log(
                        f"{alpha1:.2f},{beta1:.2f},{alpha2:.2f},{beta2:.2f},"
                        f"{round_idx},Pruned,"
                        f"{metrics_p['CHAIR-s']:.4f},{metrics_p['CHAIR-i']:.4f},"
                        f"{metrics_p['F1']:.4f},{metrics_p['Len']:.2f}"
                    )
                    pruned_list.append([
                        metrics_p['CHAIR-s'],
                        metrics_p['CHAIR-i'],
                        metrics_p['F1'],
                        metrics_p['Len'],
                    ])

                # Per-round relative change vs. the baseline `rb`, averaged over rounds.
                # NOTE(review): a round whose baseline metric is 0 is divided by ~1e-6 and
                # can dominate this mean; consider also reporting the change of the means.
                rp = np.array(pruned_list)
                cr = ((rp - rb) / (rb + 1e-6)).mean(axis=0)
                write_log(
                    f"{group_tag} Average change rate: "
                    f"CHAIR-s={cr[0]*100:.2f}%, "
                    f"CHAIR-i={cr[1]*100:.2f}%, "
                    f"F1={cr[2]*100:.2f}%, "
                    f"Len={cr[3]*100:.2f}%"
                )

print("Experiment complete. Logs in", log_file)

# ---- Wall-clock timing summary (start_time is taken before model loading) ----
end_time = datetime.datetime.now()
delta = end_time - start_time
total_seconds = delta.total_seconds()
# Floor division / modulo: same values as the nested divmod() split.
hours = total_seconds // 3600
minutes = (total_seconds % 3600) // 60

timing_lines = [
    f"Start time: {start_time.strftime('%Y-%m-%d %H:%M:%S')}\n",
    f"End time: {end_time.strftime('%Y-%m-%d %H:%M:%S')}\n",
    f"Total duration: {int(hours)} hours {int(minutes)} minutes\n",
    "-" * 40 + "\n",
]
# Append to the same log file used by the rest of the run.
with open(log_file, "a", encoding="utf-8") as f:
    f.writelines(timing_lines)
