import os, json ,sys
import torch
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoProcessor
from PIL import Image
import numpy as np
import pandas as pd
from datetime import datetime
from skopt import gp_minimize
from skopt.space import Real
from collections import defaultdict



# ---- Model and dataset locations ----
model_id = "llava-hf/llava-1.5-7b-hf"
img_dir = "path/to/val2014"
ann_file = "path/to/val2014/annotations/instances_val2014.json"

# ---- Image lists ----
used_imgs = "holo.txt"  # names to exclude; only read when the file exists
chosen_imgs = "../hallu_img_samples/hallu-7b.txt"  # candidate image names, one per line

# ---- Output files ----
time_log_file = "time_log_7b_5_18_hallu.txt"  # timestamped progress messages
content_log_file = "content_log_7b_5_18_hallu.jsonl"  # one JSON record per prediction
csv_path = "auto_eigenscore_7b_5-18_hallu.csv"  # per-head attribution scores
log_file = "auto_experiment_7b_5_18_log_hallu.txt"  # per-round metric rows

# ---- Generation / attribution settings ----
prompt = "<image>\nPlease describe the image in detail."
max_new_tokens = 128
batch_size = 1
K = 15  # number of attribution samples, see SHPM
target_layers = [*range(5, 19)]  # layers targeted for pruning: 5-18 inclusive
step = 0.01  # step size of the faithful-interval search



# Bayesian-optimization search space. The four dimensions are, in order, the
# two quantile levels bounding the hallucination-head interval and the two
# pruning parameters; the optimizer returns its best point in this same order.
search_space = [
    Real(lower, upper, name=dim_name)
    for lower, upper, dim_name in (
        (0.1, 0.3, "low"),      # halluc heads low quantile
        (0.7, 0.9, "high"),     # halluc heads high quantile
        (0.2, 0.75, "mius"),    # μ
        (0.8, 0.99, "lamda"),   # λ
    )
]


# ========== Make the Parent Directory Importable for Custom Imports ==========
# The parent is taken relative to the current working directory (not this
# file's location), so the script must be launched from its own folder.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.insert(0, parent_dir)

from llava_hooks import batch_run_with_cache_llava
from experiments import auto_circuit_experiment
from pruning_llava_utils import prune_model_llava_dynamic, safe_batch_generate
from chair_metrics import batch_compute_chair_metrics

# Load the model, tokenizer and processor from the same checkpoint id.
# The model weights are requested in float16 with automatic device placement.
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    use_fast=True,
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(
    model_id,
    use_fast=True,
    trust_remote_code=True,
)


# Load COCO label info and build a lookup from image file name to the names
# of every category annotated on that image.
# JSON is UTF-8 by specification, so do not rely on the locale default.
with open(ann_file, "r", encoding="utf-8") as f:
    coco = json.load(f)
imgid2fname = {img["id"]: img["file_name"] for img in coco["images"]}
catid2name = {cat["id"]: cat["name"] for cat in coco["categories"]}
fname2labels = defaultdict(set)  # a set collapses repeated categories per image
for ann in coco["annotations"]:
    fname = imgid2fname[ann["image_id"]]
    catname = catid2name[ann["category_id"]]
    fname2labels[fname].add(catname)
# Sort each label set: plain list(set) order changes between interpreter runs
# (string hash randomization), which made the logged gt_label lists
# irreproducible. Images without annotations are simply absent from the dict.
fname2labels = {k: sorted(v) for k, v in fname2labels.items()}

# ---- Logging helpers ---
def log_time(msg):
    """Append a timestamped message line to the timing log.

    The line has the form ``[YYYY-MM-DD HH:MM:SS] msg`` using the local
    wall-clock time, and is appended to the file named by ``time_log_file``.

    Args:
        msg: Text to record after the timestamp.
    """
    ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Explicit UTF-8 keeps the log readable regardless of the platform's
    # locale default and matches the encoding used for the content log.
    with open(time_log_file, "a", encoding="utf-8") as f:
        f.write(f"[{ts}] {msg}\n")

