import os, json, torch ,sys
from PIL import Image
import pandas as pd
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoProcessor
from datetime import datetime
from collections import defaultdict
import time
import tempfile

from skopt import gp_minimize
from skopt.space import Real

# ---- Model checkpoint and dataset locations ----
model_id = "llava-hf/llava-1.5-7b-hf"
img_dir = "path/to/val2014"
ann_file = "path/to/val2014/annotations/instances_val2014.json"

# ---- Input artifacts ----
# Per-head score table (read with pandas further down) and the list of image
# file names to process, one per line.
csv_path = "auto_eigenscore_7b_5-18_hallu.csv"
used_imgs_file = "valid_imgs.txt"

# ---- Output artifacts ----
# JSONL log of generated captions and the per-round summary table.
content_log_file = "content_log_7b_5_18_hallu_retry-autojudge-bias.jsonl"
summary_csv = "search_summary_bias-73-i.csv"

# ---- Generation settings ----
prompt = "<image>\nPlease describe the image in detail."
max_new_tokens = 128
max_retries = 3

# ---- Search settings ----
# NOTE(review): orig_mius / orig_lamda are not referenced anywhere else in the
# visible part of this file; kept because other code may import them.
orig_mius = 0.65
orig_lamda = 0.95
step = 0.01  # Faith step size

# Bayesian-optimization search space for head selection and pruning parameters.
# The dimension order matters: the optimizer hands candidates back as a flat
# list that is unpacked positionally as (low, high, mius, lamda).
_SEARCH_BOUNDS = (
    ("low", 0.1, 0.3),
    ("high", 0.7, 0.9),
    ("mius", 0.2, 0.75),
    ("lamda", 0.8, 0.99),
)
search_space = [
    Real(lower, upper, name=dim_name)
    for dim_name, lower, upper in _SEARCH_BOUNDS
]

# ========== Make the parent of the working directory importable ==========
# The project-local modules imported right after this are resolved through
# this sys.path entry, so it has to be placed at the front of the search path.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path[:0] = [parent_dir]

from pruning_llava_utils import prune_model_llava_dynamic, safe_batch_generate
from chair_metrics import batch_compute_chair_metrics



# Load the vision-language model in half precision on the GPU, wrap it with
# torch.compile, then load the matching tokenizer and processor.
_hub_kwargs = {"trust_remote_code": True}
_fast_kwargs = {**_hub_kwargs, "use_fast": True}

model = torch.compile(
    AutoModelForVision2Seq.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="cuda",
        **_hub_kwargs,
    )
)

tokenizer = AutoTokenizer.from_pretrained(model_id, **_fast_kwargs)
processor = AutoProcessor.from_pretrained(model_id, **_fast_kwargs)


# Load COCO label info: map every annotated image file name to the category
# names that appear in its annotations.
with open(ann_file, "r", encoding="utf-8") as f:
    coco = json.load(f)
imgid2fname = {img["id"]: img["file_name"] for img in coco["images"]}
catid2name = {cat["id"]: cat["name"] for cat in coco["categories"]}
fname2labels = defaultdict(set)
for ann in coco["annotations"]:
    fname = imgid2fname[ann["image_id"]]
    catname = catid2name[ann["category_id"]]
    fname2labels[fname].add(catname)
# Sorted (instead of list(set)) so the label order written to the logs is
# reproducible across runs; set iteration order depends on string hashing.
# Images without any annotation are simply absent from this plain dict.
fname2labels = {k: sorted(v) for k, v in fname2labels.items()}


# Image file names to process, one per non-blank line.
with open(used_imgs_file, "r", encoding="utf-8") as f:
    used_img_names = [line.strip() for line in f if line.strip()]


# Per-head score table; the main loop filters it by the "experiment" column
# and reads the "delta_eigenscore", "layer" and "head" columns.
df = pd.read_csv(csv_path)

