{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "from IPython.display import Image, display\n",
    "from PIL import Image\n",
    "import io\n",
    "import torch\n",
    "import numpy as np\n",
    "from processing_image import Preprocess\n",
    "from visualizing_image import SingleImageViz\n",
    "from modeling_frcnn import GeneralizedRCNN\n",
    "from utils import Config\n",
    "import utils\n",
    "import json\n",
    "from transformers import VisualBertForQuestionAnswering, BertTokenizerFast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Image URL (alternative: https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/images/input.jpg)\n",
    "URL = \"https://vqa.cloudcv.org/media/test2014/COCO_test2014_000000262567.jpg\"\n",
    "\n",
    "# Locations of the object, attribute, and answer label files\n",
    "OBJ_URL = \"https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/objects_vocab.txt\"\n",
    "ATTR_URL = \"https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/attributes_vocab.txt\"\n",
    "VQA_URL = \"https://dl.fbaipublicfiles.com/pythia/data/answers_vqa.txt\"\n",
    "\n",
    "# Fetch the label lists in order: objects, attributes, answers\n",
    "objids, attrids, vqa_answers = (\n",
    "    utils.get_data(label_url) for label_url in (OBJ_URL, ATTR_URL, VQA_URL)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(i):\n",
    "    \"\"\"Load the i-th sample of the news_clippings merged_balanced test split.\n",
    "\n",
    "    The caption comes from the VisualNews entry whose id equals the\n",
    "    annotation's \"id\"; the image comes from the entry whose id equals the\n",
    "    annotation's \"image_id\".\n",
    "\n",
    "    Args:\n",
    "        i: Position of the sample in the \"annotations\" list of test.json.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (image, caption, image_path, ann_true): the opened PIL image,\n",
    "        the caption, the resolved image path, and the annotation entry.\n",
    "\n",
    "    Note:\n",
    "        Both JSON files are re-read on every call.\n",
    "    \"\"\"\n",
    "    # Context managers close the JSON files instead of leaving the handles open.\n",
    "    with open(\"../../../datasets/visualnews/origin/data.json\") as visual_news_file:\n",
    "        visual_news_data = json.load(visual_news_file)\n",
    "    visual_news_data_mapping = {ann[\"id\"]: ann for ann in visual_news_data}\n",
    "\n",
    "    with open(\"../../../news_clippings/news_clippings/data/merged_balanced/test.json\") as split_file:\n",
    "        data = json.load(split_file)\n",
    "    annotations = data[\"annotations\"]\n",
    "    ann_true = annotations[i]\n",
    "\n",
    "    caption = visual_news_data_mapping[ann_true[\"id\"]][\"caption\"]\n",
    "    image_path = visual_news_data_mapping[ann_true[\"image_id\"]][\"image_path\"]\n",
    "    # Drop the first two characters of the stored path (presumably a leading \"./\" - confirm)\n",
    "    # before anchoring it at the VisualNews root directory.\n",
    "    image_path = \"../../../datasets/visualnews/origin/\"+image_path[2:]\n",
    "    image = Image.open(image_path)\n",
    "    return image, caption, image_path, ann_true\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load models and components.\n",
    "# This cell must run before the sample is preprocessed: the next cell calls\n",
    "# image_preprocess, which is created here.\n",
    "frcnn_cfg = Config.from_pretrained(\"unc-nlp/frcnn-vg-finetuned\")\n",
    "\n",
    "frcnn = GeneralizedRCNN.from_pretrained(\"unc-nlp/frcnn-vg-finetuned\", config=frcnn_cfg)\n",
    "\n",
    "image_preprocess = Preprocess(frcnn_cfg)\n",
    "\n",
    "bert_tokenizer = BertTokenizerFast.from_pretrained(\"bert-base-uncased\")\n",
    "visualbert_vqa = VisualBertForQuestionAnswering.from_pretrained(\"uclanlp/visualbert-vqa\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the first test sample and preprocess its image for the detector.\n",
    "image, caption, image_path, annotation = get_data(0)\n",
    "images, sizes, scales_yx = image_preprocess(image_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run frcnn on the preprocessed image, taking the detection limit from frcnn_cfg.\n",
    "frcnn_kwargs = dict(\n",
    "    scales_yx=scales_yx,\n",
    "    padding=\"max_detections\",\n",
    "    max_detections=frcnn_cfg.max_detections,\n",
    "    return_tensors=\"pt\",\n",
    ")\n",
    "output_dict = frcnn(images, sizes, **frcnn_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features from the frcnn output; passed to VisualBERT as visual_embeds below.\n",
    "features = output_dict.get(\"roi_features\")\n",
    "# .get returns None for a missing key, so fail here with a clear message\n",
    "# rather than later when features.shape is read.\n",
    "assert features is not None, \"frcnn output has no 'roi_features' entry\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Question posed to the VQA model, with the sample's caption substituted in.\n",
    "prompt_template = \"Does this caption: {}, match the image?\"\n",
    "prompt = prompt_template.format(caption)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize the prompt, padded or truncated to exactly 512 tokens.\n",
    "inputs = bert_tokenizer(\n",
    "    prompt,\n",
    "    padding=\"max_length\",\n",
    "    max_length=512,\n",
    "    truncation=True,\n",
    "    return_token_type_ids=True,\n",
    "    return_attention_mask=True,\n",
    "    add_special_tokens=True,\n",
    "    return_tensors=\"pt\",\n",
    ")\n",
    "\n",
    "# Inference only: no_grad avoids building an autograd graph for the forward pass.\n",
    "with torch.no_grad():\n",
    "    output_vqa = visualbert_vqa(\n",
    "        input_ids=inputs.input_ids,\n",
    "        attention_mask=inputs.attention_mask,\n",
    "        visual_embeds=features,\n",
    "        # One mask entry per visual feature vector (every dimension of features except the last).\n",
    "        visual_attention_mask=torch.ones(features.shape[:-1]),\n",
    "        token_type_ids=inputs.token_type_ids,\n",
    "        output_attentions=False,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the index of the highest logit and look it up in the answer list.\n",
    "answer_logits = output_vqa[\"logits\"]\n",
    "pred = answer_logits.argmax(-1)\n",
    "print(vqa_answers[pred])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "experiments",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
