{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import argparse\n",
    "import torch\n",
    "\n",
    "from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n",
    "from llava.conversation import conv_templates, SeparatorStyle\n",
    "from llava.model.builder import load_pretrained_model\n",
    "from llava.utils import disable_torch_init\n",
    "from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path\n",
    "\n",
    "import requests\n",
    "from PIL import Image\n",
    "from io import BytesIO\n",
    "import json\n",
    "from transformers import TextStreamer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import lru_cache\n",
    "\n",
    "VISUALNEWS_ROOT = \"../../../datasets/visualnews/origin\"\n",
    "NEWSCLIPPINGS_TEST_JSON = \"../../../news_clippings/news_clippings/data/merged_balanced/test.json\"\n",
    "\n",
    "\n",
    "@lru_cache(maxsize=None)\n",
    "def _load_json(path):\n",
    "    \"\"\"Parse the JSON file at `path` once and memoize the result.\n",
    "\n",
    "    Both dataset files are large, so they are read on the first call only.\n",
    "    The returned object is shared between calls: do not mutate it.\n",
    "    \"\"\"\n",
    "    with open(path) as f:\n",
    "        return json.load(f)\n",
    "\n",
    "\n",
    "@lru_cache(maxsize=None)\n",
    "def _visual_news_by_id():\n",
    "    \"\"\"Index the VisualNews data.json annotations by their id field (built once).\"\"\"\n",
    "    return {ann[\"id\"]: ann for ann in _load_json(f\"{VISUALNEWS_ROOT}/data.json\")}\n",
    "\n",
    "\n",
    "def get_data(i):\n",
    "    \"\"\"Load the i-th annotation of the NewsCLIPpings merged_balanced test split.\n",
    "\n",
    "    The caption is looked up by the annotation's id and the image by its\n",
    "    image_id, both in the VisualNews data.json index.\n",
    "\n",
    "    Args:\n",
    "        i: Index into the test split's annotations list.\n",
    "\n",
    "    Returns:\n",
    "        (image, caption, image_path, ann_true): the image as an RGB PIL image,\n",
    "        the caption string, the resolved image path, and the raw annotation\n",
    "        dict (shared with the cache -- copy it before mutating).\n",
    "    \"\"\"\n",
    "    visual_news = _visual_news_by_id()\n",
    "    ann_true = _load_json(NEWSCLIPPINGS_TEST_JSON)[\"annotations\"][i]\n",
    "\n",
    "    caption = visual_news[ann_true[\"id\"]][\"caption\"]\n",
    "    # data.json paths are relative to VISUALNEWS_ROOT and presumably start with\n",
    "    # './' (TODO confirm); strip that prefix only when it is actually present.\n",
    "    relative_path = visual_news[ann_true[\"image_id\"]][\"image_path\"]\n",
    "    image_path = f\"{VISUALNEWS_ROOT}/{relative_path.removeprefix('./')}\"\n",
    "\n",
    "    # convert('RGB') decodes the pixels so the file can be closed immediately,\n",
    "    # and gives process_images a uniform 3-channel input (grayscale/RGBA/palette).\n",
    "    with Image.open(image_path) as img:\n",
    "        image = img.convert(\"RGB\")\n",
    "    return image, caption, image_path, ann_true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME = \"liuhaotian/llava-v1.6-34b\"\n",
    "temperature = 0.2  # > 0 turns on sampling in generate(); use 0 for greedy decoding\n",
    "max_new_tokens = 512\n",
    "SEED = 0  # sampling is stochastic: seed torch so reruns are repeatable (modulo GPU nondeterminism)\n",
    "\n",
    "torch.manual_seed(SEED)\n",
    "# llava helper; presumably skips default weight init since the checkpoint\n",
    "# overwrites those weights anyway -- see llava.utils.disable_torch_init.\n",
    "disable_torch_init()\n",
    "\n",
    "model_name = get_model_name_from_path(MODEL_NAME)\n",
    "tokenizer, model, image_processor, context_len = load_pretrained_model(\n",
    "    MODEL_NAME,\n",
    "    model_base=None,\n",
    "    model_name=model_name,\n",
    "    load_8bit=False,\n",
    "    load_4bit=False,\n",
    "    device_map=\"auto\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def select_conv_mode(model_name):\n",
    "    \"\"\"Return the LLaVA conversation-template key for a checkpoint name.\n",
    "\n",
    "    Substrings are checked in order, so v1.6-34b must come before the generic v1.\n",
    "    \"\"\"\n",
    "    name = model_name.lower()\n",
    "    for marker, mode in (\n",
    "        (\"llama-2\", \"llava_llama_2\"),\n",
    "        (\"mistral\", \"mistral_instruct\"),\n",
    "        (\"v1.6-34b\", \"chatml_direct\"),\n",
    "        (\"v1\", \"llava_v1\"),\n",
    "        (\"mpt\", \"mpt\"),\n",
    "    ):\n",
    "        if marker in name:\n",
    "            return mode\n",
    "    return \"llava_v0\"\n",
    "\n",
    "\n",
    "conv_mode = select_conv_mode(model_name)\n",
    "\n",
    "\n",
    "def ask_caption_match(sample_idx, stream=True):\n",
    "    \"\"\"Ask the loaded model whether sample `sample_idx`'s caption matches its image.\n",
    "\n",
    "    Uses globals from earlier cells: model, tokenizer, image_processor,\n",
    "    conv_mode, temperature, max_new_tokens and get_data.\n",
    "\n",
    "    Args:\n",
    "        sample_idx: Index passed to get_data().\n",
    "        stream: If True, print tokens to stdout as they are generated.\n",
    "\n",
    "    Returns:\n",
    "        The decoded reply with special tokens removed. The prompt asks for\n",
    "        YES or NO, but the reply is not validated.\n",
    "    \"\"\"\n",
    "    image, caption, _, _ = get_data(sample_idx)\n",
    "\n",
    "    # Similar operation in llava's model_worker.py; process_images may return a list.\n",
    "    image_tensor = process_images([image], image_processor, model.config)\n",
    "    if isinstance(image_tensor, list):\n",
    "        image_tensor = [t.to(model.device, dtype=torch.float16) for t in image_tensor]\n",
    "    else:\n",
    "        image_tensor = image_tensor.to(model.device, dtype=torch.float16)\n",
    "\n",
    "    # The role is supplied to append_message below, so it must not also be\n",
    "    # prefixed to the question text (that duplicated the role marker inside the turn).\n",
    "    question = f\"The caption: {caption}, matches the image? Answer only as YES or NO.\"\n",
    "    if model.config.mm_use_im_start_end:\n",
    "        image_marker = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN\n",
    "    else:\n",
    "        image_marker = DEFAULT_IMAGE_TOKEN\n",
    "\n",
    "    conv = conv_templates[conv_mode].copy()\n",
    "    conv.append_message(conv.roles[0], image_marker + '\\n' + question)\n",
    "    conv.append_message(conv.roles[1], None)\n",
    "    prompt = conv.get_prompt()\n",
    "\n",
    "    input_ids = (\n",
    "        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')\n",
    "        .unsqueeze(0)\n",
    "        .to(model.device)\n",
    "    )\n",
    "    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) if stream else None\n",
    "\n",
    "    with torch.inference_mode():\n",
    "        output_ids = model.generate(\n",
    "            input_ids,\n",
    "            images=image_tensor,\n",
    "            image_sizes=[image.size],\n",
    "            do_sample=temperature > 0,\n",
    "            temperature=temperature,\n",
    "            max_new_tokens=max_new_tokens,\n",
    "            streamer=streamer,\n",
    "            use_cache=True,\n",
    "        )\n",
    "\n",
    "    # Drop special tokens (e.g. end-of-turn markers) so the answer is plain text.\n",
    "    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()\n",
    "\n",
    "\n",
    "answer = ask_caption_match(0)\n",
    "print(\"outputs\", answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "experiments",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