def log_content(rnd, exp_id, type_, file_name, pred, gt_label):
    """Append one prediction record to the JSON-lines content log.

    Each call writes a single JSON object followed by a newline to the file
    named by ``content_log_file``.

    Args:
        rnd: Round number, stored under ``"round"``.
        exp_id: Experiment identifier, stored under ``"experiment"``.
        type_: Record kind, stored under ``"type"``.
        file_name: Image file name the prediction belongs to.
        pred: Predicted text (or a placeholder) to record.
        gt_label: Ground-truth labels for the image.
    """
    log_obj = {
        "round": rnd,
        "experiment": exp_id,
        "type": type_,
        "file_name": file_name,
        "pred": pred,
        "gt_label": gt_label
    }
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened as UTF-8 explicitly; with the locale default this could raise
    # UnicodeEncodeError or produce a file that is not valid UTF-8 JSONL.
    with open(content_log_file, "a", encoding="utf-8") as f:
        f.write(json.dumps(log_obj, ensure_ascii=False) + "\n")


# ---- Prepare images for the experiment ----
all_selected_names = []  # filled by the main loop, in processing order

# Names listed in the exclusion file (if it exists) are never selected again.
used_img_names = set()
if os.path.exists(used_imgs):
    with open(used_imgs, "r") as f:
        used_img_names.update(name for name in map(str.strip, f) if name)

# Candidate names: one per line, blank lines ignored.
with open(chosen_imgs, "r") as f:
    chosen_img_names = [name for name in map(str.strip, f) if name]

# Drop excluded names, then cap the experiment at the first 120 candidates.
chosen_img_names = [
    name for name in chosen_img_names if name not in used_img_names
][:120]
print(f"Chosen {len(chosen_img_names)} images for this experiment.")


# ---- Main experiment loop: one image per round ----
def _select_heads(df_e, scores, low, high, inside):
    """Return ``[layer, head]`` pairs selected by a quantile band of ``scores``.

    Args:
        df_e: Rows of the current experiment; needs ``layer`` and ``head`` columns.
        scores: Score Series taken from ``df_e`` (same index).
        low: Lower quantile level.
        high: Upper quantile level.
        inside: When true keep rows with ``q_low <= score <= q_high``;
            otherwise keep rows strictly outside that band.

    Returns:
        A list of ``[layer, head]`` lists.
    """
    q_low, q_high = scores.quantile([low, high]).tolist()
    if inside:
        mask = (scores >= q_low) & (scores <= q_high)
    else:
        mask = (scores < q_low) | (scores > q_high)
    return df_e[mask][["layer", "head"]].values.tolist()


def _generate_pruned(images, prompts, faithful_heads, halluc_heads, alpha, beta):
    """Generate with pruning hooks installed and always remove them afterwards.

    Every pruned evaluation (search and final) goes through this helper so
    that all of them use the same ``target_layers`` and so that the hooks are
    removed even when generation raises; otherwise leftover hooks would
    silently alter every later round.

    Returns:
        Whatever ``safe_batch_generate`` returns; callers treat ``None`` as a
        failed generation.
    """
    hooks = prune_model_llava_dynamic(
        model,
        faithful_heads=faithful_heads,
        hallucination_heads=halluc_heads,
        target_layers=target_layers,
        alpha=alpha,
        beta=beta
    )
    try:
        with torch.no_grad():
            return safe_batch_generate(
                model, tokenizer, processor, images, prompts,
                device='cuda', max_new_tokens=max_new_tokens,
                min_len=10, max_retries=3
            )
    finally:
        for hook in hooks:
            hook.remove()


def _climb_faith_interval(eval_fn, start, low_delta, high_delta):
    """Greedily move the interval bounds while the gain strictly improves.

    Args:
        eval_fn: Callable ``(low, high) -> (gain, low, high)``.
        start: ``(gain, low, high)`` of the starting interval.
        low_delta: Amount added to the lower bound at each step.
        high_delta: Amount added to the upper bound at each step.

    Returns:
        The best ``(gain, low, high)`` reached; ``start`` if the first step
        does not improve or would leave the valid range.
    """
    best = start
    cur_low, cur_high = start[1], start[2]
    while True:
        nl, nh = cur_low + low_delta, cur_high + high_delta
        # Stop before leaving [0, 1] (expansion) or collapsing (contraction).
        if nl < 0 or nh > 1 or nl >= nh:
            break
        candidate = eval_fn(nl, nh)
        if candidate[0] > best[0]:
            best = candidate
            cur_low, cur_high = nl, nh
        else:
            break
    return best


