{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d894b6d-04db-495b-ae54-42c222553320",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "\n",
    "import matplotlib\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.auto import tqdm\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "import clip\n",
    "import ImageReward as RM\n",
    "import numpy as np\n",
    "import torch\n",
    "from transformers import AutoModel, AutoProcessor\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04db05f6-0902-4833-bc5f-5f8aa137d0e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CLIP and ImageReward handles start empty; clip_score() and image_reward()\n",
    "# populate them the first time they are called.\n",
    "clip_model = None\n",
    "clip_preprocess = None\n",
    "ir_model = None\n",
    "\n",
    "# PickScore is loaded eagerly: the processor and the model weights come from\n",
    "# two different pretrained identifiers.\n",
    "processor_name_or_path = 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K'\n",
    "model_pretrained_name_or_path = 'yuvalkirstain/PickScore_v1'\n",
    "pickscore_processor = AutoProcessor.from_pretrained(processor_name_or_path)\n",
    "pickscore_model = AutoModel.from_pretrained(model_pretrained_name_or_path)\n",
    "pickscore_model = pickscore_model.eval().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69d557f0-3345-4b6d-aaab-be30110fd714",
   "metadata": {},
   "outputs": [],
   "source": [
    "def clip_score(caption, image_paths):\n",
    "    \"\"\"Score image/caption agreement with CLIP ViT-B/32.\n",
    "\n",
    "    The CLIP model is loaded on first use and cached in the module-level\n",
    "    ``clip_model`` / ``clip_preprocess`` globals.\n",
    "\n",
    "    Args:\n",
    "        caption: A caption string, or a list of caption strings.\n",
    "        image_paths: An image path, or a list of image paths.\n",
    "\n",
    "    Returns:\n",
    "        A flat list of floats: the ``logits_per_image`` output of the model,\n",
    "        flattened and divided by 100. An empty list when no image paths are\n",
    "        given.\n",
    "    \"\"\"\n",
    "    global clip_model, clip_preprocess\n",
    "\n",
    "    caption = caption if isinstance(caption, list) else [caption]\n",
    "    image_paths = image_paths if isinstance(image_paths, list) else [image_paths]\n",
    "    if not image_paths:\n",
    "        # Concatenating an empty list of tensors fails; nothing to score.\n",
    "        return []\n",
    "\n",
    "    # Compare against None explicitly rather than relying on module truthiness.\n",
    "    if clip_model is None:\n",
    "        clip_model, clip_preprocess = clip.load('ViT-B/32', device=device)\n",
    "        clip_model.eval()\n",
    "\n",
    "    image_inputs = []\n",
    "    for image_path in image_paths:\n",
    "        # Image.open keeps the underlying file open until the image is closed;\n",
    "        # the context manager releases it once the tensor has been built.\n",
    "        with Image.open(image_path) as image:\n",
    "            image_inputs.append(clip_preprocess(image).unsqueeze(0).to(device))\n",
    "    image_inputs = torch.concat(image_inputs, dim=0)\n",
    "    text_inputs = clip.tokenize(caption).to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        logits_per_image, _ = clip_model(image_inputs, text_inputs)\n",
    "        logits = logits_per_image.view(-1)\n",
    "    return (logits.cpu().numpy() / 100.0).tolist()\n",
    "\n",
    "\n",
    "def image_reward(caption, image_paths):\n",
    "    \"\"\"Compute one ImageReward score per image for the given caption.\n",
    "\n",
    "    The ImageReward model is loaded on first use and cached in the\n",
    "    module-level ``ir_model`` global.\n",
    "\n",
    "    Args:\n",
    "        caption: A caption string, or a list of caption strings.\n",
    "        image_paths: An image path, or a list of image paths.\n",
    "\n",
    "    Returns:\n",
    "        A list with one ``ir_model.score`` result per image path, in order.\n",
    "    \"\"\"\n",
    "    global ir_model\n",
    "    if not ir_model:\n",
    "        # First call: load the checkpoint once and keep it for later calls.\n",
    "        ir_model = RM.load('ImageReward-v1.0')\n",
    "\n",
    "    if not isinstance(caption, list):\n",
    "        caption = [caption]\n",
    "    if not isinstance(image_paths, list):\n",
    "        image_paths = [image_paths]\n",
    "\n",
    "    # Each image is scored on its own, wrapped in a single-element list.\n",
    "    return [ir_model.score(caption, [path]) for path in image_paths]\n",
    "\n",
    "\n",
    "class PickScorer:\n",
    "    \"\"\"Scores images against a text prompt with a CLIP-style model.\n",
    "\n",
    "    Args:\n",
    "        model: Model exposing ``get_image_features``, ``get_text_features``\n",
    "            and ``logit_scale``. It is used as given; this class never moves\n",
    "            it to ``device``.\n",
    "        preprocessor: Processor called with ``images=...`` or ``text=...``\n",
    "            whose result supports ``.to(device)``.\n",
    "        device: Device the processed inputs are moved to.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, preprocessor, device='cuda'):\n",
    "        self.model = model\n",
    "        self.preprocessor = preprocessor\n",
    "        self.device = device\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def _features(self, images, texts):\n",
    "        \"\"\"Return L2-normalised ``(image_embs, text_embs)`` for processed inputs.\"\"\"\n",
    "        image_embs = self.model.get_image_features(**images)\n",
    "        image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)\n",
    "        text_embs = self.model.get_text_features(**texts)\n",
    "        text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)\n",
    "        return image_embs, text_embs\n",
    "\n",
    "    def _load(self, images, prompts):\n",
    "        \"\"\"Run the processor on images and prompts separately and move both to the device.\"\"\"\n",
    "        image_inputs = self.preprocessor(\n",
    "          images=images, padding=True, truncation=True, max_length=77,\n",
    "          return_tensors='pt').to(self.device)\n",
    "        text_inputs = self.preprocessor(\n",
    "          text=prompts, padding=True, truncation=True, max_length=77,\n",
    "          return_tensors='pt').to(self.device)\n",
    "        return image_inputs, text_inputs\n",
    "\n",
    "    def score(self, image_fns, prompt):\n",
    "        \"\"\"Score each image against a single prompt.\n",
    "\n",
    "        Args:\n",
    "            image_fns: An image path, or a list of image paths.\n",
    "            prompt: The prompt to score the images against.\n",
    "\n",
    "        Returns:\n",
    "            A 1-D numpy array with one value per image: the scaled\n",
    "            image/text similarity logit divided by 100.\n",
    "        \"\"\"\n",
    "        # isinstance (not an exact type check) so list subclasses are not\n",
    "        # wrapped in a second list.\n",
    "        image_fns = image_fns if isinstance(image_fns, list) else [image_fns]\n",
    "\n",
    "        images = []\n",
    "        for image_fn in image_fns:\n",
    "            # Image.open keeps the file open; copy the pixel data into an\n",
    "            # independent image so the handle can be released right away.\n",
    "            with Image.open(image_fn) as image:\n",
    "                images.append(image.copy())\n",
    "\n",
    "        image_inputs, text_inputs = self._load(images, [prompt])\n",
    "        # no_grad is still needed here: logit_scale is used outside _features.\n",
    "        with torch.no_grad():\n",
    "            image_embs, text_embs = self._features(image_inputs, text_inputs)\n",
    "            logits = image_embs @ text_embs.t()\n",
    "            logits = self.model.logit_scale.exp() * logits\n",
    "        return logits.view(-1).cpu().numpy() / 100.\n",
    "\n",
    "\n",
    "class TextNormPickScorer(PickScorer):\n",
    "    \"\"\"PickScore variant normalised against a set of contrastive prompts.\n",
    "\n",
    "    For each image the score is ``exp(log p + lbd * delta)`` where ``p`` is the\n",
    "    softmax probability of the scored prompt among all prompts, and ``delta``\n",
    "    is the smallest logit margin over the contrastive prompts, each margin\n",
    "    divided by ``100 * (1 - similarity)`` with ``similarity`` the dot product\n",
    "    between the normalised text embeddings of the two prompts.\n",
    "\n",
    "    Args:\n",
    "        model: Forwarded to ``PickScorer``.\n",
    "        preprocessor: Forwarded to ``PickScorer``.\n",
    "        scale: Multiplier applied to the logits before the softmax.\n",
    "        lbd: Weight of the margin term ``delta``.\n",
    "        base_prompts: Optional default contrastive prompts.\n",
    "        device: Forwarded to ``PickScorer``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, preprocessor, scale=1.0, lbd=5.0, base_prompts=None, device='cuda'):\n",
    "        super().__init__(model, preprocessor, device=device)\n",
    "        self._scale = scale\n",
    "        self._lbd = lbd\n",
    "        self._base_prompts = base_prompts\n",
    "\n",
    "    def score(self, image_fns, prompt, base_prompts=None):\n",
    "        \"\"\"Score each image against ``prompt`` relative to contrastive prompts.\n",
    "\n",
    "        Args:\n",
    "            image_fns: An image path, or a list of image paths.\n",
    "            prompt: The prompt to score the images against.\n",
    "            base_prompts: Contrastive prompts (a string or an iterable of\n",
    "                strings). Only used when none were given to the constructor.\n",
    "\n",
    "        Returns:\n",
    "            A 1-D numpy array with one score per image.\n",
    "\n",
    "        Raises:\n",
    "            NotImplementedError: If no base prompts are available.\n",
    "            ValueError: If no prompt other than ``prompt`` remains.\n",
    "        \"\"\"\n",
    "        # Prompts fixed at construction time win over the per-call argument.\n",
    "        base_prompts = self._base_prompts if self._base_prompts else base_prompts\n",
    "        if base_prompts is None:\n",
    "            raise NotImplementedError(\"base prompts are not set.\")\n",
    "        if isinstance(base_prompts, str):\n",
    "            # A bare string would otherwise be split into single characters.\n",
    "            base_prompts = [base_prompts]\n",
    "\n",
    "        # Order-preserving de-duplication: unlike a set, the prompt order (and\n",
    "        # therefore the batch fed to the model) is the same on every run.\n",
    "        contrastive = [p for p in dict.fromkeys(base_prompts) if p != prompt]\n",
    "        if not contrastive:\n",
    "            raise ValueError(\n",
    "                \"base prompts must contain at least one prompt other than the scored prompt.\")\n",
    "        # The scored prompt always sits at index 0.\n",
    "        all_prompts = [prompt] + contrastive\n",
    "\n",
    "        image_fns = image_fns if isinstance(image_fns, list) else [image_fns]\n",
    "        images = []\n",
    "        for image_fn in image_fns:\n",
    "            # Copy the pixel data so the file handle can be closed right away.\n",
    "            with Image.open(image_fn) as image:\n",
    "                images.append(image.copy())\n",
    "\n",
    "        image_inputs, text_inputs = self._load(images, all_prompts)\n",
    "        with torch.no_grad():\n",
    "            image_embs, text_embs = self._features(image_inputs, text_inputs)\n",
    "            logits = image_embs @ text_embs.t()\n",
    "            logits = self.model.logit_scale.exp() * logits\n",
    "\n",
    "            # Log-probability of the scored prompt among all prompts.\n",
    "            probs = (logits * self._scale).softmax(dim=-1)\n",
    "            logp_xy = probs[:, 0].clamp(min=1e-10).log()\n",
    "\n",
    "            # Similarity of every contrastive prompt to the scored prompt.\n",
    "            text_sim = (text_embs[1:] @ text_embs[[0]].t()).view(-1)\n",
    "            # Margins shrink in weight for prompts that are far from the scored one;\n",
    "            # the worst (smallest) margin per image is kept.\n",
    "            delta = (logits[:, [0]] - logits[:, 1:]) / (100. * (1 - text_sim.view(1, -1)))\n",
    "            delta = delta.amin(1)\n",
    "\n",
    "        return (logp_xy + self._lbd * delta).exp().cpu().numpy()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb43950a-66e0-4b9c-bc36-e1499e7a590b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target prompt that every image below is scored against.\n",
    "prompt = \"a realistic photo of a book and a teddy bear\"\n",
    "# Two candidate images for the prompt; the evaluation cell labels the first\n",
    "# one \"Good\" and the second one \"Bad\".\n",
    "image_fns = [\n",
    "    \"images/book_bear_good.png\",\n",
    "    \"images/book_bear_bad.png\",\n",
    "]\n",
    "\n",
    "# Contrastive prompt set generated by ChatGPT; passed as `base_prompts` to\n",
    "# TextNormPickScorer.score in the evaluation cell.\n",
    "prompt_neg = [\n",
    "    'a cartoon drawing of two books and three teddy bears',\n",
    "    'a realistic photo of five pencils and three notebooks',\n",
    "    'a realistic photo of nine apples and one orange',\n",
    "    'a realistic photo of a pen and four papers',\n",
    "    'a realistic photo of six pens and two notebooks',\n",
    "    'a realistic photo of seven crayons and a coloring book',\n",
    "    'a realistic photo of three rulers and a calculator',\n",
    "    'a realistic photo of a pencil case and five erasers',\n",
    "    'a realistic photo of six markers and two sketchbooks',\n",
    "    'a realistic photo of a paintbrush and a palette',\n",
    "    'a realistic photo of six cups and two saucers',\n",
    "    'a realistic photo of three plates and a fork',\n",
    "    'a realistic photo of five spoons and three knives',\n",
    "    'a realistic photo of three bowls and a spoon',\n",
    "    'a realistic photo of three glasses and a bottle',\n",
    "    'a realistic photo of eight plates and a tablecloth',\n",
    "    'a realistic photo of a mug and a coaster',\n",
    "    'a realistic photo of two napkins and a placemat',\n",
    "    'a realistic photo of a teapot and five teacups',\n",
    "    'a realistic photo of three candles and two flowers'\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8364b4a2-b72f-4b06-8b3b-3c0a85e8dc7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both scorers share the same PickScore model and processor.\n",
    "pickscore = PickScorer(pickscore_model, pickscore_processor, device=device)\n",
    "textnorm = TextNormPickScorer(pickscore_model, pickscore_processor, device=device)\n",
    "\n",
    "# One entry per metric; each value holds one result per image in image_fns.\n",
    "scores = {\n",
    "    'Human': [\"Good\", \"Bad\"],\n",
    "    'CLIP': clip_score(prompt, image_fns),\n",
    "    'ImageReward': image_reward(prompt, image_fns),\n",
    "    'PickScore': pickscore.score(image_fns, prompt),\n",
    "    'TextNorm': textnorm.score(image_fns, prompt, prompt_neg),\n",
    "}\n",
    "\n",
    "# Print a tab-separated row per metric: name, first image, second image.\n",
    "for method, score in scores.items():\n",
    "    print(f\"{method:10}\\t{score[0]:.4}\\t{score[1]:.4}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce5ddd90-cd75-4e4c-8efa-fc6d6d7b6d39",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
