{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86067865-0742-4b44-a0e2-bc87447a977a",
   "metadata": {
    "is_executing": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%load_ext line_profiler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c613b1ca-4366-4c13-b0d4-47e581195d5b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from utils.evaluation import gurobi_solve_lp\n",
    "\n",
    "import os\n",
    "import torch\n",
    "from scipy.linalg import LinAlgError\n",
    "import numpy as np\n",
    "from torch_geometric.data import Batch, HeteroData\n",
    "from scipy.sparse import coo_array\n",
    "\n",
    "from utils.evaluation import data_contraint_heuristic, data_inactive_constraints, normalize_cons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5253b78c-64f5-4b2d-b4d9-fd7107bbdebd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Seed every random source up front so random draws are repeatable after a kernel restart.\n",
    "SEED = 1\n",
    "np.random.seed(SEED)     # legacy global NumPy state, used by any np.random.* call\n",
    "torch.manual_seed(SEED)  # torch's global generator\n",
    "rng = np.random.RandomState(SEED)  # dedicated generator for code that takes one explicitly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca5baa7a-a9fa-41b9-aa05-755f29c333b1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Output location for the generated batches.\n",
    "# NOTE(review): the directory name suggests 250 x 250 at density 0.02, while the generator\n",
    "# cell uses nrows = ncols = 200 and density = 0.025 - confirm which one is intended.\n",
    "root = 'datasets/gen_250_250_0.02'\n",
    "\n",
    "# makedirs creates missing parents (including `root` itself) and exist_ok=True keeps the\n",
    "# cell re-runnable; plain os.mkdir raised FileExistsError on the second run.\n",
    "os.makedirs(os.path.join(root, 'processed'), exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26cbdb3b-3931-4ae5-a4b5-a6da6b315bd8",
   "metadata": {},
   "source": [
    "### Generic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48e31733-3880-4299-bb84-1e7bb191dd7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "density = 0.025\n",
    "nrows = 200\n",
    "ncols = 200\n",
    "\n",
    "def surrogate_gen(random_state=None):\n",
    "    \"\"\"Generate one random sparse LP instance ``(A, b, c)``.\n",
    "\n",
    "    ``A`` has shape ``(nrows, ncols)`` and is filled from ``int(nrows * ncols * density)``\n",
    "    sampled positions (values at coinciding positions are summed); every row and every\n",
    "    column is sampled at least once. ``b`` equals ``A @ x_feas`` plus a non-negative slack\n",
    "    for a non-negative ``x_feas``, so ``A x <= b`` with ``x >= 0`` is feasible by\n",
    "    construction. ``c`` is entry-wise non-negative.\n",
    "\n",
    "    Args:\n",
    "        random_state: optional ``np.random.RandomState`` to draw from. Defaults to the\n",
    "            notebook-level seeded ``rng`` so that generation is reproducible.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(A, b, c)`` of dense arrays with shapes ``(nrows, ncols)``, ``(nrows,)``\n",
    "        and ``(ncols,)``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``density`` is too small to place one entry in every row and column.\n",
    "    \"\"\"\n",
    "    gen = rng if random_state is None else random_state\n",
    "\n",
    "    n_small, n_large = min(nrows, ncols), max(nrows, ncols)\n",
    "    nnz = int(nrows * ncols * density)\n",
    "    # The coverage pass below consumes one entry per index of the longer axis, so the\n",
    "    # budget has to be at least that large (otherwise num_rest would be negative).\n",
    "    if nnz < n_large:\n",
    "        raise ValueError(f'density={density} yields only {nnz} non-zeros, but at least '\n",
    "                         f'{n_large} are needed to cover every row and column')\n",
    "\n",
    "    # make sure rows and cols are selected at least once\n",
    "    idx_large = np.arange(n_large)\n",
    "    idx_small = np.hstack([np.arange(n_small), gen.randint(0, n_small, (n_large - n_small,))])\n",
    "    # Assign the index vectors to the correct axis so A really is (nrows, ncols),\n",
    "    # also when there are more rows than columns.\n",
    "    if nrows <= ncols:\n",
    "        rows, cols = idx_small, idx_large\n",
    "    else:\n",
    "        rows, cols = idx_large, idx_small\n",
    "\n",
    "    # generate the rest\n",
    "    num_rest = nnz - n_large\n",
    "    rows_rest = gen.randint(0, nrows, (num_rest,))\n",
    "    cols_rest = gen.randint(0, ncols, (num_rest,))\n",
    "\n",
    "    values = gen.randn(nnz)\n",
    "\n",
    "    A = coo_array((values, (np.hstack([rows, rows_rest]), np.hstack([cols, cols_rest]))),\n",
    "                  shape=(nrows, ncols)).toarray()\n",
    "\n",
    "    x_feas = np.abs(gen.randn(ncols))  # Ensure x_feas is non-negative\n",
    "    b = A @ x_feas + np.abs(gen.randn(nrows))  # Ensure feasibility\n",
    "\n",
    "    c = np.abs(gen.randn(ncols))\n",
    "    return A, b, c\n",
    "\n",
    "bounds = None"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4e01d11-8f2c-40a3-be31-7c8250564aac",
   "metadata": {},
   "source": [
    "# Create inequality-constrained LP instances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07eb0423-ed04-42e9-accd-483f224901f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate up to `num` solvable LP instances (at most `max_iter` attempts) and write them\n",
    "# to `{root}/processed/batch<k>.pt` in packages of at most 1000 graphs.\n",
    "ips = []\n",
    "graphs = []\n",
    "pkg_idx = 0\n",
    "success_cnt = 0\n",
    "\n",
    "max_iter = 15000\n",
    "num = 10000\n",
    "\n",
    "pbar = tqdm(range(max_iter))\n",
    "for i in pbar:\n",
    "    A, b, c = surrogate_gen()\n",
    "    c = c / (np.abs(c).max() + 1.e-10)  # does not change the result\n",
    "    A, b = normalize_cons(A, b)\n",
    "\n",
    "    try:\n",
    "        # Any failed check (or a LinAlgError) rejects this instance and moves on.\n",
    "        assert np.linalg.matrix_rank(A) == min(*A.shape)\n",
    "        assert np.all(np.any(A, axis=1)) and np.all(np.any(A, axis=0))\n",
    "        # res = linprog(c, A_ub=A, b_ub=b, bounds=bounds, method='highs')\n",
    "        solution, duals = gurobi_solve_lp(A, b, c)\n",
    "        assert solution is not None\n",
    "        assert c.dot(solution) != 0.\n",
    "    except (AssertionError, LinAlgError):\n",
    "        continue\n",
    "    else:\n",
    "        heur_idx = data_contraint_heuristic(None, A, b, c)\n",
    "        inactive_idx = data_inactive_constraints(A, b, solution)\n",
    "        # Fraction of heuristic picks that are truly inactive; max(..., 1) avoids 0 / 0\n",
    "        # when the heuristic selects nothing.\n",
    "        inactive_heur_acc = np.isin(heur_idx, inactive_idx).sum() / max(len(heur_idx), 1)\n",
    "\n",
    "        A = torch.from_numpy(A).to(torch.float)\n",
    "        b = torch.from_numpy(b).to(torch.float)\n",
    "        c = torch.from_numpy(c).to(torch.float)\n",
    "        x = torch.from_numpy(solution).to(torch.float)\n",
    "\n",
    "        # Non-zero entries of A become the edges of the constraint/variable bipartite graph.\n",
    "        A_where = torch.where(A)\n",
    "        data = HeteroData(\n",
    "            cons={\n",
    "                'num_nodes': b.shape[0],\n",
    "                # torch.empty leaves the values uninitialised; presumably a placeholder\n",
    "                # that is filled or ignored downstream - TODO confirm\n",
    "                'x': torch.empty(b.shape[0]),\n",
    "            },\n",
    "            vals={\n",
    "                'num_nodes': c.shape[0],\n",
    "                'x': torch.empty(c.shape[0]),\n",
    "            },\n",
    "            cons__to__vals={'edge_index': torch.vstack(A_where),\n",
    "                            'edge_attr': A[A_where][:, None]},\n",
    "            x_solution=x,\n",
    "            duals=torch.from_numpy(duals).float(),\n",
    "            obj_solution=c.dot(x),\n",
    "            q=c,\n",
    "            b=b,\n",
    "            heur_idx=torch.from_numpy(heur_idx).long(),\n",
    "            inactive_idx=torch.from_numpy(inactive_idx).long(),\n",
    "        )\n",
    "        success_cnt += 1\n",
    "        graphs.append(data)\n",
    "\n",
    "    # Checkpoint a full package, or the final partial one once the target is reached.\n",
    "    if len(graphs) >= 1000 or success_cnt == num:\n",
    "        torch.save(Batch.from_data_list(graphs), f'{root}/processed/batch{pkg_idx}.pt')\n",
    "        pkg_idx += 1\n",
    "        graphs = []\n",
    "\n",
    "    if success_cnt >= num:\n",
    "        break\n",
    "\n",
    "    pbar.set_postfix({'suc': success_cnt, 'inactive_heur_acc': inactive_heur_acc})\n",
    "\n",
    "# If max_iter ran out before `num` successes, the graphs collected since the last\n",
    "# checkpoint are still in memory; save them instead of silently dropping them.\n",
    "if graphs:\n",
    "    torch.save(Batch.from_data_list(graphs), f'{root}/processed/batch{pkg_idx}.pt')\n",
    "    pkg_idx += 1\n",
    "    graphs = []\n",
    "\n",
    "if success_cnt < num:\n",
    "    print(f'only {success_cnt} of {num} requested instances were generated in {max_iter} attempts')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ba781a6-076b-462a-9366-0fb6a3457468",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.dataset import LPDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f9f79d2-e484-4f1c-a35c-3924381ad80f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open the dataset stored under `root`; the second argument 'test' presumably selects\n",
    "# a split - TODO confirm against LPDataset.\n",
    "ds = LPDataset(root, 'test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e39d6d22-91bd-4a44-9a27-599fceadf2a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the first item of the dataset as the working sample for the cells below.\n",
    "data = ds[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a6a46a5-9254-491d-89c1-89f25f6050fc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f9a0aa1-bbe1-4d57-a518-bfc5b0dcc927",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51991120-e680-4eef-8798-9eef97bcb723",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce4472d9-a978-4876-ae4e-26a154bc5fbe",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "729bb646-619a-4245-b542-fe67bb3c5dd0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3fabd98-e334-4631-93ac-78e695cd0172",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transforms.lp_preserve import (DropInactiveConstraint, OracleDropInactiveConstraint, OracleDropIdleVariable,\n",
    "                                    AddRedundantConstraint,\n",
    "                                    ScaleConstraint, ScaleCoordinate,\n",
    "                                    AddSubOrthogonalConstraint,\n",
    "                                    AddDumbVariables, OracleBiasProblem)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfd7005e-8561-4843-9506-c44a4b6b63bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transforms.lp_preserve import ComboPreservedTransforms, ComboInterpolateTransforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e2d4b90-0c8d-428e-b6bd-301ab43661df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combined transform built from a {transform name: number} mapping plus a trailing 10.\n",
    "# NOTE(review): the numbers look like per-transform strengths or probabilities and the\n",
    "# meaning of 10 is not visible here - confirm against ComboInterpolateTransforms.\n",
    "tf = ComboInterpolateTransforms(\n",
    "    dict(\n",
    "        DropInactiveConstraint=0.0,\n",
    "        OracleDropInactiveConstraint=0.,\n",
    "        OracleDropIdleVariable=0.9,\n",
    "        # OracleBiasProblem=1.,\n",
    "        ScaleConstraint=1.,\n",
    "        ScaleCoordinate=1.,\n",
    "        AddRedundantConstraint=0.5,\n",
    "        AddDumbVariables=0.5,\n",
    "    ),\n",
    "    10,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dd102d4-fc73-4f3c-b682-ca951a8265c3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c795a9c-f93d-4c92-a560-6271a5ced8bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.evaluation import recover_lp_from_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "772e4a40-afbd-4870-8f52-5a1321bfefc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the LP from the stored sample and solve it again.\n",
    "# Note the unpacking order is (A, c, b, ...) while the solver is called with (A, b, c).\n",
    "A,c,b,*_ = recover_lp_from_data(data)\n",
    "solution, duals = gurobi_solve_lp(A, b, c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8097a04f-f092-4f04-bb7d-41d2711c3f85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Objective value c . x of the solution just computed (shown as the cell output).\n",
    "c.dot(solution)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfede2fe-d964-4353-8b8c-0777d8944551",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f3c0410-e528-43e4-a185-a7189e0d3fed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply `tf` to the same sample 100 times; for each result, re-solve the recovered LP and\n",
    "# print that objective next to the objective stored on the transformed sample.\n",
    "for _ in range(100):\n",
    "    d1 = tf(data)\n",
    "    # second argument is presumably the dtype of the recovered arrays - TODO confirm\n",
    "    A, c, b, *_ = recover_lp_from_data(d1, np.float64)\n",
    "    solution, duals = gurobi_solve_lp(A, b, c)\n",
    "    obj = c.dot(solution)  # objective recomputed from a fresh solve\n",
    "    transformed_obj = d1.obj_solution  # objective carried by the transformed sample\n",
    "    print(obj, transformed_obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0c0eea1-2bc3-4f48-95e4-a8fc2eeec0c0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
