{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35150174",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "from model_zoo import get_model\n",
    "from dataset_zoo import VG_Relation, VG_Attribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02d2d2e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data root directory: put your own path here. The VG-Relation and VG-Attribution\n",
    "# images (a ~1GB zip file, a subset of GQA) are downloaded into it.\n",
    "#\n",
    "# \"~\" is expanded explicitly because Python path functions do not expand it on\n",
    "# their own; passed through raw, it could end up as a literal \"~\" folder inside\n",
    "# the current working directory.\n",
    "root_dir = os.path.expanduser(\"~/.cache\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44bf8eb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the\n",
    "# notebook still runs (more slowly) on machines without CUDA.\n",
    "# NOTE(review): assumes get_model accepts device=\"cpu\" -- confirm in model_zoo.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Load the CLIP ViT-B/32 model together with its image preprocessing transform.\n",
    "model, preprocess = get_model(\n",
    "    model_name=\"openai-clip:ViT-B/32\",\n",
    "    device=device,\n",
    "    root_dir=root_dir,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b942fb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the VG-Relation dataset under root_dir (download requested)\n",
    "vgr_dataset = VG_Relation(\n",
    "    image_preprocess=preprocess,\n",
    "    download=True,\n",
    "    root_dir=root_dir,\n",
    ")\n",
    "\n",
    "# Iterate in a fixed order, 16 items per batch\n",
    "vgr_loader = DataLoader(dataset=vgr_dataset, batch_size=16, shuffle=False)\n",
    "\n",
    "# Score every test case with the model\n",
    "vgr_scores = model.get_retrieval_scores_batched(vgr_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0da0737",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the raw scores into evaluation records for VG-Relation\n",
    "vgr_records = vgr_dataset.evaluate_scores(vgr_scores)\n",
    "\n",
    "# Relations that are dropped before the accuracy is averaged.\n",
    "# NOTE(review): the variable name suggests these are relations whose argument\n",
    "# order does not matter -- confirm the selection criterion.\n",
    "symmetric = [\n",
    "    'adjusting', 'attached to', 'between', 'bigger than', 'biting', 'boarding', 'brushing',\n",
    "    'chewing', 'cleaning', 'climbing', 'close to', 'coming from', 'coming out of', 'contain',\n",
    "    'crossing', 'dragging', 'draped over', 'drinking', 'drinking from', 'driving',\n",
    "    'driving down', 'driving on', 'eating from', 'eating in', 'enclosing', 'exiting', 'facing',\n",
    "    'filled with', 'floating in', 'floating on', 'flying', 'flying above', 'flying in',\n",
    "    'flying over', 'flying through', 'full of', 'going down', 'going into', 'going through',\n",
    "    'grazing in', 'growing in', 'growing on', 'guiding', 'hanging from', 'hanging in',\n",
    "    'hanging off', 'hanging over', 'higher than', 'holding onto', 'hugging', 'in between',\n",
    "    'jumping off', 'jumping on', 'jumping over', 'kept in', 'larger than', 'leading',\n",
    "    'leaning over', 'leaving', 'licking', 'longer than', 'looking in', 'looking into',\n",
    "    'looking out', 'looking over', 'looking through', 'lying next to', 'lying on top of',\n",
    "    'making', 'mixed with', 'mounted on', 'moving', 'on the back of', 'on the edge of',\n",
    "    'on the front of', 'on the other side of', 'opening', 'painted on', 'parked at',\n",
    "    'parked beside', 'parked by', 'parked in', 'parked in front of', 'parked near',\n",
    "    'parked next to', 'perched on', 'petting', 'piled on', 'playing', 'playing in',\n",
    "    'playing on', 'playing with', 'pouring', 'reaching for', 'reading', 'reflected on',\n",
    "    'riding on', 'running in', 'running on', 'running through', 'seen through',\n",
    "    'sitting behind', 'sitting beside', 'sitting by', 'sitting in front of', 'sitting near',\n",
    "    'sitting next to', 'sitting under', 'skiing down', 'skiing on', 'sleeping in',\n",
    "    'sleeping on', 'smiling at', 'sniffing', 'splashing', 'sprinkled on', 'stacked on',\n",
    "    'standing against', 'standing around', 'standing behind', 'standing beside',\n",
    "    'standing in front of', 'standing near', 'standing next to', 'staring at', 'stuck in',\n",
    "    'surrounding', 'swimming in', 'swinging', 'talking to', 'topped with', 'touching',\n",
    "    'traveling down', 'traveling on', 'tying', 'typing on', 'underneath', 'wading in',\n",
    "    'waiting for', 'walking across', 'walking by', 'walking down', 'walking next to',\n",
    "    'walking through', 'working in', 'working on', 'worn on', 'wrapped around', 'wrapped in',\n",
    "    'by', 'of', 'near', 'next to', 'with', 'beside', 'on the side of', 'around',\n",
    "]\n",
    "\n",
    "# Keep only the rows whose relation is not in the exclusion list, then report\n",
    "# the mean of the Accuracy column over what remains.\n",
    "df = pd.DataFrame(vgr_records)\n",
    "df = df[~df.Relation.isin(symmetric)]\n",
    "print(f\"VG-Relation Macro Accuracy: {df.Accuracy.mean()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fe7ced3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the VG-Attribution dataset under root_dir (download requested)\n",
    "vga_dataset = VG_Attribution(\n",
    "    image_preprocess=preprocess,\n",
    "    download=True,\n",
    "    root_dir=root_dir,\n",
    ")\n",
    "\n",
    "# Iterate in a fixed order, 16 items per batch\n",
    "vga_loader = DataLoader(dataset=vga_dataset, batch_size=16, shuffle=False)\n",
    "\n",
    "# Score every test case with the model\n",
    "vga_scores = model.get_retrieval_scores_batched(vga_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "760c2ef6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the raw scores into evaluation records for VG-Attribution\n",
    "vga_records = vga_dataset.evaluate_scores(vga_scores)\n",
    "\n",
    "# Tabulate the records and report the mean of the Accuracy column\n",
    "df = pd.DataFrame(data=vga_records)\n",
    "print(f\"VG-Attribution Macro Accuracy: {df['Accuracy'].mean()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c974373",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
