{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "gpt2_demo.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6_i3EQa2yOzA",
        "outputId": "07cea0ca-55a5-4545-fd64-064d0652690f"
      },
      "source": [
        "# Setup: install a CUDA-enabled JAX build and flaxmodels.\n",
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install --upgrade pip\n",
        "# The extras spec is quoted so the shell never treats the brackets as a glob pattern.\n",
        "%pip install \"jax[cuda11_cudnn805]\" -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
        "%pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: pip in /usr/local/lib/python3.7/dist-packages (21.2.4)\n",
            "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n",
            "Looking in links: https://storage.googleapis.com/jax-releases/jax_releases.html\n",
            "Requirement already satisfied: jax[cuda111] in /usr/local/lib/python3.7/dist-packages (0.2.19)\n",
            "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax[cuda111]) (3.3.0)\n",
            "Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax[cuda111]) (0.12.0)\n",
            "Requirement already satisfied: numpy>=1.18 in /usr/local/lib/python3.7/dist-packages (from jax[cuda111]) (1.19.5)\n",
            "Collecting jaxlib==0.1.70+cuda111\n",
            "  Downloading https://storage.googleapis.com/jax-releases/cuda111/jaxlib-0.1.70%2Bcuda111-cp37-none-manylinux2010_x86_64.whl (197.0 MB)\n",
            "\u001b[K     |████████████████████████████████| 197.0 MB 19 kB/s \n",
            "\u001b[?25hRequirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib==0.1.70+cuda111->jax[cuda111]) (1.12)\n",
            "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib==0.1.70+cuda111->jax[cuda111]) (1.4.1)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->jax[cuda111]) (1.15.0)\n",
            "Installing collected packages: jaxlib\n",
            "  Attempting uninstall: jaxlib\n",
            "    Found existing installation: jaxlib 0.1.66+cuda111\n",
            "    Uninstalling jaxlib-0.1.66+cuda111:\n",
            "      Successfully uninstalled jaxlib-0.1.66+cuda111\n",
            "Successfully installed jaxlib-0.1.70+cuda111\n",
            "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n",
            "Collecting git+https://github.com/matthias-wright/flaxmodels.git\n",
            "  Cloning https://github.com/matthias-wright/flaxmodels.git to /tmp/pip-req-build-cg84k2dn\n",
            "  Running command git clone -q https://github.com/matthias-wright/flaxmodels.git /tmp/pip-req-build-cg84k2dn\n",
            "  Resolved https://github.com/matthias-wright/flaxmodels.git to commit 242ced2a4a12ace8adc32a705b08064ffeeb31ac\n",
            "Requirement already satisfied: h5py==2.10.0 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (2.10.0)\n",
            "Requirement already satisfied: numpy==1.19.5 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (1.19.5)\n",
            "Requirement already satisfied: requests==2.23.0 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (2.23.0)\n",
            "Requirement already satisfied: packaging==20.9 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (20.9)\n",
            "Requirement already satisfied: dataclasses==0.6 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (0.6)\n",
            "Requirement already satisfied: filelock==3.0.12 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (3.0.12)\n",
            "Requirement already satisfied: jax in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (0.2.19)\n",
            "Requirement already satisfied: jaxlib in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (0.1.70+cuda111)\n",
            "Requirement already satisfied: flax in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (0.3.4)\n",
            "Requirement already satisfied: Pillow==7.1.2 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (7.1.2)\n",
            "Requirement already satisfied: regex==2021.4.4 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (2021.4.4)\n",
            "Requirement already satisfied: tqdm==4.60.0 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (4.60.0)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from h5py==2.10.0->flaxmodels==0.1.0) (1.15.0)\n",
            "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging==20.9->flaxmodels==0.1.0) (2.4.7)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0->flaxmodels==0.1.0) (3.0.4)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0->flaxmodels==0.1.0) (2.10)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0->flaxmodels==0.1.0) (2021.5.30)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0->flaxmodels==0.1.0) (1.24.3)\n",
            "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from flax->flaxmodels==0.1.0) (3.2.2)\n",
            "Requirement already satisfied: msgpack in /usr/local/lib/python3.7/dist-packages (from flax->flaxmodels==0.1.0) (1.0.2)\n",
            "Requirement already satisfied: optax in /usr/local/lib/python3.7/dist-packages (from flax->flaxmodels==0.1.0) (0.0.9)\n",
            "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax->flaxmodels==0.1.0) (3.3.0)\n",
            "Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax->flaxmodels==0.1.0) (0.12.0)\n",
            "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib->flaxmodels==0.1.0) (1.4.1)\n",
            "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib->flaxmodels==0.1.0) (1.12)\n",
            "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax->flaxmodels==0.1.0) (2.8.2)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax->flaxmodels==0.1.0) (1.3.1)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax->flaxmodels==0.1.0) (0.10.0)\n",
            "Requirement already satisfied: chex>=0.0.4 in /usr/local/lib/python3.7/dist-packages (from optax->flax->flaxmodels==0.1.0) (0.0.8)\n",
            "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax->flaxmodels==0.1.0) (0.11.1)\n",
            "Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax->flaxmodels==0.1.0) (0.1.6)\n",
            "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qr2BfYc9YVHx"
      },
      "source": [
        "# Generate text"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RHa6ySp-ywef"
      },
      "source": [
        "This is a very simple greedy text generation loop: at every step the most likely next token is appended to the sequence. There are more sophisticated [methods](https://huggingface.co/blog/how-to-generate) out there."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Y-nDnbE-yvWY",
        "outputId": "3a8d9c4a-6349-4967-aacc-be9b8335f3c0"
      },
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import flaxmodels as fm\n",
        "\n",
        "# Number of tokens the greedy loop below appends to the prompt\n",
        "NUM_NEW_TOKENS = 20\n",
        "\n",
        "key = jax.random.PRNGKey(0)\n",
        "\n",
        "# Initialize tokenizer\n",
        "tokenizer = fm.gpt2.get_tokenizer()\n",
        "\n",
        "# Encode start sequence; `generated` collects the token ids of the whole text\n",
        "generated = tokenizer.encode('The Manhattan bridge')\n",
        "\n",
        "# The first model call receives the full prompt, wrapped in a list to add a batch axis\n",
        "context = jnp.array([generated])\n",
        "# No cached keys/values exist before the first forward pass\n",
        "past = None\n",
        "\n",
        "# Initialize model\n",
        "# Models to choose from ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']\n",
        "model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')\n",
        "params = model.init(key, input_ids=context, past_key_values=past)\n",
        "\n",
        "for i in range(NUM_NEW_TOKENS):\n",
        "    # Predict next token in sequence\n",
        "    output = model.apply(params, input_ids=context, past_key_values=past, use_cache=True)\n",
        "    # Greedy choice: index of the largest logit at the last position.\n",
        "    # argmax runs over the flattened slice, so this assumes a batch of one sequence.\n",
        "    token = jnp.argmax(output['logits'][..., -1, :])\n",
        "    # Only the newly chosen token is fed to the next call, together with `past`\n",
        "    context = jnp.expand_dims(token, axis=0)\n",
        "    # Add token to sequence as a plain Python int so the list does not mix in JAX arrays\n",
        "    generated += [int(token)]\n",
        "    # Update past keys and values\n",
        "    past = output['past_key_values']\n",
        "\n",
        "# Decode sequence of tokens\n",
        "sequence = tokenizer.decode(generated)\n",
        "\n",
        "print()\n",
        "print(sequence)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading: \"https://www.dropbox.com/s/7f5n1gf348sy1mt/merges.txt\" to /tmp/flaxmodels/merges.txt\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "100%|██████████| 456k/456k [00:00<00:00, 12.1MiB/s]\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Downloading: \"https://www.dropbox.com/s/s93xkhgcac5nbmn/vocab.json\" to /tmp/flaxmodels/vocab.json\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "100%|██████████| 1.04M/1.04M [00:00<00:00, 23.1MiB/s]\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Downloading: \"https://www.dropbox.com/s/0wdgj0gazwt9nm7/gpt2.h5\" to /tmp/flaxmodels/gpt2.h5\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "100%|██████████| 703M/703M [00:14<00:00, 48.1MiB/s]\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Downloading: \"https://www.dropbox.com/s/s5xl32dgwc8322p/gpt2.json\" to /tmp/flaxmodels/gpt2.json\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "100%|██████████| 715/715 [00:00<00:00, 159kiB/s]\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "\n",
            "The Manhattan bridge is a major artery for the city's subway system, and the bridge is one of the busiest in\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kKnwDOU2YhSN"
      },
      "source": [
        "# Get language model head output from text input"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zW-IBk_FYm9a"
      },
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import flaxmodels as fm\n",
        "\n",
        "# Tokenize the prompt and wrap the ids in a list to add a leading batch axis\n",
        "tokenizer = fm.gpt2.get_tokenizer()\n",
        "input_ids = jnp.array([tokenizer.encode('The Manhattan bridge')])\n",
        "\n",
        "# Build the pretrained model with language model head and create its parameters\n",
        "key = jax.random.PRNGKey(0)\n",
        "model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')\n",
        "params = model.init(key, input_ids=input_ids)\n",
        "\n",
        "# Run the forward pass on the encoded prompt\n",
        "output = model.apply(params, input_ids=input_ids, use_cache=True)\n",
        "# output: {'last_hidden_state': ..., 'past_key_values': ..., 'loss': ..., 'logits': ...}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ui2DneCuYrOA"
      },
      "source": [
        "# Get language model head output from embeddings\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "W8PrhOpZYuRZ"
      },
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import flaxmodels as fm\n",
        "\n",
        "key = jax.random.PRNGKey(0)\n",
        "# A PRNG key should be consumed only once, so derive one key for the dummy data\n",
        "# and another one for model.init instead of passing `key` to both\n",
        "input_key, init_key = jax.random.split(key)\n",
        "\n",
        "# Dummy input, assumed layout (batch, sequence, embedding).\n",
        "# 768 is presumably the embedding width of the 'gpt2' checkpoint -- adjust for other sizes.\n",
        "input_embds = jax.random.normal(input_key, shape=(2, 10, 768))\n",
        "\n",
        "# Initialize model\n",
        "model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')\n",
        "params = model.init(init_key, input_embds=input_embds)\n",
        "\n",
        "# Compute output\n",
        "output = model.apply(params, input_embds=input_embds, use_cache=True)\n",
        "# output: {'last_hidden_state': ..., 'past_key_values': ..., 'loss': ..., 'logits': ...}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j0IUgj4yYwET"
      },
      "source": [
        "# Get model output from text input"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jSuAZ1YjYxmo"
      },
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import flaxmodels as fm\n",
        "\n",
        "# Tokenize the prompt and wrap the ids in a list to add a leading batch axis\n",
        "tokenizer = fm.gpt2.get_tokenizer()\n",
        "input_ids = jnp.array([tokenizer.encode('The Manhattan bridge')])\n",
        "\n",
        "# Build the pretrained base model (no language model head) and create its parameters\n",
        "key = jax.random.PRNGKey(0)\n",
        "model = fm.gpt2.GPT2Model(pretrained='gpt2')\n",
        "params = model.init(key, input_ids=input_ids)\n",
        "\n",
        "# Run the forward pass on the encoded prompt\n",
        "output = model.apply(params, input_ids=input_ids, use_cache=True)\n",
        "# output: {'last_hidden_state': ..., 'past_key_values': ...}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-jR2kX9GYzIn"
      },
      "source": [
        "# Get model output from embeddings"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Z1taV3BGY06n"
      },
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import flaxmodels as fm\n",
        "\n",
        "key = jax.random.PRNGKey(0)\n",
        "# A PRNG key should be consumed only once, so derive one key for the dummy data\n",
        "# and another one for model.init instead of passing `key` to both\n",
        "input_key, init_key = jax.random.split(key)\n",
        "\n",
        "# Dummy input, assumed layout (batch, sequence, embedding).\n",
        "# 768 is presumably the embedding width of the 'gpt2' checkpoint -- adjust for other sizes.\n",
        "input_embds = jax.random.normal(input_key, shape=(2, 10, 768))\n",
        "\n",
        "# Initialize model\n",
        "model = fm.gpt2.GPT2Model(pretrained='gpt2')\n",
        "params = model.init(init_key, input_embds=input_embds)\n",
        "\n",
        "# Compute output\n",
        "output = model.apply(params, input_embds=input_embds, use_cache=True)\n",
        "# output: {'last_hidden_state': ..., 'past_key_values': ...}"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}