{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "140926f1-774d-4fa2-9bff-6f2af2b4e66b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import demo_util\n",
    "import numpy as np\n",
    "import torch\n",
    "from PIL import Image\n",
    "import imagenet_classes\n",
    "from IPython.display import display\n",
    "import os\n",
    "from huggingface_hub import hf_hub_download\n",
    "from modeling.maskgit import ImageBert\n",
    "from modeling.titok import TiTok"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3eea3494-1b3b-4350-94c3-3682847fdbc1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# load the pretrained models from huggingface\n",
    "# supported tokenizer: [tokenizer_titok_l32_imagenet, tokenizer_titok_b64_imagenet, tokenizer_titok_s128_imagenet]\n",
    "# titok_tokenizer = TiTok.from_pretrained(\"yucornetto/tokenizer_titok_l32_imagenet\")\n",
    "# titok_generator = ImageBert.from_pretrained(\"yucornetto/generator_titok_l32_imagenet\")\n",
    "\n",
    "# or alternatively, downloads from hf\n",
    "# hf_hub_download(repo_id=\"fun-research/TiTok\", filename=\"tokenizer_titok_l32.bin\", local_dir=\"./\")\n",
    "# hf_hub_download(repo_id=\"fun-research/TiTok\", filename=\"generator_titok_l32.bin\", local_dir=\"./\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e97141d6-ecb3-4c32-bd44-2abb34906824",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "torch.manual_seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3c0babe-1c3e-47e2-b57d-bc2fb84dbbb4",
   "metadata": {},
   "source": [
    "## Prepare the TiTok models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3c35007-978b-4ce4-9140-dac536a6971d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "config = demo_util.get_config(\"configs/infer/TiTok/titok_l32.yaml\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37903799-6008-4013-86f8-7fbd6e345038",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40d6cc42-dccf-42f3-898c-074e242ce71f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# supported tokenizer: [tokenizer_titok_l32_imagenet, tokenizer_titok_b64_imagenet, tokenizer_titok_s128_imagenet]\n",
    "titok_tokenizer = TiTok.from_pretrained(\"yucornetto/tokenizer_titok_l32_imagenet\")\n",
    "titok_tokenizer.eval()\n",
    "titok_tokenizer.requires_grad_(False)\n",
    "# or alternatively, downloads from hf\n",
    "\n",
    "# hf_hub_download(repo_id=\"fun-research/TiTok\", filename=\"tokenizer_titok_l32.bin\", local_dir=\"./\")\n",
    "# titok_tokenizer = demo_util.get_titok_tokenizer(config)\n",
    "\n",
    "print(titok_tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b715a90d-b7d4-4d62-9dca-4701de86e8c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# supported generator: [generator_titok_l32_imagenet, generator_titok_b64_imagenet, generator_titok_s128_imagenet]\n",
    "titok_generator = ImageBert.from_pretrained(\"yucornetto/generator_titok_l32_imagenet\")\n",
    "titok_generator.eval()\n",
    "titok_generator.requires_grad_(False)\n",
    "\n",
    "# or alternatively, downloads from hf\n",
    "# hf_hub_download(repo_id=\"fun-research/TiTok\", filename=\"generator_titok_l32.bin\", local_dir=\"./\")\n",
    "# titok_generator = demo_util.get_titok_generator(config)\n",
    "print(titok_generator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a54d50f4-3e83-493b-b197-4bd043347e34",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "device = \"cuda\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06b1de95-d3c8-4c94-8345-5b8666b65d6f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "titok_tokenizer = titok_tokenizer.to(device)\n",
    "titok_generator = titok_generator.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "226bd4ce-82b3-47ca-8210-655ac397fec7",
   "metadata": {},
   "source": [
    "## Tokenize and Reconstruct an image with 32 discrete tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94cc58e5-e78c-4c3c-a8a9-d82ff5c2ea4e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Tokenize an Image into 32 discrete tokens\n",
    "\n",
    "def tokenize_and_reconstruct(img_path):\n",
    "    original_image = Image.open(img_path)\n",
    "    image = torch.from_numpy(np.array(original_image).astype(np.float32)).permute(2, 0, 1).unsqueeze(0) / 255.0\n",
    "    encoded_tokens = titok_tokenizer.encode(image.to(device))[1][\"min_encoding_indices\"]\n",
    "    reconstructed_image = titok_tokenizer.decode_tokens(encoded_tokens)\n",
    "    reconstructed_image = torch.clamp(reconstructed_image, 0.0, 1.0)\n",
    "    reconstructed_image = (reconstructed_image * 255.0).permute(0, 2, 3, 1).to(\"cpu\", dtype=torch.uint8).numpy()[0]\n",
    "    reconstructed_image = Image.fromarray(reconstructed_image)\n",
    "    print(f\"Input Image is represented by codes {encoded_tokens} with shape {encoded_tokens.shape}\")\n",
    "    print(\"orginal image:\")\n",
    "    display(original_image)\n",
    "    print(\"reconstructed image:\")\n",
    "    display(reconstructed_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a38689f-9d22-451f-b487-7ef3b160b09d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "tokenize_and_reconstruct(\"assets/ILSVRC2012_val_00008636.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66eed36e-97c4-407a-8039-2320a4905f5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenize_and_reconstruct(\"assets/ILSVRC2012_val_00010240.png\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45a75ef7-a3be-45e1-9c8f-e1b9c720586b",
   "metadata": {},
   "source": [
    "## Generate an image from 32 discrete tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "419a473b-bcaf-41e3-909c-f4aedeca286c",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_labels = [torch.randint(0, 999, size=(1,)).item()]\n",
    "\n",
    "# The guidance_scale and randomize_temperature can be adjusted to trade-off between quality and diversity.\n",
    "generated_image = demo_util.sample_fn(\n",
    "    generator=titok_generator,\n",
    "    tokenizer=titok_tokenizer,\n",
    "    labels=sample_labels,\n",
    "    guidance_scale=3.5,\n",
    "    randomize_temperature=1.0,\n",
    "    num_sample_steps=8,\n",
    "    device=device\n",
    ")\n",
    "\n",
    "for i in range(generated_image.shape[0]):\n",
    "    print(f\"labels {sample_labels[i]}, {imagenet_classes.imagenet_idx2classname[sample_labels[i]]}\")\n",
    "    display(Image.fromarray(generated_image[i]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6 (default, Oct 18 2022, 12:41:40) \n[Clang 14.0.0 (clang-1400.0.29.202)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
