{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, gc, random, logging, json,sys\n",
    "import torch\n",
    "from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoProcessor\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "\n",
    "# Put the parent of the current working directory first on the import path so\n",
    "# that modules stored one level above the working directory can be imported.\n",
    "parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))\n",
    "sys.path.insert(0, parent_dir)\n",
    "\n",
    "# Send INFO-and-above records to a UTF-8 log file and to the notebook output.\n",
    "# force=True replaces any handlers already attached to the root logger: without\n",
    "# it basicConfig() silently does nothing when the root logger is already\n",
    "# configured (e.g. when this cell is re-run), while the freshly created\n",
    "# FileHandler would still be opened and then left dangling.\n",
    "logging.basicConfig(\n",
    "    level=logging.INFO,\n",
    "    format='%(asctime)s | %(levelname)s | %(message)s',\n",
    "    handlers=[\n",
    "        logging.FileHandler(\"prune_llava_7b_5-18-hallu.log\", encoding='utf-8'),\n",
    "        logging.StreamHandler(),\n",
    "    ],\n",
    "    force=True,\n",
    ")\n",
    "\n",
    "\n",
    "# Checkpoint identifier handed to every from_pretrained() call (model, tokenizer, processor).\n",
    "model_id = \"llava-hf/llava-1.5-7b-hf\"\n",
    "# Directory that the sampled image file names are resolved against.\n",
    "img_dir = \"/path/to/val2014\"\n",
    "# Text file with one image file name per line; this is the pool images are sampled from.\n",
    "txt_file = '../hallu_img_samples/hallu-7b.txt'\n",
    "# Annotation JSON; its \"images\", \"categories\" and \"annotations\" entries are used to\n",
    "# build the per-image ground-truth label lists.\n",
    "ann_file = \"/path/to/val2014/annotations/instances_val2014.json\"\n",
    "# Number of sampling + experiment rounds. Images are never reused across rounds,\n",
    "# so the pool must contain at least NUM_ROUNDS * N distinct names.\n",
    "NUM_ROUNDS = 10\n",
    "# Number of images sampled (without replacement) in each round.\n",
    "N = 16  \n",
    "# Forwarded as num_samples to auto_circuit_experiment; its exact meaning is\n",
    "# defined there -- TODO confirm.\n",
    "K = 15\n",
    "# Batch size passed to both batch_run_with_cache_llava and auto_circuit_experiment.\n",
    "batch_size= 16\n",
    "# Layer indices 5..18 (inclusive), passed as target_layers to auto_circuit_experiment.\n",
    "target_layers = list(range(5, 19))\n",
    "# CSV path handed to auto_circuit_experiment for recording experiment results.\n",
    "csv_path = \"all_experiments_eigenscore_7b_5-18-hallu.csv\"\n",
    "# Prompt attached to every sampled image.\n",
    "prompt = \"<image>\\nPlease describe the image in detail.\"\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidate image file names, one per line. Blank lines are dropped so that an\n",
    "# empty name can never be sampled later and fail when the image is opened.\n",
    "with open(txt_file, \"r\") as f:\n",
    "    all_img_names = [name for name in (line.strip() for line in f) if name]\n",
    "\n",
    "# Annotation JSON; the keys read below are \"images\", \"categories\" and \"annotations\".\n",
    "with open(ann_file, \"r\") as f:\n",
    "    coco = json.load(f)\n",
    "\n",
    "imgid2fname = {img[\"id\"]: img[\"file_name\"] for img in coco[\"images\"]}\n",
    "catid2name = {cat[\"id\"]: cat[\"name\"] for cat in coco[\"categories\"]}\n",
    "\n",
    "# Collect, per image file name, the distinct category names annotated in it.\n",
    "labels_by_fname = defaultdict(set)\n",
    "for ann in coco[\"annotations\"]:\n",
    "    labels_by_fname[imgid2fname[ann[\"image_id\"]]].add(catid2name[ann[\"category_id\"]])\n",
    "\n",
    "# Freeze into a plain dict of sorted lists: iteration order of a set of strings\n",
    "# changes between interpreter runs, which made the label order irreproducible.\n",
    "fname2labels = {fname: sorted(labels) for fname, labels in labels_by_fname.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "\n",
    "# Load the checkpoint in half precision (float16) with device_map=\"cuda\".\n",
    "model = AutoModelForVision2Seq.from_pretrained(\n",
    "    model_id, torch_dtype=torch.float16, device_map=\"cuda\", trust_remote_code=True\n",
    ")\n",
    "\n",
    "# The tokenizer and the processor come from the same checkpoint and share the\n",
    "# same loading options.\n",
    "# NOTE(review): trust_remote_code=True permits executing code shipped with the\n",
    "# checkpoint repository -- confirm it is actually required for this model.\n",
    "preproc_options = {\"trust_remote_code\": True, \"use_fast\": True}\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id, **preproc_options)\n",
    "processor = AutoProcessor.from_pretrained(model_id, **preproc_options)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from llava_hooks import batch_run_with_cache_llava\n",
    "from experiments import auto_circuit_experiment\n",
    "\n",
    "# Seed for the per-round image sampler. Keep it fixed to reproduce the same\n",
    "# selection on every run; set it to None to draw a fresh random selection.\n",
    "SAMPLING_SEED = 0\n",
    "sampler = random.Random(SAMPLING_SEED)\n",
    "\n",
    "# De-duplicated candidate pool, built once instead of once per round.\n",
    "candidate_names = set(all_img_names)\n",
    "used_img_names = set()\n",
    "all_selected_names = []\n",
    "\n",
    "for round_id in range(NUM_ROUNDS):\n",
    "    torch.cuda.empty_cache()\n",
    "    logging.info(f\"========== Starting round {round_id+1}/{NUM_ROUNDS} ==========\")\n",
    "\n",
    "    # Images not drawn in any earlier round. Sorting pins the order of the pool:\n",
    "    # iteration order of a set of strings varies between interpreter runs, which\n",
    "    # would make the draw below irreproducible even with a fixed seed.\n",
    "    available_names = sorted(candidate_names - used_img_names)\n",
    "    if len(available_names) < N:\n",
    "        raise ValueError(f\"Not enough available images for N={N}, please check your total images or N setting!\")\n",
    "\n",
    "    selected_names = sampler.sample(available_names, N)\n",
    "    used_img_names.update(selected_names)  # Record the images used in this round\n",
    "    all_selected_names.extend(selected_names)\n",
    "\n",
    "    patching_samples = []\n",
    "    for fname in selected_names:\n",
    "        # convert() returns a fully loaded copy, so the source file can be closed\n",
    "        # right away instead of staying open until garbage collection.\n",
    "        with Image.open(f\"{img_dir}/{fname}\") as raw_img:\n",
    "            img = raw_img.convert(\"RGB\")\n",
    "        patching_samples.append({\n",
    "            \"image\": img,\n",
    "            \"prompt\": prompt,\n",
    "            \"file_name\": fname,\n",
    "            \"gt_label\": fname2labels.get(fname, [])\n",
    "        })\n",
    "\n",
    "    # Run inference and get cache\n",
    "    patching_logits, patching_cache = batch_run_with_cache_llava(\n",
    "        model, processor, patching_samples, device=\"cuda\", batch_size=batch_size\n",
    "    )\n",
    "\n",
    "    # Run one ablation experiment (will write to CSV and assign an experiment ID)\n",
    "    exp_id = auto_circuit_experiment(\n",
    "        csv_path=csv_path,\n",
    "        model=model,\n",
    "        processor=processor,\n",
    "        val_samples=patching_samples,\n",
    "        val_cache=patching_cache,\n",
    "        ablation_scheme=\"mean\",\n",
    "        device=\"cuda\",\n",
    "        include_mlps=False,\n",
    "        num_samples=K,\n",
    "        layer_hidden_index=None,\n",
    "        target_layers=target_layers,\n",
    "        batch_size=batch_size,\n",
    "    )\n",
    "\n",
    "    logging.info(f\"Round {round_id+1} completed, exp_id={exp_id}\")\n",
    "\n",
    "    # ---------- Proactively release all non-model resources from this round to prevent GPU memory leaks ----------\n",
    "    # Drop the references first, then collect garbage, and only then empty the\n",
    "    # CUDA cache: memory still held by uncollected objects cannot be released.\n",
    "    del patching_samples, patching_logits, patching_cache\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "    logging.info(f\"Resources for round {round_id+1} have been cleaned up\\n\")\n",
    "\n",
    "# Persist the file names used across all rounds, one per line, in draw order.\n",
    "with open(\"chosen_imgs_7b_5_18_hallu.txt\", \"w\") as f:\n",
    "    for fname in all_selected_names:\n",
    "        f.write(f\"{fname}\\n\")\n",
    "\n",
    "logging.info(\"✅ All experiments completed.\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
