{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from base import PromptAttn, p_svm_solver"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# eq_support=False: token scores X[0] @ v are (1.0, 0.9, 0);\n",
    "# eq_support=True: the third token ties the runner-up at 0.9.\n",
    "eq_support = False\n",
    "T = 3     # tokens per sequence (not referenced elsewhere in this notebook)\n",
    "d = 3     # token / prompt dimension\n",
    "ITN = 20  # gradient steps per prompt initialization\n",
    "# One prompt initialization per row.\n",
    "init_p = torch.tensor([\n",
    "    [0.,1.,0],\n",
    "    [0.,0.5,0],\n",
    "    [0.,0.,0],\n",
    "    [0.,-0.5,0],\n",
    "    [0.,-1.,0],\n",
    "    [0.,-1.5,0],\n",
    "])\n",
    "\n",
    "# ps[i, it] holds the prompt of run i before step it (filled by the training cell).\n",
    "ps = np.zeros((init_p.shape[0], ITN, d))\n",
    "\n",
    "def data_generator(equal_support=None):\n",
    "    \"\"\"Return the single-sample toy problem (X, Y, v).\n",
    "\n",
    "    X is a (1, 3, 3) float tensor of tokens, Y is tensor([1]) and v is a\n",
    "    length-3 float vector. Only v depends on the support setting.\n",
    "\n",
    "    equal_support: when None (the default) the module-level eq_support flag\n",
    "    is read at call time, matching the original zero-argument behaviour.\n",
    "    \"\"\"\n",
    "    if equal_support is None:\n",
    "        equal_support = eq_support\n",
    "    X = torch.tensor([[[-0.1, 1., 0.], [1., 0., 0.], [0., 0., 1.]]])\n",
    "    Y = torch.tensor([1])\n",
    "    # The third entry of v sets the score of token X[0, 2].\n",
    "    v = torch.tensor([0.9, 1.09, 0.9 if equal_support else 0.])\n",
    "    # v = torch.tensor([1.,1.1,1])  # alternative equal-support variant\n",
    "    return X, Y, v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0 0.38149148\n",
      "1 0 0.40367585\n",
      "2 0 0.42580315\n",
      "3 0 0.44534504\n",
      "4 0 0.46084064\n",
      "5 0 0.4721161\n"
     ]
    }
   ],
   "source": [
    "# Normalized gradient descent on the prompt only, once per row of init_p.\n",
    "LOG_EVERY = 1000  # print period; with ITN = 20 only iteration 0 is printed\n",
    "GRAD_EPS = 1e-9   # keeps the gradient normalization finite for a zero gradient\n",
    "for i in range(init_p.shape[0]):\n",
    "    X, Y, v = data_generator()\n",
    "    model = PromptAttn(d, hidden_size=3, identity_W=False)\n",
    "    # Fixed key/query weights. NOTE(review): hidden_size=3 but these have 2 rows; confirm intended.\n",
    "    model.key.weight.data   = torch.tensor([[1.,0,0],[0,1.,0]])\n",
    "    model.query.weight.data = torch.tensor([[1.,0,0],[0,1.,0]])\n",
    "    model.prompt.data = init_p[i].clone().float().view(-1)\n",
    "    model.w.data = v.clone().float().view(-1)\n",
    "    optimizer = torch.optim.SGD([model.prompt], lr=1)\n",
    "    for it in range(ITN):\n",
    "        # Record before stepping, so ps[i, 0] is the initialization.\n",
    "        ps[i,it] = model.prompt.detach().numpy()\n",
    "        optimizer.zero_grad()\n",
    "        # Logistic loss log(1 + exp(-Y f(X))), computed stably via softplus.\n",
    "        loss = torch.nn.functional.softplus(-Y*model(X)).mean()\n",
    "        loss.backward()\n",
    "        # Unit-norm gradient: with plain SGD (lr=1) every step has length ~1.\n",
    "        model.prompt.grad /= (model.prompt.grad.norm() + GRAD_EPS)\n",
    "        optimizer.step()\n",
    "        if it % LOG_EVERY == 0:\n",
    "            print(i, it, loss.detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  1.9157e-01  3.8314e-01  6e-01  1e+00  2e-16\n",
      " 1:  3.1526e-01  4.4010e-01  5e-02  2e-01  1e-16\n",
      " 2:  5.0096e-01  4.9031e-01  1e-02  8e-17  9e-16\n",
      " 3:  4.9511e-01  4.9459e-01  5e-04  2e-16  1e-16\n",
      " 4:  4.9505e-01  4.9504e-01  6e-06  8e-17  5e-16\n",
      " 5:  4.9505e-01  4.9505e-01  6e-08  2e-16  1e-15\n",
      "Optimal solution found.\n",
      "     pcost       dcost       gap    pres   dres\n",
      " 0:  1.9194e-01  3.8580e-01  6e-01  1e+00  2e-16\n",
      " 1:  3.1693e-01  4.4402e-01  5e-02  2e-01  0e+00\n",
      " 2:  5.0601e-01  4.9521e-01  1e-02  1e-16  1e-15\n",
      " 3:  5.0006e-01  4.9954e-01  5e-04  2e-17  4e-16\n",
      " 4:  5.0000e-01  4.9999e-01  7e-06  2e-16  7e-16\n",
      " 5:  5.0000e-01  5.0000e-01  7e-08  1e-17  2e-16\n",
      "Optimal solution found.\n"
     ]
    }
   ],
   "source": [
    "# W = K^T Q. Every run above uses the same fixed key/query weights, so the last `model` suffices.\n",
    "W = model.key.weight.data.T @ model.query.weight.data\n",
    "# Presumably the prompt-SVM solutions selecting token 0 / token 1 (see base.p_svm_solver).\n",
    "sol_best, _ = p_svm_solver(X, W, torch.tensor([0]))\n",
    "sol_second, _ = p_svm_solver(X, W, torch.tensor([1]))\n",
    "# One arrow color per initialization; sized from init_p so adding rows cannot IndexError.\n",
    "colors = ['gray'] * init_p.shape[0]\n",
    "# Token scores X[0] @ v under the current data_generator setting.\n",
    "labels = [r'(-0.1,1), score=1', r'(1,0), score=0.9', rf'(0,0), score={0.9 if eq_support else 0}']\n",
    "alpha = 1\n",
    "plt.figure(figsize=(8,6))\n",
    "x = np.arange(0,20)\n",
    "# Dashed rays along the two solution directions in the (x_1, x_2) plane.\n",
    "plt.plot(-x,sol_best[1]/sol_best[0]*-x, 'r--', linewidth=3, zorder=-1)\n",
    "plt.plot(x,sol_second[1]/sol_second[0]*x,'b--', linewidth=3, zorder=-1)\n",
    "# One arrow per recorded step ps[j, i] -> ps[j, i+1], for every run j.\n",
    "for i in range(ITN-1):\n",
    "    for j in range(init_p.shape[0]):\n",
    "        plt.arrow((ps[j,i,0]), (ps[j,i,1]), (ps[j,i+1,0]-ps[j,i,0]), (ps[j,i+1,1]-ps[j,i,1]), length_includes_head=True,head_width=0.1, head_length=0.1, color=colors[j], linewidth=2, alpha=alpha, zorder=-1)\n",
    "\n",
    "# Tokens projected onto their first two coordinates.\n",
    "plt.scatter(X[0,2,0],X[0,2,1],color='g',s=140, label=labels[2])\n",
    "plt.scatter(X[0,1,0],X[0,1,1],marker='s',color='b',s=140, label=labels[1], zorder=1)\n",
    "plt.scatter(X[0,0,0],X[0,0,1], marker='*',color='r',s=220, label=labels[0], zorder=1)\n",
    "if not eq_support:\n",
    "    plt.xlim([-2,12])\n",
    "    plt.ylim([-2,10])\n",
    "else:\n",
    "    plt.xlim([-5,9])\n",
    "    plt.ylim([-2,10])\n",
    "\n",
    "plt.xticks(fontsize=25)\n",
    "plt.yticks(fontsize=25)\n",
    "plt.xlabel(r'$x_1$', fontsize=30)\n",
    "plt.ylabel(r'$x_2$', fontsize=30)\n",
    "plt.legend(fontsize=20, loc='upper right')\n",
    "plt.grid()\n",
    "# plt.show()\n",
    "plt.tight_layout()\n",
    "# plt.savefig(f'local_converge_{eq_support * 1}_v2.pdf')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
