import os, gc, random, logging, json,sys
import torch
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoProcessor
from PIL import Image
import numpy as np
from collections import defaultdict

# ========== Add Parent Directory to Path for Custom Imports ==========
# The parent is resolved relative to the current working directory (not this
# file's location), so which local modules are importable depends on where the
# script is launched from.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.insert(0, parent_dir)

# CUDA caching-allocator option. PyTorch reads PYTORCH_CUDA_ALLOC_CONF when the
# allocator is first initialised, so it has to be in the environment before the
# first CUDA allocation (the model load further down).
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})

# Root logger: INFO and above, sent both to prune_llava.log (UTF-8) and to the
# console, all sharing one timestamped "time | level | message" format.
logging.basicConfig(
    format='%(asctime)s | %(levelname)s | %(message)s',
    level=logging.INFO,
    handlers=(
        logging.FileHandler("prune_llava.log", encoding='utf-8'),
        logging.StreamHandler(),
    ),
)


# ---------------- Experiment configuration ----------------
# Hugging Face checkpoint to load.
model_id = "llava-hf/llava-1.5-13b-hf"
# Placeholder: directory containing the image files listed in txt_file.
img_dir = "path/to/val2014"
# Candidate image file names, one per line (path relative to the working dir).
txt_file = "../all_img_names.txt"
# Placeholder: COCO instances annotation file used for ground-truth labels.
ann_file = "path/to/val2014/annotations/instances_val2014.json"
# Number of independent sampling/experiment rounds.
NUM_ROUNDS = 10
# Forwarded to auto_circuit_experiment as `num_samples`.
K = 15
# Images drawn per round; images are never reused across rounds.
N = 16
# Batch size for both the cache run and the ablation experiment.
batch_size = 8
# Layers 5 through 18 inclusive (matches the "5-18" tag in csv_path).
target_layers = list(range(5, 18 + 1))
# Results CSV written by auto_circuit_experiment.
csv_path = "all_experiments_eigenscore_13b_5-18.csv"
# Prompt shared by every sample; contains the `<image>` placeholder token.
prompt = "USER: <image>\nPlease describe the image in detail.\nASSISTANT:"

# Load COCO label info
# all_img_names: candidate image file names, one per line of txt_file. Blank or
# whitespace-only lines are skipped: an empty name would otherwise be sampled
# later and resolve to the image directory itself instead of an image file.
with open(txt_file, "r", encoding="utf-8") as f:
    all_img_names = [name for name in (line.strip() for line in f) if name]
with open(ann_file, "r", encoding="utf-8") as f:
    coco = json.load(f)
imgid2fname = {img["id"]: img["file_name"] for img in coco["images"]}
catid2name = {cat["id"]: cat["name"] for cat in coco["categories"]}
# Collect the distinct category names annotated on each image file.
fname2labels = defaultdict(set)
for ann in coco["annotations"]:
    fname = imgid2fname[ann["image_id"]]
    catname = catid2name[ann["category_id"]]
    fname2labels[fname].add(catname)
# Sorted lists give a deterministic label order; list(set_of_str) would vary
# between runs because str hashing is randomised per process.
fname2labels = {k: sorted(v) for k, v in fname2labels.items()}

# Load model and tokenizer
# fp16 weights, placed on the GPU via device_map.
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map="cuda",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
# Tokenizer first, then processor, both loaded from the same checkpoint.
tokenizer, processor = (
    loader.from_pretrained(model_id, trust_remote_code=True, use_fast=True)
    for loader in (AutoTokenizer, AutoProcessor)
)

# Project-local helpers; presumably resolved through the parent directory that
# was pushed onto sys.path at the top of the script — confirm module locations.
from llava_hooks import batch_run_with_cache_llava
from experiments import auto_circuit_experiment

used_img_names = set()
all_selected_names = []
# Candidate pool, built once rather than every round.
all_img_name_set = set(all_img_names)

# ------------------ Multi-round Automated Experiment Main Loop ------------------
for round_id in range(NUM_ROUNDS):
    torch.cuda.empty_cache()
    logging.info(f"========== Starting round {round_id+1}/{NUM_ROUNDS} of experiments ==========")

    # Available image names = all_img_names - used_img_names.
    # Sorted because set iteration order of strings depends on per-process hash
    # randomisation; without it, random.sample is not reproducible even when
    # the `random` module is seeded.
    available_names = sorted(all_img_name_set - used_img_names)
    if len(available_names) < N:
        raise ValueError(f"Not enough available images for N={N}. Please check your total image count or the value of N!")

    selected_names = random.sample(available_names, N)
    used_img_names.update(selected_names)  # Record the images used in this round
    all_selected_names.extend(selected_names)

    patching_samples = []
    for fname in selected_names:
        img_path = f"{img_dir}/{fname}"
        # The context manager closes the file handle right away; convert()
        # returns a converted in-memory copy that outlives the open file.
        with Image.open(img_path) as src:
            img = src.convert("RGB")
        patching_samples.append({
            "image": img,
            "prompt": prompt,
            "file_name": fname,
            "gt_label": fname2labels.get(fname, [])
        })

    # Inference and obtain cache
    patching_logits, patching_cache = batch_run_with_cache_llava(
        model, processor, patching_samples, device="cuda", batch_size=batch_size
    )

    # Perform one ablation experiment (automatically writes to CSV and assigns group ID)
    exp_id = auto_circuit_experiment(
        csv_path=csv_path,
        model=model,
        processor=processor,
        # List of dicts with keys "image" (RGB PIL image), "prompt", "file_name", "gt_label".
        val_samples=patching_samples,
        val_cache=patching_cache,
        ablation_scheme="mean",
        device="cuda",
        include_mlps=False,
        num_samples=K,
        layer_hidden_index=None,
        target_layers=target_layers,
        batch_size=batch_size
    )

    logging.info(f"Round {round_id+1} experiment complete, exp_id={exp_id}")

    # ----------- Proactively release all non-model resources from this round to prevent GPU memory leaks -----------
    # Drop references first, then run the cycle collector so tensors kept alive
    # only by reference cycles are freed, and only then hand cached CUDA blocks
    # back to the driver: empty_cache() cannot release memory still referenced.
    del patching_samples, patching_logits, patching_cache
    gc.collect()
    torch.cuda.empty_cache()
    logging.info(f"Resources for round {round_id+1} have been cleaned up\n")

# Persist every image name sampled across all rounds, one per line, in
# selection order. Explicit UTF-8 keeps the output independent of the locale.
with open("chosen_imgs_13b_5_18.txt", "w", encoding="utf-8") as f:
    f.writelines(f"{fname}\n" for fname in all_selected_names)
logging.info("✅ All experiments have been completed.")