{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86067865-0742-4b44-a0e2-bc87447a977a",
   "metadata": {
    "is_executing": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# autoreload in mode 2 picks up edits to imported project modules without a kernel restart\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "# line_profiler extension (provides %lprun for line-by-line profiling)\n",
    "%load_ext line_profiler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c613b1ca-4366-4c13-b0d4-47e581195d5b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from utils.evaluation import gurobi_solve_lp\n",
    "\n",
    "import os\n",
    "import torch\n",
    "from scipy.linalg import LinAlgError\n",
    "import numpy as np\n",
    "from torch_geometric.data import Batch, HeteroData\n",
    "from scipy.sparse import coo_array\n",
    "\n",
    "from generate_instances_lp import generate_setcover, Graph, generate_indset, generate_cauctions, generate_capacited_facility_location\n",
    "from utils.evaluation import data_contraint_heuristic, data_inactive_constraints, normalize_cons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5253b78c-64f5-4b2d-b4d9-fd7107bbdebd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Single fixed-seed RandomState shared by the instance generators defined below.\n",
    "rng = np.random.RandomState(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca5baa7a-a9fa-41b9-aa05-755f29c333b1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Output folder of the generated dataset; batch files are written to <root>/processed.\n",
    "root = 'datasets/setc_100_100_0.05'\n",
    "# makedirs(..., exist_ok=True) also creates missing parents such as 'datasets' and\n",
    "# does nothing when the folders already exist, so this cell can be re-run safely.\n",
    "os.makedirs(os.path.join(root, 'processed'), exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22b2dfbe-1ca0-4fda-b22f-b356c25a8e9b",
   "metadata": {},
   "source": [
    "### Setcover"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98b27565-ab4d-4872-8798-0050a16cfaa2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def surrogate_gen(nrows=100, ncols=100, density=0.05):\n",
    "    \"\"\"Sample one set-cover instance through generate_setcover.\n",
    "\n",
    "    NOTE: every generator cell in this notebook binds the name `surrogate_gen`;\n",
    "    the dataset loop uses whichever of them was executed last.\n",
    "\n",
    "    Args:\n",
    "        nrows: first positional argument forwarded to generate_setcover.\n",
    "        ncols: second positional argument forwarded to generate_setcover.\n",
    "        density: int(nrows * ncols * density) is forwarded as the third argument\n",
    "            (presumably the number of nonzero entries -- confirm in generate_instances_lp).\n",
    "\n",
    "    Returns:\n",
    "        The (A, b, c) triple returned by generate_setcover, drawn with the global `rng`.\n",
    "    \"\"\"\n",
    "    nnzrs = int(nrows * ncols * density)\n",
    "    A, b, c = generate_setcover(nrows, ncols, nnzrs, rng)\n",
    "    return A, b, c"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8feba305-fe78-4c36-9e05-c9925320ce8f",
   "metadata": {},
   "source": [
    "### Indset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d011cf5b-7a78-4eee-bec4-4c95e823787f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def surrogate_gen(nnodes=100, edge_probability=0.02):\n",
    "    \"\"\"Sample one instance through generate_indset on a random graph.\n",
    "\n",
    "    NOTE: every generator cell in this notebook binds the name `surrogate_gen`;\n",
    "    the dataset loop uses whichever of them was executed last.\n",
    "\n",
    "    Args:\n",
    "        nnodes: node count, passed to both Graph.erdos_renyi and generate_indset.\n",
    "        edge_probability: forwarded to Graph.erdos_renyi.\n",
    "\n",
    "    Returns:\n",
    "        The (A, b, c) triple returned by generate_indset; the graph is drawn with the global `rng`.\n",
    "    \"\"\"\n",
    "    graph = Graph.erdos_renyi(number_of_nodes=nnodes, edge_probability=edge_probability, random=rng)\n",
    "    A, b, c = generate_indset(graph=graph, nnodes=nnodes)\n",
    "    return A, b, c"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b10bae9-954f-4f8d-891b-c11ef61e61d2",
   "metadata": {},
   "source": [
    "### Cauctions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "409d4b02-d2c1-4765-9137-e840fd511ad0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def surrogate_gen(n_items=100, n_bids=100, min_value=0.5, max_value=1., add_item_prob=0.5):\n",
    "    \"\"\"Sample one instance through generate_cauctions.\n",
    "\n",
    "    NOTE: every generator cell in this notebook binds the name `surrogate_gen`;\n",
    "    the dataset loop uses whichever of them was executed last.\n",
    "\n",
    "    Args:\n",
    "        n_items, n_bids, min_value, max_value, add_item_prob: forwarded unchanged\n",
    "            as keyword arguments to generate_cauctions.\n",
    "\n",
    "    Returns:\n",
    "        The (A, b, c) triple returned by generate_cauctions, drawn with the global `rng`.\n",
    "    \"\"\"\n",
    "    A, b, c = generate_cauctions(n_items=n_items, n_bids=n_bids, rng=rng,\n",
    "                                 min_value=min_value, max_value=max_value,\n",
    "                                 add_item_prob=add_item_prob)\n",
    "    # c = np.ones_like(c, dtype=np.float32) * -1.\n",
    "    return A, b, c"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70957397-d822-4a34-b05d-8eeeba50214a",
   "metadata": {},
   "source": [
    "### Facilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54e45f1f-07de-42ac-8d2d-2c1ac4c5f8d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def surrogate_gen(n_customers=25, n_facilities=3, ratio=0.5):\n",
    "    \"\"\"Sample one instance through generate_capacited_facility_location.\n",
    "\n",
    "    NOTE: every generator cell in this notebook binds the name `surrogate_gen`;\n",
    "    the dataset loop uses whichever of them was executed last.\n",
    "\n",
    "    Args:\n",
    "        n_customers, n_facilities: forwarded unchanged to the generator.\n",
    "        ratio: forwarded unchanged; the original author's note says the minimum\n",
    "            would be around 0.2.\n",
    "\n",
    "    Returns:\n",
    "        The (A_ub, b_ub, c) triple returned by the generator, drawn with the global `rng`.\n",
    "    \"\"\"\n",
    "    A_ub, b_ub, c = generate_capacited_facility_location(n_customers=n_customers,\n",
    "                                                         n_facilities=n_facilities,\n",
    "                                                         ratio=ratio, rng=rng)\n",
    "    return A_ub, b_ub, c"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4e01d11-8f2c-40a3-be31-7c8250564aac",
   "metadata": {},
   "source": [
    "# Create the inequality-constrained LP dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07eb0423-ed04-42e9-accd-483f224901f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the dataset: sample LPs with surrogate_gen(), solve each one, and keep\n",
    "# every solved instance as a HeteroData graph with 'cons' and 'vals' node types.\n",
    "graphs = []        # solved instances not yet written to disk\n",
    "pkg_idx = 0        # index of the next batch file\n",
    "success_cnt = 0    # solved instances so far\n",
    "\n",
    "max_iter = 15000   # cap on sampling attempts\n",
    "num = 10000        # solved instances wanted\n",
    "batch_size = 1000  # graphs per saved batch file\n",
    "\n",
    "\n",
    "def _save_batch(pending, idx):\n",
    "    \"\"\"Collate `pending` into one Batch and write it to {root}/processed/batch{idx}.pt.\"\"\"\n",
    "    torch.save(Batch.from_data_list(pending), f'{root}/processed/batch{idx}.pt')\n",
    "\n",
    "\n",
    "pbar = tqdm(range(max_iter))\n",
    "for _ in pbar:\n",
    "    A, b, c = surrogate_gen()\n",
    "    c = c / (np.abs(c).max() + 1.e-10)  # does not change the result\n",
    "    A, b = normalize_cons(A, b)\n",
    "\n",
    "    try:\n",
    "        solution, duals = gurobi_solve_lp(A, b, c)\n",
    "    except (AssertionError, LinAlgError):\n",
    "        # the solver rejected this sample; draw another one\n",
    "        continue\n",
    "    # Explicit checks rather than `assert`, which is stripped under `python -O`.\n",
    "    if solution is None or c.dot(solution) == 0.:\n",
    "        continue\n",
    "\n",
    "    # heur_idx = data_contraint_heuristic(None, A, b, c)\n",
    "    inactive_idx = data_inactive_constraints(A, b, solution)\n",
    "    # inactive_heur_acc = np.isin(heur_idx, inactive_idx).sum() / len(heur_idx)\n",
    "\n",
    "    A = torch.from_numpy(A).to(torch.float)\n",
    "    b = torch.from_numpy(b).to(torch.float)\n",
    "    c = torch.from_numpy(c).to(torch.float)\n",
    "    x = torch.from_numpy(solution).to(torch.float)\n",
    "\n",
    "    A_where = torch.where(A)  # index tensors of A's nonzero entries\n",
    "    data = HeteroData(\n",
    "        cons={\n",
    "            'num_nodes': b.shape[0],\n",
    "            # placeholder features: zeros rather than torch.empty, so no\n",
    "            # uninitialized memory ends up in the saved files\n",
    "            'x': torch.zeros(b.shape[0]),\n",
    "        },\n",
    "        vals={\n",
    "            'num_nodes': c.shape[0],\n",
    "            'x': torch.zeros(c.shape[0]),\n",
    "        },\n",
    "        cons__to__vals={'edge_index': torch.vstack(A_where),\n",
    "                        'edge_attr': A[A_where][:, None]},\n",
    "        x_solution=x,\n",
    "        duals=torch.from_numpy(duals).float(),\n",
    "        obj_solution=c.dot(x),\n",
    "        q=c,\n",
    "        b=b,\n",
    "        # heur_idx=torch.from_numpy(heur_idx).long(),\n",
    "        inactive_idx=torch.from_numpy(inactive_idx).long(),\n",
    "    )\n",
    "    success_cnt += 1\n",
    "    graphs.append(data)\n",
    "\n",
    "    if len(graphs) >= batch_size or success_cnt == num:\n",
    "        _save_batch(graphs, pkg_idx)\n",
    "        pkg_idx += 1\n",
    "        graphs = []\n",
    "\n",
    "    if success_cnt >= num:\n",
    "        break\n",
    "\n",
    "    pbar.set_postfix({'suc': success_cnt, 'obj': c.numpy().dot(solution)})\n",
    "pbar.close()\n",
    "\n",
    "# The loop can use up max_iter attempts before reaching `num` successes; write the\n",
    "# graphs still pending so that a final partial batch is not lost.\n",
    "if graphs:\n",
    "    _save_batch(graphs, pkg_idx)\n",
    "    pkg_idx += 1\n",
    "    graphs = []\n",
    "if success_cnt < num:\n",
    "    print(f'only {success_cnt}/{num} instances solved within {max_iter} attempts')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ba781a6-076b-462a-9366-0fb6a3457468",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.dataset import LPDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f9f79d2-e484-4f1c-a35c-3924381ad80f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open `root` (the folder the generation loop saved its batches under) with the project's LPDataset.\n",
    "# NOTE(review): what the 'test' argument selects is defined in data.dataset - confirm.\n",
    "ds = LPDataset(root, 'test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e39d6d22-91bd-4a44-9a27-599fceadf2a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the first item for inspection; this rebinds `data`, a name the generation loop also assigns.\n",
    "data = ds[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cae8555c-06e4-4821-893d-8c3fe2abfbfc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
