{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DawKmcCvwEN_"
      },
      "source": [
        "# Running a VAE on Fashion MNIST"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EgTTI-23wJSO"
      },
      "source": [
        "### Creating the VAE Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-xjMpTgLwLsQ"
      },
      "outputs": [],
      "source": [
        "from psutil import virtual_memory\n",
        "ram_gb = virtual_memory().total / 1e9\n",
        "print('Your runtime has {:.1f} gigabytes of available RAM\\n'.format(ram_gb))\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.utils.data import random_split\n",
        "\n",
        "import numpy as np\n",
        "import torchvision\n",
        "from tqdm import tqdm\n",
        "from torchvision.utils import save_image, make_grid\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "plt.style.use('seaborn-darkgrid')\n",
        "dataset_path = '~/datasets'\n",
        "\n",
        "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "x_dim  = 784\n",
        "hidden_dim = 50\n",
        "latent_dim = 50\n",
        "\n",
        "\n",
        "#################\n",
        "batch_size = 100\n",
        "lr = 1e-3\n",
        "epochs = 50\n",
        "#################\n",
        "\n",
        "from torchvision.datasets import MNIST, FashionMNIST\n",
        "import torchvision.transforms as transforms\n",
        "from torch.utils.data import DataLoader\n",
        "\n",
        "\n",
        "mnist_transform = transforms.Compose([transforms.ToTensor(),])\n",
        "#ToTensor converts the  data images to torch tensor\n",
        "kwargs = {'num_workers': 1, 'pin_memory': True}\n",
        "#The above is some setting with the GPUs\n",
        "\n",
        "train_dataset = FashionMNIST(dataset_path, transform=mnist_transform, train=True, download=True)\n",
        "test_dataset  = FashionMNIST(dataset_path, transform=mnist_transform, train=False, download=True)\n",
        "\n",
        "train_size = 60000\n",
        "\n",
        "\n",
        "train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, **kwargs)\n",
        "test_loader  = DataLoader(dataset=test_dataset,  batch_size=batch_size, shuffle=False, **kwargs)\n",
        "\n",
        "class Encoder(nn.Module):\n",
        "\n",
        "    def __init__(self, input_dim, hidden_dim, latent_dim):\n",
        "        super(Encoder, self).__init__()\n",
        "\n",
        "        self.FC_layer_1 = nn.Linear(input_dim, hidden_dim)\n",
        "        self.FC_layer_2 = nn.Linear(hidden_dim, hidden_dim)\n",
        "        self.FC_mean  = nn.Linear(hidden_dim, latent_dim)\n",
        "        self.FC_var   = nn.Linear (hidden_dim, latent_dim)\n",
        "\n",
        "        self.LeakyReLU = nn.LeakyReLU(0.2)\n",
        "\n",
        "        self.training = True\n",
        "\n",
        "    def forward(self, x):\n",
        "        h_1      = self.LeakyReLU(self.FC_layer_1(x))\n",
        "        # R^{image_dim} \\ni x -> = ReLU(A_1(x)) = h_1 \\in R^{hidden_dim}\n",
        "        h_2       = self.LeakyReLU(self.FC_layer_2(h_1))\n",
        "        # R^{hidden_dim} \\ni h_1 -> ReLU(A_2(h_1)) = h_2 \\in R^{hidden_dim}\n",
        "        mean     = self.FC_mean(h_2)\n",
        "        log_var  = self.FC_var(h_2)\n",
        "        # R^{hidden_dim} \\ni h_2 -> (A_31(h_2),A_32(h_2)) = (mean,log_var) \\in R^{hidden_dim} x R^{hidden_dim}\n",
        "\n",
        "        # encoder produces mean and log of variance i.e., parameters of a Gaussian distribution \"q\"\n",
        "\n",
        "        return mean, log_var\n",
        "\n",
        "\n",
        "class Decoder(nn.Module):\n",
        "    def __init__(self, latent_dim, hidden_dim, output_dim):\n",
        "        super(Decoder, self).__init__()\n",
        "\n",
        "        self.FC_dec_layer_1 = nn.Linear(latent_dim, hidden_dim)\n",
        "        self.FC_dec_layer_2 = nn.Linear(hidden_dim, hidden_dim)\n",
        "        self.FC_output = nn.Linear(hidden_dim, output_dim)\n",
        "\n",
        "        self.LeakyReLU = nn.LeakyReLU(0.2)\n",
        "\n",
        "    def forward(self, z):\n",
        "        dec_h_1     = self.LeakyReLU(self.FC_dec_layer_1(z))\n",
        "        # R^{latent_dim} \\ni z -> ReLU(B1(z)) = dec_h_1 \\in R^{hidden_dim}\n",
        "\n",
        "        dec_h_2     = self.LeakyReLU(self.FC_dec_layer_2(dec_h_1))\n",
        "        # R^{hidden_dim} \\ni dec_h_1 -> ReLU(B2(dec_h_1)) = dec_h_2 \\in R^{hidden_dim}\n",
        "\n",
        "        x_hat = torch.sigmoid(self.FC_output(dec_h_2))\n",
        "        #R^{hidden_dim} \\ni dec_h_2 -> Sigmoid(B3(dec_h_2)) = x_hat \\in R^{output_dim}\n",
        "\n",
        "        return x_hat\n",
        "\n",
        "\n",
        "def reparameterization(mean, var):\n",
        "    epsilon = torch.randn_like(var).to(DEVICE)\n",
        "    # sampling epsilon ~ N(0,I_{latent-dimension x latent-dimension})\n",
        "    y = mean + var*epsilon\n",
        "    # The so-called \"reparameterization trick\"\n",
        "    return y\n",
        "## Now we define the final model\n",
        "\n",
        "\n",
        "class Model(nn.Module):\n",
        "    def __init__(self, Encoder, Decoder):\n",
        "        super(Model, self).__init__()\n",
        "\n",
        "        self.Encoder = Encoder\n",
        "        self.Decoder = Decoder\n",
        "\n",
        "\n",
        "    def forward(self, x):\n",
        "        mean, log_var = self.Encoder(x)\n",
        "        y = reparameterization(mean, torch.exp(0.5 * log_var))\n",
        "        x_hat = self.Decoder(y)\n",
        "\n",
        "        return x_hat, mean, log_var"
      ]
    },
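    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before training, a quick sanity check is cheap insurance: the cell below (a minimal sketch, not part of the original pipeline) pushes one random batch through an untrained model to confirm that the encoder/decoder dimensions line up."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative shape check on an untrained model; the underscore-prefixed\n",
        "# names are throwaways introduced here and unused by the experiments below.\n",
        "_enc = Encoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)\n",
        "_dec = Decoder(latent_dim=latent_dim, hidden_dim=hidden_dim, output_dim=x_dim)\n",
        "_vae = Model(Encoder=_enc, Decoder=_dec).to(DEVICE)\n",
        "\n",
        "_x = torch.rand(batch_size, x_dim).to(DEVICE)  # dummy batch of flattened 28x28 images\n",
        "_x_hat, _mean, _log_var = _vae(_x)\n",
        "print(_x_hat.shape, _mean.shape, _log_var.shape)\n",
        "# Expected: torch.Size([100, 784]) torch.Size([100, 50]) torch.Size([100, 50])"
      ]
    },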
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SDnZqa6nwZL5"
      },
      "source": [
        "### Experiment Code"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JQYkIhNQwasA"
      },
      "outputs": [],
      "source": [
        "\n",
        "import pickle\n",
        "from datetime import datetime\n",
        "from torch.optim import Adam, SGD\n",
        "def experiment_run(algo, eta, gamma, delta, epochs):\n",
        "    algo = algo.lower()\n",
        "    delta = delta if delta else 0\n",
        "    print()\n",
        "    print(\"Starting new experimen at\", datetime.now())\n",
        "    print(\"Parameters:\", algo, eta, gamma, delta, epochs)\n",
        "    print()\n",
        "    encoder = Encoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)\n",
        "    decoder = Decoder(latent_dim=latent_dim, hidden_dim = hidden_dim, output_dim = x_dim)\n",
        "    model = Model(Encoder=encoder, Decoder=decoder).to(DEVICE)\n",
        "\n",
        "    def loss_function(x, x_hat, mean, log_var):\n",
        "        MSE_Loss = nn.functional.mse_loss(x_hat, x, reduction='sum')\n",
        "        KLD      = - 0.5 * torch.sum(1 + log_var - mean.pow(2) - (log_var).exp())\n",
        "        # For dim = 1, torch.sum(a_ij, dim=1) = \\sum_{j=1}^d a_ij\n",
        "        # 1 here is an all ones matrix of size (minibatch_size, latent_dimension)\n",
        "        return MSE_Loss + KLD\n",
        "\n",
        "    if algo == \"adam\":\n",
        "        optimizer = Adam(model.parameters(), lr=eta)\n",
        "    else:\n",
        "        optimizer = SGD(model.parameters(), lr=eta)\n",
        "\n",
        "    print(\"Starting VAE training\")\n",
        "    model.train()\n",
        "    Training_Loss = []\n",
        "    Risk = []\n",
        "    Epoch = []\n",
        "    grad_norms = []\n",
        "\n",
        "    for epoch in range(epochs):\n",
        "\n",
        "        # Optional learning rate scheduling\n",
        "        if epoch == 99 or epoch == 149:\n",
        "            eta *= 0.1\n",
        "\n",
        "        # At the beginning of each epoch we calculate the training and the test loss.\n",
        "        training_loss = 0\n",
        "        for batch_number, (r, _) in enumerate(train_loader):\n",
        "                r = r.view(batch_size, x_dim)\n",
        "                r = r.to(DEVICE)\n",
        "                r_hat, mean, log_var = model(r)\n",
        "                mini_batch_loss = loss_function(r, r_hat, mean, log_var)\n",
        "                training_loss += mini_batch_loss.item()\n",
        "                #mini_batch_loss.item() = training loss on the current mini-batch\n",
        "                #training_loss is accumulating the mini-batch losses to compute the loss on the entire training data.\n",
        "\n",
        "        Training_Loss.append(training_loss/(batch_number*batch_size))\n",
        "\n",
        "        #Now we compute the test loss at the same model parameters at which the above training loss was calculated.\n",
        "        test_loss = 0\n",
        "        for test_batch_number, (t, _) in enumerate(test_loader):\n",
        "                t = t.view(batch_size, x_dim)\n",
        "                t = t.to(DEVICE)\n",
        "                t_hat, mean, log_var = model(t)\n",
        "                mini_batch_loss = loss_function(t, t_hat, mean, log_var)\n",
        "                test_loss += mini_batch_loss.item()\n",
        "\n",
        "        Risk.append(test_loss/(test_batch_number*batch_size))\n",
        "\n",
        "        for batch_number, (x, _) in enumerate(train_loader):\n",
        "            x = x.view(batch_size, x_dim)\n",
        "            x = x.to(DEVICE)\n",
        "\n",
        "            #(x,_) pulls out a mini-batch from the train_loader which has now been converted into an enumeratable data type\n",
        "            #There is some ancilliary information attached to each mini-batch which we dont care about and that is held in that \"_\"\n",
        "            #batch_number adds 1 to itself everytime a mini-batch is pulled out\n",
        "            #Thus the batch_number counts the number of mini-batches in the training data - a number we did not know till now.\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "\n",
        "            x_hat, mean, log_var = model(x)\n",
        "            mini_batch_loss = loss_function(x, x_hat, mean, log_var)\n",
        "            mini_batch_loss.backward()\n",
        "\n",
        "            # If GClip or d-GClip, adjust step size per iteration\n",
        "            norm_grad_f = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float(\"inf\")).item()\n",
        "            grad_norms.append(norm_grad_f)\n",
        "            if algo == \"gclip\" or algo == \"d-gclip\":\n",
        "                h = min(eta, eta * max(delta, gamma / norm_grad_f) )\n",
        "                for g in optimizer.param_groups:\n",
        "                    g[\"lr\"] = h\n",
        "            optimizer.step()\n",
        "\n",
        "        print(\"\\tEpoch\", epoch + 1, \"complete. Latest training and testing losses:\", Training_Loss[-1], Risk[-1])\n",
        "\n",
        "        Epoch.append(epoch+1)\n",
        "\n",
        "\n",
        "    print(\"The VAE training is over! Final training loss:\", Training_Loss[-1], \" test_loss:\", Risk[-1])\n",
        "    data = {\"training_loss\": Training_Loss, \"testing_loss\": Risk, \"gradient_norms\":grad_norms}\n",
        "    return data, model\n"
      ]
    },
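    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what the GClip / d-GClip step-size rule h = min(eta, eta * max(delta, gamma/||grad||)) does before running it inside the training loop, here is a small standalone sketch. The gradient norms are hypothetical values chosen to cover both the clipped and unclipped regimes; the parameter values match the experiments below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Standalone sketch of the per-iteration step size used by gclip / d-gclip.\n",
        "# The gradient norms below are made-up values spanning both regimes.\n",
        "eta_demo, gamma_demo, delta_demo = 1e-3, 200.0, 0.1\n",
        "\n",
        "for norm in [50.0, 200.0, 1000.0, 10000.0]:\n",
        "    h_gclip  = min(eta_demo, eta_demo * max(0.0, gamma_demo / norm))         # plain GClip (delta = 0)\n",
        "    h_dgclip = min(eta_demo, eta_demo * max(delta_demo, gamma_demo / norm))  # d-GClip: lr never falls below delta*eta\n",
        "    print(f\"||grad|| = {norm:8.1f} -> gclip lr = {h_gclip:.1e}, d-gclip lr = {h_dgclip:.1e}\")"
      ]
    },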
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rbfFBeA6we0O"
      },
      "source": [
        "### Run Experiments"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "Y8ATOqxkwgG6"
      },
      "outputs": [],
      "source": [
        "experiments = [{\"algo\":\"adam\", \"eta\":1e-3, \"gamma\":0, \"delta\":0},\n",
        "                {\"algo\":\"gd\", \"eta\":1e-4, \"gamma\":0, \"delta\":0},\n",
        "                {\"algo\":\"gclip\", \"eta\":1e-3, \"gamma\":200, \"delta\":0},\n",
        "                {\"algo\":\"d-gclip\", \"eta\":1e-3, \"gamma\":200, \"delta\":0.1}]\n",
        "\n",
        "\n",
        "experiment_results = []\n",
        "for exp in experiments:\n",
        "    epochs = 200\n",
        "    data, model = experiment_run(**exp, epochs=epochs)\n",
        "    epoch = [i+1 for i in range(epochs)]\n",
        "    print(data)\n",
        "    experiment_results.append([exp, data])"
      ]
    },
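    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cells below are one possible way to inspect the results, not part of the original run: first plot the recorded per-epoch losses for each configuration, then decode a few random latent vectors with the last trained model (this finally uses the make_grid helper imported at the top). Both assume experiment_results and model from the previous cell."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Compare the four optimizers: training and testing loss per epoch.\n",
        "fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n",
        "for exp, data in experiment_results:\n",
        "    axes[0].plot(data[\"training_loss\"], label=exp[\"algo\"])\n",
        "    axes[1].plot(data[\"testing_loss\"], label=exp[\"algo\"])\n",
        "axes[0].set_title(\"Training loss\")\n",
        "axes[1].set_title(\"Testing loss\")\n",
        "for ax in axes:\n",
        "    ax.set_xlabel(\"epoch\")\n",
        "    ax.set_ylabel(\"loss per sample\")\n",
        "    ax.legend()\n",
        "plt.tight_layout()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sample from the prior: decode random z ~ N(0, I) with the last trained model.\n",
        "model.eval()\n",
        "with torch.no_grad():\n",
        "    noise = torch.randn(64, latent_dim).to(DEVICE)\n",
        "    samples = model.Decoder(noise).view(64, 1, 28, 28)\n",
        "\n",
        "grid = make_grid(samples.cpu(), nrow=8)  # make_grid returns a 3-channel image tensor\n",
        "plt.figure(figsize=(6, 6))\n",
        "plt.imshow(grid.permute(1, 2, 0))\n",
        "plt.axis(\"off\")\n",
        "plt.show()"
      ]
    },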
    {
      "cell_type": "markdown",
      "source": [
        "\n",
        "\n",
        "---\n",
        "\n"
      ],
      "metadata": {
        "id": "5Mv6UO7M5Eey"
      }
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}