{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "metadata": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import cvxpy as cp\n",
    "import numpy as np\n",
    "import osqp\n",
    "from scipy import sparse\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import warnings\n",
    "import torch\n",
    "import pandas as pd\n",
    "from qpth.qp import QPFunction\n",
    "import time\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def osqp_interface(P,q,A,l,u):\n",
    "    prob = osqp.OSQP()\n",
    "    prob.setup(P, q, A, l, u,verbose = False)\n",
    "    t0 = time.time()\n",
    "    res = prob.solve()\n",
    "    return res.x,res.y,time.time() - t0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "ndim = 20\n",
    "neq = 10\n",
    "nineq = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "P = np.random.random((ndim,ndim))\n",
    "P = P.T@P+(0.0001*np.eye(ndim,ndim))\n",
    "Pm = P.copy()\n",
    "P = sparse.csc_matrix(P)\n",
    "q = np.random.random(ndim)\n",
    "A = np.random.random((neq,ndim))\n",
    "G = np.random.random((nineq,ndim))\n",
    "b = np.random.random(neq)\n",
    "h = G@np.random.random(ndim)\n",
    "osA = np.vstack([G,A])\n",
    "osA = sparse.csc_matrix(osA)\n",
    "l = np.hstack([-np.inf*np.ones(nineq),b])\n",
    "u = np.hstack([h,b])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 1.OSQP Forward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "x_value, y_value, time_spent = osqp_interface(P,q,osA,l,u)\n",
    "print('OSQP Forward Time spent:',time_spent)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 2.OSQP Backward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "lambs = y_value[:nineq] # active set\n",
    "active_set = np.argwhere(lambs>1e-8)\n",
    "bG = G[active_set,:].squeeze()\n",
    "bb = np.zeros(neq)\n",
    "bh = np.zeros(len(active_set))\n",
    "bq = np.ones(ndim)\n",
    "osnewA = np.vstack([bG,A])\n",
    "osnewA = sparse.csc_matrix(osnewA)\n",
    "l_new = np.hstack([bh,bb])\n",
    "u_new = np.hstack([bh,bb])\n",
    "\n",
    "x_grad, y_grad, time_spent_backward = osqp_interface(P,bq,osnewA,l_new,u_new)\n",
    "print('OSQP Backward Time spent:',time_spent_backward)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 3.1 CVXPY Backward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "qq = cp.Parameter(ndim)\n",
    "qq.value = q\n",
    "x1 = cp.Variable(ndim)\n",
    "prob = cp.Problem(cp.Minimize((1 / 2) * cp.quad_form(x1, P) + qq.T @ x1),\n",
    "                              [G @ x1 <= h,\n",
    "                               A @ x1 == b])\n",
    "t3 = time.time()\n",
    "prob.solve(requires_grad = True,solver='SCS')\n",
    "print('CVXPY Forward Time Spent:',time.time() - t3)\n",
    "t4 = time.time()\n",
    "prob.backward()\n",
    "t4 = time.time() - t4\n",
    "print('CVXPY Backward Time Spent:',t4)\n",
    "total_time_cvxpy = time.time() - t3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3.2 QPTH/OptNet Backward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "qpf = QPFunction(verbose=0,maxIter=4000)\n",
    "tP, tq, tG, th, tA, tb = [torch.Tensor(x).unsqueeze(0).cuda() for x in [Pm, q, G, h, A, b]]\n",
    "tq.requires_grad = True\n",
    "t6 = time.time()\n",
    "qpth_x_value = qpf(tP, tq, tG, th, tA, tb)\n",
    "print('qpth Forward Time Spent:', time.time() - t6)\n",
    "t7 = time.time()\n",
    "qpth_x_value.backward(torch.ones(1, ndim).cuda())\n",
    "qpth_x_grad = tq.grad.squeeze().cpu().numpy()\n",
    "print('qpth Backward Time Spent:', time.time() - t7)\n",
    "qpth_x_value = qpth_x_value.squeeze().detach().cpu().numpy()\n",
    "total_time_qpth = time.time() - t6"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 4. Exact backward (Matrix inverse, Gould et al. 2022)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "t5 = time.time()\n",
    "KKT_L1 = np.hstack([Pm,G.T,A.T])\n",
    "KKT_L2 = np.hstack([np.diag(lambs)@G, np.diag(G@x_value-h),np.zeros((nineq,neq))])\n",
    "KKT_L3 = np.hstack([A, np.zeros((neq,neq)),np.zeros((neq,nineq))])\n",
    "KKT = np.vstack([KKT_L1,KKT_L2,KKT_L3])\n",
    "exact_bb =-(np.linalg.inv(KKT)@np.hstack([np.ones(ndim),np.zeros(nineq),np.zeros(neq)]))[:ndim]\n",
    "t5 = time.time()-t5\n",
    "print(\"Exact Backward Time Spent\", t5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 5. Alt-diff, Sun et al. 2023"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import scipy\n",
    "def relu(s):\n",
    "    ss = s\n",
    "    for i in range(len(s)):\n",
    "        if s[i]<0:\n",
    "            ss[i] = 0\n",
    "    return ss\n",
    "\n",
    "def proj(s):\n",
    "    ss = s\n",
    "    for i in range(len(s)):\n",
    "        if s[i]<0:\n",
    "            ss[i] = (ss[i] + math.sqrt(ss[i] ** 2 + 4 * 0.001))/2\n",
    "    return ss\n",
    "\n",
    "def sgn(s,p):\n",
    "    ss = np.zeros(p)\n",
    "    for i in range(len(s)):\n",
    "        if s[i]<=0:\n",
    "            ss[i] = 0\n",
    "        else:\n",
    "            ss[i] = 1\n",
    "    return ss\n",
    "\n",
    "def alt_diff(P,q,A, b, G, h):\n",
    "    t0 = time.time()\n",
    "    n, m, p = len(P), len(A), len(G)\n",
    "    sk = np.zeros(p)\n",
    "    lamb = np.zeros(m)\n",
    "    nu = np.zeros(p)\n",
    "\n",
    "    dxk = np.zeros((m, n))\n",
    "    dsk = np.zeros((p, n))\n",
    "    dlamb = np.zeros((m, n))\n",
    "    dnu = np.zeros((p, n))\n",
    "\n",
    "    rho = 1\n",
    "    M = P + rho * A.T @ A + rho * G.T @ G\n",
    "\n",
    "    begininv = time.time()\n",
    "    R = - np.linalg.inv(P + rho * A.T @ A + rho * G.T @ G)\n",
    "    endinv = time.time()\n",
    "    print(\"The inverse time is\", endinv-begininv)\n",
    "\n",
    "\n",
    "    res = [1000, -100]\n",
    "    # thres = 1e-5\n",
    "\n",
    "    begin2 = time.time()\n",
    "    xk = np.ones(n)\n",
    "    thres = 1e-4\n",
    "    while abs((np.linalg.norm(res[-1]) - np.linalg.norm(res[-2])) / np.linalg.norm(res[-2])) > thres:\n",
    "\n",
    "        xk = R @ (q + A.T @ lamb + G.T @ nu - rho * A.T @ b + rho * G.T @ (sk - h))\n",
    "\n",
    "        dxk = R @ (np.ones(n) + A.T @ dlamb + G.T @ dnu  + rho * G.T @ dsk)\n",
    "        sk = relu(- (1 / rho) * nu - (G @ xk - h))\n",
    "\n",
    "        dsk = (-1 / rho) * sgn(sk,p).reshape(p, 1) @ np.ones((1,n)) * (dnu + rho * G @ dxk)\n",
    "\n",
    "        lamb = lamb + rho * (A @ xk - b)\n",
    "        dlamb = dlamb + rho * (A @ dxk)\n",
    "        nu = nu + rho * (G @ xk + sk - h)\n",
    "        dnu = dnu + rho * (G @ dxk + dsk)\n",
    "        res.append(xk)\n",
    "\n",
    "    y_f = dxk.T[0]\n",
    "    end2 = time.time()\n",
    "    return xk, y_f, endinv-begininv, end2-begin2, end2-t0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "xk, dq,_,_,t6= alt_diff(Pm, q, A, b, G, h)\n",
    "print(\"Alt-diff Total Time Spent\", t6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 6.JAXOpt, Blondel et al., 2021"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import time\n",
    "import jaxopt\n",
    "import jax.numpy as jnp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def jax_qp_solver(c):\n",
    "    qp = jaxopt.OSQP()\n",
    "    sol = qp.run(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h)).params\n",
    "    loss = jnp.dot(sol.primal,jnp.ones(ndim))\n",
    "    return loss\n",
    "def jax_qp(Q,c,A,b,G,h):\n",
    "    qp = jaxopt.OSQP()\n",
    "    t0 = time.time()\n",
    "    [Qj,cj,Aj,bj,Gj,hj] = [jnp.array(vec) for vec in [Q,c,A,b,G,h]]\n",
    "    sol = qp.run(params_obj=(Qj, cj), params_eq=(Aj, bj), params_ineq=(Gj, hj)).params\n",
    "    t1 = time.time()\n",
    "    forward_time = t1 - t0\n",
    "    t2 = time.time()\n",
    "    jax_gradient = jax.grad(jax_qp_solver)(c)\n",
    "    t3 = time.time()\n",
    "    backward_time = t3 - t2 - forward_time\n",
    "    return jax_gradient, forward_time, backward_time\n",
    "Q = Pm\n",
    "c = q\n",
    "\n",
    "jax_gradient, forward_time, backward_time = jax_qp(Q,c,A,b,G,h)\n",
    "print('JAXOpt Backward time spent', backward_time)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 7. Accuracy Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def cal_cos_accuracy(x_exact,x_approx):\n",
    "    return np.dot(x_exact,x_approx)/(np.linalg.norm(x_exact)*np.linalg.norm(x_approx))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "print('Backward pass')\n",
    "print(pd.DataFrame({'CVXPY':qq.gradient, 'qpth': qpth_x_grad, 'BPQP': x_grad, 'Alt-diff':dq,'JAXOpt':jax_gradient, 'exact': exact_bb}).head(5))\n",
    "print('Time Spent')\n",
    "time_table = pd.Series({'CVXPY':total_time_cvxpy, 'qpth': total_time_qpth, 'BPQP': time_spent+time_spent_backward, 'Alt-diff':t6,'JAXOpt':t7, 'exact': t5})\n",
    "print(time_table.to_frame())\n",
    "np.log(1+time_table.to_frame()).plot.bar(logy=True,legend=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "print('OSQP vs CVXPY forward dif',cal_cos_accuracy(x_value , x1.value))\n",
    "print('OSQP vs qpth forward dif',cal_cos_accuracy(x_value , qpth_x_value))\n",
    "print('OSQP vs Exact backward dif',cal_cos_accuracy(exact_bb , x_grad))\n",
    "print('CVXPY vs Exact backward dif',cal_cos_accuracy(exact_bb, qq.gradient))\n",
    "print('qpth vs Exact backward dif',cal_cos_accuracy(exact_bb, qpth_x_grad))\n",
    "print('qpth vs CVXPY backward dif',cal_cos_accuracy(qq.gradient, qpth_x_grad))\n",
    "print('alt-diff vs Exact backward backward dif',cal_cos_accuracy(exact_bb,dq))\n",
    "print('JAXOpt vs Exact backward backward dif',cal_cos_accuracy(exact_bb,jax_gradient))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 8. Comparison with state-of-art differentiable QP optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def QP_instances(ndim,neq,nineq):\n",
    "\n",
    "    P = np.random.randn(ndim,ndim)\n",
    "    P = P.T@P+(0.001*np.eye(ndim,ndim))\n",
    "    Pm = P.copy()\n",
    "    P = sparse.csc_matrix(P)\n",
    "    q = np.random.randn(ndim)\n",
    "    A = np.random.randn(neq,ndim)\n",
    "    G = np.random.randn(nineq,ndim)\n",
    "    b = np.random.randn(neq)\n",
    "    h = np.random.randn(nineq)\n",
    "    osA = np.vstack([G,A])\n",
    "    osA = sparse.csc_matrix(osA)\n",
    "    l = np.hstack([-np.inf*np.ones(len(h)),b])\n",
    "    u = np.hstack([h,b])\n",
    "    return P,q,G,b,A,h,osA,l,u,Pm\n",
    "\n",
    "def QP_OSQP_backward(y_value,P,G,A):\n",
    "    nineq,ndim = G.shape\n",
    "    neq = A.shape[0]\n",
    "    lambs = y_value[:nineq] # active set\n",
    "    active_set = np.argwhere(lambs>1e-8)\n",
    "    bG = G[active_set,:].squeeze()\n",
    "    bb = np.zeros(neq)\n",
    "    bh = np.zeros(len(active_set))\n",
    "    bq = np.ones(ndim)\n",
    "    osnewA = np.vstack([bG,A])\n",
    "    osnewA = sparse.csc_matrix(osnewA)\n",
    "    l_new = np.hstack([bh,bb])\n",
    "    u_new = np.hstack([bh,bb])\n",
    "    x_grad, y_grad, time_spent_backward = osqp_interface(P,bq,osnewA,l_new,u_new)\n",
    "    return x_grad, y_grad, time_spent_backward\n",
    "\n",
    "def QP_cvxpy_backward(P,G,A,h,b,q):\n",
    "    nineq,ndim = G.shape\n",
    "    neq = nineq\n",
    "    qq = cp.Parameter(ndim)\n",
    "    qq.value = q\n",
    "    x1 = cp.Variable(ndim)\n",
    "    Pp = cp.psd_wrap(P)\n",
    "    prob = cp.Problem(cp.Minimize((1 / 2) * cp.quad_form(x1, Pp) + qq.T @ x1),\n",
    "                                  [G @ x1 <= h,\n",
    "                                   A @ x1 == b])\n",
    "    t3 = time.time()\n",
    "    try:\n",
    "        prob.solve(requires_grad=True, solver='SCS')\n",
    "        time_spent_forward = time.time() - t3\n",
    "        t4 = time.time()\n",
    "        prob.backward()\n",
    "        time_spent_backward = time.time() - t4\n",
    "        return x1.value,prob.value, qq.gradient,time_spent_forward,time_spent_backward\n",
    "    except:\n",
    "        return np.nan*np.zeros(ndim),np.nan,np.nan*np.zeros(ndim),np.nan,np.nan\n",
    "\n",
    "def QP_qpth_evaluate(Pm,G,A,h,b,q):\n",
    "    nineq,ndim = G.shape\n",
    "    neq = nineq\n",
    "    qpf = QPFunction(verbose=0,maxIter=4000)\n",
    "    tP, tq, tG, th, tA, tb = [torch.Tensor(x).unsqueeze(0).cuda() for x in [Pm, q, G, h, A, b]]\n",
    "    tq.requires_grad = True\n",
    "    try:\n",
    "        t6 = time.time()\n",
    "        qpth_x_value = qpf(tP, tq, tG, th, tA, tb)\n",
    "        t7 = time.time()\n",
    "        qpth_x_value.backward(torch.ones(1, ndim).cuda())\n",
    "        qpth_x_grad = tq.grad.squeeze().cpu().numpy()\n",
    "        t8 = time.time()\n",
    "        qpth_x_value = qpth_x_value.squeeze().detach().cpu().numpy()\n",
    "        return qpth_x_value,qpth_x_grad,t7-t6,t8-t7\n",
    "    except:\n",
    "        return np.nan*np.zeros(ndim),np.nan*np.zeros(ndim),np.nan,np.nan\n",
    "\n",
    "def QP_cal_exact_backward(Pm,G,A,y_value,x_value,h):\n",
    "    t5 = time.time()\n",
    "    nineq,ndim = G.shape\n",
    "    neq = A.shape[0]\n",
    "    lambs = y_value[:nineq] # active set\n",
    "    KKT_L1 = np.hstack([Pm,G.T,A.T])\n",
    "    KKT_L2 = np.hstack([np.diag(lambs)@G, np.diag(G@x_value-h),np.zeros((nineq,neq))])\n",
    "    KKT_L3 = np.hstack([A, np.zeros((neq,neq)),np.zeros((neq,nineq))])\n",
    "    KKT = np.vstack([KKT_L1,KKT_L2,KKT_L3])\n",
    "    exact_bb =-(np.linalg.inv(KKT)@np.hstack([np.ones(ndim),np.zeros(nineq),np.zeros(neq)]))[:ndim]\n",
    "    return exact_bb,time.time()-t5\n",
    "\n",
    "\n",
    "\n",
    "def dict_report(stats, key, value):\n",
    "    if key in stats.keys():\n",
    "        stats[key] = np.append(stats[key], value)\n",
    "    else:\n",
    "        stats[key] = value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# small scale\n",
    "n_list = [10, 50, 100, 500]\n",
    "nconstraints_list = [5, 10, 20, 100]\n",
    "stats = {}\n",
    "iters = 200\n",
    "for i in tqdm(range(iters)):\n",
    "    for ndim,neq in zip(n_list,nconstraints_list):\n",
    "        P,q,G,b,A,h,osA,l,u,Pm = QP_instances(ndim,neq,neq) # neq = nineq\n",
    "        x_value, y_value, time_spent_forward_osqp = osqp_interface(P,q,osA,l,u) # OSQP Forward\n",
    "        x_grad, y_grad, time_spent_backward_osqp = QP_OSQP_backward(y_value,P,G,A) # OSQP Backward\n",
    "        x_cp_value,y_cp_value, x_cp_grad,time_spent_forward,time_spent_backward = QP_cvxpy_backward(Pm,G,A,h,b,q) # cvxpy Forward and Backward\n",
    "        x_qpth_value, x_qpth_grad,time_spent_qpth_forward, time_qpth_backward= QP_qpth_evaluate(Pm,G,A,h,b,q)\n",
    "        x_alt, x_alt_grad, inverse_time,sol_time, time_spend_alt_overall = alt_diff(Pm,q,A, b, G, h)\n",
    "        Q = Pm\n",
    "        c = q\n",
    "        jax_gradient, jax_forward_time, jax_backward_time = jax_qp(Q,c,A,b,G,h)\n",
    "\n",
    "        exact_bb,t_spent = QP_cal_exact_backward(Pm,G,A,y_value,x_value,h) # Exact backward\n",
    "        acc_forward = cal_cos_accuracy(x_value,x_qpth_value)\n",
    "        acc_osqp_bb = cal_cos_accuracy(exact_bb,x_grad)\n",
    "        acc_cvxpy_bb = cal_cos_accuracy(exact_bb,x_cp_grad)\n",
    "        acc_qpth_bb = cal_cos_accuracy(exact_bb,x_qpth_grad)\n",
    "        acc_alt_bb = cal_cos_accuracy(exact_bb,x_alt_grad)\n",
    "        acc_jax_bb = cal_cos_accuracy(exact_bb,jax_gradient)\n",
    "\n",
    "        print(f\"cvxpy:{time_spent_forward}, {time_spent_backward}\")\n",
    "        dict_report(stats, 'Time OSQP Forward'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_spent_forward_osqp)\n",
    "        dict_report(stats, 'Time OSQP Backward'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_spent_backward_osqp)\n",
    "        dict_report(stats, 'Time OSQP Overall'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_spent_forward_osqp+time_spent_backward_osqp)\n",
    "        dict_report(stats, 'Time CVXPY Forward'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_spent_forward)\n",
    "        dict_report(stats, 'Time CVXPY Backward'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_spent_backward)\n",
    "        dict_report(stats, 'Time CVXPY Overall'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_spent_forward+time_spent_backward)\n",
    "        dict_report(stats, 'Time Exact Backward'+f' ndim:{ndim}'+f' neq=nineq={neq}', t_spent)\n",
    "        dict_report(stats, 'Time Exact Overall'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_spent_forward_osqp+t_spent)\n",
    "        dict_report(stats, 'Time Qpth Backward'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_qpth_backward)\n",
    "        dict_report(stats, 'Time Qpth Forward'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_spent_qpth_forward)\n",
    "        dict_report(stats, 'Time Qpth Overall'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_qpth_backward+time_spent_qpth_forward)\n",
    "        dict_report(stats, 'Time Alt-diff Forward'+f' ndim:{ndim}'+f' neq=nineq={neq}', inverse_time)\n",
    "        dict_report(stats, 'Time Alt-diff Backward'+f' ndim:{ndim}'+f' neq=nineq={neq}', sol_time)\n",
    "        dict_report(stats, 'Time Alt-diff Overall'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_spend_alt_overall)\n",
    "        dict_report(stats, 'Time JAXOpt Forward'+f' ndim:{ndim}'+f' neq=nineq={neq}', jax_forward_time)\n",
    "        dict_report(stats, 'Time JAXOpt Backward'+f' ndim:{ndim}'+f' neq=nineq={neq}', jax_backward_time)\n",
    "\n",
    "        dict_report(stats, 'Accuracy JAXOpt Backward'+f' ndim:{ndim}'+f' neq=nineq={neq}', acc_jax_bb)\n",
    "        dict_report(stats, 'Accuracy Alt-diff Backward'+f' ndim:{ndim}'+f' neq=nineq={neq}', acc_alt_bb)\n",
    "        dict_report(stats, 'Accuracy OSQP Backward'+f' ndim:{ndim}'+f' neq=nineq={neq}', acc_osqp_bb)\n",
    "        dict_report(stats, 'Accuracy CVXPY Backward'+f' ndim:{ndim}'+f' neq=nineq={neq}', acc_cvxpy_bb)\n",
    "        dict_report(stats, 'Accuracy QPTH Backward'+f' ndim:{ndim}'+f' neq=nineq={neq}', acc_qpth_bb)\n",
    "    try:\n",
    "        pd.DataFrame(stats).to_csv('./results/Small_scale_qp.csv')\n",
    "    except:\n",
    "        pd.Series(stats).to_csv('./results/Small_scale_qp.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# large scale\n",
    "n_list = [500, 1500, 3000,5000]\n",
    "nconstraints_list = [200, 500, 1000,2000]\n",
    "stats_large = {}\n",
    "iters = 10\n",
    "for i in tqdm(range(iters)):\n",
    "    for ndim,neq in zip(n_list,nconstraints_list):\n",
    "        P,q,G,b,A,h,osA,l,u,Pm = QP_instances(ndim,neq,neq) # neq = nineq\n",
    "        x_value, y_value, time_spent_forward_osqp = osqp_interface(P,q,osA,l,u) # OSQP Forward\n",
    "        x_grad, y_grad, time_spent_backward_osqp = QP_OSQP_backward(y_value,P,G,A) # OSQP Backward\n",
    "\n",
    "        x_qpth_value, x_qpth_grad,time_spent_qpth_forward, time_qpth_backward= QP_qpth_evaluate(Pm,G,A,h,b,q)\n",
    "        x_alt, x_alt_grad, inverse_time,sol_time,time_spend_alt_overall = alt_diff(Pm,q,A, b, G, h)\n",
    "\n",
    "        exact_bb,t_spent = QP_cal_exact_backward(Pm,G,A,y_value,x_value,h) # Exact backward\n",
    "        acc_forward = cal_cos_accuracy(x_value,x_qpth_value)\n",
    "        acc_osqp_bb = cal_cos_accuracy(exact_bb,x_grad)\n",
    "\n",
    "        acc_qpth_bb = cal_cos_accuracy(exact_bb,x_qpth_grad)\n",
    "        acc_alt_bb = cal_cos_accuracy(exact_bb,x_alt_grad)\n",
    "\n",
    "        dict_report(stats_large, 'Time OSQP Forward'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_spent_forward_osqp)\n",
    "        dict_report(stats_large, 'Time OSQP Backward'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_spent_backward_osqp)\n",
    "        dict_report(stats_large, 'Time OSQP Overall'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_spent_forward_osqp+time_spent_backward_osqp)\n",
    "        dict_report(stats_large, 'Time Exact Backward'+f' ndim:{ndim}'+f' neq=nineq={neq}', t_spent)\n",
    "        dict_report(stats_large, 'Time Exact Overall'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_spent_forward_osqp+t_spent)\n",
    "        dict_report(stats_large, 'Time Qpth Backward'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_qpth_backward)\n",
    "        dict_report(stats_large, 'Time Qpth Forward'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_spent_qpth_forward)\n",
    "        dict_report(stats_large, 'Time Qpth Overall'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_qpth_backward+time_spent_qpth_forward)\n",
    "        dict_report(stats_large, 'Time Alt-diff Forward'+f' ndim:{ndim}'+f' neq=nineq={neq}', inverse_time)\n",
    "        dict_report(stats_large, 'Time Alt-diff Backward'+f' ndim:{ndim}'+f' neq=nineq={neq}', sol_time)\n",
    "        dict_report(stats_large, 'Time Alt-diff Overall'+f' ndim:{ndim}'+f' neq=nineq={neq}', time_spend_alt_overall)\n",
    "        dict_report(stats_large, 'Accuracy Alt-diff Backward'+f' ndim:{ndim}'+f' neq=nineq={neq}', acc_alt_bb)\n",
    "        dict_report(stats_large, 'Accuracy OSQP Backward'+f' ndim:{ndim}'+f' neq=nineq={neq}', acc_osqp_bb)\n",
    "        dict_report(stats_large, 'Accuracy QPTH Backward'+f' ndim:{ndim}'+f' neq=nineq={neq}', acc_qpth_bb)\n",
    "    try:\n",
    "        pd.DataFrame(stats_large).to_csv('./results/Large_scale_qp.csv')\n",
    "    except:\n",
    "        pd.Series(stats_large).to_csv('./results/Large_scale_qp.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def get_results_table(results_dict):\n",
    "    d = {}\n",
    "    missing_methods = []\n",
    "    for method in results_dict.keys():\n",
    "        if method in results_dict:\n",
    "            d[method] = ['{:.1e}({:.1e})'.format(np.nanmean(results_dict[method]),np.nanstd(results_dict[method]))]\n",
    "        else:\n",
    "            missing_methods.append(method)\n",
    "    df = pd.DataFrame.from_dict(d, orient='index')\n",
    "    df.index.names = ['avg']\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "df = get_results_table(stats)\n",
    "df.to_csv('./results/small_scale_qp_stats.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "dfl = get_results_table(stats_large)\n",
    "dfl.to_csv('./results/Large_scale_qp_stats.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 9. Small scale LPs Performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def LP_instances(ndim,nineq,delta):\n",
    "    neq = nineq\n",
    "    s0 = np.random.randn(nineq)\n",
    "    lamb0 = np.maximum(-s0, 0)\n",
    "    s0 = np.maximum(s0, 0)\n",
    "    x0 = np.random.random(ndim) # one feasible solution\n",
    "\n",
    "    G = np.random.random((nineq, ndim))\n",
    "    A = np.random.random((neq, ndim))\n",
    "    b = A@x0\n",
    "    h = G @ x0 + s0\n",
    "    c = -G.T @ lamb0\n",
    "    delta = 1e-6\n",
    "    P = np.eye(ndim,ndim)*delta\n",
    "    return P,c,b,A,G,h\n",
    "\n",
    "def LP_cvxpy_backward(P,c,A,b,G,h):\n",
    "    nineq,ndim = A.shape\n",
    "    qq = cp.Parameter(ndim)\n",
    "    qq.value = c\n",
    "    x = cp.Variable(ndim)\n",
    "    Pp = cp.psd_wrap(P)\n",
    "    prob = cp.Problem(cp.Minimize((1 / 2) * cp.quad_form(x, Pp) + qq.T@x),\n",
    "                 [A @ x == b,\n",
    "                  G @ x <= h\n",
    "                  ])\n",
    "    t3 = time.time()\n",
    "\n",
    "    prob.solve(requires_grad=True, solver='SCS')\n",
    "    time_spent_forward = time.time() - t3\n",
    "    t4 = time.time()\n",
    "    prob.backward()\n",
    "    time_spent_backward = time.time() - t4\n",
    "    return x.value,prob.value, qq.gradient,time_spent_forward,time_spent_backward\n",
    "\n",
    "def LP_OSQP_forward(P,G,A,b,h):\n",
    "    nineq,ndim = G.shape\n",
    "    neq = A.shape[0]\n",
    "    Pm = sparse.csc_matrix(P)\n",
    "    osA = np.vstack([G,A])\n",
    "    osA = sparse.csc_matrix(osA)\n",
    "    l = np.hstack([-np.inf*np.ones(nineq),b])\n",
    "    u = np.hstack([h,b])\n",
    "    x_value, y_value, time_spent = osqp_interface(Pm,c,osA,l,u)\n",
    "\n",
    "    return x_value,y_value, time_spent, np.dot(c,x_value)\n",
    "\n",
    "def LP_OSQP_backward(y_value,P,G,A):\n",
    "    nineq,ndim = G.shape\n",
    "    neq = A.shape[0]\n",
    "    lambs = y_value[:nineq] # active set\n",
    "    active_set = np.argwhere(lambs>1e-8)\n",
    "    bG = G[active_set,:].squeeze()\n",
    "    bb = np.zeros(neq)\n",
    "    bh = np.zeros(len(active_set))\n",
    "    bq = np.ones(ndim)\n",
    "    osnewA = np.vstack([bG,A])\n",
    "    osnewA = sparse.csc_matrix(osnewA)\n",
    "    l_new = np.hstack([bh,bb])\n",
    "    u_new = np.hstack([bh,bb])\n",
    "    x_grad, y_grad, time_spent_backward = osqp_interface(P,bq,osnewA,l_new,u_new)\n",
    "    return x_grad, y_grad, time_spent_backward\n",
    "\n",
    "def LP_qpth_evaluate(Pm, G, A, h, b, q):\n",
    "    nineq, ndim = G.shape\n",
    "    neq = A.shape[0]\n",
    "    qpf = QPFunction(verbose=0, maxIter=4000)\n",
    "    tP, tq, tG, th, tA, tb = [torch.Tensor(x).unsqueeze(0).cuda() for x in [Pm, q, G, h, A, b]]\n",
    "    tq.requires_grad = True\n",
    "\n",
    "    t6 = time.time()\n",
    "    qpth_x_value = qpf(tP, tq, tG, th, tA, tb)\n",
    "    t7 = time.time()\n",
    "    qpth_x_value.backward(torch.ones(1, ndim).cuda())\n",
    "    qpth_x_grad = tq.grad.squeeze().cpu().numpy()\n",
    "    t8 = time.time()\n",
    "    qpth_x_value = qpth_x_value.squeeze().detach().cpu().numpy()\n",
    "    return qpth_x_value, qpth_x_grad, t7 - t6, t8 - t7\n",
    "\n",
    "def LP_cal_exact_backward(Pm,G,A,y_value,x_value,h):\n",
    "    nineq,ndim = G.shape\n",
    "    neq = A.shape[0]\n",
    "    lambs = y_value[:nineq] # active set\n",
    "    t5 = time.time()\n",
    "    KKT_L1 = np.hstack([Pm,G.T,A.T])\n",
    "    KKT_L2 = np.hstack([np.diag(lambs)@G, np.diag(G@x_value-h),np.zeros((nineq,neq))])\n",
    "    KKT_L3 = np.hstack([A, np.zeros((neq,neq)),np.zeros((neq,nineq))])\n",
    "    KKT = np.vstack([KKT_L1,KKT_L2,KKT_L3])\n",
    "    exact_bb =-(np.linalg.inv(KKT)@np.hstack([np.ones(ndim),np.zeros(nineq),np.zeros(neq)]))[:ndim]\n",
    "    return exact_bb,time.time()-t5\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "n_list = [10, 50, 100, 500]\n",
    "nconstraints_list = [5, 10, 20, 100]\n",
    "delta = 1e-6\n",
    "lp_stats = {}\n",
    "iters = 200\n",
    "for i in tqdm(range(iters)):\n",
    "    for ndim, neq in zip(n_list, nconstraints_list):\n",
    "        P,c,b,A,G,h = LP_instances(ndim, neq,delta)  # neq = nineq\n",
    "        Pm = sparse.csc_matrix(P)\n",
    "        x_value,y_value, time_spent, f_value = LP_OSQP_forward(Pm,G,A,b,h) # OSQP Forward\n",
    "\n",
    "        x_cp_value, y_cp_value, x_cp_grad, time_spent_forward, time_spent_backward = LP_cvxpy_backward(P,c,A,b,G,h)  # cvxpy Forward and Backward\n",
    "        x_grad, y_grad, time_spent_backward_osqp = LP_OSQP_backward(y_value,Pm,G,A)  # OSQP Backward\n",
    "        x_qpth_value, x_qpth_grad,time_spent_qpth_forward, time_qpth_backward=LP_qpth_evaluate(P, G, A, h, b, c)\n",
    "        exact_bb, t_spent = LP_cal_exact_backward(P,G,A,y_value,x_value,h)  # Exact backward\n",
    "\n",
    "        acc_osqp_bb = cal_cos_accuracy(exact_bb, x_grad)\n",
    "        acc_cvxpy_bb = cal_cos_accuracy(exact_bb, x_cp_grad)\n",
    "        acc_qpth_bb = cal_cos_accuracy(exact_bb, x_qpth_grad)\n",
    "\n",
    "        dict_report(lp_stats, 'Time OSQP Forward' + f' ndim:{ndim}' + f' neq=nineq={neq}', time_spent)\n",
    "        dict_report(lp_stats, 'Time OSQP Backward' + f' ndim:{ndim}' + f' neq=nineq={neq}', time_spent_backward_osqp)\n",
    "        dict_report(lp_stats, 'Time OSQP Overall' + f' ndim:{ndim}' + f' neq=nineq={neq}', time_spent+time_spent_backward_osqp)\n",
    "        dict_report(lp_stats, 'Time CVXPY Forward' + f' ndim:{ndim}' + f' neq=nineq={neq}', time_spent_forward)\n",
    "        dict_report(lp_stats, 'Time CVXPY Backward' + f' ndim:{ndim}' + f' neq=nineq={neq}', time_spent_backward)\n",
    "        dict_report(lp_stats, 'Time CVXPY Overall' + f' ndim:{ndim}' + f' neq=nineq={neq}', time_spent_forward+time_spent_backward)\n",
    "        dict_report(lp_stats, 'Time Exact Backward' + f' ndim:{ndim}' + f' neq=nineq={neq}', t_spent)\n",
    "        dict_report(lp_stats, 'Time Exact Overall' + f' ndim:{ndim}' + f' neq=nineq={neq}', t_spent+time_spent)\n",
    "\n",
    "        dict_report(lp_stats, 'Accuracy OSQP Forward' + f' ndim:{ndim}' + f' neq=nineq={neq}', f_value)\n",
    "        dict_report(lp_stats, 'Accuracy CVXPY Forward' + f' ndim:{ndim}' + f' neq=nineq={neq}', np.dot(c,x_cp_value))\n",
    "        dict_report(lp_stats, 'Accuracy OSQP Backward' + f' ndim:{ndim}' + f' neq=nineq={neq}', acc_osqp_bb)\n",
    "        dict_report(lp_stats, 'Accuracy CVXPY Backward' + f' ndim:{ndim}' + f' neq=nineq={neq}', acc_cvxpy_bb)\n",
    "    if i>1:\n",
    "        pd.DataFrame(lp_stats).to_csv('./results/small_scale_lp.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "df_lp = get_results_table(lp_stats)\n",
    "df_lp.to_csv('./results/lp_stats.csv')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "zeqiye_bpqp",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.undefined"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
