{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional, Union, Tuple, List, Callable, Dict\n",
    "import torch\n",
    "from diffusers import StableDiffusionPipeline\n",
    "import torch.nn.functional as nnf\n",
    "import numpy as np\n",
    "import abc\n",
    "import ptp_utils\n",
    "import seq_aligner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8611b56fe2074865b646a46b91d4d5f8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Notebook-wide settings; defined here once so later cells can read them.\n",
    "LOW_RESOURCE = False\n",
    "NUM_DIFFUSION_STEPS = 50\n",
    "GUIDANCE_SCALE = 7.5\n",
    "MAX_NUM_WORDS = 77\n",
    "\n",
    "# Use the selected CUDA device when one is available, otherwise fall back to the CPU.\n",
    "gpu_num = 0\n",
    "device = torch.device(f'cuda:{gpu_num}' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Load the pretrained pipeline, move it to the chosen device, and keep a\n",
    "# short alias for its tokenizer.\n",
    "ldm_stable = StableDiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\")\n",
    "ldm_stable = ldm_stable.to(device)\n",
    "tokenizer = ldm_stable.tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{1: 'a', 2: 'black', 3: 'cat', 4: 'and', 5: 'a', 6: 'white', 7: 'dog'}\n"
     ]
    }
   ],
   "source": [
    "prompt = \"A black cat and a white dog\"\n",
    "\n",
    "# Tokenize the prompt once and reuse the ids for every entry of the mapping,\n",
    "# instead of re-tokenizing inside the comprehension filter.\n",
    "prompt_token_ids = tokenizer(prompt)['input_ids']\n",
    "\n",
    "# Map each token position to its decoded text. The first and last ids are dropped\n",
    "# (presumably the tokenizer's start/end markers -- TODO confirm), so the\n",
    "# positions of the remaining tokens start at 1.\n",
    "token_idx_to_word = {idx: tokenizer.decode(token_id)\n",
    "                     for idx, token_id in enumerate(prompt_token_ids[1:-1], start=1)}\n",
    "print(token_idx_to_word)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "diffuser",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
