{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import requests\n",
    "import json\n",
    "from PIL import Image\n",
    "from transformers import CLIPProcessor, CLIPModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(i):\n",
    "    visual_news_data = json.load(open(\"../../../datasets/visualnews/origin/data.json\"))\n",
    "    visual_news_data_mapping = {ann[\"id\"]: ann for ann in visual_news_data}\n",
    "\n",
    "    data = json.load(open(\"../../../news_clippings/news_clippings/data/merged_balanced/test.json\"))\n",
    "    annotations = data[\"annotations\"]\n",
    "    ann_true = annotations[i]\n",
    "\n",
    "    caption = visual_news_data_mapping[ann_true[\"id\"]][\"caption\"]\n",
    "    image_path = visual_news_data_mapping[ann_true[\"image_id\"]][\"image_path\"]\n",
    "    image_path = \"../../../datasets/visualnews/origin/\"+image_path[2:]\n",
    "    image = Image.open(image_path)\n",
    "    #print(\"DATA SAMPLE\")\n",
    "    #print(\"Caption: \", caption)\n",
    "    #print(\"Misinformation (Ground Truth): {}\".format(ann_true[\"falsified\"]))\n",
    "    return image, caption, image_path, ann_true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch_dtype = torch.float16\n",
    "\n",
    "model = CLIPModel.from_pretrained(\n",
    "    \"openai/clip-vit-base-patch32\",\n",
    "    device_map=\"auto\",\n",
    "    torch_dtype=torch_dtype,\n",
    ")\n",
    "\n",
    "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "image, caption, image_path, annotation = get_data(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt1 = \"The caption: {}, matches the image\".format(caption)\n",
    "prompt2 = \"The caption: {}, does not match the image\".format(caption)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = processor(text=[prompt1, prompt2], images=image, return_tensors=\"pt\", padding=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs = model(**inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "probs = outputs.logits_per_image.softmax(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = [False, True]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results[int(torch.argmax(probs))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "experiments",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