def init_summary(df_cols):
    """Load the per-round summary table from ``summary_csv`` or start a new one.

    An existing file is reused only when its set of column names equals
    ``df_cols``. A missing file, an empty (zero-byte) file, or a file with a
    different schema all yield an empty table with the requested columns.

    Args:
        df_cols: Column names of the summary table; must include ``"round"``.

    Returns:
        A DataFrame indexed by ``"round"`` with the ``"round"`` column kept
        as a regular column as well.
    """
    summary = None
    if os.path.exists(summary_csv):
        try:
            summary = pd.read_csv(summary_csv)
        except pd.errors.EmptyDataError:
            # A zero-byte file has no header to parse; treat it as "no summary".
            print(f"[WARN] {summary_csv} is empty, starting a fresh summary.")
        else:
            if set(summary.columns) != set(df_cols):
                # The old rows are dropped here and the file is overwritten on
                # the next update, so make that visible to the user.
                print(f"[WARN] {summary_csv} has a different column set, "
                      f"starting a fresh summary.")
                summary = None
    if summary is None:
        summary = pd.DataFrame(columns=df_cols)
    return summary.set_index("round", drop=False)

def update_summary(df_sum, rec):
    """Insert or overwrite one round in the summary table and persist it.

    The row labelled ``rec["round"]`` is updated field by field when it
    already exists, otherwise it is appended. The table is then sorted by its
    index and written to ``summary_csv`` through a temporary file that is
    swapped into place with ``os.replace``.

    Args:
        df_sum: Summary table indexed by round. It is modified in place before
            the sorted copy is made.
        rec: Mapping of column name to value; must contain ``"round"``.

    Returns:
        The updated table, sorted by index.
    """
    rnd = rec["round"]
    if rnd in df_sum.index:
        for k, v in rec.items():
            df_sum.at[rnd, k] = v
    else:
        df_sum.loc[rnd] = rec
    df_sum = df_sum.sort_index()

    # The temp file must live next to the destination: os.replace cannot move
    # a file across filesystems, and the system temp directory is often on a
    # different mount than the working directory.
    target_dir = os.path.dirname(os.path.abspath(summary_csv))
    fd, tmp_path = tempfile.mkstemp(suffix=".csv", dir=target_dir)
    os.close(fd)  # pandas reopens the file by path
    try:
        df_sum.to_csv(tmp_path, index=False)
        os.replace(tmp_path, summary_csv)
    except BaseException:
        # Do not leave stray temp files behind when writing or swapping fails.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    return df_sum

