{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os , sys\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "from PIL import Image,UnidentifiedImageError\n",
    "from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoProcessor\n",
    "from conversation import conv_templates\n",
    "from attentionSPIN import llama_modify_spin\n",
    "\n",
    "\n",
    "parent_dir = os.path.abspath(os.path.join(os.getcwd(), '../..'))\n",
    "sys.path.insert(0, parent_dir)\n",
    "\n",
    "\n",
    "model_id = \"llava-hf/llava-1.5-7b-hf\"\n",
    "txt_file = \"../../auto_cir/chosen_img.txt\"\n",
    "image_dir = \"path/to/val2014\"\n",
    "output_baseline = \"baseline_captions.jsonl\"\n",
    "output_spin = \"spin_captions.jsonl\"\n",
    "\n",
    "\n",
    "model = AutoModelForVision2Seq.from_pretrained(\n",
    "    model_id,\n",
    "    torch_dtype=torch.float16,\n",
    "    device_map=\"auto\",\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=True)\n",
    "processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)\n",
    "\n",
    "# model.eval()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "with open(txt_file, \"r\") as f:\n",
    "    image_list = [line.strip() for line in f if line.strip()]\n",
    "print(f\"[INFO] Total {len(image_list)} images to process.\")\n",
    "\n",
    "prompt = \"<image>\\nPlease describe the image in detail.\"\n",
    "conv = conv_templates[\"llava_v1\"].copy()\n",
    "conv.append_message(conv.roles[0], prompt)\n",
    "conv.append_message(conv.roles[1], None)\n",
    "text_prompt = conv.get_prompt()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "open(output_baseline, \"w\").close()\n",
    "\n",
    "for img_name in tqdm(image_list, desc=\"Running Baseline Inference\"):\n",
    "    img_path = os.path.join(image_dir, img_name)\n",
    "\n",
    "    try:\n",
    "        image = Image.open(img_path).convert(\"RGB\")\n",
    "    except (UnidentifiedImageError, OSError):\n",
    "        print(f\"[Skip] Cannot open image: {img_name}\")\n",
    "        continue\n",
    "\n",
    "    inputs = processor(text=text_prompt, images=image, return_tensors=\"pt\").to(\"cuda\", torch.float16)\n",
    "\n",
    "    with torch.no_grad():\n",
    "         output_ids = model.generate(\n",
    "                **inputs,\n",
    "                do_sample=False,\n",
    "                max_new_tokens=128,\n",
    "                num_beams=1,\n",
    "                use_cache=True,\n",
    "                output_attentions=False,\n",
    "                output_hidden_states=False,\n",
    "                return_dict_in_generate=False,\n",
    "        )\n",
    "    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
    "    output_text = output_text.split(\"ASSISTANT:\")[-1].strip()\n",
    "    image_id = int(img_name.split(\"_\")[-1].split(\".\")[0])\n",
    "    with open(output_baseline, \"a\", encoding=\"utf-8\") as fout:\n",
    "        fout.write(json.dumps({\"image_id\": image_id, \"caption\": output_text}) + \"\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "open(output_spin, \"w\").close()\n",
    "\n",
    "for img_name in tqdm(image_list, desc=\"Running SPIN Inference\"):\n",
    "    img_path = os.path.join(image_dir, img_name) \n",
    "    \n",
    "    try:\n",
    "        image = Image.open(img_path).convert(\"RGB\")\n",
    "    except (UnidentifiedImageError, OSError):\n",
    "        print(f\"[Skip] Cannot open image: {img_name}\")\n",
    "        continue\n",
    "\n",
    "    inputs = processor(text=text_prompt, images=image, return_tensors=\"pt\").to(\"cuda\", torch.float16)\n",
    "\n",
    "    input_ids = inputs[\"input_ids\"][0].tolist()\n",
    "    image_token_idx = 32000\n",
    "    image_token_positions = [i for i, t in enumerate(input_ids) if t == image_token_idx]\n",
    "    if not image_token_positions:\n",
    "        print(f\"[Skip] No <image> token found in {img_name}\")\n",
    "        continue\n",
    "    img_start_idx = image_token_positions[0]\n",
    "    img_end_idx = image_token_positions[-1] + 1\n",
    "\n",
    "    llama_modify_spin(\n",
    "            model=model,\n",
    "            start_layer=0,\n",
    "            end_layer=32,\n",
    "            img_start_idx=img_start_idx,\n",
    "            img_end_idx=img_end_idx,\n",
    "            routed_head=0.95,\n",
    "            use_spin_img=True,\n",
    "            small_num_mask=0.08,\n",
    "    )\n",
    "\n",
    "    with torch.no_grad():\n",
    "        output_ids_spin = model.generate(\n",
    "                **inputs,\n",
    "                do_sample=False,\n",
    "                max_new_tokens=128,\n",
    "                num_beams=1,\n",
    "                use_cache=True,\n",
    "                output_attentions=False,\n",
    "                output_hidden_states=False,\n",
    "                return_dict_in_generate=False,\n",
    "        )\n",
    "    output_text_spin = tokenizer.decode(output_ids_spin[0], skip_special_tokens=True)\n",
    "    output_text_spin = output_text_spin.split(\"ASSISTANT:\")[-1].strip()\n",
    "    image_id = int(img_name.split(\"_\")[-1].split(\".\")[0])\n",
    "    with open(output_spin, \"a\", encoding=\"utf-8\") as fout:\n",
    "        fout.write(json.dumps({\"image_id\": image_id, \"caption\": output_text_spin}) + \"\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cir",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
