{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, gc, random, logging, json,sys\n",
    "import torch\n",
    "from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoProcessor\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "\n",
    "# ========== Add Parent Directory to Path for Custom Imports ==========\n",
    "parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))\n",
    "# Drop an earlier entry first so that re-running this cell does not pile up\n",
    "# duplicates; parent_dir still ends up at the front of sys.path.\n",
    "if parent_dir in sys.path:\n",
    "    sys.path.remove(parent_dir)\n",
    "sys.path.insert(0, parent_dir)\n",
    "\n",
    "# Option for PyTorch's CUDA allocator; set here, before any cell in this\n",
    "# notebook touches the GPU.\n",
    "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
    "\n",
    "# Log to a file and to the notebook output at the same time.\n",
    "logging.basicConfig(\n",
    "    level=logging.INFO,\n",
    "    format='%(asctime)s | %(levelname)s | %(message)s',\n",
    "    handlers=[\n",
    "        logging.FileHandler(\"prune_llava.log\", encoding='utf-8'),\n",
    "        logging.StreamHandler()\n",
    "    ],\n",
    "    # basicConfig() is a silent no-op when the root logger already has handlers;\n",
    "    # force=True removes them first so this configuration always takes effect.\n",
    "    force=True,\n",
    ")\n",
    "\n",
    "\n",
    "# ========== Experiment configuration ==========\n",
    "model_id = \"llava-hf/llava-1.5-13b-hf\"  # checkpoint passed to from_pretrained\n",
    "img_dir = \"path/to/val2014\"  # directory the sampled file names are resolved against\n",
    "txt_file = \"../all_img_names.txt\"  # candidate image file names, one per line\n",
    "ann_file = \"path/to/val2014/annotations/instances_val2014.json\"  # JSON with images/categories/annotations\n",
    "NUM_ROUNDS = 10  # number of sample-then-experiment rounds\n",
    "K = 15  # forwarded as num_samples to auto_circuit_experiment\n",
    "N = 16  # images drawn per round (no image is reused across rounds)\n",
    "batch_size = 8  # forwarded to both the caching run and the experiment\n",
    "target_layers = list(range(5, 19))  # layer indices 5..18 inclusive\n",
    "csv_path = \"all_experiments_eigenscore_13b_5-18.csv\"  # forwarded as csv_path to auto_circuit_experiment\n",
    "prompt = \"USER: <image>\\nPlease describe the image in detail.\\nASSISTANT:\"\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load COCO label info\n",
    "with open(txt_file, \"r\") as f:\n",
    "    # Skip blank lines (e.g. a trailing newline at the end of the file) so that\n",
    "    # an empty file name can never end up in the sampling pool.\n",
    "    all_img_names = [name for name in (line.strip() for line in f) if name]\n",
    "with open(ann_file, \"r\") as f:\n",
    "    coco = json.load(f)\n",
    "\n",
    "# Lookup tables: image id -> file name, category id -> category name\n",
    "imgid2fname = {img[\"id\"]: img[\"file_name\"] for img in coco[\"images\"]}\n",
    "catid2name = {cat[\"id\"]: cat[\"name\"] for cat in coco[\"categories\"]}\n",
    "\n",
    "# Collect, per file name, the distinct category names that appear in its annotations\n",
    "label_sets = defaultdict(set)\n",
    "for ann in coco[\"annotations\"]:\n",
    "    label_sets[imgid2fname[ann[\"image_id\"]]].add(catid2name[ann[\"category_id\"]])\n",
    "\n",
    "# Plain dict of lists; labels are sorted so their order is the same on every run\n",
    "# (iteration order of a set of strings is not).\n",
    "fname2labels = {fname: sorted(labels) for fname, labels in label_sets.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Load the model, tokenizer and processor for the checkpoint named by model_id\n",
    "# NOTE(review): trust_remote_code=True permits running code shipped with the hub\n",
    "# repository - confirm it is actually required for this checkpoint.\n",
    "_model_kwargs = dict(\n",
    "    torch_dtype=torch.float16,\n",
    "    device_map=\"cuda\",\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "model = AutoModelForVision2Seq.from_pretrained(model_id, **_model_kwargs)\n",
    "\n",
    "# The tokenizer and the processor are loaded with the same options\n",
    "_preproc_kwargs = dict(trust_remote_code=True, use_fast=True)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id, **_preproc_kwargs)\n",
    "processor = AutoProcessor.from_pretrained(model_id, **_preproc_kwargs)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from llava_hooks import batch_run_with_cache_llava\n",
    "from experiments import auto_circuit_experiment\n",
    "\n",
    "# Bookkeeping across rounds: names already drawn (never re-sampled) and the\n",
    "# complete draw history in the order the images were selected.\n",
    "used_img_names = set()\n",
    "all_selected_names = []\n",
    "# The candidate pool does not change between rounds, so de-duplicate it once.\n",
    "all_img_name_pool = set(all_img_names)\n",
    "\n",
    "# ------------------ Multi-round Automated Experiment Main Loop ------------------\n",
    "for round_id in range(NUM_ROUNDS):\n",
    "    torch.cuda.empty_cache()\n",
    "    logging.info(f\"========== Starting round {round_id+1}/{NUM_ROUNDS} of experiments ==========\")\n",
    "\n",
    "    # Available image names = all_img_names - used_img_names.\n",
    "    # Sorted so the draw depends only on the RNG state, not on set iteration order.\n",
    "    available_names = sorted(all_img_name_pool - used_img_names)\n",
    "    if len(available_names) < N:\n",
    "        raise ValueError(f\"Not enough available images for N={N}. Please check your total image count or the value of N!\")\n",
    "\n",
    "    selected_names = random.sample(available_names, N)\n",
    "    used_img_names.update(selected_names)  # Record the images used in this round\n",
    "    all_selected_names.extend(selected_names)\n",
    "\n",
    "    patching_samples = []\n",
    "    for fname in selected_names:\n",
    "        # convert() returns an independent in-memory RGB image; the context manager\n",
    "        # then closes the source file instead of leaving that to garbage collection.\n",
    "        with Image.open(os.path.join(img_dir, fname)) as raw_img:\n",
    "            img = raw_img.convert(\"RGB\")\n",
    "        patching_samples.append({\n",
    "            \"image\": img,\n",
    "            \"prompt\": prompt,\n",
    "            \"file_name\": fname,\n",
    "            \"gt_label\": fname2labels.get(fname, [])  # [] when the file has no annotations\n",
    "        })\n",
    "\n",
    "    # Inference and obtain cache\n",
    "    patching_logits, patching_cache = batch_run_with_cache_llava(\n",
    "        model, processor, patching_samples, device=\"cuda\", batch_size=batch_size\n",
    "    )\n",
    "\n",
    "    # Perform one ablation experiment (automatically writes to CSV and assigns group ID)\n",
    "    exp_id = auto_circuit_experiment(\n",
    "        csv_path=csv_path,\n",
    "        model=model,\n",
    "        processor=processor,\n",
    "        val_samples=patching_samples,  # list of dicts: PIL image, prompt, file name, labels\n",
    "        val_cache=patching_cache,\n",
    "        ablation_scheme=\"mean\",\n",
    "        device=\"cuda\",\n",
    "        include_mlps=False,\n",
    "        num_samples=K,\n",
    "        layer_hidden_index=None,\n",
    "        target_layers=target_layers,\n",
    "        batch_size=batch_size\n",
    "    )\n",
    "\n",
    "    logging.info(f\"Round {round_id+1} experiment complete, exp_id={exp_id}\")\n",
    "\n",
    "    # ----------- Proactively release all non-model resources from this round to prevent GPU memory leaks -----------\n",
    "    del patching_samples, patching_logits, patching_cache\n",
    "    # Collect first: objects freed by the garbage collector only become releasable\n",
    "    # cached blocks afterwards, so empty_cache() has to run after gc.collect().\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "    logging.info(f\"Resources for round {round_id+1} have been cleaned up\\n\")\n",
    "\n",
    "# Persist every image name that was drawn, one per line, in selection order\n",
    "with open(\"chosen_imgs_13b_5_18.txt\", \"w\") as f:\n",
    "    for fname in all_selected_names:\n",
    "        f.write(f\"{fname}\\n\")\n",
    "logging.info(\"✅ All experiments have been completed.\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
