{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import numpy.linalg as npl\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from base import PromptAttn, p_svm_solver, get_margin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 2  # number of training samples (fixed by data_generator)\n",
    "T = 2  # tokens per sample (fixed by data_generator)\n",
    "d = 2  # token / prompt dimension\n",
    "ITN = 1000  # gradient steps per (loss, C) run\n",
    "\n",
    "# Losses compared in this run. The training loop also supports 'exp'\n",
    "# (full set: ['-x', 'exp', 'log']); it is excluded here.\n",
    "loss_type = ['-x', 'log']\n",
    "Cs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 100]\n",
    "# Per-(loss, C, iteration) records filled in by the training loop.\n",
    "ps      = np.zeros((len(loss_type), len(Cs), ITN, d))\n",
    "margins = np.zeros((len(loss_type), len(Cs), ITN))\n",
    "corrs   = np.zeros((len(loss_type), len(Cs), ITN))\n",
    "\n",
    "def data_generator(C):\n",
    "    \"\"\"Build the two-sample, two-token toy dataset for a given C.\n",
    "\n",
    "    Each sample has a zero token and one non-zero token, (1,1) or (1,-1);\n",
    "    both labels are +1. The head v satisfies v.(1,1) = C and v.(1,-1) = 1,\n",
    "    so C sets the score of the (1,1) token relative to the (1,-1) token.\n",
    "\n",
    "    Returns (X, Y, v) with shapes (2, 2, 2), (2,) and (2,).\n",
    "    \"\"\"\n",
    "    v = torch.tensor([(C+1)/2,(C-1)/2])\n",
    "    Y = torch.tensor([1,1])\n",
    "    X = torch.tensor([\n",
    "        [[0,0],[1.,1.]],\n",
    "        [[0,0],[1.,-1.]],\n",
    "    ])\n",
    "    return X, Y, v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "0 0 -0.5\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "0 0 -0.75\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "0 0 -1.0\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "0 0 -1.25\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "0 0 -1.5\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "0 0 -1.75\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "0 0 -2.0\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "0 0 -2.25\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "0 0 -2.5\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "0 0 -2.75\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "0 0 -25.25\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "1 0 0.474077\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "1 0 0.3936693\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "1 0 0.33774516\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "1 0 0.3005025\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "1 0 0.27648336\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "1 0 0.26133215\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "1 0 0.2519137\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "1 0 0.24611348\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "1 0 0.24256237\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "1 0 0.24039617\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n",
      "1 0 0.2370385\n"
     ]
    }
   ],
   "source": [
    "# The data (and hence the p-SVM solution) depend only on C, so build them once per C\n",
    "# instead of re-solving the QP for every (loss, C) pair.\n",
    "datasets = [data_generator(C) for C in Cs]\n",
    "svm_sols = [p_svm_solver(X, torch.diag(torch.ones(d)), torch.tensor([1,1])) for X, _, _ in datasets]\n",
    "\n",
    "for i, loss_name in enumerate(loss_type):\n",
    "    for j, (X, Y, v) in enumerate(datasets):\n",
    "        sol, Xtemp = svm_sols[j]\n",
    "        model = PromptAttn(d)\n",
    "        model.prompt.data = torch.zeros(d)\n",
    "        model.w.data = v.clone().float().view(-1)\n",
    "        optimizer = torch.optim.SGD([model.prompt], lr=0.1)\n",
    "        for it in range(ITN):\n",
    "            # ps records the iterate *before* each step, so ps[i, j, 0] is the zero init.\n",
    "            ps[i,j,it] = model.prompt.detach().numpy()\n",
    "            optimizer.zero_grad()\n",
    "            signed_out = Y*model(X)\n",
    "            if loss_name == '-x':\n",
    "                loss = -signed_out\n",
    "            elif loss_name == 'exp':\n",
    "                loss = torch.exp(-signed_out)\n",
    "            elif loss_name == 'log':\n",
    "                loss = torch.log(1+torch.exp(-signed_out))\n",
    "            else:\n",
    "                # Without this, `loss` would be a stale tensor from a previous step (or undefined).\n",
    "                raise ValueError(f'unknown loss_type: {loss_name!r}')\n",
    "            loss = loss.mean()\n",
    "            loss.backward()\n",
    "            # Normalised gradient descent: rescale the gradient to unit norm (eps avoids 0/0).\n",
    "            model.prompt.grad /= (model.prompt.grad.norm()+1e-9)\n",
    "            optimizer.step()\n",
    "            # margins/corrs are recorded *after* the step; at the zero init corr would be 0/0.\n",
    "            margins[i,j,it] = get_margin(torch.Tensor(Xtemp), model.prompt)\n",
    "            prompt_np = model.prompt.detach().numpy()\n",
    "            corrs[i,j,it] = prompt_np.dot(sol)/npl.norm(prompt_np)/npl.norm(sol)\n",
    "            # With ITN=1000 this logs only the initial loss of each run.\n",
    "            if it % 1000 == 0:\n",
    "                print(i, it, loss.detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  2.2222e-01  4.4444e-01  7e-01  1e+00  2e-16\n",
      " 1:  3.8047e-01  4.9185e-01  9e-03  1e-01  2e-16\n",
      " 2:  5.0019e-01  5.0000e-01  2e-04  6e-19  9e-16\n",
      " 3:  5.0000e-01  5.0000e-01  2e-06  2e-20  1e-16\n",
      " 4:  5.0000e-01  5.0000e-01  2e-08  9e-23  9e-23\n",
      "Optimal solution found.\n"
     ]
    }
   ],
   "source": [
    "# Token positions are the same for every C; regenerate them here rather than\n",
    "# relying on the X left over from the training loop.\n",
    "X_plot, _, _ = data_generator(Cs[-1])\n",
    "colors = ['r','b','g']  # one colour per loss in loss_type\n",
    "num = ITN  # number of iterations of each trajectory to draw\n",
    "loss_labels = {\n",
    "    '-x': r'$\\ell(u)=-u$',\n",
    "    'exp': r'$\\ell(u)=e^{-u}$',\n",
    "    'log': r'$\\ell(u)=\\log(1+e^{-u})$',\n",
    "}\n",
    "n_faded = len(Cs) - 1  # every C except the last is drawn with increasing opacity\n",
    "show_svm_direction = False  # set True to overlay the p-SVM direction\n",
    "save_path = None  # e.g. 'diff_loss_func_v2.pdf'\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8,6))\n",
    "if show_svm_direction:\n",
    "    sol, _ = p_svm_solver(X_plot, torch.diag(torch.ones(d)), torch.tensor([1,1]))\n",
    "    x = np.arange(0,20)\n",
    "    ax.plot(x, sol[1]/sol[0]*x, 'k--', linewidth=3)\n",
    "\n",
    "ls_loss = []  # one handle per loss, for the loss legend\n",
    "ls_Cs = []  # one handle per C (taken from the first loss), for the C legend\n",
    "for i in range(len(loss_type)):\n",
    "    for j in range(n_faded):\n",
    "        l, = ax.plot(ps[i,j,:num,0], ps[i,j,:num,1], linewidth=3, linestyle='-', color=colors[i], alpha=0.9*(j+1)/n_faded, zorder=-1)\n",
    "        if i == 0:\n",
    "            ls_Cs.append(l)\n",
    "        if j == n_faded-1:\n",
    "            ls_loss.append(l)\n",
    "    # The largest C is drawn dashed and fully opaque.\n",
    "    l, = ax.plot(ps[i,-1,:num,0], ps[i,-1,:num,1], linewidth=3, linestyle='--', color=colors[i], alpha=1, zorder=-1)\n",
    "    if i == 0:\n",
    "        ls_Cs.append(l)\n",
    "\n",
    "ax.scatter(X_plot[0,0,0],X_plot[0,0,1],marker='^',color='g',s=140, label=r'(0,0), non-opt token($\\gamma$=0)', zorder=1)\n",
    "ax.scatter(X_plot[0,1,0],X_plot[0,1,1],marker='*',color='g',s=220, label=r'(1,1), opt token($\\gamma$=C)', zorder=1)\n",
    "ax.scatter(X_plot[1,0,0],X_plot[1,0,1],marker='v',color='c',s=140, label=r'(0,0), non-opt token($\\gamma$=0)', zorder=1)\n",
    "ax.scatter(X_plot[1,1,0],X_plot[1,1,1],marker='*',color='c',s=220, label=r'(1,-1), opt token($\\gamma$=1)', zorder=1)\n",
    "ax.set_xlim([-2,12])\n",
    "ax.set_ylim([-6,6])\n",
    "\n",
    "mid = (len(Cs) - 2) // 2  # index of a representative C from the middle range\n",
    "legend1 = ax.legend([ls_Cs[0], ls_Cs[mid], ls_Cs[-2], ls_Cs[-1]], [f'C={Cs[0]}', f'C={Cs[1]}~{Cs[-3]}', f'C={Cs[-2]}', f'C={Cs[-1]}'], loc=3, fontsize=15)\n",
    "legend2 = ax.legend(ls_loss, [loss_labels[name] for name in loss_type], loc=4, fontsize=15)\n",
    "# Each ax.legend call replaces the previous one, so re-attach the first two as artists.\n",
    "ax.legend(fontsize=15, loc='upper right')\n",
    "ax.add_artist(legend1)\n",
    "ax.add_artist(legend2)\n",
    "\n",
    "ax.tick_params(axis='both', labelsize=25)\n",
    "ax.grid()\n",
    "fig.tight_layout()\n",
    "if save_path:\n",
    "    fig.savefig(save_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python36",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
