{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e1c4d959",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import warnings\n",
    "\n",
    "import cvxpy as cp\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import osqp\n",
    "import pandas as pd\n",
    "import torch\n",
    "from qpth.qp import QPFunction\n",
    "from scipy import sparse\n",
    "from tqdm import tqdm\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "17bff98e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Sizes for the single worked example: ndim is the number of decision variables;\n",
    "# neq / nineq size the equality / inequality data built in the next cell.\n",
    "ndim = 5\n",
    "neq = 5\n",
    "nineq = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "90ed11a9",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Random problem data for the worked example.\n",
    "# q is a cvxpy Parameter so the solve can later be differentiated with respect to it.\n",
    "q = cp.Parameter(ndim)\n",
    "q.value = np.random.random(ndim)\n",
    "# G and A must exist before they are used below; without these definitions the cell\n",
    "# raises NameError on a fresh kernel. Shapes are chosen to agree with l and u.\n",
    "G = np.random.random((nineq, ndim))\n",
    "A = np.random.random((neq, ndim))\n",
    "b = np.random.random(neq)\n",
    "h = G@np.random.random(ndim)\n",
    "# Radius of the norm-ball constraint ||x||_2 <= c used by the SOCP below.\n",
    "c = np.linalg.norm(np.random.random(ndim),2)\n",
    "# Stacked constraint data in the form l <= osA @ x <= u: the nineq rows of G are\n",
    "# bounded above by h, the neq rows of A are pinned to b.\n",
    "osA = np.vstack([G,A])\n",
    "osA = sparse.csc_matrix(osA)\n",
    "l = np.hstack([-np.inf*np.ones(nineq),b])\n",
    "u = np.hstack([h,b])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "064c5283",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The optimal value is -1.0754149890363587\n",
      "A solution x is\n",
      "[-0.41945053 -0.1753434  -0.95889347 -0.09294559 -0.27032009]\n",
      "gradient is\n",
      "[-0.37548808 -0.81061491  0.58608191 -0.95749098 -0.64131667]\n"
     ]
    }
   ],
   "source": [
    "# Forward problem: minimize q^T x subject to ||x||_2 <= c.\n",
    "x = cp.Variable(ndim)\n",
    "soc_constraints = [cp.norm(x,2)<= c]\n",
    "\n",
    "prob = cp.Problem(cp.Minimize(q.T@x),\n",
    "                  soc_constraints)\n",
    "# requires_grad=True asks cvxpy to keep what it needs for prob.backward().\n",
    "prob.solve(requires_grad = True)\n",
    "print(\"The optimal value is\", prob.value)\n",
    "print(\"A solution x is\")\n",
    "print(x.value)\n",
    "print('gradient is')\n",
    "# backward() fills q.gradient. NOTE(review): x.gradient is never set here, so cvxpy\n",
    "# presumably falls back to its default (all ones, i.e. d sum(x) / dq) - confirm in the docs.\n",
    "prob.backward()\n",
    "print(q.gradient)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "93c042ca",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def osqp_interface(P,q,A,l,u):\n",
    "    \"\"\"Run OSQP on the problem data (P, q, A, l, u) and time the solve.\n",
    "\n",
    "    The arguments are forwarded unchanged to ``OSQP.setup`` with verbose output\n",
    "    disabled. Only the ``solve()`` call is timed; ``setup()`` is excluded.\n",
    "\n",
    "    Returns:\n",
    "        (x, y, seconds): the ``x`` and ``y`` attributes of the OSQP result and the\n",
    "        wall-clock duration of ``solve()``.\n",
    "    \"\"\"\n",
    "    solver = osqp.OSQP()\n",
    "    solver.setup(P, q, A, l, u, verbose=False)\n",
    "    started = time.time()\n",
    "    solution = solver.solve()\n",
    "    return solution.x, solution.y, time.time() - started"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "0950fb5f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Scalar summaries of the solved example: t1 sums the dual value of the norm\n",
    "# constraint, t0 is ||x*||_2. Neither name is read again in the visible cells.\n",
    "t1 = np.sum(soc_constraints[0].dual_value)\n",
    "t0 = np.linalg.norm(x.value,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "07f4c882",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.7606515  -0.05421308 -0.29647289 -0.02873713 -0.08357819]\n",
      " [-0.05421308  0.86767542 -0.12393491 -0.01201302 -0.03493829]\n",
      " [-0.29647289 -0.12393491  0.21258027 -0.06569511 -0.19106563]\n",
      " [-0.02873713 -0.01201302 -0.06569511  0.88397035 -0.01852   ]\n",
      " [-0.08357819 -0.03493829 -0.19106563 -0.01852     0.83647518]]\n"
     ]
    }
   ],
   "source": [
    "# Quadratic term of the backward QP, built from the primal solution and the dual\n",
    "# value of the norm constraint:\n",
    "#     Pm = (m1 / m0) * I - (m1 / m0**3) * x x^T,   m1 = dual value, m0 = ||x||_2.\n",
    "# This has the form of m1 times the Hessian of ||x||_2 evaluated at the solution.\n",
    "pi = x.value\n",
    "# Outer product x x^T: (n, 1) @ (1, n).\n",
    "pis = pi[None].T@pi[None]\n",
    "m1 = soc_constraints[0].dual_value\n",
    "m0 = np.linalg.norm(x.value,2)\n",
    "Pm = m1/m0*np.eye(len(x.value)) - m1/(m0**3)*pis\n",
    "print(Pm)\n",
    "# OSQP is handed the matrix in CSC sparse form.\n",
    "P = sparse.csc_matrix(Pm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "62fe1c6a",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OSQP Backward Time spent: 0.0001900196075439453\n"
     ]
    }
   ],
   "source": [
    "# Backward pass as a QP solved by OSQP: quadratic term P (from the previous cell),\n",
    "# linear term bq = ones, and the single equality constraint pi @ z = 0.\n",
    "lambs = soc_constraints[0].dual_value # active set\n",
    "# active_set is only used for its length (to size bb / bh); bh is not used afterwards.\n",
    "active_set = np.argwhere(lambs>1e-8)\n",
    "# bG = -c[i].T[active_set,:].squeeze()\n",
    "bb = np.zeros(len(active_set))\n",
    "bh = np.zeros(len(active_set))\n",
    "bq = np.ones(ndim)\n",
    "# Constraint matrix is the solution itself as a single row.\n",
    "osnewA = np.vstack([pi[None]])\n",
    "osnewA = sparse.csc_matrix(osnewA)\n",
    "# Equal lower and upper bounds (both zero) turn the row into an equality constraint.\n",
    "l_new = np.hstack([bb])\n",
    "u_new = np.hstack([bb])\n",
    "\n",
    "x_grad, y_grad, time_spent_backward = osqp_interface(P,bq,osnewA,l_new,u_new)\n",
    "print('OSQP Backward Time spent:',time_spent_backward)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "feb762ba",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.37548995, -0.81061755,  0.58608174, -0.95749388, -0.64131901])"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Primal solution of the OSQP backward QP; compare with q.gradient printed earlier.\n",
    "x_grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "0a3831d3",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.37548785, -0.81061481,  0.58608243, -0.95749093, -0.64131652])"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reference gradient from a dense solve of the (ndim + 1) x (ndim + 1) system\n",
    "#     [ Pm        G^T ]\n",
    "#     [ lambs*G   br  ]\n",
    "# with right-hand side [ones(ndim); 0]; the first ndim entries, negated, are kept.\n",
    "lambs = soc_constraints[0].dual_value\n",
    "# G is the solution as a single row; note this rebinds the notebook-level name G.\n",
    "G = pi[None]\n",
    "# Constraint residual ||x||_2 - c (G has one row, so its matrix 2-norm equals the vector norm).\n",
    "br = np.linalg.norm(G,2) - c\n",
    "KKT_L1 = np.hstack([Pm, G.T])\n",
    "# br[None,None] lifts the scalar residual to shape (1, 1) so it can be stacked.\n",
    "KKT_L2 = np.hstack([lambs*G, br[None,None]])\n",
    "KKT = np.vstack([KKT_L1,KKT_L2])\n",
    "exact_bb =-(np.linalg.inv(KKT)@np.hstack([np.ones(ndim),np.zeros(1)]))[:ndim]\n",
    "exact_bb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "e742d62c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>CVXPY</th>\n",
       "      <th>BPQP</th>\n",
       "      <th>Exact</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>-0.375488</td>\n",
       "      <td>-0.375490</td>\n",
       "      <td>-0.375488</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>-0.810615</td>\n",
       "      <td>-0.810618</td>\n",
       "      <td>-0.810615</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.586082</td>\n",
       "      <td>0.586082</td>\n",
       "      <td>0.586082</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>-0.957491</td>\n",
       "      <td>-0.957494</td>\n",
       "      <td>-0.957491</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-0.641317</td>\n",
       "      <td>-0.641319</td>\n",
       "      <td>-0.641317</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      CVXPY      BPQP     Exact\n",
       "0 -0.375488 -0.375490 -0.375488\n",
       "1 -0.810615 -0.810618 -0.810615\n",
       "2  0.586082  0.586082  0.586082\n",
       "3 -0.957491 -0.957494 -0.957491\n",
       "4 -0.641317 -0.641319 -0.641317"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Side-by-side comparison of the three gradients computed above: cvxpy's backward(),\n",
    "# the OSQP-based backward QP, and the dense KKT solve.\n",
    "report_table = pd.DataFrame({'CVXPY':q.gradient, 'BPQP': x_grad, 'Exact': exact_bb})\n",
    "report_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "74eec342",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def SOCP_instances(ndim):\n",
    "    q = np.random.random(ndim)\n",
    "    c = np.linalg.norm(np.random.random(ndim),2)\n",
    "    return c,q\n",
    "def SOCP_cvxpy_solver(q,c):\n",
    "    \"\"\"Solve ``min q^T x  s.t. ||x||_2 <= c`` with cvxpy and time each stage.\n",
    "\n",
    "    Two identical problems are built: one is solved with ECOS (forward pass only),\n",
    "    the other with SCS and ``requires_grad=True`` so that ``backward()`` can be run.\n",
    "\n",
    "    Args:\n",
    "        q: 1-D cost vector; its length fixes the problem dimension.\n",
    "        c: bound on the Euclidean norm of the solution.\n",
    "\n",
    "    Returns:\n",
    "        (x, grad, dual, t_scs_forward, t_backward, t_ecos_forward):\n",
    "        the ECOS solution, ``qq.gradient`` after ``backward()`` on the SCS problem,\n",
    "        the dual value of the SCS problem's norm constraint, and the three wall-clock\n",
    "        timings in seconds.\n",
    "    \"\"\"\n",
    "    # Take the dimension from the input rather than from a notebook-level `ndim`,\n",
    "    # so the function does not depend on whichever value that global currently has.\n",
    "    n = len(q)\n",
    "    qq = cp.Parameter(n)\n",
    "    qq.value = q\n",
    "    x = cp.Variable(n)\n",
    "    x1 = cp.Variable(n)\n",
    "    soc_constraints = [cp.norm(x,2)<= c]\n",
    "    soc_constraints2 = [cp.norm(x1,2)<= c]\n",
    "    prob = cp.Problem(cp.Minimize(qq.T@x),\n",
    "                      soc_constraints)\n",
    "    prob2 = cp.Problem(cp.Minimize(qq.T@x1),\n",
    "                      soc_constraints2)\n",
    "    # ECOS forward solve (no gradient information requested).\n",
    "    tt3 = time.time()\n",
    "    prob2.solve(solver = 'ECOS')\n",
    "    tt4 = time.time() - tt3\n",
    "    # SCS forward solve with differentiation enabled.\n",
    "    tt0 = time.time()\n",
    "    prob.solve(requires_grad = True, solver = 'SCS')\n",
    "    tt1 = time.time() - tt0\n",
    "    # cvxpy backward pass; populates qq.gradient.\n",
    "    tt11 = time.time()\n",
    "    prob.backward()\n",
    "    tt2 = time.time() - tt11\n",
    "    return x1.value, qq.gradient, soc_constraints[0].dual_value, tt1, tt2, tt4\n",
    "\n",
    "def SOCP_BPQP_backward(x,q,c,lambdas):\n",
    "    if lambdas==0:\n",
    "        print('No well-defined gradients')\n",
    "        return 0\n",
    "    pi = x\n",
    "    pis = pi[None].T@pi[None]\n",
    "    m1 = lambdas\n",
    "    m0 = np.linalg.norm(x,2)\n",
    "    P = m1/m0*np.eye(len(x)) - m1/(m0**3)*pis\n",
    "    P = sparse.csc_matrix(P)\n",
    "    bb = np.zeros(1)\n",
    "    bh = np.zeros(1)\n",
    "    bq = np.ones(len(x))\n",
    "    osnewA = np.vstack([pi[None]])\n",
    "    osnewA = sparse.csc_matrix(osnewA)\n",
    "    l_new = np.hstack([bb])\n",
    "    u_new = np.hstack([bb])\n",
    "    x_grad, y_grad, time_spent_backward = osqp_interface(P,bq,osnewA,l_new,u_new)\n",
    "    return x_grad, time_spent_backward\n",
    "def dict_report(stats, key, value):\n",
    "    if key in stats.keys():\n",
    "        stats[key] = np.append(stats[key], value)\n",
    "    else:\n",
    "        stats[key] = value\n",
    "def SOCP_cal_exact_backward(lambs,x,c):\n",
    "    if lambs==0:\n",
    "        print('No well-defined gradients')\n",
    "        return 0\n",
    "    ndim = len(x)\n",
    "    pi = x\n",
    "    pis = pi[None].T@pi[None]\n",
    "    m1 = lambs\n",
    "    m0 = np.linalg.norm(pi,2)\n",
    "    Pm = m1/m0*np.eye(len(pi)) - m1/(m0**3)*pis\n",
    "\n",
    "    G = pi[None]\n",
    "    br = np.linalg.norm(G,2) - c\n",
    "    KKT_L1 = np.hstack([Pm, G.T])\n",
    "    KKT_L2 = np.hstack([lambs*G, br[None,None]])\n",
    "    KKT = np.vstack([KKT_L1,KKT_L2])\n",
    "    t5 = time.time()\n",
    "    exact_bb =-(np.linalg.inv(KKT)@np.hstack([np.ones(ndim),np.zeros(1)]))[:ndim]\n",
    "    return exact_bb,time.time()-t5\n",
    "def cal_L2_accuracy(x_exact,x_approx):\n",
    "    return np.sqrt(np.sum((x_exact - x_approx)**2))/len(x_exact)\n",
    "def get_results_table(results_dict):\n",
    "    d = {}\n",
    "    missing_methods = []\n",
    "    for method in results_dict.keys():\n",
    "        if method in results_dict:\n",
    "            d[method] = ['{:.1e}({:.1e})'.format(np.nanmean(results_dict[method]),np.nanstd(results_dict[method]))]\n",
    "        else:\n",
    "            missing_methods.append(method)\n",
    "    df = pd.DataFrame.from_dict(d, orient='index')\n",
    "    df.index.names = ['avg']\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "78f9b1a5",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████| 100/100 [00:18<00:00,  5.32it/s]\n"
     ]
    }
   ],
   "source": [
    "# Benchmark: for every problem size in n_list, draw `iters` random instances and\n",
    "# record timings and gradient accuracies. Each dict below maps str(ndim) to the\n",
    "# per-iteration measurements (see dict_report).\n",
    "n_list = [10,50,100,500]\n",
    "cvxpy_stats = {}   # cvxpy backward() time\n",
    "stats = {}         # OSQP-based backward time\n",
    "time_spent = {}    # despite the name: accuracy of the cvxpy gradient vs. the exact one\n",
    "forward = {}       # cvxpy forward time with SCS\n",
    "ecos_f = {}        # cvxpy forward time with ECOS\n",
    "exact_back = {}    # time of the dense exact solve\n",
    "exact_acc = {}     # accuracy of the OSQP-based gradient vs. the exact one\n",
    "iters = 200\n",
    "for i in tqdm(range(iters)):\n",
    "    # NOTE: this loop rebinds the notebook-level `ndim`; keep the name unless every\n",
    "    # helper called below takes the problem size from its arguments.\n",
    "    for ndim in n_list:\n",
    "        c,q = SOCP_instances(ndim) # neq = nineq\n",
    "        x_cp_value, x_cp_grad, lambdas,time_spent_forward,time_spent_backward, time_ecos_spent = SOCP_cvxpy_solver(q,c) # cvxpy Forward and Backward\n",
    "        \n",
    "        x_grad, bpqp_backward_timespent = SOCP_BPQP_backward(x_cp_value,q,c,lambdas)\n",
    "        exact_grad, exact_timespent = SOCP_cal_exact_backward(lambdas,x_cp_value,c)\n",
    "\n",
    "        # acc_backward: cvxpy gradient vs. exact; acc_cvxpy (despite the name): OSQP-based\n",
    "        # gradient vs. exact; dif: cvxpy vs. OSQP-based (computed but not stored).\n",
    "        acc_backward = cal_L2_accuracy(x_cp_grad,exact_grad)\n",
    "        acc_cvxpy = cal_L2_accuracy(exact_grad,x_grad)\n",
    "        dif = cal_L2_accuracy(x_cp_grad,x_grad)\n",
    "\n",
    "        dict_report(cvxpy_stats, f'{ndim}', time_spent_backward)\n",
    "        dict_report(ecos_f, f'{ndim}', time_ecos_spent)\n",
    "        dict_report(forward, f'{ndim}', time_spent_forward)\n",
    "        dict_report(stats, f'{ndim}', bpqp_backward_timespent)\n",
    "        dict_report(exact_back, f'{ndim}', exact_timespent)\n",
    "        dict_report(exact_acc, f'{ndim}', acc_cvxpy)\n",
    "\n",
    "        dict_report(time_spent, f'{ndim}', acc_backward)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "85351640",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Persist every benchmark series as a CSV (one column per problem size).\n",
    "ANALYSIS_DIR = './analysis'\n",
    "# to_csv does not create missing directories, so make sure the target exists.\n",
    "os.makedirs(ANALYSIS_DIR, exist_ok=True)\n",
    "\n",
    "# DataFrame.to_csv(path) returns None, so build the frames first and keep them in\n",
    "# the df_* names; the files are written in a separate step below.\n",
    "df_cp = pd.DataFrame(cvxpy_stats)\n",
    "df_bp = pd.DataFrame(stats)\n",
    "df_time = pd.DataFrame(time_spent)\n",
    "df_ecos = pd.DataFrame(ecos_f)\n",
    "df_forward = pd.DataFrame(forward)\n",
    "df_exact = pd.DataFrame(exact_back)\n",
    "df_acc = pd.DataFrame(exact_acc)\n",
    "\n",
    "for frame, filename in [\n",
    "    (df_cp, 'cvxpy_backward_time.csv'),\n",
    "    (df_bp, 'bpqp_backward_time.csv'),\n",
    "    (df_time, 'cp_acc.csv'),\n",
    "    (df_ecos, 'ecos_forward.csv'),\n",
    "    (df_forward, 'cvxpy_forward_time.csv'),\n",
    "    (df_exact, 'exact_back_time.csv'),\n",
    "    (df_acc, 'bpqp_acc.csv'),\n",
    "]:\n",
    "    frame.to_csv(os.path.join(ANALYSIS_DIR, filename))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}