import os, json ,sys
import torch
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoProcessor
from PIL import Image
from datetime import datetime
from collections import defaultdict


# ========== Experiment Configuration ==========
# Model checkpoint and dataset locations.
model_id = "llava-hf/llava-1.5-7b-hf"
img_dir = "path/to/val2014"
ann_file = "path/to/val2014/annotations/instances_val2014.json"

# Output logs: timestamped progress lines go to the text file, JSON records
# (one per line) go to the .jsonl file.
time_log_file = "time_log_7b_5_26_hallu.txt"
content_log_file = "content_log_7b_5_26_hallu.jsonl"

# Image lists, one file name per line.
#   used_imgs   - optional; names listed here are excluded from this run.
#   chosen_imgs - required; the candidate images for this run.
used_imgs = "holo.txt"
chosen_imgs = "used_images.txt"

max_new_tokens = 128
batch_size = 1
K = 15  # passed to auto_circuit_experiment as `num_samples`
# Layers passed to auto_circuit_experiment as `target_layers`: indices 5..26
# inclusive (the upper bound of range() is exclusive).
target_layers = list(range(5, 27))
csv_path = "auto_eigenscore_7b_5-26_hallu.csv"
prompt = "<image>\nPlease describe the image in detail."
log_file = "auto_experiment_7b_5_26_log_hallu.txt"



# ========== Add Parent Directory to Path for Custom Imports ==========
# The directory added is the parent of the *current working directory*, not of
# this file, so the imports below resolve only when the script is launched
# from the right place. It is put first on sys.path so it wins over any other
# module with the same name.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.insert(0, parent_dir)

# Project-local modules (presumably located in `parent_dir` - verify).
from llava_hooks import batch_run_with_cache_llava
from experiments import auto_circuit_experiment

# Load the model, tokenizer and processor from the same checkpoint.
# The model weights are requested in float16 and placed with device_map="auto".
# NOTE(review): trust_remote_code=True allows the checkpoint repository to run
# its own Python code at load time - confirm it is needed for `model_id`.
model = AutoModelForVision2Seq.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
    use_fast=True,
)
processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
    use_fast=True,
)


# Load COCO label info: build a mapping from image file name to the category
# names annotated on that image.
with open(ann_file, "r", encoding="utf-8") as f:
    coco = json.load(f)
imgid2fname = {img["id"]: img["file_name"] for img in coco["images"]}
catid2name = {cat["id"]: cat["name"] for cat in coco["categories"]}
# A set per file keeps each category once, however many annotations carry it.
fname2labels = defaultdict(set)
for ann in coco["annotations"]:
    fname = imgid2fname[ann["image_id"]]
    catname = catid2name[ann["category_id"]]
    fname2labels[fname].add(catname)
# Convert to a plain dict of lists. The labels are sorted because the iteration
# order of a set of strings can differ from one interpreter run to the next,
# which would make logged label lists non-reproducible. A plain dict also
# means lookups for unknown files no longer create empty entries.
fname2labels = {k: sorted(v) for k, v in fname2labels.items()}

# ---- Logging helpers ---
def log_time(msg):
    """Append a timestamped line to the timing log.

    The line has the form ``[YYYY-MM-DD HH:MM:SS] <msg>`` using the local
    (naive) clock, and is appended to the file named by ``time_log_file``.

    Args:
        msg: Text written after the timestamp.
    """
    stamp = f"{datetime.now():%Y-%m-%d %H:%M:%S}"
    entry = f"[{stamp}] {msg}\n"
    # Re-opened in append mode on every call, so each line is written out as
    # soon as the call returns.
    with open(time_log_file, "a") as log:
        log.write(entry)

def log_content(rnd, exp_id, type_, file_name, pred, gt_label):
    """Append one JSON record to the content log (JSON Lines format).

    The record is written as a single line to the file named by
    ``content_log_file``.

    Args:
        rnd: Stored under the ``"round"`` key.
        exp_id: Stored under the ``"experiment"`` key.
        type_: Stored under the ``"type"`` key.
        file_name: Stored under the ``"file_name"`` key.
        pred: Stored under the ``"pred"`` key.
        gt_label: Stored under the ``"gt_label"`` key.

    Raises:
        TypeError: If any value is not JSON-serializable.
    """
    log_obj = {
        "round": rnd,
        "experiment": exp_id,
        "type": type_,
        "file_name": file_name,
        "pred": pred,
        "gt_label": gt_label,
    }
    # ensure_ascii=False writes non-ASCII characters verbatim, so the file is
    # opened explicitly as UTF-8; the platform default encoding may be unable
    # to represent them and would raise UnicodeEncodeError.
    with open(content_log_file, "a", encoding="utf-8") as f:
        f.write(json.dumps(log_obj, ensure_ascii=False) + "\n")


# ---- Prepare images for the experiment ----
all_selected_names = []  # filled by the main loop, in processing order

# Names to exclude: every non-blank line of the `used_imgs` file, if it exists.
used_img_names = set()
if os.path.exists(used_imgs):
    with open(used_imgs, "r") as f:
        used_img_names.update(name for name in map(str.strip, f) if name)

# Candidate names: every non-blank line of the `chosen_imgs` file (required).
with open(chosen_imgs, "r") as f:
    chosen_img_names = [name for name in map(str.strip, f) if name]

# Drop the excluded names, keep file order, and cap the run at 83 images.
chosen_img_names = [x for x in chosen_img_names if x not in used_img_names][:83]
print(f"Chosen {len(chosen_img_names)} images for this experiment.")


# ---- Main experiment loop: one image per round ----
# `rnd` is 1-based. Each round builds a one-sample batch, records a cache for
# it and runs the auto-circuit experiment on that sample.
for rnd, fname in enumerate(chosen_img_names, 1):
    log_time(f"Round {rnd} start.")
    # Drop the previous round's cache before releasing allocator memory:
    # empty_cache() can only return memory that no live tensor occupies, and
    # otherwise the old cache would stay alive until the new one is assigned
    # below (assumes the cache holds GPU tensors - harmless if it does not).
    patching_cache = None
    torch.cuda.empty_cache()

    # -- Prepare image sample --
    used_img_names.add(fname)
    all_selected_names.append(fname)
    img_path = f"{img_dir}/{fname}"
    img = Image.open(img_path).convert("RGB")
    patching_samples = [{
        "image": img,
        "prompt": prompt,
        "file_name": fname,
        # Files without annotations get an empty label list.
        "gt_label": fname2labels.get(fname, [])
    }]
    log_time(f"Loaded image: {fname}")
    # NOTE(review): `images` and `prompts` are not used in the visible part of
    # this loop; kept in case code further down relies on them.
    images = [s["image"] for s in patching_samples]
    images = [img.convert("RGB") for img in images]
    prompts = [s["prompt"] for s in patching_samples]

    # -- get cache --
    # The first return value is discarded; only the cache is used.
    _, patching_cache = batch_run_with_cache_llava(
        model, processor, patching_samples, device="cuda", batch_size=batch_size
    )
    log_time(f"Cache complete for round {rnd}.")

    # -- Run auto_circuit_experiment to calculate per-head attribution (eigenscore) --
    log_time("Auto circuit experiment start.")
    exp_id = auto_circuit_experiment(
        csv_path=csv_path,
        model=model,
        processor=processor,
        val_samples=patching_samples,
        val_cache=patching_cache,
        ablation_scheme="mean",
        device="cuda",
        include_mlps=False,
        num_samples=K,
        layer_hidden_index=None,
        target_layers=target_layers,
        batch_size=batch_size,
    )
    log_time(f"Auto circuit experiment end. exp_id={exp_id}")

