{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.cm as cm\n",
    "from sklearn import datasets\n",
    "from math import *\n",
    "from losses.label_smoothing import CrossEntropyLoss\n",
    "from rtpt import RTPT\n",
    "\n",
    "# Import pyplot explicitly: `matplotlib.pyplot` is a submodule and is only\n",
    "# reachable as an attribute of `matplotlib` once something has imported it.\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "sns.set_style('white')\n",
    "\n",
    "%matplotlib inline\n",
    "# Global figure defaults shared by every plot in this notebook.\n",
    "matplotlib.rcParams['figure.figsize'] = (5, 5)\n",
    "matplotlib.rcParams['font.size'] = 14\n",
    "matplotlib.rcParams['text.usetex'] = False\n",
    "matplotlib.rcParams['lines.linewidth'] = 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training data: three blobs, one per class label 0, 1 and 2\n",
    "np.random.seed(7777)\n",
    "size = 500\n",
    "train_range = (-10, 10)\n",
    "blob_centers = [[0, -10], [-10, 10], [10, 10]]\n",
    "\n",
    "X_train, Y_train = datasets.make_blobs(\n",
    "    n_samples=size,\n",
    "    centers=blob_centers,\n",
    "    cluster_std=1.6,\n",
    "    center_box=train_range,\n",
    "    random_state=17,\n",
    ")\n",
    "\n",
    "# Test data: a regular size x size grid covering test_range on both axes.\n",
    "# NOTE: `size` is deliberately rebound here; later cells use it as the grid resolution.\n",
    "size = 100\n",
    "test_range = (-15, 15)\n",
    "test_rng = np.linspace(*test_range, size)\n",
    "\n",
    "X1_test, X2_test = np.meshgrid(test_rng, test_rng)\n",
    "X_test = torch.from_numpy(np.stack([X1_test.ravel(), X2_test.ravel()]).T).float()\n",
    "\n",
    "# Preview of the training samples, one scatter call per class\n",
    "preview_styles = [dict(c='coral'), dict(c='yellow'), dict(c='blue', marker='p')]\n",
    "for cls, style in enumerate(preview_styles):\n",
    "    cls_samples = X_train[Y_train == cls]\n",
    "    plt.scatter(cls_samples[:, 0],\n",
    "                cls_samples[:, 1],\n",
    "                edgecolors='k',\n",
    "                linewidths=0.5,\n",
    "                **style)\n",
    "plt.xlim(test_range)\n",
    "plt.ylim(test_range)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Model, self).__init__()\n",
    "        torch.manual_seed(42)\n",
    "        self.model = nn.Sequential(nn.Linear(2, 20),\n",
    "                                          nn.BatchNorm1d(20),\n",
    "                                          nn.ReLU(),\n",
    "                                          nn.Linear(20, 20),\n",
    "                                          nn.BatchNorm1d(20),\n",
    "                                          nn.ReLU(), \n",
    "                                          nn.Linear(20, 3, bias=False))\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.model(x)\n",
    "        return x\n",
    "\n",
    "def train(model, X, Y, alpha, n_iters=5000, lr=1e-3, momentum=0.9):\n",
    "    \"\"\"Train model on the full batch (X, Y) with SGD on the GPU.\n",
    "\n",
    "    Args:\n",
    "        model: network to optimise; it is moved to the GPU and put in training mode.\n",
    "        X: numpy array of inputs, converted to a float tensor.\n",
    "        Y: numpy array of class labels, converted to a long tensor.\n",
    "        alpha: value forwarded as label_smoothing to CrossEntropyLoss.\n",
    "        n_iters: number of full-batch update steps.\n",
    "        lr: SGD learning rate.\n",
    "        momentum: SGD momentum.\n",
    "\n",
    "    Prints the loss of the last iteration; returns nothing.\n",
    "    \"\"\"\n",
    "    X_train = torch.from_numpy(X).float().cuda()\n",
    "    y_train = torch.from_numpy(Y).long().cuda()\n",
    "\n",
    "    model = model.cuda()\n",
    "    # The caller may hand in a model that was left in eval mode (compute_conf\n",
    "    # does that); mode-dependent layers such as BatchNorm must train in train mode.\n",
    "    model.train()\n",
    "    opt = optim.SGD(model.parameters(), lr=lr, momentum=momentum)\n",
    "    loss_fkt = CrossEntropyLoss(label_smoothing=alpha)\n",
    "    loss = None\n",
    "    for it in range(n_iters):\n",
    "        opt.zero_grad()\n",
    "        y = model(X_train)\n",
    "        loss = loss_fkt(y, y_train)\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "\n",
    "    # loss stays None when n_iters is 0, so there is nothing to report then\n",
    "    if loss is not None:\n",
    "        print(f'Loss: {loss.item():.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_conf(model, X_test):\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        conf_map = F.softmax(model(X_test.cuda()), 1).squeeze().cpu().numpy()\n",
    "        conf = conf_map.max(1)\n",
    "    return conf\n",
    "\n",
    "def compute_route(model, X_train, Y_train, target_cls=2, initial_point=[-3, -10], max_steps=10000, target_conf=0.95):\n",
    "    initial_point = torch.tensor([initial_point]).float().cuda()\n",
    "    target = torch.tensor([target_cls]).cuda()\n",
    "    optimizer = optim.SGD([initial_point.requires_grad_()], lr=0.1, momentum=0.0)\n",
    "    points = [initial_point.clone().detach()]\n",
    "    grad_list = []\n",
    "\n",
    "    for step in range(max_steps):\n",
    "        output = model(initial_point.requires_grad_())\n",
    "        loss = F.cross_entropy(output, target)\n",
    "        optimizer.zero_grad()\n",
    "        model.zero_grad()\n",
    "        loss.backward()\n",
    "        grad = initial_point.grad\n",
    "        grad_list.append(grad.clone().detach().norm().cpu().item())\n",
    "        optimizer.step()\n",
    "        points.append(initial_point.clone().detach())\n",
    "        if output.softmax(1)[0, target_cls] > target_conf:\n",
    "            break\n",
    "    steps_required = step\n",
    "    closest_sample_dist = torch.min(torch.norm(torch.tensor(X_train[Y_train == 2]) - initial_point.cpu(), dim=1))\n",
    "    points = torch.cat(points).cpu()\n",
    "    return points, steps_required, closest_sample_dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot(X, Y, X1_test, X2_test, Z, test_range, title=None, points=None, file_name=None, colorbar=False):\n",
    "    \"\"\"Draw the filled contour map Z with the samples X, coloured by label Y, on top.\n",
    "\n",
    "    Args:\n",
    "        X, Y: samples and their labels; labels 0, 1 and 2 are drawn.\n",
    "        X1_test, X2_test: grid coordinates passed to contourf.\n",
    "        Z: values on that grid, drawn with levels from 0 to 1 in steps of 0.05.\n",
    "        test_range: limits applied to both axes.\n",
    "        title: optional figure title.\n",
    "        points: optional sequence of 2-D points drawn as a connected yellow line.\n",
    "        file_name: if given, the figure is saved there before it is shown.\n",
    "        colorbar: whether to add a colour bar for the contour map.\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=(6, 5))\n",
    "    if title is not None:\n",
    "        plt.title(title)\n",
    "\n",
    "    contour = plt.contourf(X1_test, X2_test, Z,\n",
    "                           alpha=0.7,\n",
    "                           cmap='RdPu',\n",
    "                           levels=np.arange(0.0, 1.01, 0.05))\n",
    "    if colorbar:\n",
    "        plt.colorbar(contour, ticks=[0.0, 0.25, 0.5, 0.75, 1.0])\n",
    "\n",
    "    # Marker style per class label; the position in the tuple is the label\n",
    "    class_styles = (\n",
    "        dict(c='#69c765', marker='o', linewidths=0.5),\n",
    "        dict(c='#5cd0f7', marker='s', linewidths=0.5),\n",
    "        dict(c='#fcc947', marker='p', s=50, linewidths=0.8),\n",
    "    )\n",
    "    for label, style in enumerate(class_styles):\n",
    "        samples = X[Y == label]\n",
    "        plt.scatter(samples[:, 0], samples[:, 1], edgecolors='k', **style)\n",
    "\n",
    "    if points is not None:\n",
    "        # markersize=0.0 hides the markers, so only the connecting line is visible\n",
    "        plt.plot(points[:, 0], points[:, 1], '-o', c='yellow', linewidth=4.0, markersize=0.0)\n",
    "\n",
    "    plt.xlim(test_range)\n",
    "    plt.ylim(test_range)\n",
    "    plt.xticks([])\n",
    "    plt.yticks([])\n",
    "    if file_name is not None:\n",
    "        plt.savefig(file_name, dpi=300, bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train one model per smoothing factor, then plot its confidence map together\n",
    "# with the gradient route from the start point towards class 2.\n",
    "alpha_list = [0.0, 0.05, -0.05]\n",
    "for alpha in alpha_list:\n",
    "    model = Model().cuda()\n",
    "    train(model, X_train, Y_train, alpha=alpha)\n",
    "    conf = compute_conf(model, X_test)\n",
    "    points, steps_required, closest_sample_dist = compute_route(model, X_train, Y_train, target_cls=2, initial_point=[-3, -10], max_steps=10000, target_conf=0.95)\n",
    "\n",
    "    plot(X_train,\n",
    "         Y_train,\n",
    "         X1_test,\n",
    "         X2_test,\n",
    "         # Reshape to the grid itself rather than to the global `size`, which\n",
    "         # is assigned more than once in this notebook.\n",
    "         conf.reshape(X1_test.shape),\n",
    "         test_range,\n",
    "         title=f'Steps Taken: {steps_required}\\nClosest Sample Distance: {closest_sample_dist:.2f}',\n",
    "         points=points,\n",
    "         file_name=f'toy_{alpha}.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same pipeline for a single smoothing factor; this figure carries the colour bar.\n",
    "alpha_list = [0.01]\n",
    "for alpha in alpha_list:\n",
    "    model = Model().cuda()\n",
    "    train(model, X_train, Y_train, alpha=alpha)\n",
    "    conf = compute_conf(model, X_test)\n",
    "    points, steps_required, closest_sample_dist = compute_route(model, X_train, Y_train, target_cls=2, initial_point=[-3, -10], max_steps=10000, target_conf=0.95)\n",
    "\n",
    "    plot(X_train,\n",
    "         Y_train,\n",
    "         X1_test,\n",
    "         X2_test,\n",
    "         # Reshape to the grid itself rather than to the global `size`, which\n",
    "         # is assigned more than once in this notebook.\n",
    "         conf.reshape(X1_test.shape),\n",
    "         test_range,\n",
    "         title=f'Steps Taken: {steps_required}\\nClosest Sample Distance: {closest_sample_dist:.2f}',\n",
    "         points=points,\n",
    "         file_name='legend.pdf',\n",
    "         colorbar=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Softmax scores of `model` (the one left over from the last training loop)\n",
    "# along the straight line between initial_point and target_point.\n",
    "target_cls = 2\n",
    "loss_fkt = CrossEntropyLoss(label_smoothing=0.0)\n",
    "initial_point = torch.tensor([[-3., -10.]]).cuda()\n",
    "target_point = torch.tensor([[10., 10.]]).cuda()\n",
    "target = torch.tensor([target_cls]).cuda()\n",
    "\n",
    "# 50 evenly spaced mixing weights: weight 0 gives target_point, weight 1 gives initial_point\n",
    "points = torch.cat([\n",
    "    lam * initial_point + (1 - lam) * target_point\n",
    "    for lam in np.linspace(0, 1, 50)\n",
    "]).cuda()\n",
    "\n",
    "with torch.no_grad():\n",
    "    output = model(points)\n",
    "    conf = torch.softmax(output, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ALPHA was never defined in this notebook. `alpha` is the loop variable left\n",
    "# behind by the last training loop, i.e. the smoothing factor of the `model`\n",
    "# whose scores are plotted here.\n",
    "ALPHA = alpha\n",
    "lambdas = np.linspace(0, 1, 50)\n",
    "\n",
    "plt.figure(figsize=(6, 5))\n",
    "plt.title(f'Smoothing Factor: α={ALPHA}')\n",
    "\n",
    "plt.plot(lambdas, conf[:, 0].cpu(), label='Class 0', c='coral', linewidth=4)\n",
    "plt.plot(lambdas, conf[:, 2].cpu(), label='Class 2', c='blue', linewidth=4)\n",
    "plt.xlabel(xlabel='Interpolation λ', fontsize=18)\n",
    "plt.ylabel(ylabel='Softmax Score', fontsize=18)\n",
    "plt.legend(fontsize=18)\n",
    "plt.savefig(f'softmax_scores_{ALPHA}.jpg', dpi=300, bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
