{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "moved-southwest",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Matplotlib / IPython display setup for the figures produced in this notebook.\n",
    "import matplotlib\n",
    "# NOTE(review): '%matplotlib inline' on the next line presumably switches the backend\n",
    "# again, which would make this 'svg' selection redundant -- confirm before removing.\n",
    "matplotlib.use('svg')\n",
    "%matplotlib inline\n",
    "from IPython.display import set_matplotlib_formats\n",
    "# Ask the inline backend for both PNG and PDF representations of each figure.\n",
    "set_matplotlib_formats('png', 'pdf')\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "# Font type 42 embeds TrueType fonts in PDF/PS output.\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "# Serif text in Times New Roman, math text in Computer Modern ('cm'), base size 12.\n",
    "matplotlib.rcParams['font.family'] = 'serif'\n",
    "matplotlib.rcParams['font.serif'] = 'Times New Roman'\n",
    "matplotlib.rcParams[\"mathtext.fontset\"] = \"cm\"\n",
    "matplotlib.rcParams.update({'font.size': 12})\n",
    "# Horizontal padding added around the data limits, as a fraction of the data range.\n",
    "plt.rcParams['axes.xmargin'] = 0.05"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "normal-recipe",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from utils import *\n",
    "from models.preactresnet import preactresnet18\n",
    "from models.mnist import mnist_classifier\n",
    "import time\n",
    "from train import normalize\n",
    "from datasets import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "motivated-upper",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_hmc(loader, model, m_max, p, d, sigma=0.1, l=10, path_len=0.3,\n",
    "                   repeats=3, epsilon=0.03, n_batch=None, anneal_theta=True):\n",
    "    \"\"\"Path-sampling estimate of the loss under HMC-sampled input perturbations.\n",
    "\n",
    "    For every sample count m (d values evenly spaced over [1, m_max], truncated to\n",
    "    int) and every batch, a chain of m HMC transitions is run on a perturbation\n",
    "    delta that is kept inside [lower_limit, upper_limit] by reflection. Transition i\n",
    "    uses -theta_i * log(loss) as potential energy and records the loss at the\n",
    "    current delta before moving. The per-example estimate is the geometric mean of\n",
    "    the m recorded losses; estimates are averaged over all processed examples.\n",
    "\n",
    "    The acceptance rate of each repeat and the average timing are printed.\n",
    "\n",
    "    Args:\n",
    "        loader: iterable yielding (X, y) batches. Batches are moved to the GPU.\n",
    "            Presumably pixel values lie in [0, 1] (the limits are built from -X\n",
    "            and 1 - X) -- TODO confirm against the data pipeline.\n",
    "        model: classifier. Called on X + delta directly when X has one channel,\n",
    "            otherwise on normalize(X + delta).\n",
    "        m_max: largest number of HMC transitions per chain.\n",
    "        p: upper end of the theta range [0, p].\n",
    "        d: number of sample counts to evaluate.\n",
    "        sigma: standard deviation of the Gaussian momentum.\n",
    "        l: number of leapfrog steps per transition (must be >= 1).\n",
    "        path_len: sets the leapfrog step size, alpha = path_len * sigma**2 / l.\n",
    "        repeats: number of independent repetitions of the whole sweep.\n",
    "        epsilon: maximum absolute perturbation per input element.\n",
    "        n_batch: if given, stop after this many batches per sample count.\n",
    "        anneal_theta: if True, theta sweeps linearly from 0 to p along the chain;\n",
    "            otherwise each theta is drawn uniformly from [0, p).\n",
    "\n",
    "    Returns:\n",
    "        (plot_x, plot_y): flat lists of length repeats * d holding the sample\n",
    "        counts and the corresponding averaged estimates.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if l < 1.\n",
    "    \"\"\"\n",
    "    if l < 1:\n",
    "        # Without a leapfrog step there is no proposal gradient or energy to evaluate.\n",
    "        raise ValueError('l must be at least 1, got {}'.format(l))\n",
    "\n",
    "    plot_x = []\n",
    "    plot_y = []\n",
    "    avg_time = 0\n",
    "\n",
    "    criterion = nn.CrossEntropyLoss(reduction='none')\n",
    "    sample_counts = [int(m) for m in np.linspace(1, m_max, d)]\n",
    "    alpha = path_len * sigma ** 2 / l    # leapfrog step size\n",
    "    for repeat in range(repeats):\n",
    "        avg_running_estimates = np.zeros(d)\n",
    "        num_accepts = 0\n",
    "        total_n = 0\n",
    "        for idx, m in enumerate(sample_counts):\n",
    "            n_total = 0\n",
    "            t = time.time()\n",
    "            for batch_idx, (X, y) in enumerate(loader):\n",
    "                X, y = X.cuda(), y.cuda()\n",
    "\n",
    "                # Element-wise box for delta: |delta| <= epsilon, -X <= delta <= 1 - X.\n",
    "                lower_limit = torch.max(-X, torch.tensor(-epsilon, dtype=X.dtype).view(1, 1, 1).cuda())\n",
    "                upper_limit = torch.min(1 - X, torch.tensor(epsilon, dtype=X.dtype).view(1, 1, 1).cuda())\n",
    "\n",
    "                losses = torch.zeros(X.size(0), m)\n",
    "\n",
    "                # Start the chain from a uniform draw inside the box.\n",
    "                delta = (lower_limit - upper_limit) * torch.rand_like(X) + upper_limit\n",
    "                delta.requires_grad = True\n",
    "\n",
    "                if anneal_theta:\n",
    "                    thetas = np.linspace(0, p, m)\n",
    "                else:\n",
    "                    thetas = [np.random.uniform(0, p) for ii in range(m)]\n",
    "                for i, theta in enumerate(thetas):\n",
    "                    # Fresh Gaussian momentum for every transition.\n",
    "                    mom = torch.randn_like(X).cuda() * sigma\n",
    "\n",
    "                    if X.shape[1] == 1:\n",
    "                        yp = model(X + delta)\n",
    "                    else:\n",
    "                        yp = model(normalize(X + delta))\n",
    "                    loss = criterion(yp, y)\n",
    "                    losses[:, i] = loss.detach().cpu()\n",
    "                    log_loss = theta * torch.log(loss + 1e-10)\n",
    "                    # NOTE(review): if the model parameters require grad, backward() also\n",
    "                    # accumulates into their .grad, which is never cleared here.\n",
    "                    log_loss.sum().backward()\n",
    "                    # Hamiltonian at the current state: kinetic energy minus theta * log(loss).\n",
    "                    h_delta = torch.norm(mom.view(X.size(0), -1), dim=1)**2/sigma**2/2 - log_loss\n",
    "                    mom += 0.5 * alpha * delta.grad # half step of momentum\n",
    "                    proposal = delta.data\n",
    "                    for j in range(l):\n",
    "                        proposal = proposal.data + alpha * mom / sigma**2    # full step of position\n",
    "                        # reflection: mirror out-of-box coordinates back inside and flip\n",
    "                        # their momentum until every coordinate respects both limits\n",
    "                        while len(torch.where(proposal < lower_limit)[0]) > 0 or len(torch.where(proposal > upper_limit)[0]) > 0:\n",
    "                            idx_ = torch.where(proposal < lower_limit)\n",
    "                            if len(idx_[0]) > 0:\n",
    "                                proposal.data[idx_] = 2*lower_limit[idx_] - proposal.data[idx_]\n",
    "                                mom[idx_] = -mom[idx_]\n",
    "                            idx_ = torch.where(proposal > upper_limit)\n",
    "                            if len(idx_[0]) > 0:\n",
    "                                proposal.data[idx_] = 2*upper_limit[idx_] - proposal.data[idx_]\n",
    "                                mom[idx_] = -mom[idx_]\n",
    "                        proposal.requires_grad = True\n",
    "                        assert proposal.grad is None\n",
    "\n",
    "                        if X.shape[1] == 1:\n",
    "                            yp_next = model(X + proposal)\n",
    "                        else:\n",
    "                            yp_next = model(normalize(X + proposal))\n",
    "                        loss_next = criterion(yp_next, y)\n",
    "                        log_loss_next = theta * torch.log(loss_next + 1e-10)\n",
    "                        log_loss_next.sum().backward()\n",
    "\n",
    "                        if j != (l-1):\n",
    "                            mom += alpha * proposal.grad         # full step of momentum\n",
    "                    mom += 0.5 * alpha * proposal.grad         # half step of momentum\n",
    "\n",
    "                    # Metropolis correction, decided independently per example.\n",
    "                    h_proposal = torch.norm(mom.view(X.size(0), -1), dim=1)**2/sigma**2/2 - log_loss_next\n",
    "                    delta_h = h_proposal - h_delta\n",
    "                    u = torch.zeros_like(delta_h).uniform_(0,1)\n",
    "                    idx_accept = torch.where(u <= torch.exp(-delta_h))\n",
    "                    delta.data[idx_accept] = proposal.data[idx_accept]\n",
    "\n",
    "                    num_accepts += len(idx_accept[0])\n",
    "                    total_n += delta.size(0)\n",
    "                    delta.grad.zero_()\n",
    "\n",
    "                # compute metrics: geometric mean of the m recorded losses per example,\n",
    "                # then a size-weighted running sum over batches\n",
    "                gm = torch.exp(torch.log(losses + 1e-10).sum(dim=1) / m)\n",
    "                gm = gm.mean()\n",
    "                avg_running_estimates[idx] += gm * X.size(0)\n",
    "                n_total += X.size(0)\n",
    "                if n_batch is not None and batch_idx +1 == n_batch:\n",
    "                    break\n",
    "\n",
    "        # t is restarted for every sample count, so this adds only the duration of the\n",
    "        # last (largest) sample count of this repeat.\n",
    "        avg_time += time.time() - t\n",
    "        avg_running_estimates /= n_total\n",
    "        plot_y += list(avg_running_estimates)\n",
    "        plot_x += sample_counts\n",
    "        if total_n > 0:\n",
    "            print('acceptance rate', num_accepts / total_n)\n",
    "\n",
    "    avg_time /= repeats\n",
    "    print('avg time', avg_time)\n",
    "\n",
    "    return plot_x, plot_y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "tight-relative",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_random_sampling(loader, model, m_max, p, d, repeats=3, epsilon=0.03, n_batch=None):\n",
    "    \"\"\"Monte Carlo estimate of the power-mean loss under uniform random perturbations.\n",
    "\n",
    "    For every sample count m (d values evenly spaced over [1, m_max], truncated to\n",
    "    int) and every batch, m perturbations are drawn uniformly inside\n",
    "    [lower_limit, upper_limit] and the per-example estimate\n",
    "    (mean_i loss_i ** p) ** (1 / p) is evaluated in log space via logsumexp.\n",
    "    Estimates are averaged over all processed examples. The average timing is printed.\n",
    "\n",
    "    Args:\n",
    "        loader: iterable yielding (X, y) batches. Batches are moved to the GPU.\n",
    "        model: classifier. Called on the clamped input directly when X has one\n",
    "            channel, otherwise on normalize(...) of it.\n",
    "        m_max: largest number of random perturbations per example.\n",
    "        p: exponent of the power mean.\n",
    "        d: number of sample counts to evaluate.\n",
    "        repeats: number of independent repetitions of the whole sweep.\n",
    "        epsilon: maximum absolute perturbation per input element.\n",
    "        n_batch: if given, stop after this many batches per sample count.\n",
    "\n",
    "    Returns:\n",
    "        (plot_x, plot_y): flat lists of length repeats * d holding the sample\n",
    "        counts and the corresponding averaged estimates.\n",
    "    \"\"\"\n",
    "    plot_x = []\n",
    "    plot_y = []\n",
    "    avg_time = 0\n",
    "\n",
    "    criterion = nn.CrossEntropyLoss(reduction='none')\n",
    "    sample_counts = [int(m) for m in np.linspace(1, m_max, d)]\n",
    "    for repeat in range(repeats):\n",
    "        avg_running_estimates = np.zeros(d)\n",
    "        for idx, m in enumerate(sample_counts):\n",
    "            t = time.time()\n",
    "\n",
    "            n_total = 0\n",
    "\n",
    "            for batch_idx, (X, y) in enumerate(loader):\n",
    "                X, y = X.cuda(), y.cuda()\n",
    "\n",
    "                # Element-wise box for delta: |delta| <= epsilon, -X <= delta <= 1 - X.\n",
    "                lower_limit = torch.max(-X, torch.tensor(-epsilon, dtype=X.dtype).view(1, 1, 1).cuda())\n",
    "                upper_limit = torch.min(1 - X, torch.tensor(epsilon, dtype=X.dtype).view(1, 1, 1).cuda())\n",
    "\n",
    "                losses = torch.zeros(X.size(0), m)\n",
    "\n",
    "                # Pure evaluation: no gradient is ever taken, so skip building the\n",
    "                # autograd graph for each of the m forward passes.\n",
    "                with torch.no_grad():\n",
    "                    for i in range(m):\n",
    "                        delta = (lower_limit - upper_limit) * torch.rand_like(X) + upper_limit\n",
    "                        perturbed = torch.clamp(X + delta, min=0, max=1)\n",
    "                        if X.shape[1] == 1:\n",
    "                            yp = model(perturbed)\n",
    "                        else:\n",
    "                            yp = model(normalize(perturbed))\n",
    "                        # Explicit host copy: the losses buffer lives on the CPU.\n",
    "                        losses[:, i] = criterion(yp, y).cpu()\n",
    "                # (1/m * sum_i loss_i ** p) ** (1/p), computed in log space for stability.\n",
    "                gm = (torch.exp(torch.logsumexp(torch.log(losses) * p - math.log(m), dim=1)/ p)).mean()\n",
    "                avg_running_estimates[idx] += gm * X.size(0)\n",
    "                n_total += X.size(0)\n",
    "                if n_batch is not None and batch_idx +1 == n_batch:\n",
    "                    break\n",
    "        # t is restarted for every sample count, so this adds only the duration of the\n",
    "        # last (largest) sample count of this repeat.\n",
    "        avg_time += time.time() - t\n",
    "        avg_running_estimates /= n_total\n",
    "        plot_y += list(avg_running_estimates)\n",
    "        plot_x += sample_counts\n",
    "\n",
    "    avg_time /= repeats\n",
    "    print('avg time', avg_time)\n",
    "\n",
    "    return plot_x, plot_y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "moved-giving",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the experiment configuration and build the matching test-set loader.\n",
    "config = 'configs/train/cifar10/standard.json'\n",
    "config_dict = get_config(config)\n",
    "config = config_to_namedtuple(config_dict)\n",
    "test_loader = get_test_loader(config)\n",
    "\n",
    "# Restore the trained weights into the architecture that matches the dataset.\n",
    "d = torch.load('experiments/cifar10/standard/checkpoints/checkpoint_199.pth')\n",
    "if config.data.dataset == 'mnist':\n",
    "    model = mnist_classifier().cuda()\n",
    "elif config.data.dataset == 'cifar10':\n",
    "    model = preactresnet18().cuda()\n",
    "else:\n",
    "    # Fail here instead of with a NameError on `model` further down.\n",
    "    raise ValueError('unsupported dataset: {}'.format(config.data.dataset))\n",
    "\n",
    "model.load_state_dict(d[\"model\"])\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "embedded-russell",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path sampling + HMC estimate on a single test batch, one repeat.\n",
    "plot_x1, plot_y1 = plot_hmc(\n",
    "    test_loader,\n",
    "    model,\n",
    "    m_max=500,\n",
    "    p=1000,\n",
    "    d=100,\n",
    "    sigma=0.1,\n",
    "    l=10,\n",
    "    path_len=0.6,\n",
    "    repeats=1,\n",
    "    n_batch=1,\n",
    ")\n",
    "print(plot_x1, plot_y1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "revolutionary-finance",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random-sampling baseline on a single test batch, one repeat.\n",
    "plot_x2, plot_y2 = plot_random_sampling(\n",
    "    test_loader,\n",
    "    model,\n",
    "    m_max=5000,\n",
    "    p=1000,\n",
    "    d=100,\n",
    "    repeats=1,\n",
    "    n_batch=1,\n",
    ")\n",
    "print(plot_x2, plot_y2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "hollywood-armor",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convergence of both estimators as a function of the number of iterations.\n",
    "fig, ax1 = plt.subplots(1, figsize=(5, 3.5))\n",
    "# Scale the HMC sample counts into a separate name so that re-running this cell does\n",
    "# not multiply plot_x1 by 10 again.\n",
    "# NOTE(review): the factor 10 presumably matches l=10 leapfrog steps per HMC sample\n",
    "# in the plot_hmc call -- confirm and keep the two in sync.\n",
    "hmc_iterations = [x * 10 for x in plot_x1]\n",
    "ax1.scatter(hmc_iterations, plot_y1, 10)\n",
    "ax1.scatter(plot_x2, plot_y2, 10)\n",
    "ax1.set_ylabel('Estimate')\n",
    "ax1.set_xlabel('Iterations')\n",
    "lgd = plt.figlegend(('Path sampling + HMC', 'Random sampling'),\n",
    "                  loc='upper center', ncol=2,\n",
    "                  bbox_to_anchor=(0.56, 1.04),\n",
    "                  borderaxespad=0)\n",
    "ax1.set_ylim(bottom=0)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"figures/combined_convergence_plot_p=1000.pdf\", bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
