{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys,random, torch\n",
    "from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoProcessor\n",
    "from PIL import Image\n",
    "from collections import defaultdict\n",
    "import os\n",
    "import json\n",
    "\n",
    "parent_dir = os.path.abspath(os.path.join(os.getcwd(), '../..'))\n",
    "sys.path.insert(0, parent_dir)\n",
    "\n",
    "from pruning_llava_utils import batch_generate_llava\n",
    "from chair_metrics import batch_compute_chair_metrics\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Checkpoint to load and the JSONL files that receive the generated captions.\n",
    "model_id = \"llava-hf/llava-1.5-7b-hf\"\n",
    "jsonl_file = \"baseline_captions.jsonl\"\n",
    "jsonl_file_patched = \"patched_captions.jsonl\"\n",
    "\n",
    "# fp16 weights; device_map=\"auto\" leaves device placement to the loader.\n",
    "model = AutoModelForVision2Seq.from_pretrained(\n",
    "    model_id, torch_dtype=torch.float16, device_map=\"auto\", trust_remote_code=True\n",
    ")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=True)\n",
    "processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True, use_fast=True)\n",
    "\n",
    "prompt_template = \"<image>\\nPlease describe the image in detail.\"\n",
    "\n",
    "# Copy the language model's rotary embedding module and head count onto each\n",
    "# layer's self-attention module as plain attributes.\n",
    "# NOTE(review): presumably required by the attention patch applied later; confirm.\n",
    "rotary_emb = model.model.language_model.rotary_emb\n",
    "num_attention_heads = model.model.language_model.config.num_attention_heads\n",
    "\n",
    "for layer in model.model.language_model.layers:\n",
    "    attn = layer.self_attn\n",
    "    attn.rotary_emb = rotary_emb\n",
    "    attn.num_attention_heads = num_attention_heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Input locations ---\n",
    "img_dir = \"path/to/val2014\"  # directory containing the image files\n",
    "ann_file = \"path/to/val2014/annotations/instances_val2014.json\"  # COCO instances JSON\n",
    "txt_file = \"../../auto_cir/chosen_img.txt\"  # candidate image file names, one per line\n",
    "\n",
    "# --- Output ---\n",
    "log_file = \"experiment_log_patch.txt\"  # CSV rows of per-round metrics\n",
    "\n",
    "# --- Evaluation schedule ---\n",
    "N = 1  # images per round\n",
    "n_rounds = 120\n",
    "max_new_tokens = 128\n",
    "\n",
    "# --- Attention patch settings ---\n",
    "guided_layers = [*range(5, 19)]  # layer indices 5..18 inclusive\n",
    "aggregation = \"mean\"\n",
    "alpha = 0.5\n",
    "\n",
    "def get_image_id(fname):\n",
    "    \"\"\"Return the integer that follows the last underscore of a file name.\n",
    "\n",
    "    The extension is removed first, e.g.\n",
    "    \"COCO_val2014_000000391895.jpg\" -> 391895. A name without any underscore\n",
    "    is parsed as a whole. Raises ValueError when that part is not an integer.\n",
    "    \"\"\"\n",
    "    stem, _ext = os.path.splitext(fname)\n",
    "    return int(stem.rsplit('_', 1)[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Token-position span [img_start_idx, img_end_idx) passed to the attention patch.\n",
    "img_start_idx = 1\n",
    "\n",
    "vc = model.config.vision_config\n",
    "# Square patch grid plus one extra position.\n",
    "# NOTE(review): the +1 presumably counts the vision tower's CLS token; confirm it\n",
    "# matches the number of image tokens the processor actually inserts.\n",
    "patches_per_side = vc.image_size // vc.patch_size\n",
    "num_patches = patches_per_side ** 2 + 1\n",
    "img_end_idx = img_start_idx + num_patches\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidate image file names, one per line; blank lines are skipped.\n",
    "with open(txt_file, encoding=\"utf-8\") as f:\n",
    "    all_img_names = [x.strip() for x in f if x.strip()]\n",
    "\n",
    "# COCO instance annotations -> mapping from image file name to category names.\n",
    "# A category is appended once per annotation, so repeated objects repeat the name.\n",
    "with open(ann_file, encoding=\"utf-8\") as f:\n",
    "    coco = json.load(f)\n",
    "imgid2fn = {i[\"id\"]: i[\"file_name\"] for i in coco[\"images\"]}\n",
    "cat2nm = {c[\"id\"]: c[\"name\"] for c in coco[\"categories\"]}\n",
    "fname2labels = defaultdict(list)\n",
    "for ann in coco[\"annotations\"]:\n",
    "    fname2labels[imgid2fn[ann[\"image_id\"]]].append(cat2nm[ann[\"category_id\"]])\n",
    "\n",
    "# Fail early with a readable message instead of random.sample's generic error.\n",
    "n_needed = N * n_rounds\n",
    "if len(all_img_names) < n_needed:\n",
    "    raise ValueError(\n",
    "        f\"{txt_file} lists {len(all_img_names)} images, \"\n",
    "        f\"but N * n_rounds = {n_needed} are required\"\n",
    "    )\n",
    "\n",
    "# Private, seeded RNG: the same images are drawn after every kernel restart and\n",
    "# the global random state is left untouched.\n",
    "sample_seed = 0\n",
    "fixed = random.Random(sample_seed).sample(all_img_names, n_needed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (Re)create the metrics log so it holds only the CSV header; the evaluation\n",
    "# loops append one row per round afterwards.\n",
    "with open(log_file, \"w\") as log_fh:\n",
    "    log_fh.write(\"Round,Type,CHAIR-s,CHAIR-i,F1,Len\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline pass: caption every round's images with the unpatched model, store the\n",
    "# captions as JSONL and append one CHAIR metrics row per round to the CSV log.\n",
    "\n",
    "# Start from an empty captions file: the per-round writes below append, so without\n",
    "# this a re-run would stack duplicate entries on top of the previous run's output.\n",
    "with open(jsonl_file, \"w\", encoding=\"utf-8\"):\n",
    "    pass\n",
    "\n",
    "for rnd in range(1, n_rounds + 1):\n",
    "    names = fixed[(rnd - 1) * N : rnd * N]\n",
    "\n",
    "    imgs = []\n",
    "    for fn in names:\n",
    "        # convert() returns a separate image, so the source file can be closed\n",
    "        # as soon as the with-block exits instead of lingering until GC.\n",
    "        with Image.open(os.path.join(img_dir, fn)) as raw_img:\n",
    "            imgs.append(raw_img.convert(\"RGB\"))\n",
    "    # One prompt per image actually loaded this round.\n",
    "    prompts = [prompt_template] * len(names)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        preds_base = batch_generate_llava(\n",
    "            model, tokenizer, processor,\n",
    "            imgs, prompts,\n",
    "            device=\"cuda\", max_new_tokens=max_new_tokens\n",
    "        )\n",
    "\n",
    "    # Persist the raw captions keyed by the numeric id parsed from the file name.\n",
    "    with open(jsonl_file, \"a\", encoding=\"utf-8\") as f_jsonl:\n",
    "        for fn, caption in zip(names, preds_base):\n",
    "            entry = {\n",
    "                \"image_id\": get_image_id(fn),\n",
    "                \"caption\": caption.strip()\n",
    "            }\n",
    "            f_jsonl.write(json.dumps(entry, ensure_ascii=False) + \"\\n\")\n",
    "\n",
    "    # Images without annotations are scored against an empty label list.\n",
    "    metrics_base = batch_compute_chair_metrics(\n",
    "        preds_base,\n",
    "        [fname2labels.get(fn, []) for fn in names]\n",
    "    )\n",
    "    line = (f\"{rnd},Baseline,\"\n",
    "            f\"{metrics_base['CHAIR-s']:.4f},\"\n",
    "            f\"{metrics_base['CHAIR-i']:.4f},\"\n",
    "            f\"{metrics_base['F1']:.4f},\"\n",
    "            f\"{metrics_base['Len']:.2f}\\n\")\n",
    "    with open(log_file, \"a\") as f:\n",
    "        f.write(line)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modify_attention import llama_head_guide\n",
    "\n",
    "# Patched pass: apply head guidance to the language model, then repeat the same\n",
    "# rounds as the baseline so both runs are scored on identical images.\n",
    "# NOTE(review): the patch appears to modify the model in place; re-running this\n",
    "# cell without reloading the model may apply it twice -- confirm in modify_attention.\n",
    "llama_head_guide(\n",
    "    model.model.language_model,\n",
    "    guided_layer_range=guided_layers,\n",
    "    aggregation=aggregation,\n",
    "    alpha=alpha,\n",
    "    img_start_idx=img_start_idx,\n",
    "    img_end_idx=img_end_idx\n",
    ")\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "# Start from an empty captions file: the per-round writes below append, so without\n",
    "# this a re-run would stack duplicate entries on top of the previous run's output.\n",
    "with open(jsonl_file_patched, \"w\", encoding=\"utf-8\"):\n",
    "    pass\n",
    "\n",
    "for rnd in range(1, n_rounds + 1):\n",
    "    names = fixed[(rnd - 1) * N : rnd * N]\n",
    "\n",
    "    imgs = []\n",
    "    for fn in names:\n",
    "        # convert() returns a separate image, so the source file can be closed\n",
    "        # as soon as the with-block exits instead of lingering until GC.\n",
    "        with Image.open(os.path.join(img_dir, fn)) as raw_img:\n",
    "            imgs.append(raw_img.convert(\"RGB\"))\n",
    "    # One prompt per image actually loaded this round.\n",
    "    prompts = [prompt_template] * len(names)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        preds_patch = batch_generate_llava(\n",
    "            model, tokenizer, processor,\n",
    "            imgs, prompts,\n",
    "            device=\"cuda\", max_new_tokens=max_new_tokens\n",
    "        )\n",
    "\n",
    "    # Persist the raw captions keyed by the numeric id parsed from the file name.\n",
    "    with open(jsonl_file_patched, \"a\", encoding=\"utf-8\") as f_jsonl:\n",
    "        for fn, caption in zip(names, preds_patch):\n",
    "            entry = {\n",
    "                \"image_id\": get_image_id(fn),\n",
    "                \"caption\": caption.strip()\n",
    "            }\n",
    "            f_jsonl.write(json.dumps(entry, ensure_ascii=False) + \"\\n\")\n",
    "\n",
    "    # Images without annotations are scored against an empty label list.\n",
    "    metrics_patch = batch_compute_chair_metrics(\n",
    "        preds_patch,\n",
    "        [fname2labels.get(fn, []) for fn in names]\n",
    "    )\n",
    "    line = (f\"{rnd},Patched,\"\n",
    "            f\"{metrics_patch['CHAIR-s']:.4f},\"\n",
    "            f\"{metrics_patch['CHAIR-i']:.4f},\"\n",
    "            f\"{metrics_patch['F1']:.4f},\"\n",
    "            f\"{metrics_patch['Len']:.2f}\\n\")\n",
    "    with open(log_file, \"a\") as f:\n",
    "        f.write(line)\n",
    "\n",
    "print(\"All done.\", log_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cir",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