def update_jsonl_entry(jsonl_path, new_entry):
    """Insert or replace one record in a JSONL log, keeping pairs adjacent.

    Records are identified by ``(file_name, type)``. An existing record with
    the same identity is replaced in place. Otherwise the new record is placed
    next to its partner for the same file: a ``"baseline"`` record goes
    directly before the existing ``"pruned"`` one, any other type goes directly
    after the existing ``"baseline"`` one. Without a partner it is appended.

    Lines that are not valid JSON objects are skipped on read and therefore
    dropped when the file is rewritten.

    Args:
        jsonl_path: Path of the JSONL file; created if it does not exist.
        new_entry: Record to store; must contain ``"file_name"`` and ``"type"``.
    """
    entries = []
    if os.path.exists(jsonl_path):
        with open(jsonl_path, "r", encoding="utf-8") as fr:
            for line in fr:
                try:
                    parsed = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # Valid JSON that is not an object (e.g. a list) has no
                # file_name/type to match on; treat it like a malformed line.
                if isinstance(parsed, dict):
                    entries.append(parsed)

    fname = new_entry["file_name"]
    typ = new_entry["type"]

    def _find(entry_type):
        """Return the index of the first record for ``fname`` of that type."""
        return next(
            (
                i for i, e in enumerate(entries)
                if e.get("file_name") == fname and e.get("type") == entry_type
            ),
            None,
        )

    existing_idx = _find(typ)
    if existing_idx is not None:
        entries[existing_idx] = new_entry
    else:
        if typ == "baseline":
            partner_idx = _find("pruned")
            insert_idx = len(entries) if partner_idx is None else partner_idx
        else:  # treated as "pruned"
            partner_idx = _find("baseline")
            insert_idx = len(entries) if partner_idx is None else partner_idx + 1
        entries.insert(insert_idx, new_entry)

    # Write to a sibling file and swap it in, so an interruption mid-write
    # cannot truncate the existing log.
    tmp_path = f"{jsonl_path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as fw:
            fw.writelines(
                json.dumps(e, ensure_ascii=False) + "\n" for e in entries
            )
        os.replace(tmp_path, jsonl_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def eval_faith(low, high):
    """Score one candidate faithful-head interval for the current round.

    Heads whose score lies between the ``low`` and ``high`` quantiles of
    ``scores`` (inclusive) are used as the faithful set; the hallucination
    set is the already chosen ``halluc_fixed``. The model is hooked with these
    heads, one caption is generated, and the hooks are removed again.

    Reads the module-level round state set by the main loop: ``scores``,
    ``df_e``, ``halluc_fixed``, ``mius``, ``lamda``, ``images``, ``prompts``,
    ``fname`` and ``base_i``.

    Args:
        low: Lower quantile fraction of the faithful interval.
        high: Upper quantile fraction of the faithful interval.

    Returns:
        ``(gain, low, high)`` where ``gain`` is ``base_i`` minus the CHAIR-i
        of the pruned caption, or ``0.0`` when generation returned ``None``.
    """
    ql, qh = scores.quantile([low, high]).tolist()
    faiths = df_e[(scores >= ql) & (scores <= qh)][["layer", "head"]].values.tolist()
    hooks = prune_model_llava_dynamic(model, faiths, halluc_fixed, alpha=mius, beta=lamda)
    try:
        with torch.no_grad():
            preds = safe_batch_generate(
                model, tokenizer, processor, images, prompts,
                device='cuda', max_new_tokens=max_new_tokens,
                min_len=10, max_retries=max_retries
            )
    finally:
        # Always detach the hooks: if generation raises and they stay
        # registered, every later inference runs on a still-pruned model.
        for h in hooks:
            h.remove()
    if preds is None:
        return 0.0, low, high
    m = batch_compute_chair_metrics(preds, [fname2labels[fname]])
    return base_i - m['CHAIR-i'], low, high



# Column layout of the per-round summary CSV.
summary_columns = [
    "round", "h_low", "h_high", "f_low", "f_high",
    "mius", "lamda",
    "ΔCHAIR-s (%)", "ΔCHAIR-i (%)",
    "pruned_CHAIR-s", "pruned_CHAIR-i", "pruned_F1",
    "baseline_CHAIR-s", "baseline_CHAIR-i", "baseline_F1",
    "time"
]
df_summary = init_summary(summary_columns)

# Resume support: collect, per image, which record types ("baseline" /
# "pruned") are already present in the content log.
types_by_file = defaultdict(set)
if os.path.exists(content_log_file):
    # The log is written as UTF-8 with non-ASCII characters left unescaped,
    # so it has to be read back as UTF-8 regardless of the locale default.
    with open(content_log_file, "r", encoding="utf-8") as fr:
        for line in fr:
            try:
                d = json.loads(line)
                types_by_file[d["file_name"]].add(d.get("type"))
            except (json.JSONDecodeError, KeyError, TypeError, AttributeError):
                # Malformed line, or a record without a usable file_name.
                continue

# Images that have a baseline caption but no pruned one yet.
only_baseline = {
    fname for fname, tset in types_by_file.items()
    if "baseline" in tset and "pruned" not in tset
}
# Images with both record types; the main loop skips these.
completed = {
    fname for fname, tset in types_by_file.items()
    if "baseline" in tset and "pruned" in tset
}

# ---- Main experiment loop: one image per round ----
# Per round: baseline caption -> Bayesian search over the hallucination-head
# interval and (mius, lamda) -> local expand/contract search over the
# faithful-head interval -> final pruned caption. Captions go to the JSONL
# content log, metrics to the summary CSV.
for rnd, fname in enumerate(used_img_names, 1):

    if fname in completed:
        print(f"[INFO] Skip round {rnd}: {fname} already fully done.")
        continue

    if not fname2labels.get(fname, []):
        print(f"[WARN] {fname} has empty label, skip.")
        continue

    # --- Bayesian optimization for hallucination interval and pruning parameters ---
    def objective(params):
        """Return the negated CHAIR-i gain of one candidate for gp_minimize.

        ``params`` is ``[low, high, mius, lamda]`` in search-space order.
        Invalid intervals and failed generations get the penalty value 1e6.
        Reads the round state (``scores``, ``df_e``, ``images``, ``prompts``,
        ``base_i``) at call time, so it must only be called after the baseline
        step of this round has run.
        """
        low, high, mius, lamda = params
        if low >= high:  # Penalize invalid intervals
            return 1e6
        # Select faithful heads (middle quantile, fixed during search)
        qf_low, qf_high = scores.quantile([0.4, 0.6]).tolist()
        faithful_fixed = df_e[(scores >= qf_low) & (scores <= qf_high)][["layer", "head"]].values.tolist()
        # Select hallucination heads (outside Bayesian-chosen interval)
        ql, qh = scores.quantile([low, high]).tolist()
        halluc_heads = df_e[(scores < ql) | (scores > qh)][["layer", "head"]].values.tolist()
        # Prune heads and run inference
        hooks = prune_model_llava_dynamic(model, faithful_fixed, halluc_heads, alpha=mius, beta=lamda)
        try:
            with torch.no_grad():
                preds = safe_batch_generate(
                    model, tokenizer, processor, images, prompts,
                    device='cuda', max_new_tokens=max_new_tokens,
                    min_len=10, max_retries=max_retries
                )
        finally:
            # Detach the hooks even when generation raises, otherwise the
            # next evaluation would run on a still-pruned model.
            for h in hooks:
                h.remove()
        if preds is None:
            return 1e6
        m = batch_compute_chair_metrics(preds, [fname2labels[fname]])
        # Return negative improvement in CHAIR-i (Bayesian minimization)
        return -(base_i - m['CHAIR-i'])

    start_time = time.time()
    print(f"\n=== Round {rnd}: {fname} ===")

    # Score rows for this round; the "experiment" column is matched against
    # the 1-based round number.
    df_e = df[df["experiment"] == rnd]
    if df_e.empty:
        print(f"[WARN] no data for round {rnd}, skip")
        continue
    scores = df_e["delta_eigenscore"]

    # ---- Run baseline inference before pruning (needed for gain calculation) ----
    # convert() returns a new, fully loaded image, so the file handle can be
    # closed right away instead of staying open for the whole round.
    with Image.open(os.path.join(img_dir, fname)) as raw_img:
        img = raw_img.convert("RGB")
    images, prompts = [img], [prompt]
    with torch.no_grad():
        preds_base = safe_batch_generate(
            model, tokenizer, processor,
            images, prompts,
            device='cuda', max_new_tokens=max_new_tokens,
            min_len=10, max_retries=1
        )
    if preds_base is None:
        print(f"[SKIP] Round {rnd} baseline too short")
        continue
    for p in preds_base:
        entry = {
            "round": rnd, "experiment": rnd, "type": "baseline",
            "file_name": fname, "pred": p, "gt_label": fname2labels.get(fname, [])
        }
        update_jsonl_entry(content_log_file, entry)

    base_m = batch_compute_chair_metrics(preds_base, [fname2labels[fname]])
    base_s = base_m['CHAIR-s']
    base_i = base_m['CHAIR-i']
    base_f1 = base_m['F1']

    # ---- Bayesian optimization: jointly search hallucination interval and pruning parameters ----
    opt_result = gp_minimize(
        objective,
        search_space,
        n_calls=20,
        n_initial_points=8,
        acq_func="EI",
        random_state=2025,
        verbose=True
    )

    # -- Use optimal hallucination interval to select heads for pruning --
    best_h_low, best_h_high, best_alpha, best_beta = opt_result.x
    best_h_gain = -opt_result.fun  # objective is the negated gain

    print(f"[INFO] BO Search: best interval=[{best_h_low:.2f},{best_h_high:.2f}], "
          f"mius={best_alpha:.3f}, lamda={best_beta:.3f} → gain={best_h_gain:.4f}")

    if best_h_gain <= 0:
        print("[WARN] no positive gain on any candidate, skip pruning this round")
        elapsed = time.time() - start_time
        # Whole minutes as an int, matching the format used for normal rounds
        # (divmod on a float would render e.g. "2.0m").
        mins = int(elapsed // 60)
        secs = elapsed % 60
        rec = {
            "round":       rnd,
            "h_low":      "skip",
            "h_high":     "skip",
            "f_low":      "skip",
            "f_high":     "skip",
            "mius":       "skip",
            "lamda":       "skip",
            "ΔCHAIR-s (%)": 0.0,
            "ΔCHAIR-i (%)": 0.0,
            "pruned_CHAIR-s": "skip",
            "pruned_CHAIR-i": "skip",
            "pruned_F1":      "skip",
            "baseline_CHAIR-s": round(base_s, 2),
            "baseline_CHAIR-i": round(base_i, 2),
            "baseline_F1":      round(base_f1, 2),
            "time": f"{mins}m{secs:.2f}s",
        }
        df_summary = update_summary(df_summary, rec)
        continue

    # From here on mius/lamda are module-level state read by eval_faith().
    mius, lamda = best_alpha, best_beta
    print(f"[INFO] initial pick interval=[{best_h_low:.2f},{best_h_high:.2f}], "
        f"mius={best_alpha},lamda={best_beta} → gain={best_h_gain:.4f}")

    # Hallucination heads: everything outside the chosen quantile interval.
    # eval_faith() reads halluc_fixed as module-level state.
    hh_ql, hh_qh = scores.quantile([best_h_low, best_h_high]).tolist()
    halluc_fixed = df_e[(scores < hh_ql) | (scores > hh_qh)][["layer", "head"]].values.tolist()

    # ---- Faithful head interval search: contraction/expansion with local refinement ----
    # -- Faith head search: start from [0.40, 0.60], expand/contract for local optimum --
    initial_low, initial_high = 0.40, 0.60
    best_f_gain, best_f_low, best_f_high = eval_faith(initial_low, initial_high)
    print(f"[DEBUG] faith initial [{initial_low:.2f},{initial_high:.2f}] → gain={best_f_gain:.4f}")

    # —— expand ——
    # Widen the interval by `step` on both sides while the gain strictly improves.
    best_exp_gain = best_f_gain
    best_exp_low, best_exp_high = best_f_low, best_f_high
    fract_low, fract_high = best_f_low, best_f_high

    while True:
        nl, nh = fract_low - step, fract_high + step
        if nl < 0 or nh > 1:
            break
        gain, cand_low, cand_high = eval_faith(nl, nh)
        if gain > best_exp_gain:
            best_exp_gain = gain
            best_exp_low, best_exp_high = cand_low, cand_high
            fract_low, fract_high = nl, nh
        else:
            break

    # —— contract ——
    # Shrink the interval by `step` on both sides while the gain strictly improves.
    best_con_gain = best_f_gain
    best_con_low, best_con_high = best_f_low, best_f_high
    fract_low, fract_high = best_f_low, best_f_high

    while True:
        nl, nh = fract_low + step, fract_high - step
        if nl >= nh:
            break
        gain, cand_low, cand_high = eval_faith(nl, nh)
        if gain > best_con_gain:
            best_con_gain = gain
            best_con_low, best_con_high = cand_low, cand_high
            fract_low, fract_high = nl, nh
        else:
            break

    # Take the best faith interval from expansion, contraction, or initial
    final_f_gain = best_f_gain
    final_f_low, final_f_high = best_f_low, best_f_high

    if best_exp_gain > final_f_gain:
        final_f_gain = best_exp_gain
        final_f_low, final_f_high = best_exp_low, best_exp_high
        print(f"[INFO] faith picked expand: [{final_f_low:.2f},{final_f_high:.2f}], gain={final_f_gain:.4f}")

    if best_con_gain > final_f_gain:
        final_f_gain = best_con_gain
        final_f_low, final_f_high = best_con_low, best_con_high
        print(f"[INFO] faith picked contract: [{final_f_low:.2f},{final_f_high:.2f}], gain={final_f_gain:.4f}")

    best_f_low, best_f_high, best_f_gain = final_f_low, final_f_high, final_f_gain
    print(f"[INFO] Faith final fracts: [{best_f_low:.2f},{best_f_high:.2f}], gain={best_f_gain:.4f}")

    # Select faithful heads according to final optimized interval
    h_th_low, h_th_high = scores.quantile([best_h_low, best_h_high]).tolist()
    f_th_low, f_th_high = scores.quantile([best_f_low, best_f_high]).tolist()

    # -- Run final pruned inference with selected heads/params --
    best_h_heads = df_e[(scores < h_th_low) | (scores > h_th_high)][["layer", "head"]].values.tolist()
    best_f_heads = df_e[(scores >= f_th_low) & (scores <= f_th_high)][["layer", "head"]].values.tolist()

    hooks = prune_model_llava_dynamic(
        model, faithful_heads=best_f_heads, hallucination_heads=best_h_heads,
        alpha=mius, beta=lamda
    )
    try:
        with torch.no_grad():
            preds_final = safe_batch_generate(
                model, tokenizer, processor,
                images, prompts,
                device='cuda', max_new_tokens=max_new_tokens,
                min_len=7, max_retries=max_retries
            )
    finally:
        # Detach the hooks even on failure so the next round's baseline is
        # generated by the unmodified model.
        for h in hooks:
            h.remove()

    # -- Save pruned results and compute metrics --
    final_m = None
    if preds_final is not None:
        final_m = batch_compute_chair_metrics(preds_final, [fname2labels[fname]])
        for p in preds_final:
            entry = {
                "round": rnd, "experiment": rnd, "type": "pruned",
                "file_name": fname, "pred": p, "gt_label": fname2labels.get(fname, [])
            }
            update_jsonl_entry(content_log_file, entry)
    else:
        # A placeholder "pruned" record still marks the image as completed
        # when the log is scanned on the next start.
        entry = {
            "round": rnd, "experiment": rnd, "type": "pruned",
            "file_name": fname, "pred": "SKIP", "gt_label": fname2labels.get(fname, [])
        }
        update_jsonl_entry(content_log_file, entry)

    h_low_fmt = round(best_h_low, 2)
    h_high_fmt = round(best_h_high, 2)
    f_low_fmt = round(best_f_low, 2)
    f_high_fmt = round(best_f_high, 2)

    # save summary_csv
    if final_m:
        pct_delta_s = round((base_s - final_m['CHAIR-s']) * 100, 2)
        pct_delta_i = round((base_i - final_m['CHAIR-i']) * 100, 2)
    else:
        pct_delta_s = pct_delta_i = None

    elapsed = time.time() - start_time
    mins = int(elapsed // 60)
    secs = elapsed % 60
    print(f"[INFO] Round {rnd} time: {mins}m{secs:.2f}s")

    rec = {
        "round": rnd,
        "h_low": h_low_fmt,
        "h_high": h_high_fmt,
        "f_low": f_low_fmt,
        "f_high": f_high_fmt,
        "mius": round(mius, 4),
        "lamda": round(lamda, 4),
        "ΔCHAIR-s (%)": pct_delta_s,
        "ΔCHAIR-i (%)": pct_delta_i,
        "pruned_CHAIR-s": round(final_m['CHAIR-s'], 2) if final_m else None,
        "pruned_CHAIR-i": round(final_m['CHAIR-i'], 2) if final_m else None,
        "pruned_F1": round(final_m['F1'], 2) if final_m else None,
        "baseline_CHAIR-s": round(base_s, 2),
        "baseline_CHAIR-i": round(base_i, 2),
        "baseline_F1": round(base_f1, 2),
        "time": f"{mins}m{secs:.2f}s",
    }
    df_summary = update_summary(df_summary, rec)

    print(f"[DONE] Round {rnd} → summary updated.")