for rnd, fname in enumerate(chosen_img_names, 1):
    log_time(f"Round {rnd} start.")
    torch.cuda.empty_cache()

    # -- Prepare image sample --
    used_img_names.add(fname)
    all_selected_names.append(fname)
    img_path = f"{img_dir}/{fname}"
    # convert() returns an independent RGB image, so the file can be closed.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB")
    patching_samples = [{
        "image": img,
        "prompt": prompt,
        "file_name": fname,
        "gt_label": fname2labels.get(fname, [])
    }]
    log_time(f"Loaded image: {fname}")
    images = [s["image"] for s in patching_samples]  # already RGB
    prompts = [s["prompt"] for s in patching_samples]
    gt_labels = [s["gt_label"] for s in patching_samples]

    # -- get cache --
    _, patching_cache = batch_run_with_cache_llava(
        model, processor, patching_samples, device="cuda", batch_size=batch_size
    )
    log_time(f"Cache complete for round {rnd}.")

    # -- Run auto_circuit_experiment to calculate per-head attribution (eigenscore) --
    log_time("Auto circuit experiment start.")
    exp_id = auto_circuit_experiment(
        csv_path=csv_path,
        model=model,
        processor=processor,
        val_samples=patching_samples,
        val_cache=patching_cache,
        ablation_scheme="mean",
        device="cuda",
        include_mlps=False,
        num_samples=K,
        layer_hidden_index=None,
        target_layers=target_layers,
        batch_size=batch_size,
    )
    log_time(f"Auto circuit experiment end. exp_id={exp_id}")

    # Read this experiment's per-head scores back from the CSV
    # (presumably written there by auto_circuit_experiment — confirm).
    df = pd.read_csv(csv_path)
    df_e = df[df["experiment"] == exp_id]
    scores = df_e["delta_eigenscore"]

    # Defensive: clear any forward hooks still registered on the attention
    # modules so the baseline below runs on the unpruned model.
    for layer in model.model.language_model.layers:
        layer.self_attn._forward_hooks.clear()

    # ---- Run baseline inference before pruning (needed for gain calculation) ----
    log_time("Baseline eval start.")
    with torch.no_grad():
        preds_base = safe_batch_generate(
            model, tokenizer, processor, images, prompts,
            device='cuda', max_new_tokens=max_new_tokens,
            min_len=10, max_retries=3
        )
    if preds_base is None:
        log_time(f"[SKIP] Round {rnd} baseline too short after retries")
        print(f"[SKIP] Round {rnd} baseline too short after retries")
        for samp in patching_samples:
            log_content(rnd, exp_id, "baseline", samp["file_name"], "SKIP", samp["gt_label"])
        with open(log_file, "a") as f:
            f.write(f"{rnd},{exp_id},SKIP,SKIP,SKIP,SKIP,SKIP,SKIP,SKIP,SKIP\n")
        continue
    log_time("Baseline eval end.")

    for samp, pred in zip(patching_samples, preds_base):
        log_content(rnd, exp_id, "baseline", samp["file_name"], pred, samp["gt_label"])
    metrics_base = batch_compute_chair_metrics(preds_base, gt_labels)

    log_time("Pruned eval start.")

    # Faithful heads stay fixed at the middle quantile band during the
    # Bayesian search, so compute them once instead of on every objective call.
    faithful_fixed = _select_heads(df_e, scores, 0.4, 0.6, inside=True)

    # --- Bayesian optimization for hallucination interval and pruning parameters ---
    def objective(params):
        """Return the negated CHAIR-i improvement for one search point."""
        low, high, alpha, beta = params
        if low >= high:  # Penalize invalid intervals
            return 1e6
        # Hallucination heads lie outside the Bayesian-chosen interval.
        candidate_halluc = _select_heads(df_e, scores, low, high, inside=False)
        preds = _generate_pruned(images, prompts, faithful_fixed, candidate_halluc, alpha, beta)
        if preds is None:  # failed generation gets the same penalty
            return 1e6
        m = batch_compute_chair_metrics(preds, gt_labels)
        # gp_minimize minimizes, so return the negative improvement in CHAIR-i.
        return -(metrics_base['CHAIR-i'] - m['CHAIR-i'])

    # ---- Bayesian optimization: jointly search hallucination interval and pruning parameters ----
    opt_result = gp_minimize(
        objective,
        search_space,
        n_calls=20,
        n_initial_points=8,
        acq_func="EI",
        random_state=2025,
        verbose=True
    )
    best_h_low, best_h_high, mius, lamda = opt_result.x

    # -- Use optimal hallucination interval to select heads for pruning --
    halluc_heads = _select_heads(df_e, scores, best_h_low, best_h_high, inside=False)

    # ---- Faithful head interval search: contraction/expansion with local refinement ----
    def eval_faith(low, high):
        """Return ``(CHAIR-i gain, low, high)`` for one faithful interval."""
        faiths = _select_heads(df_e, scores, low, high, inside=True)
        preds = _generate_pruned(images, prompts, faiths, halluc_heads, mius, lamda)
        if preds is None:  # failed generation counts as no gain
            return 0.0, low, high
        m = batch_compute_chair_metrics(preds, gt_labels)
        return metrics_base['CHAIR-i'] - m['CHAIR-i'], low, high

    # -- Faith head search: start from [0.40, 0.60], expand/contract for local optimum --
    initial = eval_faith(0.40, 0.60)
    # Expand interval to see if larger coverage helps
    expanded = _climb_faith_interval(eval_faith, initial, -step, step)
    # Contract interval for potential improvement
    contracted = _climb_faith_interval(eval_faith, initial, step, -step)
    # Take the best faith interval from expansion, contraction, or initial
    final_f_gain, final_f_low, final_f_high = initial
    for cand in (expanded, contracted):
        if cand[0] > final_f_gain:
            final_f_gain, final_f_low, final_f_high = cand
    # Select faithful heads according to final optimized interval
    faithful_heads = _select_heads(df_e, scores, final_f_low, final_f_high, inside=True)

    print(f"Bayesian hallu=[{best_h_low:.2f},{best_h_high:.2f}], faith=[{final_f_low:.2f},{final_f_high:.2f}], mius={mius:.3f}, lamda={lamda:.3f}")
    print(f"Selected {len(halluc_heads)} halluc heads, {len(faithful_heads)} faithful heads.")

    # -- Run final pruned inference with selected heads/params --
    preds_pruned = _generate_pruned(images, prompts, faithful_heads, halluc_heads, mius, lamda)
    pruned_success = preds_pruned is not None

    # -- Save pruned results and compute metrics --
    if pruned_success:
        for samp, pred in zip(patching_samples, preds_pruned):
            log_content(rnd, exp_id, "pruned", samp["file_name"], pred, samp["gt_label"])
        metrics_pruned = batch_compute_chair_metrics(preds_pruned, gt_labels)

        with open(log_file, "a") as f:
            f.write(
                f"{rnd},{exp_id},{metrics_base['CHAIR-s']:.4f},{metrics_base['CHAIR-i']:.4f},"
                f"{metrics_base['F1']:.4f},{metrics_base['Len']:.2f},"
                f"{metrics_pruned['CHAIR-s']:.4f},{metrics_pruned['CHAIR-i']:.4f},"
                f"{metrics_pruned['F1']:.4f},{metrics_pruned['Len']:.2f}\n"
            )
    else:
        log_time(f"[SKIP] Round {rnd} pruned too short after retries")
        print(f"[SKIP] Round {rnd} pruned too short after retries")
        with open(log_file, "a") as f:
            f.write(f"{rnd},{exp_id},SKIP,SKIP,SKIP,SKIP,SKIP,SKIP,SKIP,SKIP\n")

    # Hooks were already removed inside _generate_pruned; keep the log marker.
    log_time("Hooks cleared.")

# Record all selected images: append each distinct name once, in the order
# the rounds processed them (dict.fromkeys de-duplicates while keeping order).
# NOTE(review): the exclusion list is read from `used_imgs`, but names are
# written to the literal "imgs_used.txt" — confirm the two are meant to differ.
unique_selected = dict.fromkeys(all_selected_names)
with open("imgs_used.txt", "a") as f:
    f.writelines(f"{img_name}\n" for img_name in unique_selected)