{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86067865-0742-4b44-a0e2-bc87447a977a",
   "metadata": {
    "is_executing": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%load_ext line_profiler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c613b1ca-4366-4c13-b0d4-47e581195d5b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "import os\n",
    "import torch\n",
    "from scipy.linalg import LinAlgError\n",
    "import numpy as np\n",
    "from torch_geometric.data import Batch, HeteroData\n",
    "from scipy.sparse import coo_array\n",
    "\n",
    "from sklearn.datasets import make_sparse_spd_matrix\n",
    "from utils.evaluation import normalize_cons, data_inactive_constraints, data_contraint_heuristic, gurobi_solve_qp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5253b78c-64f5-4b2d-b4d9-fd7107bbdebd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "rng = np.random.RandomState(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca5baa7a-a9fa-41b9-aa05-755f29c333b1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "root = 'datasets/qp_250_250_0.02_0.02'\n",
    "os.mkdir(root)\n",
    "os.mkdir(os.path.join(root, 'processed'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26cbdb3b-3931-4ae5-a4b5-a6da6b315bd8",
   "metadata": {},
   "source": [
    "### Generic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48e31733-3880-4299-bb84-1e7bb191dd7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "density = 0.05\n",
    "Pdensity = 0.05\n",
    "nrows = 100\n",
    "ncols = 100\n",
    "\n",
    "def surrogate_gen():\n",
    "    assert max(nrows, ncols) * density > 1\n",
    "\n",
    "    m, n = min(nrows, ncols), max(nrows, ncols)\n",
    "\n",
    "    # make sure rows and cols are selected at least once\n",
    "    rows = np.hstack([np.arange(m), np.random.randint(0, m, (n - m,))])\n",
    "    cols = np.arange(n)\n",
    "\n",
    "    # generate the rest\n",
    "    nnz = int(nrows * ncols * density)\n",
    "    num_rest = nnz - n\n",
    "\n",
    "    rows_rest = np.random.randint(0, m, (num_rest,))\n",
    "    cols_rest = np.random.randint(0, n, (num_rest,))\n",
    "\n",
    "    values = np.random.randn(nnz)\n",
    "\n",
    "    A = coo_array((values, (np.hstack([rows, rows_rest]), np.hstack([cols, cols_rest]))), shape=(m, n)).toarray()\n",
    "\n",
    "    x_feas = np.abs(np.random.randn(ncols))  # Ensure x_feas is non-negative\n",
    "    b = A @ x_feas + np.abs(np.random.randn(nrows))  # Ensure feasibility\n",
    "\n",
    "    c = np.random.rand(ncols)\n",
    "    P = make_sparse_spd_matrix(n_dim=A.shape[1], alpha=1 - Pdensity / 2.,\n",
    "                               smallest_coef=0.1, largest_coef=0.9, random_state=rng).astype(np.float64)\n",
    "    return P, A, b, c\n",
    "\n",
    "bounds = None"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4e01d11-8f2c-40a3-be31-7c8250564aac",
   "metadata": {},
   "source": [
    "# create ineq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07eb0423-ed04-42e9-accd-483f224901f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "ips = []\n",
    "graphs = []\n",
    "pkg_idx = 0\n",
    "success_cnt = 0\n",
    "\n",
    "max_iter = 15000\n",
    "num = 10000\n",
    "\n",
    "pbar = tqdm(range(max_iter))\n",
    "for i in pbar:\n",
    "    P, A, b, c = surrogate_gen()\n",
    "    P = P / np.abs(P).max()\n",
    "    c = c / (np.abs(c).max() + 1.e-10)  # does not change the result\n",
    "    A, b = normalize_cons(A, b)\n",
    "    \n",
    "    try:\n",
    "        assert np.linalg.matrix_rank(A) == min(*A.shape)\n",
    "        assert np.all(np.any(A, axis=1)) and np.all(np.any(A, axis=0))\n",
    "        solution, duals = gurobi_solve_qp(P, c, A, b)\n",
    "        assert solution is not None\n",
    "        \n",
    "    except (AssertionError, LinAlgError):\n",
    "        continue\n",
    "    else:\n",
    "        if solution is not None:\n",
    "            inactive_idx = data_inactive_constraints(A, b, solution)\n",
    "            heur_idx = data_contraint_heuristic(P, A, b, c)\n",
    "\n",
    "            obj = 0.5 * solution @ P @ solution + c.dot(solution)\n",
    "\n",
    "            P = torch.from_numpy(P).to(torch.float)\n",
    "            P_where = torch.where(P)\n",
    "            \n",
    "            A = torch.from_numpy(A).to(torch.float)\n",
    "            b = torch.from_numpy(b).to(torch.float)\n",
    "            c = torch.from_numpy(c).to(torch.float)\n",
    "            x = torch.from_numpy(solution).to(torch.float)\n",
    "\n",
    "            A_where = torch.where(A)\n",
    "            data = HeteroData(\n",
    "                cons={\n",
    "                    'num_nodes': b.shape[0],\n",
    "                    'x': torch.empty(b.shape[0]),\n",
    "                     },\n",
    "                vals={\n",
    "                    'num_nodes': c.shape[0],\n",
    "                    'x': torch.empty(c.shape[0]),\n",
    "                },\n",
    "                cons__to__vals={'edge_index': torch.vstack(A_where),\n",
    "                                'edge_attr': A[A_where][:, None]},\n",
    "                vals__to__vals={'edge_index': torch.vstack(P_where),\n",
    "                                'edge_attr': P[P_where][:, None]},\n",
    "                x_solution=x,\n",
    "                duals=torch.from_numpy(duals).float(),\n",
    "                obj_solution=torch.tensor(obj).float(),\n",
    "                q=c,\n",
    "                b=b,\n",
    "                inactive_idx=torch.from_numpy(inactive_idx).long(),\n",
    "                heur_idx=torch.from_numpy(heur_idx).long(),\n",
    "            )\n",
    "            success_cnt += 1\n",
    "            graphs.append(data)\n",
    "\n",
    "    if len(graphs) >= 1000 or success_cnt == num:\n",
    "        torch.save(Batch.from_data_list(graphs), f'{root}/processed/batch{pkg_idx}.pt')\n",
    "        pkg_idx += 1\n",
    "        graphs = []\n",
    "\n",
    "    if success_cnt >= num:\n",
    "        break\n",
    "\n",
    "    pbar.set_postfix({'suc': success_cnt, 'obj': obj})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ba781a6-076b-462a-9366-0fb6a3457468",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.dataset import LPDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f9f79d2-e484-4f1c-a35c-3924381ad80f",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = LPDataset(root, 'test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e39d6d22-91bd-4a44-9a27-599fceadf2a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = ds[10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bed3ade0-02dc-4b52-ba21-12b451265ac9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e3167d5-1c3a-48e4-b983-767d8ec87c82",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3fabd98-e334-4631-93ac-78e695cd0172",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transforms.lp_preserve import AddDumbVariables, OracleDropInactiveConstraint, AddRedundantConstraint, ScaleConstraint, ScaleCoordinate, OracleDropIdleVariable, OracleBiasProblem\n",
    "from transforms.lp_preserve import ComboPreservedTransforms\n",
    "from utils.evaluation import recover_qp_from_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "772e4a40-afbd-4870-8f52-5a1321bfefc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "Q,A,c,b,*_ = recover_qp_from_data(data, np.float64)\n",
    "solution, duals = gurobi_solve_qp(Q, c, A, b)\n",
    "0.5 * solution @ Q @ solution + c.dot(solution)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59b884a9-2e63-4244-ab72-95cf5d6548ee",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f19ff7d-5dd8-4670-8cf1-a86c594dcf19",
   "metadata": {},
   "outputs": [],
   "source": [
    "tf = ComboPreservedTransforms({'OracleDropInactiveConstraint': 0.9,\n",
    "                               'OracleDropIdleVariable': 0.9,\n",
    "                               'ScaleConstraint': 1.,\n",
    "                               'ScaleCoordinate': 1.,\n",
    "                               'AddRedundantConstraint': 0.5,\n",
    "                               'AddDumbVariables': 0.5})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4556ce94-c909-4045-b208-5d606e25c333",
   "metadata": {},
   "outputs": [],
   "source": [
    "tf = OracleBiasProblem(1.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46ad95f3-2125-460a-8ff1-950ab485b64f",
   "metadata": {},
   "outputs": [],
   "source": [
    "for _ in range(100):\n",
    "    d1 = tf(data)\n",
    "    Q, A, c, b, *_ = recover_qp_from_data(d1, np.float64)\n",
    "    solution, duals = gurobi_solve_qp(Q, c, A, b)\n",
    "    obj = 0.5 * solution @ Q @ solution + c.dot(solution)\n",
    "    print(obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4aab95e-04ff-43aa-926e-dff4e417534d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c734ad96-6cd8-49b0-9c08-beb6dff061be",
   "metadata": {},
   "outputs": [],
   "source": [
    "for _ in range(100):\n",
    "    d1 = tf(data)\n",
    "    Q, A, c, b, *_ = recover_qp_from_data(d1, np.float64)\n",
    "    solution, duals = gurobi_solve_qp(Q, c, A, b)\n",
    "    obj = 0.5 * solution @ Q @ solution + c.dot(solution)\n",
    "    transformed_obj = d1.obj_solution\n",
    "    print(obj, transformed_obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f64a0f5-e356-4a10-bfd3-d2ee76f86277",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74a56743-7a40-446a-8a2b-3b42feffbf94",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
