{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import random\n",
    "import sys\n",
    "from collections import defaultdict\n",
    "\n",
    "import torch\n",
    "from PIL import Image\n",
    "from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer\n",
    "\n",
    "from vhr import replace_attn\n",
    "\n",
    "# Put the directory two levels above the CWD first on sys.path so the shared helpers below resolve.\n",
    "parent_dir = os.path.abspath(os.path.join(os.getcwd(), '../..'))\n",
    "sys.path.insert(0, parent_dir)\n",
    "\n",
    "from pruning_llava_utils import batch_generate_llava\n",
    "from chair_metrics import batch_compute_chair_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model and input data\n",
    "model_id = \"llava-hf/llava-1.5-7b-hf\"\n",
    "img_dir = \"path/to/val2014\"\n",
    "ann_file = \"path/to/val2014/annotations/instances_val2014.json\"\n",
    "txt_file = \"../../auto_cir/chosen_img.txt\"  # candidate image file names, one per line\n",
    "\n",
    "# Outputs\n",
    "jsonl_file = \"baseline_captions.jsonl\"\n",
    "jsonl_file_vhr = \"vhr_captions.jsonl\"\n",
    "log_file = \"experiment_log_vhr.txt\"\n",
    "\n",
    "# Schedule: n_rounds rounds of N images each (N * n_rounds images sampled in total)\n",
    "N = 1\n",
    "n_rounds = 120\n",
    "\n",
    "# Generation\n",
    "max_new_tokens = 128\n",
    "prompt_template = \"<image>\\nPlease describe the image in detail.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForVision2Seq.from_pretrained(\n",
    "    model_id,\n",
    "    torch_dtype=torch.float16,\n",
    "    device_map=\"auto\",\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=True)\n",
    "processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True, use_fast=True)\n",
    "\n",
    "# Hand every self-attention module of the language model a reference to the shared\n",
    "# rotary embedding and the attention-head count. NOTE(review): presumably the VHR\n",
    "# attention replacement reads these attributes -- TODO confirm in vhr.replace_attn.\n",
    "language_model = model.model.language_model\n",
    "shared_rotary_emb = language_model.rotary_emb\n",
    "head_count = language_model.config.num_attention_heads\n",
    "for decoder_layer in language_model.layers:\n",
    "    decoder_layer.self_attn.rotary_emb = shared_rotary_emb\n",
    "    decoder_layer.self_attn.num_attention_heads = head_count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_image_id(fname):\n",
    "    \"\"\"Return the integer image id encoded after the last '_' of a file name.\n",
    "\n",
    "    e.g. \"COCO_val2014_000000391895.jpg\" -> 391895\n",
    "    \"\"\"\n",
    "    return int(os.path.splitext(fname)[0].split('_')[-1])\n",
    "\n",
    "# Fixed seed so the sampled image subset (and any torch-side sampling) is reproducible across runs.\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "with open(txt_file) as f:\n",
    "    all_img_names = [x.strip() for x in f if x.strip()]\n",
    "\n",
    "# Map each annotated file name to the category names of its ground-truth instances.\n",
    "with open(ann_file) as f:\n",
    "    coco = json.load(f)\n",
    "imgid2fn = {i[\"id\"]: i[\"file_name\"] for i in coco[\"images\"]}\n",
    "cat2nm = {c[\"id\"]: c[\"name\"] for c in coco[\"categories\"]}\n",
    "fname2labels = defaultdict(list)\n",
    "for ann in coco[\"annotations\"]:\n",
    "    fname2labels[imgid2fn[ann[\"image_id\"]]].append(cat2nm[ann[\"category_id\"]])\n",
    "\n",
    "# One shared sample: Part 1 (baseline) and Part 2 (VHR) caption exactly the same images.\n",
    "fixed = random.sample(all_img_names, N * n_rounds)\n",
    "\n",
    "# Reset all outputs. The caption files are appended to round by round later,\n",
    "# so without truncating them here a re-run would accumulate duplicate entries.\n",
    "for out_path in (jsonl_file, jsonl_file_vhr):\n",
    "    with open(out_path, \"w\", encoding=\"utf-8\"):\n",
    "        pass\n",
    "with open(log_file, \"w\") as f:\n",
    "    f.write(\"Round,Type,CHAIR-s,CHAIR-i,F1,Len\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============ Part 1: Baseline ============\n",
    "# NOTE(review): run this before Part 2. replace_attn() there presumably patches\n",
    "# `model` in place, so re-running this cell afterwards would not give baseline output.\n",
    "for rnd in range(1, n_rounds + 1):\n",
    "    names = fixed[(rnd-1)*N : rnd*N]\n",
    "    # Decode to RGB copies and close each underlying file right away.\n",
    "    imgs = []\n",
    "    for fn in names:\n",
    "        with Image.open(os.path.join(img_dir, fn)) as im:\n",
    "            imgs.append(im.convert(\"RGB\"))\n",
    "    prompts = [prompt_template] * len(names)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        preds_base = batch_generate_llava(\n",
    "            model, tokenizer, processor,\n",
    "            imgs, prompts,\n",
    "            device=\"cuda\", max_new_tokens=max_new_tokens\n",
    "        )\n",
    "\n",
    "    # Append this round's captions immediately so completed rounds survive an interruption.\n",
    "    with open(jsonl_file, \"a\", encoding=\"utf-8\") as f_jsonl:\n",
    "        for fn, caption in zip(names, preds_base):\n",
    "            entry = {\n",
    "                \"image_id\": get_image_id(fn),\n",
    "                \"caption\": caption.strip()\n",
    "            }\n",
    "            f_jsonl.write(json.dumps(entry, ensure_ascii=False) + \"\\n\")\n",
    "\n",
    "    metrics_base = batch_compute_chair_metrics(\n",
    "        preds_base,\n",
    "        [fname2labels.get(fn, []) for fn in names]\n",
    "    )\n",
    "    line = (f\"{rnd},Baseline,\"\n",
    "            f\"{metrics_base['CHAIR-s']:.4f},\"\n",
    "            f\"{metrics_base['CHAIR-i']:.4f},\"\n",
    "            f\"{metrics_base['F1']:.4f},\"\n",
    "            f\"{metrics_base['Len']:.2f}\\n\")\n",
    "    with open(log_file, \"a\") as f:\n",
    "        f.write(line)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============ Part 2: VHR ============\n",
    "\n",
    "# VHR settings. Target layer indices: 1 plus the 13 consecutive layers 6..18.\n",
    "vhr_layers = [1] + list(range(19 - 13, 19))\n",
    "vhr_aug_ratio = 2.0\n",
    "vhr_filter = True\n",
    "\n",
    "# NOTE(review): this presumably patches `model` in place and is not undone, so run\n",
    "# this cell once per kernel; whether repeated calls stack is not visible here -- TODO confirm in vhr.\n",
    "replace_attn(model, target_layers=vhr_layers, aug_ratio=vhr_aug_ratio, filter=vhr_filter)\n",
    "\n",
    "for rnd in range(1, n_rounds + 1):\n",
    "    names = fixed[(rnd-1)*N : rnd*N]\n",
    "    # Decode to RGB copies and close each underlying file right away.\n",
    "    imgs = []\n",
    "    for fn in names:\n",
    "        with Image.open(os.path.join(img_dir, fn)) as im:\n",
    "            imgs.append(im.convert(\"RGB\"))\n",
    "    prompts = [prompt_template] * len(names)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        preds_vhr = batch_generate_llava(\n",
    "            model, tokenizer, processor,\n",
    "            imgs, prompts,\n",
    "            device=\"cuda\", max_new_tokens=max_new_tokens\n",
    "        )\n",
    "\n",
    "    # Append this round's captions immediately so completed rounds survive an interruption.\n",
    "    with open(jsonl_file_vhr, \"a\", encoding=\"utf-8\") as f_jsonl:\n",
    "        for fn, cap in zip(names, preds_vhr):\n",
    "            f_jsonl.write(json.dumps({\n",
    "                \"image_id\": get_image_id(fn),\n",
    "                \"caption\": cap.strip()\n",
    "            }, ensure_ascii=False) + \"\\n\")\n",
    "\n",
    "    metrics_vhr = batch_compute_chair_metrics(\n",
    "        preds_vhr,\n",
    "        [fname2labels.get(fn, []) for fn in names]\n",
    "    )\n",
    "    line = (f\"{rnd},VHR,\"\n",
    "            f\"{metrics_vhr['CHAIR-s']:.4f},\"\n",
    "            f\"{metrics_vhr['CHAIR-i']:.4f},\"\n",
    "            f\"{metrics_vhr['F1']:.4f},\"\n",
    "            f\"{metrics_vhr['Len']:.2f}\\n\")\n",
    "    with open(log_file, \"a\") as f:\n",
    "        f.write(line)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cir",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
