{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys\n",
    "import random\n",
    "import json\n",
    "from collections import defaultdict\n",
    "import torch\n",
    "from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoProcessor\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# --- Configuration ---------------------------------------------------------\n",
    "model_id = \"llava-hf/llava-1.5-7b-hf\"\n",
    "img_dir = \"path/to/val2014\"  # directory containing the COCO val2014 image files\n",
    "ann_file = \"path/to/val2014/annotations/instances_val2014.json\"  # COCO instance annotations\n",
    "txt_file = \"../all_img_names.txt\"  # candidate image file names, one per line\n",
    "N = 1                 # images per generation batch\n",
    "n_rounds = 600        # maximum number of batches to generate\n",
    "max_new_tokens = 128  # cap on generated tokens per caption\n",
    "SEED = 0              # seeds `random` (image sampling) and torch for repeatable runs\n",
    "\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Make the helper modules in the parent directory (pruning_llava_utils, chair_metrics)\n",
    "# importable; guarded so re-running this cell does not prepend duplicate entries.\n",
    "parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))\n",
    "if parent_dir not in sys.path:\n",
    "    sys.path.insert(0, parent_dir)\n",
    "\n",
    "# --- Model -------------------------------------------------------------------\n",
    "# Weights are placed by device_map=\"auto\"; `device` is only passed on to batch_generate_llava.\n",
    "device = \"cuda\"\n",
    "model = AutoModelForVision2Seq.from_pretrained(\n",
    "    model_id, torch_dtype=torch.float16, device_map=\"auto\", trust_remote_code=True\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=True)\n",
    "processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True, use_fast=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pruning_llava_utils import batch_generate_llava\n",
    "from chair_metrics import batch_compute_chair_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Candidate image file names; blank lines (e.g. a trailing newline) are skipped so\n",
    "# they never turn into a bogus image path in the sampling loop.\n",
    "with open(txt_file, \"r\") as f:\n",
    "    all_img_names = [name for name in (line.strip() for line in f) if name]\n",
    "\n",
    "# Map each COCO file name to the names of the object categories annotated in it.\n",
    "with open(ann_file, \"r\") as f:\n",
    "    coco = json.load(f)\n",
    "imgid2fname = {img[\"id\"]: img[\"file_name\"] for img in coco[\"images\"]}\n",
    "catid2name = {cat[\"id\"]: cat[\"name\"] for cat in coco[\"categories\"]}\n",
    "fname2labels = defaultdict(set)\n",
    "for ann in coco[\"annotations\"]:\n",
    "    fname = imgid2fname[ann[\"image_id\"]]\n",
    "    fname2labels[fname].add(catid2name[ann[\"category_id\"]])\n",
    "# sorted() rather than list(): iteration order of a set of strings varies with hash\n",
    "# randomisation, which would make the saved gt_label lists differ from run to run.\n",
    "fname2labels = {k: sorted(v) for k, v in fname2labels.items()}\n",
    "\n",
    "prompt = \"<image>\\nPlease describe the image in detail.\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Caption disjoint batches of N randomly chosen images and split them by the\n",
    "# sentence-level CHAIR score: CHAIRs == 1 -> hallucination, otherwise faithful.\n",
    "hallucination_samples = []\n",
    "faithful_samples = []\n",
    "\n",
    "assert N > 0, \"N (images per batch) must be positive\"\n",
    "# Draw every index needed in one call instead of rebuilding the set of unused\n",
    "# indices each round (O(len(all_img_names)) work per round). As before, no image\n",
    "# is used twice and only full batches of N are processed.\n",
    "n_batches = min(n_rounds, len(all_img_names) // N)\n",
    "order = random.sample(range(len(all_img_names)), n_batches * N)\n",
    "for start in range(0, n_batches * N, N):\n",
    "    samples = []\n",
    "    for idx in order[start:start + N]:\n",
    "        fname = all_img_names[idx]\n",
    "        # The `with` block closes the file handle once the RGB copy exists.\n",
    "        with Image.open(os.path.join(img_dir, fname)) as im:\n",
    "            img = im.convert(\"RGB\")\n",
    "        samples.append({\"image\": img, \"prompt\": prompt, \"file_name\": fname, \"gt_label\": fname2labels.get(fname, [])})\n",
    "\n",
    "    preds = batch_generate_llava(model, tokenizer, processor,\n",
    "                                [s[\"image\"] for s in samples],\n",
    "                                [s[\"prompt\"] for s in samples],\n",
    "                                device=device, max_new_tokens=max_new_tokens)\n",
    "\n",
    "    metrics = batch_compute_chair_metrics(preds, [s[\"gt_label\"] for s in samples])\n",
    "    details = metrics[\"sentence_details\"]\n",
    "    # The pairing below assumes one detail entry per caption; fail loudly instead of\n",
    "    # letting zip() silently drop or misalign samples.\n",
    "    if len(details) != len(samples):\n",
    "        raise RuntimeError(f\"expected {len(samples)} CHAIR results, got {len(details)}\")\n",
    "    for det, samp in zip(details, samples):\n",
    "        if det[\"metrics\"][\"CHAIRs\"] == 1:\n",
    "            hallucination_samples.append(samp)\n",
    "        else:\n",
    "            faithful_samples.append(samp)\n",
    "\n",
    "print(f\"Collected {len(hallucination_samples)} hallucination samples, {len(faithful_samples)} faithful samples.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Write the file names of each split, one per line.\n",
    "out_dir = '../hallu_img_samples'\n",
    "os.makedirs(out_dir, exist_ok=True)  # open(..., 'w') fails if the folder is missing\n",
    "for out_name, split in (('hallu-i.txt', hallucination_samples), ('faith-i.txt', faithful_samples)):\n",
    "    with open(os.path.join(out_dir, out_name), 'w', encoding='utf-8') as f:\n",
    "        f.writelines(sample['file_name'] + '\\n' for sample in split)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Save per-sample metadata (all fields except the PIL image) as JSON.\n",
    "def to_meta(samples):\n",
    "    \"\"\"Return [{file_name, prompt, gt_label}, ...] for the given sample dicts.\"\"\"\n",
    "    return [\n",
    "        {\"file_name\": s[\"file_name\"], \"prompt\": s[\"prompt\"], \"gt_label\": s[\"gt_label\"]}\n",
    "        for s in samples\n",
    "    ]\n",
    "\n",
    "hall_meta = to_meta(hallucination_samples)\n",
    "faith_meta = to_meta(faithful_samples)\n",
    "\n",
    "os.makedirs(\"../hallu_img_samples\", exist_ok=True)\n",
    "with open(\"../hallu_img_samples/hallucination_samples-i.json\", \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(hall_meta, f, ensure_ascii=False, indent=2)\n",
    "# Was \"faithful_samples.json-i\" (misplaced suffix); now mirrors the hallucination file name.\n",
    "with open(\"../hallu_img_samples/faithful_samples-i.json\", \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(faith_meta, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "print(f\"Saved {len(hall_meta)} hallucination samples, {len(faith_meta)} faithful samples.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cir",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
