{
  "metadata": {
    "language_info": {
      "name": "python",
      "version": "3.6.4",
      "mimetype": "text/x-python",
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "pygments_lexer": "ipython3",
      "nbconvert_exporter": "python",
      "file_extension": ".py"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "accelerator": "GPU",
    "gpuClass": "standard",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "a93132d117fa44248fafc4cbdd4dd3cf": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_7d61b78e2f0a47e69128396c509a93d4",
              "IPY_MODEL_8dac24333cc2415ab98605297b929f0a",
              "IPY_MODEL_15cd74648a284e62b8b9e5cc31bcaad6"
            ],
            "layout": "IPY_MODEL_2080c04b509c455ea704afc0119c0903"
          }
        },
        "7d61b78e2f0a47e69128396c509a93d4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_1b3f96e978d344ac974c6c91ff785d4e",
            "placeholder": "​",
            "style": "IPY_MODEL_b39527c3d86f4e5086cefa1c2bf1a3ba",
            "value": " 80%"
          }
        },
        "8dac24333cc2415ab98605297b929f0a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_eafcc48a2184475fbc0baa0a8681a43a",
            "max": 95,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_a8eb14c6b14d49dbaf2fb8ecc3b492eb",
            "value": 76
          }
        },
        "15cd74648a284e62b8b9e5cc31bcaad6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0a05339440df4fb59a2cce4daf651f2d",
            "placeholder": "​",
            "style": "IPY_MODEL_792f7ee4644f4afb939c0772c4870847",
            "value": " 76/95 [01:36&lt;00:23,  1.25s/it]"
          }
        },
        "2080c04b509c455ea704afc0119c0903": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1b3f96e978d344ac974c6c91ff785d4e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b39527c3d86f4e5086cefa1c2bf1a3ba": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "eafcc48a2184475fbc0baa0a8681a43a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "a8eb14c6b14d49dbaf2fb8ecc3b492eb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "0a05339440df4fb59a2cce4daf651f2d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "792f7ee4644f4afb939c0772c4870847": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "nbformat_minor": 0,
  "nbformat": 4,
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# Normal (in-distribution) MNIST digit. The label word is derived from the class index\n",
        "# so the prompt word and the MNIST class filter can never disagree.\n",
        "DIGIT_WORDS = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']\n",
        "normal_class = 5\n",
        "normal_label = DIGIT_WORDS[normal_class]"
      ],
      "metadata": {
        "execution": {
          "iopub.status.busy": "2023-05-23T09:17:55.073085Z",
          "iopub.execute_input": "2023-05-23T09:17:55.073426Z",
          "iopub.status.idle": "2023-05-23T09:17:55.079159Z",
          "shell.execute_reply.started": "2023-05-23T09:17:55.073359Z",
          "shell.execute_reply": "2023-05-23T09:17:55.078343Z"
        },
        "trusted": true,
        "id": "oL3vbPhQ6jv5"
      },
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Generate Prompt"
      ],
      "metadata": {
        "id": "jNa_aDzd6jv6"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import gensim\n",
        "import gensim.downloader\n",
        "from nltk.stem import WordNetLemmatizer\n",
        "import numpy as np\n",
        "import re\n",
        "from torchvision.datasets import MNIST"
      ],
      "metadata": {
        "_uuid": "1d71046b34cff8e56f33a21f55a8f4ca39b41047",
        "_cell_guid": "84fc472a-9c46-4fbb-ad0c-6083863bfb4b",
        "execution": {
          "iopub.status.busy": "2023-05-23T07:26:26.847013Z",
          "iopub.execute_input": "2023-05-23T07:26:26.847469Z",
          "iopub.status.idle": "2023-05-23T07:26:26.851902Z",
          "shell.execute_reply.started": "2023-05-23T07:26:26.847400Z",
          "shell.execute_reply": "2023-05-23T07:26:26.850969Z"
        },
        "trusted": true,
        "id": "H2Qwwvam6jv7"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import nltk\n",
        "\n",
        "# WordNetLemmatizer (used to normalise word2vec neighbours) needs the WordNet corpus.\n",
        "nltk.download('wordnet')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "DiJWvmlk-45F",
        "outputId": "e37d9710-4182-4b5c-b98b-f874aac94516"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "[nltk_data] Downloading package wordnet to /root/nltk_data...\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "True"
            ]
          },
          "metadata": {},
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Pretrained word2vec vectors (~1.6 GB download).\n",
        "# NOTE: `model` is rebound to the GLIDE diffusion model in the 'Generate Image' section,\n",
        "# so re-run this cell before re-running the prompt-sampling cell.\n",
        "model = gensim.downloader.load('word2vec-google-news-300')"
      ],
      "metadata": {
        "execution": {
          "iopub.status.busy": "2023-05-23T07:33:21.814105Z",
          "iopub.execute_input": "2023-05-23T07:33:21.814458Z",
          "iopub.status.idle": "2023-05-23T07:39:28.784422Z",
          "shell.execute_reply.started": "2023-05-23T07:33:21.814401Z",
          "shell.execute_reply": "2023-05-23T07:39:28.783685Z"
        },
        "trusted": true,
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Xh6GPUlH6jv8",
        "outputId": "c7ddfc0e-c278-40a7-cb61-c6f441e12425"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[==================================================] 100.0% 1662.8/1662.8MB downloaded\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Lemmatizer used to map word2vec neighbours to a base form before de-duplication.\n",
        "wl = WordNetLemmatizer()"
      ],
      "metadata": {
        "execution": {
          "iopub.status.busy": "2023-05-23T07:39:28.785374Z",
          "iopub.execute_input": "2023-05-23T07:39:28.785614Z",
          "iopub.status.idle": "2023-05-23T07:39:29.347223Z",
          "shell.execute_reply.started": "2023-05-23T07:39:28.785571Z",
          "shell.execute_reply": "2023-05-23T07:39:29.346560Z"
        },
        "trusted": true,
        "id": "OcGZ55as6jv9"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "pattern = '[a-z]+'\n",
        "topn = 1000\n",
        "num_word_prompts = 50  # how many word prompts to draw\n",
        "sharpness = 10         # weight = similarity ** sharpness; larger values favour the nearest words\n",
        "prompt_seed = 0        # fixed seed so the sampled prompts are reproducible\n",
        "\n",
        "# Nearest word2vec neighbours of the normal label (in practice other number words).\n",
        "all_similars = model.most_similar(positive=[normal_label], topn=topn)\n",
        "\n",
        "# Keep unique, purely lower-case alphabetic lemmas other than the normal label itself.\n",
        "normalized_similars_words = []\n",
        "normalized_similars_probs = []\n",
        "seen_words = set()\n",
        "for word, similarity in all_similars:\n",
        "    normalized = wl.lemmatize(word.lower())\n",
        "    if re.fullmatch(pattern, normalized) is None:\n",
        "        continue\n",
        "    if normalized == normal_label or normalized in seen_words:\n",
        "        continue\n",
        "    seen_words.add(normalized)\n",
        "    normalized_similars_words.append(normalized)\n",
        "    normalized_similars_probs.append(similarity ** sharpness)\n",
        "\n",
        "if not normalized_similars_words:\n",
        "    raise ValueError(f'no usable word2vec neighbours found for {normal_label!r}')\n",
        "\n",
        "# Sample with replacement, proportionally to the sharpened similarities. rng.choice replaces\n",
        "# a manual cumsum/argmax inverse-CDF draw, which silently returned the last word whenever a\n",
        "# uniform draw exceeded a cumulative sum that had rounded to just below 1.\n",
        "probabilities = np.asarray(normalized_similars_probs) / np.sum(normalized_similars_probs)\n",
        "rng = np.random.default_rng(prompt_seed)\n",
        "samples = rng.choice(len(normalized_similars_words), size=num_word_prompts, p=probabilities)\n",
        "prompts = [normalized_similars_words[i] for i in samples]"
      ],
      "metadata": {
        "execution": {
          "iopub.status.busy": "2023-05-23T09:18:18.922393Z",
          "iopub.execute_input": "2023-05-23T09:18:18.922715Z",
          "iopub.status.idle": "2023-05-23T09:18:19.403681Z",
          "shell.execute_reply.started": "2023-05-23T09:18:18.922656Z",
          "shell.execute_reply": "2023-05-23T09:18:19.402842Z"
        },
        "trusted": true,
        "id": "XNAHG9v16jv9"
      },
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Hand-written defect prompts: the normal label paired with a visible flaw\n",
        "# (crack, scratch, hole, ...). They are appended to the sampled word prompts below.\n",
        "negative_adjectives = [\n",
        "    f'A photo of {normal_label} with a crack',\n",
        "    f'A photo of a broken {normal_label}',\n",
        "    f'A photo of {normal_label} with a defect',\n",
        "    f'A photo of {normal_label} with damage',\n",
        "    f'A photo of {normal_label} with a scratch',\n",
        "    f'A photo of {normal_label} with a hole',\n",
        "    f'A photo of {normal_label} torn',\n",
        "    f'A photo of {normal_label} cut',\n",
        "    f'A photo of {normal_label} with contamination',\n",
        "    f'A photo of {normal_label} with a fracture',\n",
        "    f'A photo of a damaged {normal_label}',\n",
        "    f'A photo of a fractured {normal_label}',\n",
        "    f'A photo of {normal_label} with destruction',\n",
        "    f'A photo of {normal_label} with a mark',\n",
        "]"
      ],
      "metadata": {
        "id": "3KPoIUYYBNxt"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Append the defect prompts. Defect prompts already present from an earlier run of this\n",
        "# cell are dropped first, so re-executing the cell does not keep growing the list.\n",
        "# Sampled word prompts are single lower-case words and never equal these sentences.\n",
        "prompts = [p for p in prompts if p not in negative_adjectives] + negative_adjectives"
      ],
      "metadata": {
        "execution": {
          "iopub.status.busy": "2023-05-23T09:18:20.728452Z",
          "iopub.execute_input": "2023-05-23T09:18:20.728771Z",
          "iopub.status.idle": "2023-05-23T09:18:20.735528Z",
          "shell.execute_reply.started": "2023-05-23T09:18:20.728710Z",
          "shell.execute_reply": "2023-05-23T09:18:20.734726Z"
        },
        "trusted": true,
        "id": "XYyMJtI46jv-"
      },
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Final prompt list: sampled word prompts followed by the defect prompts.\n",
        "prompts"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5r7_HoF0Btiw",
        "outputId": "2fda6039-217b-4fe1-a4d9-ac822acd840a"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['three',\n",
              " 'eleven',\n",
              " 'three',\n",
              " 'six',\n",
              " 'eight',\n",
              " 'seven',\n",
              " 'thirty',\n",
              " 'eight',\n",
              " 'seven',\n",
              " 'two',\n",
              " 'seven',\n",
              " 'six',\n",
              " 'four',\n",
              " 'six',\n",
              " 'four',\n",
              " 'eight',\n",
              " 'nine',\n",
              " 'eight',\n",
              " 'seven',\n",
              " 'seven',\n",
              " 'seven',\n",
              " 'seven',\n",
              " 'seven',\n",
              " 'eight',\n",
              " 'six',\n",
              " 'three',\n",
              " 'four',\n",
              " 'nine',\n",
              " 'sixteen',\n",
              " 'three',\n",
              " 'seven',\n",
              " 'six',\n",
              " 'nine',\n",
              " 'three',\n",
              " 'eight',\n",
              " 'seven',\n",
              " 'eight',\n",
              " 'eight',\n",
              " 'seven',\n",
              " 'six',\n",
              " 'four',\n",
              " 'nine',\n",
              " 'four',\n",
              " 'six',\n",
              " 'two',\n",
              " 'six',\n",
              " 'four',\n",
              " 'three',\n",
              " 'seven',\n",
              " 'three',\n",
              " 'A photo of five with a crack',\n",
              " 'A photo of a broken five',\n",
              " 'A photo of five with a defect',\n",
              " 'A photo of five with damage',\n",
              " 'A photo of five with a scratch',\n",
              " 'A photo of five with a hole',\n",
              " 'A photo of five torn',\n",
              " 'A photo of five cut',\n",
              " 'A photo of five with contamination',\n",
              " 'A photo of five with a fracture',\n",
              " 'A photo of a damaged five',\n",
              " 'A photo of a fractured five',\n",
              " 'A photo of five with destruction',\n",
              " 'A photo of five with a mark']"
            ]
          },
          "metadata": {},
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Generate Image"
      ],
      "metadata": {
        "id": "4gd4VHmV6jv-"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Images per batch for the MNIST DataLoader built below.\n",
        "# NOTE(review): the generation loop reassigns `batch_size` to each batch's actual size.\n",
        "batch_size = 20"
      ],
      "metadata": {
        "execution": {
          "iopub.status.busy": "2023-05-23T09:18:25.749954Z",
          "iopub.execute_input": "2023-05-23T09:18:25.750281Z",
          "iopub.status.idle": "2023-05-23T09:18:25.754425Z",
          "shell.execute_reply.started": "2023-05-23T09:18:25.750212Z",
          "shell.execute_reply": "2023-05-23T09:18:25.753513Z"
        },
        "trusted": true,
        "id": "Clufv76R6jv_"
      },
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "%%capture\n",
        "# Environment setup: package installs, the GLIDE code download and shared imports.\n",
        "# Output is captured to keep the notebook readable; remove the capture magic when debugging.\n",
        "import sys\n",
        "import zipfile\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torchvision\n",
        "from torch.utils.data import DataLoader\n",
        "from torchvision import transforms\n",
        "\n",
        "# The pip magic installs into the running kernel's environment (unlike a shell pip call).\n",
        "# TODO: pin versions so re-runs are reproducible.\n",
        "%pip install gdown ftfy\n",
        "import ftfy\n",
        "import gdown\n",
        "\n",
        "# Download and unpack the glide-text2im code archive (presumably a snapshot of\n",
        "# https://github.com/openai/glide-text2im). The URL form works with both old and new gdown\n",
        "# releases; the `--id` flag used previously is deprecated and dropped in gdown 5.\n",
        "!gdown 'https://drive.google.com/uc?id=1cZwCFQOSUq4w2ckZjhFRlJmZ5cvSdgbs'\n",
        "with zipfile.ZipFile('./glide-text2im-code.zip', 'r') as zip_ref:\n",
        "    zip_ref.extractall('./')\n",
        "\n",
        "# Make the extracted package importable; guard so re-running does not append it twice.\n",
        "if './glide-text2im' not in sys.path:\n",
        "    sys.path.append('./glide-text2im')"
      ],
      "metadata": {
        "execution": {
          "iopub.status.busy": "2023-05-23T09:20:53.800352Z",
          "iopub.execute_input": "2023-05-23T09:20:53.800717Z",
          "iopub.status.idle": "2023-05-23T09:20:59.716084Z",
          "shell.execute_reply.started": "2023-05-23T09:20:53.800658Z",
          "shell.execute_reply": "2023-05-23T09:20:59.715091Z"
        },
        "trusted": true,
        "id": "B6eZjejr6jv_"
      },
      "execution_count": 22,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def show_samples(x):\n",
        "    \"\"\"Plot a batch of square images (N, C, H, W tensor) as a single grid figure.\"\"\"\n",
        "    x = x.permute(0, 2, 3, 1).detach().cpu().numpy()\n",
        "    # Take the tile geometry from the batch rather than assuming 64x64 RGB,\n",
        "    # so e.g. 32x32 batches (see trans_to_32) render too.\n",
        "    img = image_grid(x, size=x.shape[1], channels=x.shape[3])\n",
        "    plt.figure(figsize=(4, 4))\n",
        "    plt.axis('off')\n",
        "    plt.imshow(img)\n",
        "    plt.show()\n",
        "\n",
        "\n",
        "def image_grid(x, size=64, channels=3):\n",
        "    \"\"\"Tile a batch of square, channels-last images into one image.\n",
        "\n",
        "    Args:\n",
        "        x: array-like reshapeable to (N, size, size, channels).\n",
        "        size: side length of each tile in pixels (default 64, matching the 64x64 training images).\n",
        "        channels: channels per tile (default 3).\n",
        "\n",
        "    Returns:\n",
        "        Array of shape (rows * size, cols * size, channels) with cols = ceil(sqrt(N)) and\n",
        "        rows = ceil(N / cols). Grid cells beyond N are filled with zero (black) tiles; a strict\n",
        "        sqrt(N) x sqrt(N) reshape would instead fail for non-square N such as a batch of 20.\n",
        "    \"\"\"\n",
        "    img = np.asarray(x).reshape(-1, size, size, channels)\n",
        "    n = img.shape[0]\n",
        "    if n == 0:\n",
        "        return img.reshape(0, 0, channels)\n",
        "    cols = int(np.ceil(np.sqrt(n)))\n",
        "    rows = int(np.ceil(n / cols))\n",
        "    missing = rows * cols - n\n",
        "    if missing:\n",
        "        padding = np.zeros((missing, size, size, channels), dtype=img.dtype)\n",
        "        img = np.concatenate([img, padding], axis=0)\n",
        "    return (img.reshape(rows, cols, size, size, channels)\n",
        "               .transpose(0, 2, 1, 3, 4)\n",
        "               .reshape(rows * size, cols * size, channels))"
      ],
      "metadata": {
        "execution": {
          "iopub.status.busy": "2023-05-23T09:21:02.570025Z",
          "iopub.execute_input": "2023-05-23T09:21:02.570451Z",
          "iopub.status.idle": "2023-05-23T09:21:02.602349Z",
          "shell.execute_reply.started": "2023-05-23T09:21:02.570380Z",
          "shell.execute_reply": "2023-05-23T09:21:02.597671Z"
        },
        "trusted": true,
        "id": "-SLaTEVH6jwA"
      },
      "execution_count": 23,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from PIL import Image\n",
        "from IPython.display import display\n",
        "import torch as th\n",
        "import torch.nn as nn\n",
        "\n",
        "from glide_text2im.clip.model_creation import create_clip_model\n",
        "from glide_text2im.download import load_checkpoint\n",
        "from glide_text2im.model_creation import (\n",
        "    create_model_and_diffusion,\n",
        "    model_and_diffusion_defaults,\n",
        "    model_and_diffusion_defaults_upsampler,\n",
        ")\n",
        "from glide_text2im.tokenizer.simple_tokenizer import SimpleTokenizer"
      ],
      "metadata": {
        "execution": {
          "iopub.status.busy": "2023-05-23T09:21:03.875726Z",
          "iopub.execute_input": "2023-05-23T09:21:03.876028Z",
          "iopub.status.idle": "2023-05-23T09:21:03.908280Z",
          "shell.execute_reply.started": "2023-05-23T09:21:03.875969Z",
          "shell.execute_reply": "2023-05-23T09:21:03.907311Z"
        },
        "trusted": true,
        "id": "hTsEH_ZS6jwA"
      },
      "execution_count": 24,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "has_cuda = th.cuda.is_available()\n",
        "# Prefer the GPU when one is present; fp16 in the model cell is only enabled in that case.\n",
        "device = th.device('cuda' if has_cuda else 'cpu')"
      ],
      "metadata": {
        "id": "GRHhQwWS6jwB"
      },
      "execution_count": 25,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# GLIDE base model. NOTE: this rebinds `model`, which previously held the word2vec\n",
        "# vectors used for prompt generation.\n",
        "options = model_and_diffusion_defaults()\n",
        "options.update(use_fp16=has_cuda, timestep_respacing='100')  # fp16 on GPU only; 100 diffusion steps for fast sampling\n",
        "model, diffusion = create_model_and_diffusion(**options)\n",
        "model.eval()\n",
        "if has_cuda:\n",
        "    model.convert_to_fp16()\n",
        "model.to(device)\n",
        "model.load_state_dict(load_checkpoint('base', device))\n",
        "print('total base parameters', sum(p.numel() for p in model.parameters()))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "kRwEETQ06jwB",
        "outputId": "cf710f9c-dded-4e1f-cc25-e6967d0d1c04"
      },
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "total base parameters 385030726\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# CLIP image and text encoders are restored from separate checkpoints.\n",
        "# NOTE(review): clip_model is not used in the cells shown here - confirm it is still needed.\n",
        "clip_model = create_clip_model(device=device)\n",
        "clip_model.image_encoder.load_state_dict(load_checkpoint('clip/image-enc', device))\n",
        "clip_model.text_encoder.load_state_dict(load_checkpoint('clip/text-enc', device))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ESTzpFM66jwC",
        "outputId": "7f3d5f09-2fb5-45dd-d1b1-a9b41173a5ab"
      },
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<All keys matched successfully>"
            ]
          },
          "metadata": {},
          "execution_count": 27
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def show_images(batch: th.Tensor):\n",
        "    \"\"\"Display a batch of images with values in [-1, 1] side by side as one horizontal strip.\"\"\"\n",
        "    pixels = ((batch + 1) * 127.5).round().clamp(0, 255)\n",
        "    pixels = pixels.to(th.uint8).cpu()\n",
        "    height = batch.shape[2]\n",
        "    # (N, C, H, W) -> (H, N, W, C) -> (H, N * W, C): images laid out left to right.\n",
        "    strip = pixels.permute(2, 0, 3, 1).reshape([height, -1, 3])\n",
        "    display(Image.fromarray(strip.numpy()))"
      ],
      "metadata": {
        "id": "42qhNbTV6jwC"
      },
      "execution_count": 28,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Sampling hyper-parameters.\n",
        "# NOTE(review): presumably the classifier-free guidance scale and the upsampler temperature\n",
        "# used by the sampling code further down - confirm against their use sites.\n",
        "guidance_scale, upsample_temp = 10, 0.997"
      ],
      "metadata": {
        "id": "IfZPKFeS6jwC"
      },
      "execution_count": 29,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# 64x64, 3-channel tensors built from MNIST (grayscale replicated to three channels).\n",
        "img_transform_simple = transforms.Compose([transforms.Grayscale(3), transforms.ToTensor(), transforms.Resize((64, 64))])\n",
        "train_set = MNIST(root='./', train=True, download=True, transform=img_transform_simple)\n",
        "\n",
        "# Keep only the normal class. `targets` is already a tensor, so compare it directly;\n",
        "# re-wrapping it in torch.tensor() only copies it and raises a UserWarning.\n",
        "idx = train_set.targets == normal_class\n",
        "train_set.targets = train_set.targets[idx]\n",
        "train_set.data = train_set.data[idx]\n",
        "\n",
        "images_train = [img for img, _ in train_set]\n",
        "assert images_train, f'no MNIST training images with label {normal_class}'\n",
        "DESIRED_NUMBER_OF_SAMPLES_GENERATED = len(images_train)\n",
        "images_train_t = torch.stack(images_train)\n",
        "train_loader = DataLoader(images_train_t, batch_size=batch_size, shuffle=True)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "rOR2dHYN6jwC",
        "outputId": "850262f6-3223-4774-e5bb-722b0ae79f97"
      },
      "execution_count": 30,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "<ipython-input-30-854a825c6c85>:3: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  idx = (torch.tensor(train_set.targets) == normal_class)\n",
            "<ipython-input-30-854a825c6c85>:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  train_set.targets = torch.tensor(train_set.targets)[idx]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torchvision.transforms as transforms\n",
        "\n",
        "# Downscales generated samples to 32x32 before they are written to the .npy checkpoint.\n",
        "trans_to_32 = transforms.Compose([transforms.Resize(size=(32, 32))])"
      ],
      "metadata": {
        "id": "DtZ2rNWf6jwD"
      },
      "execution_count": 31,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import random\n",
        "\n",
        "# Raw sampler outputs, one tensor per batch.\n",
        "all_generated_samples = []\n",
        "# The same batches resized to 32x32; concatenated only when a checkpoint is written.\n",
        "all_generated_samples_32 = []\n",
        "# Number of individual images generated so far (batches can differ in size).\n",
        "num_generated = 0\n",
        "\n",
        "last_save = 50\n",
        "save_step = 50\n",
        "\n",
        "def desired_no_samples_reached():\n",
        "    \"\"\"Return True once at least DESIRED_NUMBER_OF_SAMPLES_GENERATED images have been generated.\"\"\"\n",
        "    return num_generated >= DESIRED_NUMBER_OF_SAMPLES_GENERATED\n",
        "\n",
        "while not desired_no_samples_reached():\n",
        "    for x in train_loader:\n",
        "        x = x.to(device)\n",
        "        random_name = random.choice(prompts)\n",
        "        print(\"prompt: \", random_name)\n",
        "        tokens = model.tokenizer.encode(random_name)\n",
        "        # Rescale the loader's images from [0, 1] to [-1, 1] for the init image.\n",
        "        init = x.mul(2).sub(1)\n",
        "        # Size of *this* batch (the last DataLoader batch may be partial). Kept in a\n",
        "        # separate name so the global `batch_size` is not overwritten.\n",
        "        cur_batch_size = x.shape[0]\n",
        "\n",
        "        tokens, mask = model.tokenizer.padded_tokens_and_mask(\n",
        "            tokens, options['text_ctx']\n",
        "        )\n",
        "\n",
        "        # Pack the tokens together into model kwargs, one copy per image in the batch.\n",
        "        model_kwargs = dict(\n",
        "            tokens=th.tensor([tokens] * cur_batch_size, device=device),\n",
        "            mask=th.tensor([mask] * cur_batch_size, dtype=th.bool, device=device),\n",
        "        )\n",
        "\n",
        "        # Setup guidance function for CLIP model.\n",
        "        cond_fn = clip_model.cond_fn([random_name] * cur_batch_size, guidance_scale)\n",
        "\n",
        "        # Sample from the base model, initialised from the real batch via init_image / skip_timesteps.\n",
        "        model.del_cache()\n",
        "        generated_samples = diffusion.p_sample_loop(\n",
        "            model,\n",
        "            (cur_batch_size, 3, options[\"image_size\"], options[\"image_size\"]),\n",
        "            device=device,\n",
        "            clip_denoised=True,\n",
        "            progress=True,\n",
        "            model_kwargs=model_kwargs,\n",
        "            cond_fn=cond_fn,\n",
        "            skip_timesteps=5,\n",
        "            init_image=init,  # already on `device`; forcing .cuda() would break non-CUDA devices\n",
        "        )\n",
        "        model.del_cache()\n",
        "\n",
        "        samples_cpu = generated_samples.detach().cpu()\n",
        "        all_generated_samples.append(samples_cpu)\n",
        "        # Resize only the new batch, so each iteration's work is proportional to the\n",
        "        # batch rather than to the whole history of generated samples.\n",
        "        all_generated_samples_32.append(trans_to_32(samples_cpu))\n",
        "        num_generated += samples_cpu.shape[0]\n",
        "\n",
        "        # Checkpoint roughly every `save_step` images, and always once the target is reached.\n",
        "        if num_generated >= last_save or desired_no_samples_reached():\n",
        "            # Advance past the current count so a batch larger than save_step\n",
        "            # does not trigger a save on every following batch.\n",
        "            while last_save <= num_generated:\n",
        "                last_save += save_step\n",
        "            all_generated_samples_tensor = torch.cat(all_generated_samples_32)\n",
        "            # NOTE(review): the filename uses `normal_label` while the data filter uses\n",
        "            # `normal_class`; confirm `normal_label` is defined and names the same class.\n",
        "            with open(f'./MNIST_GLIDE_NormalClass_{normal_label}.npy', 'wb') as f:\n",
        "                np.save(f, all_generated_samples_tensor.numpy())\n",
        "            print(f\"checkpoint saved with {num_generated} samples\")\n",
        "        show_images(init)\n",
        "        show_images(generated_samples)\n",
        "        if desired_no_samples_reached():\n",
        "            break"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 66,
          "referenced_widgets": [
            "a93132d117fa44248fafc4cbdd4dd3cf",
            "7d61b78e2f0a47e69128396c509a93d4",
            "8dac24333cc2415ab98605297b929f0a",
            "15cd74648a284e62b8b9e5cc31bcaad6",
            "2080c04b509c455ea704afc0119c0903",
            "1b3f96e978d344ac974c6c91ff785d4e",
            "b39527c3d86f4e5086cefa1c2bf1a3ba",
            "eafcc48a2184475fbc0baa0a8681a43a",
            "a8eb14c6b14d49dbaf2fb8ecc3b492eb",
            "0a05339440df4fb59a2cce4daf651f2d",
            "792f7ee4644f4afb939c0772c4870847"
          ]
        },
        "id": "0I8cydB46jwD",
        "outputId": "154d3940-92d1-4519-8954-d206542c8ad3"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "prompt:  four\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "  0%|          | 0/95 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "a93132d117fa44248fafc4cbdd4dd3cf"
            }
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "rnvwAmke6jwD"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}