{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Flickr30k logical composition retrieval example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "\n",
    "from datasets import load_dataset\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from subspaces import ridge_projector, join\n",
    "from model import TransformerSubspaceEmbedder\n",
    "from train_nli import NLITrainingData\n",
    "\n",
    "CACHE_DIR = \"./.cache\"\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load SNLI-trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = NLITrainingData.load(\"./runs/all-mpnet-base-v2_128x128_context35.pt\")\n",
    "\n",
    "model = TransformerSubspaceEmbedder(\n",
    "    train_data.base_model_name, train_data.N, train_data.D, train_data.lbd, two_way=False, cache_dir=CACHE_DIR,\n",
    ")\n",
    "model.load_state_dict(train_data.state_dict, strict=False)\n",
    "model.eval()\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Flickr30k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset(\"nlphuji/flickr30k\", cache_dir=CACHE_DIR, split=\"test\")\n",
    "captions, img_indices = [], []\n",
    "\n",
    "for i in range(len(dataset)):\n",
    "    cs = dataset[i][\"caption\"]\n",
    "    captions.extend(cs)\n",
    "    img_indices.extend([i]*len(cs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute caption subspaces (smooth projectors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "B = 256\n",
    "chunked_captions = [captions[i:i + B] for i in range(0, len(captions), B)]\n",
    "\n",
    "P_captions = []\n",
    "for chunk in tqdm(chunked_captions):\n",
    "    with torch.no_grad():\n",
    "        P = model.encode(chunk, max_length=train_data.max_length, device=device)\n",
    "    P_captions.append(P)\n",
    "P_captions = torch.cat(P_captions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Retrieval with logical composition of queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text2projector = lambda text : model.encode([text], train_data.max_length, device=device)\n",
    "I = torch.eye(model.D, device=P.device)[None]\n",
    "\n",
    "P1 = text2projector(\"people\")\n",
    "P2 = text2projector(\"sitting\")\n",
    "P3 = text2projector(\"on a bench\")\n",
    "P_query = P1 @ P2 @ P3 # people AND sitting AND on a bench\n",
    "\n",
    "P1 = text2projector(\"food being prepared\")\n",
    "P2 = text2projector(\"outdoors on the grill\")\n",
    "P_query = P1 @ P2 # food being prepared AND outdoors on the grill\n",
    "\n",
    "P1 = text2projector(\"an animal interacting with a human\")\n",
    "P2 = text2projector(\"in a zoo\")\n",
    "P_query = P1 @ P2 # an animal interacting with a human AND in a zoo\n",
    "\n",
    "P1 = text2projector(\"a person walking\")\n",
    "P2 = text2projector(\"on the sidewalk\")\n",
    "P_query = P1 @ (I - P2) # a person walking AND NOT on the sidewalk\n",
    "P_query = P1 @ P2  # a person walking AND on the sidewalk\n",
    "\n",
    "P1 = text2projector(\"a person riding a bicycle\")\n",
    "P2 = text2projector(\"a person walking a dog\")\n",
    "P_query = join(P1[0], P2[0], train_data.lbd) # a person riding a bicycle OR a person walking a dog\n",
    "\n",
    "P1 = text2projector(\"a group of people\")\n",
    "P2 = text2projector(\"military\")\n",
    "P_query = P1 @ P2 # a group of people AND military\n",
    "\n",
    "P1 = text2projector(\"a child playing outdoors\")\n",
    "P2 = text2projector(\"near a road\")\n",
    "P_query = P1 @ (I - P2) # a child playing outdoors AND NOT near a road\n",
    "P_query = P1 @ P2 # a child playing outdoors AND near a road\n",
    "\n",
    "P1 = text2projector(\"a man\")\n",
    "P2 = text2projector(\"on a boat\")\n",
    "P3 = text2projector(\"is fishing\")\n",
    "P_query = P1 @ (I - P2) @ P3 # a man AND NOT on a boat AND is fishing\n",
    "\n",
    "P1 = text2projector(\"a person playing the violin\")\n",
    "P2 = text2projector(\"standing in the street\")\n",
    "P_query = P1 @ (I - P2) # a person playing the violin AND NOT standing in the street\n",
    "\n",
    "P1 = text2projector(\"a man and a surfboard\")\n",
    "P2 = text2projector(\"is surfing\")\n",
    "P_query = P1 @ (I - P2) # a man and a surfboard AND NOT is surfing\n",
    "P_query = P1 @ P2 # a man and a surfboard AND is surfing\n",
    "\n",
    "P_captions = P_captions.to(device)\n",
    "scores = torch.einsum(\"bii->b\", P_query @ P_captions) / torch.einsum(\"bii->b\", P_captions)\n",
    "idx = torch.argsort(scores, descending=True)\n",
    "for i in range(10):\n",
    "    print(f\"Normalized inclusion score: {scores[idx[i]].item():.3f} | Caption: {captions[idx[i]]} (idx {idx[i]})\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display_indices = []\n",
    "i = 0\n",
    "while len(display_indices) < 6:\n",
    "    if img_indices[idx[i]] not in display_indices:\n",
    "        display_indices.append(img_indices[idx[i]])\n",
    "    i += 1\n",
    "\n",
    "fig = plt.figure(figsize=(15, 2))\n",
    "gs = gridspec.GridSpec(1, 6, wspace=0.05)\n",
    "\n",
    "for i in range(6):\n",
    "    ax = fig.add_subplot(gs[0, i])\n",
    "    img = dataset[display_indices[i]][\"image\"]\n",
    "    ax.imshow(img)\n",
    "    ax.axis(\"off\")\n",
    "    ax.set_aspect('auto')\n",
    "    print(dataset[display_indices[i]][\"caption\"])\n",
    "#plt.savefig(\"./results.pdf\", bbox_inches='tight', pad_inches=0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
