{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from base import PromptAttn, p_svm_solver"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Problem size: n sequences, T tokens per sequence, d-dimensional tokens.\n",
    "n = 3\n",
    "T = 7\n",
    "d = 2\n",
    "# Optimisation budget: ITN normalised-gradient steps, repeated RUN times.\n",
    "ITN = 1000\n",
    "RUN = 1\n",
    "# Seed for the torch.Generator handed to data_generator.\n",
    "seed = 0\n",
    "# Prompt trajectory: ps[run, it] holds the prompt recorded at iteration `it`.\n",
    "ps = np.zeros((RUN, ITN, d))\n",
    "\n",
    "def data_generator(generator):\n",
    "    \"\"\"Return the fixed toy dataset (X, Y, v) used by this notebook.\n",
    "\n",
    "    X: (n, T, d) hard-coded token embeddings, so the figure is reproducible.\n",
    "       Token 0 of each sequence is the one treated as optimal elsewhere\n",
    "       (p_svm_solver target indices and the 'Opt tokens' legend).\n",
    "    Y: (n,) labels, all +1.\n",
    "    v: (d,) vector [1, 0], copied into the model head by the training loop.\n",
    "\n",
    "    `generator` is accepted but currently unused, because X is hard-coded.\n",
    "    \"\"\"\n",
    "    v = torch.tensor([1.,0.])\n",
    "    Y = torch.ones(n)\n",
    "    # X = torch.randn((n, T, d)) * 0.2\n",
    "    X = torch.tensor([[[ 1.0281,  0.1544],\n",
    "         [-0.9691,  0.2835],\n",
    "         [-0.8976,  0.5455],\n",
    "         [-0.8692,  0.4964],\n",
    "         [ 0.0470, -0.4770],\n",
    "         [-0.0572, -0.4288],\n",
    "         [ 0.3460, -0.4928]],\n",
    "\n",
    "        [[ 0.8050, -0.3332],\n",
    "         [-0.9944,  0.7703],\n",
    "         [-1.2300,  0.4998],\n",
    "         [-1.1959,  0.2929],\n",
    "         [-0.1426, -0.4005],\n",
    "         [ 0.0158, -0.6360],\n",
    "         [ 0.1776, -0.4412]],\n",
    "\n",
    "        [[ 0.9779, -0.1994],\n",
    "         [-0.8524,  0.3241],\n",
    "         [-0.8225,  0.6877],\n",
    "         [-1.2249,  0.7280],\n",
    "         [ 0.1164, -0.1782],\n",
    "         [ 0.2241, -0.5947],\n",
    "         [-0.0599, -0.5436]]])\n",
    "    # Alternative construction (random X plus per-group shifts), three groups:\n",
    "    # shift = torch.tensor([\n",
    "    #     [1,0],\n",
    "    #     [-1,0.5],\n",
    "    #     [-0,-0.5],\n",
    "    # ])\n",
    "    # X[:,0] += shift[0]\n",
    "    # X[:,1:(T+1)//2] += shift[1]\n",
    "    # X[:,-(T-1)//2:] += shift[2]\n",
    "    # Four-group variant:\n",
    "    # shift = torch.tensor([\n",
    "    #     [1,0],\n",
    "    #     [-1,1],\n",
    "    #     [-0.5,-0.5],\n",
    "    #     [-1.5,0],\n",
    "    # ])\n",
    "    # X[:,0] += shift[0]\n",
    "    # X[:,1:1+(T-1)//3] += shift[1]\n",
    "    # X[:,1+(T-1)//3:-(T-1)//3] += shift[2]\n",
    "    # X[:,-(T-1)//3:] += shift[3]\n",
    "    # Guard against editing n/T/d in the config without updating the literal above.\n",
    "    assert X.shape == (n, T, d), f\"hard-coded X has shape {tuple(X.shape)}, expected {(n, T, d)}\"\n",
    "    return X, Y, v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0 0.83539695\n"
     ]
    }
   ],
   "source": [
    "for run in range(RUN):\n",
    "    # seed = run\n",
    "    generator = torch.Generator()\n",
    "    generator.manual_seed(seed)\n",
    "    X, Y, v = data_generator(generator)\n",
    "    model = PromptAttn(d)\n",
    "    # Prompt starts at the origin; the head w is set to v and is not passed to the optimiser.\n",
    "    model.prompt.data = torch.zeros(d)\n",
    "    model.w.data = v.float()\n",
    "    optimizer = torch.optim.SGD([model.prompt], lr=1)\n",
    "    for it in range(ITN):\n",
    "        ps[run,it] = model.prompt.detach().numpy()\n",
    "        optimizer.zero_grad()\n",
    "        # Logistic loss log(1 + exp(-Y f(X))), written as softplus so it cannot\n",
    "        # overflow to inf (and yield NaN gradients) when Y f(X) is very negative.\n",
    "        loss = torch.nn.functional.softplus(-Y*model(X)).mean()\n",
    "        loss.backward()\n",
    "        # Normalised gradient descent: rescale the gradient to unit norm (eps avoids 0/0).\n",
    "        model.prompt.grad /= (model.prompt.grad.norm()+1e-12)\n",
    "        optimizer.step()\n",
    "        # With ITN = 1000 this logs only the first iteration.\n",
    "        if it % 1000 == 0:\n",
    "            print(run, it, loss.detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  3.4000e-01  3.7230e+00  4e+01  2e+00  3e+01\n",
      " 1:  1.8767e+00 -1.4998e+00  4e+00  1e-01  3e+00\n",
      " 2:  1.2331e+00  6.8484e-01  6e-01  7e-03  2e-01\n",
      " 3:  1.2556e+00  1.1599e+00  1e-01  2e-16  3e-15\n",
      " 4:  1.2339e+00  1.2329e+00  1e-03  4e-16  4e-16\n",
      " 5:  1.2337e+00  1.2337e+00  1e-05  4e-16  7e-16\n",
      " 6:  1.2337e+00  1.2337e+00  1e-07  3e-16  9e-16\n",
      "Optimal solution found.\n",
      "[ 1.55e+00]\n",
      "[ 2.66e-01]\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ucroptml/anaconda3/lib/python3.7/site-packages/numpy/core/shape_base.py:65: FutureWarning: The input object of type 'Tensor' is an array-like implementing one of the corresponding protocols (`__array__`, `__array_interface__` or `__array_struct__`); but not a sequence (or 0-D). In the future, this object will be coerced as if it was first converted using `np.array(obj)`. To retain the old behaviour, you have to either modify the type 'Tensor', or assign to an empty array created with `np.empty(correct_shape, dtype=object)`.\n",
      "  ary = asanyarray(ary)\n",
      "/home/ucroptml/anaconda3/lib/python3.7/site-packages/numpy/core/shape_base.py:65: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
      "  ary = asanyarray(ary)\n"
     ]
    }
   ],
   "source": [
    "sol, _ = p_svm_solver(X, torch.diag(torch.ones(d)), torch.zeros(n).long())\n",
    "print(sol)\n",
    "rotate_matrix = torch.tensor([[sol[0,0],-sol[1,0]],[sol[1,0],sol[0,0]]]).float()\n",
    "X_ = X # @ rotate_matrix\n",
    "ps_ = torch.tensor(ps).float() # @ rotate_matrix\n",
    "\n",
    "# Offset of the dotted boundary: it is perpendicular to sol and passes through the\n",
    "# midpoint of the (token 0, token t) pair with the smallest margin along sol.\n",
    "# Computed here (as a plain float) so this cell does not depend on a later cell.\n",
    "sol_dir = torch.tensor(np.array(sol)).float().reshape(-1)\n",
    "pair_margins = (X[:, :1] - X[:, 1:]) @ sol_dir / sol_dir.norm()  # (n, T-1)\n",
    "i_min, t_min = divmod(int(pair_margins.argmin()), X.shape[1] - 1)\n",
    "closest_mid = (X[i_min, 0] + X[i_min, t_min + 1]) / 2\n",
    "bias = float(closest_mid[0] + closest_mid[1] * sol_dir[1] / sol_dir[0])  # x-intercept of the boundary\n",
    "\n",
    "colors = ['g','c','y','gray','gray','gray','gray','gray','gray']\n",
    "markers = ['*', '^', '^', '^', '^', '^', '^', '^', '^', '^']\n",
    "ss = [220, 140, 140, 140, 140, 140, 140, 140, 140, 140]\n",
    "alpha = 1\n",
    "plt.figure(figsize=(8,6))\n",
    "x = np.arange(0,50)\n",
    "# Dashed red: direction of sol. Dotted black: line through (bias, 0) perpendicular to sol.\n",
    "plt.plot(x,sol[1]/sol[0]*x, '--', color='red', linewidth=3, zorder=-1)\n",
    "plt.plot([-10+bias,10+bias],[sol[0]/sol[1]*10,-sol[0]/sol[1]*10], ':', color='k', linewidth=3, zorder=-1)\n",
    "# Prompt trajectory of run 0, drawn as a chain of arrows.\n",
    "for i in range(ITN-1):\n",
    "    plt.arrow((ps_[0,i,0]), (ps_[0,i,1]), (ps_[0,i+1,0]-ps_[0,i,0]), (ps_[0,i+1,1]-ps_[0,i,1]), length_includes_head=True,head_width=0.2, head_length=0.2, color='gray', linewidth=2, alpha=alpha, zorder=-1)\n",
    "\n",
    "# Token 0 scatter handles feed the 'Opt tokens' legend entry, token 1 handles 'Non-opt tokens'.\n",
    "ls_best = ()\n",
    "ls_support = ()\n",
    "for i in range(n):\n",
    "    for t in range(T):\n",
    "        l = plt.scatter(X_[i,t,0],X_[i,t,1],color=colors[i], marker=markers[t], s=ss[t])\n",
    "        if t == 0:\n",
    "            ls_best += (l,)\n",
    "        if t == 1:\n",
    "            ls_support += (l,)\n",
    "\n",
    "plt.legend([ls_best, ls_support], ['Opt tokens', 'Non-opt tokens'], fontsize=20, loc='upper right', handler_map = {tuple: matplotlib.legend_handler.HandlerTuple(None)})\n",
    "plt.xlim([-2,22])\n",
    "plt.ylim([-2,16])\n",
    "plt.xticks([0,5,10,15,20],fontsize=25)\n",
    "plt.yticks([0,5,10,15],fontsize=25)\n",
    "# Zoom-in variant: plt.xlim([-2,6]); plt.ylim([-2,4]); default ticks with fontsize=25;\n",
    "# axis labels r'$x_1$' / r'$x_2$' with fontsize=30.\n",
    "plt.grid()\n",
    "plt.tight_layout()\n",
    "plt.savefig('multi_converge_zoom_out.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[ 1.0281,  0.1544],\n",
      "         [-0.9691,  0.2835],\n",
      "         [-0.8976,  0.5455],\n",
      "         [-0.8692,  0.4964],\n",
      "         [ 0.0470, -0.4770],\n",
      "         [-0.0572, -0.4288],\n",
      "         [ 0.3460, -0.4928]],\n",
      "\n",
      "        [[ 0.8050, -0.3332],\n",
      "         [-0.9944,  0.7703],\n",
      "         [-1.2300,  0.4998],\n",
      "         [-1.1959,  0.2929],\n",
      "         [-0.1426, -0.4005],\n",
      "         [ 0.0158, -0.6360],\n",
      "         [ 0.1776, -0.4412]],\n",
      "\n",
      "        [[ 0.9779, -0.1994],\n",
      "         [-0.8524,  0.3241],\n",
      "         [-0.8225,  0.6877],\n",
      "         [-1.2249,  0.7280],\n",
      "         [ 0.1164, -0.1782],\n",
      "         [ 0.2241, -0.5947],\n",
      "         [-0.0599, -0.5436]]])\n"
     ]
    }
   ],
   "source": [
    "print(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.4246])\n"
     ]
    }
   ],
   "source": [
    "def get_margin(X, p):\n",
    "    \"\"\"Smallest margin of token 0 over the other tokens along direction p.\n",
    "\n",
    "    For every sequence i and every token t >= 1, computes\n",
    "    (X[i,0] - X[i,t]) @ p / ||p|| and returns the minimum value together with\n",
    "    the midpoint (X[i,0] + X[i,t]) / 2 of the pair attaining it (first pair on ties).\n",
    "\n",
    "    Raises ValueError if X has no sequence or fewer than two tokens, since then\n",
    "    there is no pair to compare.\n",
    "    \"\"\"\n",
    "    if X.shape[0] == 0 or X.shape[1] < 2:\n",
    "        raise ValueError('get_margin needs at least one sequence with two or more tokens')\n",
    "    p_norm = p.norm()\n",
    "    score, index = None, None\n",
    "    for i in range(X.shape[0]):\n",
    "        for t in range(1, X.shape[1]):\n",
    "            gap = (X[i,0] - X[i,t]) @ p / p_norm\n",
    "            if score is None or gap < score:\n",
    "                score = gap\n",
    "                index = (X[i,0] + X[i,t]) / 2\n",
    "    return score, index\n",
    "sol_ = torch.tensor(np.array(sol)).float()\n",
    "margin, index = get_margin(X, sol_)\n",
    "# x-intercept of the line through `index` perpendicular to sol_; algebraically equal to\n",
    "# (index[0]/sol_[1]*sol_[0]+index[1])/sol_[0]*sol_[1] but without dividing by sol_[1].\n",
    "bias = index[0] + index[1]*sol_[1]/sol_[0]\n",
    "print(bias)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python36",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
