{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86067865-0742-4b44-a0e2-bc87447a977a",
   "metadata": {
    "is_executing": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Reload edited project modules (e.g. utils.evaluation) automatically before each cell executes.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "# Makes %lprun available for line-by-line profiling; no cell in this notebook currently uses it.\n",
    "%load_ext line_profiler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c613b1ca-4366-4c13-b0d4-47e581195d5b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "from torch_geometric.data import Batch, HeteroData\n",
    "from numpy.linalg import LinAlgError\n",
    "\n",
    "from utils.evaluation import solve_sdp_cvxpy, solve_sdp_scs\n",
    "from torch_geometric.utils import to_dense_adj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d881b227-bf0d-47e8-945b-18e834931eeb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5253b78c-64f5-4b2d-b4d9-fd7107bbdebd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "rng = np.random.RandomState(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca5baa7a-a9fa-41b9-aa05-755f29c333b1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "root = 'datasets/lovas_er_100_10'\n",
    "# makedirs(exist_ok=True) keeps this cell idempotent, so Restart & Run All no longer fails with\n",
    "# FileExistsError. Existing processed/batch*.pt files may be overwritten by the generation loop.\n",
    "os.makedirs(os.path.join(root, 'processed'), exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "461b74a7-9e8a-4cc9-b047-c82847c96fbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "from torch_geometric.utils.convert import from_networkx\n",
    "\n",
    "\n",
    "def erdos_renyi_generator(rng, n_min=100, n_max=100, p_min=0.15, p_max=0.15):\n",
    "    \"\"\"Sample an Erdos-Renyi G(n, p) graph and return it as a PyG ``Data`` object.\n",
    "\n",
    "    ``n`` comes from ``rng.randint(n_min, n_max + 1)`` (inclusive range) and ``p`` from\n",
    "    ``rng.uniform(p_min, p_max)``; ``rng`` is also passed to networkx as the seed.\n",
    "    \"\"\"\n",
    "    n = rng.randint(n_min, n_max + 1)\n",
    "    p = rng.uniform(p_min, p_max)\n",
    "    G = nx.erdos_renyi_graph(n, p, rng)\n",
    "    return from_networkx(G)\n",
    "\n",
    "\n",
    "def barabasi_albert_generator(rng, n_min=100, n_max=100, m_min=4, m_max=4):\n",
    "    \"\"\"Sample a Barabasi-Albert graph and return it as a PyG ``Data`` object.\n",
    "\n",
    "    ``n`` is drawn from ``[n_min, n_max]`` and the attachment count ``m`` from\n",
    "    ``[m_min, m_max]`` (both inclusive). networkx requires ``1 <= m < n``.\n",
    "    \"\"\"\n",
    "    n = rng.randint(n_min, n_max + 1)\n",
    "    # m must come from its own range: drawing it from [n_min, n_max] ignored m_min/m_max\n",
    "    # and, with the defaults, produced m == n, which networkx rejects\n",
    "    m = rng.randint(m_min, m_max + 1)\n",
    "    G = nx.barabasi_albert_graph(n, m, rng)\n",
    "    return from_networkx(G)\n",
    "\n",
    "\n",
    "def regular_generator(rng, n_min=100, n_max=100, d_min=3, d_max=3):\n",
    "    \"\"\"Sample a random ``d``-regular graph on ``n`` nodes as a PyG ``Data`` object.\n",
    "\n",
    "    ``n`` and ``d`` are drawn from the inclusive ranges ``[n_min, n_max]`` and\n",
    "    ``[d_min, d_max]``; networkx requires ``n * d`` to be even and ``d < n``.\n",
    "    \"\"\"\n",
    "    n = rng.randint(n_min, n_max + 1)\n",
    "    d = rng.randint(d_min, d_max + 1)\n",
    "    G = nx.random_regular_graph(d, n, rng)\n",
    "    return from_networkx(G)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e9888dd-2772-4a21-81ef-4abf295607d0",
   "metadata": {},
   "source": [
    "### Max cut - Erdos-Renyi graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d83d8ce-57c2-444b-a063-7d0d8a2708be",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_max_cut_sdp(nnodes, density):\n",
    "    \"\"\"Build a Max-Cut SDP instance on a random Erdos-Renyi graph.\n",
    "\n",
    "    The graph has ``nnodes`` nodes and edge probability ``density`` and is sampled with\n",
    "    the notebook-level ``rng``.\n",
    "\n",
    "    Returns:\n",
    "        adj: dense ``(nnodes, nnodes)`` adjacency matrix, returned in the objective slot.\n",
    "        A: float32 array of shape ``(nnodes, nnodes, nnodes)``; slice ``A[..., k]`` is\n",
    "            ``e_k e_k^T``, so ``<A[..., k], X> = X[k, k]``.\n",
    "        b: float32 vector of ones, i.e. the constraints ``diag(X) == 1``.\n",
    "    \"\"\"\n",
    "    graph = erdos_renyi_generator(rng, nnodes, nnodes, density, density)\n",
    "    adj = to_dense_adj(graph.edge_index, max_num_nodes=nnodes)[0].numpy()\n",
    "\n",
    "    idx = np.arange(nnodes)\n",
    "    diag_selectors = np.zeros((nnodes, nnodes, nnodes), dtype=np.float32)\n",
    "    diag_selectors[idx, idx, idx] = 1.0\n",
    "    rhs = np.ones(nnodes, dtype=np.float32)\n",
    "\n",
    "    return adj, diag_selectors, rhs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10534bcd-6ac6-4bcb-9c5c-f7f455c5d875",
   "metadata": {},
   "source": [
    "### Max cut - random regular graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21715b5e-f762-43eb-802d-6e7816358967",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_max_cut_regular_sdp(nnodes, deg):\n",
    "    \"\"\"Build a Max-Cut SDP instance on a random ``deg``-regular graph.\n",
    "\n",
    "    The graph has ``nnodes`` nodes and is sampled with the notebook-level ``rng``.\n",
    "\n",
    "    Returns:\n",
    "        adj: dense ``(nnodes, nnodes)`` adjacency matrix, returned in the objective slot.\n",
    "        A: float32 array of shape ``(nnodes, nnodes, nnodes)``; slice ``A[..., k]`` is\n",
    "            ``e_k e_k^T``, so ``<A[..., k], X> = X[k, k]``.\n",
    "        b: float32 vector of ones, i.e. the constraints ``diag(X) == 1``.\n",
    "    \"\"\"\n",
    "    graph = regular_generator(rng, nnodes, nnodes, deg, deg)\n",
    "    adj = to_dense_adj(graph.edge_index, max_num_nodes=nnodes)[0].numpy()\n",
    "\n",
    "    idx = np.arange(nnodes)\n",
    "    diag_selectors = np.zeros((nnodes, nnodes, nnodes), dtype=np.float32)\n",
    "    diag_selectors[idx, idx, idx] = 1.0\n",
    "    rhs = np.ones(nnodes, dtype=np.float32)\n",
    "\n",
    "    return adj, diag_selectors, rhs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72f86cbb-8d6a-4208-9922-d2444b591344",
   "metadata": {},
   "source": [
    "### Lovasz theta - random regular graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab2dbfe5-a3a8-48e2-a418-5c1135e7aedd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_lovasz_theta_regular_sdp(nnodes, deg):\n",
    "    \"\"\"\n",
    "    Generates the Lovasz Theta SDP for a random ``deg``-regular graph on ``nnodes`` nodes.\n",
    "    \n",
    "    Problem:\n",
    "        Minimize   Trace(-J * X)  (Equivalent to Max sum(X_ij))\n",
    "        Subject to Trace(X) = 1\n",
    "                   X_ij = 0   for all (i, j) in Edges\n",
    "                   X positive semidefinite\n",
    "    \n",
    "    This is the \"Dual\" formulation of theta usually used in solvers.\n",
    "\n",
    "    Returns:\n",
    "        C: float32 ``(N, N)`` matrix ``-J`` (all entries -1).\n",
    "        A: float32 ``(N, N, 1 + E)``; ``A[..., 0]`` is the identity (trace constraint) and\n",
    "            ``A[..., k]`` for ``k >= 1`` has ones at ``(u, v)`` and ``(v, u)`` of the k-th edge.\n",
    "        b: float32 ``(1 + E,)``; ``b[0] = 1`` and zeros for the edge constraints.\n",
    "    E counts every undirected edge once.\n",
    "    \"\"\"\n",
    "    data = regular_generator(rng, nnodes, nnodes, deg, deg)\n",
    "    N = nnodes\n",
    "    # edge_index lists both directions of each undirected edge; keep u < v only\n",
    "    edges = data.edge_index\n",
    "    edges = edges[:, edges[0] < edges[1]].numpy()\n",
    "    E = edges.shape[1]\n",
    "\n",
    "    # Build in float32 directly: stacking a float64 np.eye with the float32 edge matrices\n",
    "    # would promote the whole tensor to float64, unlike the other generators here.\n",
    "    As = np.zeros((N, N, 1 + E), dtype=np.float32)\n",
    "    As[..., 0] = np.eye(N, dtype=np.float32)\n",
    "    k = np.arange(1, 1 + E)\n",
    "    As[edges[0], edges[1], k] = 1.0\n",
    "    As[edges[1], edges[0], k] = 1.0  # symmetry\n",
    "\n",
    "    b = np.zeros(1 + E, dtype=np.float32)\n",
    "    b[0] = 1.0\n",
    "\n",
    "    C = -np.ones((N, N), dtype=np.float32)\n",
    "    \n",
    "    return C, As, b"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6008b065-cbe5-40a9-afe2-8a97e039ef4f",
   "metadata": {},
   "source": [
    "### Lovasz theta - Erdos-Renyi graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6109b48d-749d-4c33-bd1c-34f2ae855d49",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_lovasz_theta_er_sdp(nnodes, prob):\n",
    "    \"\"\"\n",
    "    Generates the Lovasz Theta SDP for a random Erdos-Renyi graph G(nnodes, prob).\n",
    "    \n",
    "    Problem:\n",
    "        Minimize   Trace(-J * X)  (Equivalent to Max sum(X_ij))\n",
    "        Subject to Trace(X) = 1\n",
    "                   X_ij = 0   for all (i, j) in Edges\n",
    "                   X positive semidefinite\n",
    "    \n",
    "    This is the \"Dual\" formulation of theta usually used in solvers.\n",
    "\n",
    "    Returns:\n",
    "        C: float32 ``(N, N)`` matrix ``-J`` (all entries -1).\n",
    "        A: float32 ``(N, N, 1 + E)``; ``A[..., 0]`` is the identity (trace constraint) and\n",
    "            ``A[..., k]`` for ``k >= 1`` has ones at ``(u, v)`` and ``(v, u)`` of the k-th edge.\n",
    "        b: float32 ``(1 + E,)``; ``b[0] = 1`` and zeros for the edge constraints.\n",
    "    E counts every undirected edge once.\n",
    "    \"\"\"\n",
    "    data = erdos_renyi_generator(rng, nnodes, nnodes, prob, prob)\n",
    "    N = nnodes\n",
    "    # edge_index lists both directions of each undirected edge; keep u < v only\n",
    "    edges = data.edge_index\n",
    "    edges = edges[:, edges[0] < edges[1]].numpy()\n",
    "    E = edges.shape[1]\n",
    "\n",
    "    # Build in float32 directly: stacking a float64 np.eye with the float32 edge matrices\n",
    "    # would promote the whole tensor to float64, unlike the other generators here.\n",
    "    As = np.zeros((N, N, 1 + E), dtype=np.float32)\n",
    "    As[..., 0] = np.eye(N, dtype=np.float32)\n",
    "    k = np.arange(1, 1 + E)\n",
    "    As[edges[0], edges[1], k] = 1.0\n",
    "    As[edges[1], edges[0], k] = 1.0  # symmetry\n",
    "\n",
    "    b = np.zeros(1 + E, dtype=np.float32)\n",
    "    b[0] = 1.0\n",
    "\n",
    "    C = -np.ones((N, N), dtype=np.float32)\n",
    "    \n",
    "    return C, As, b"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fe1f465-6e33-4180-b501-d434eee014eb",
   "metadata": {},
   "source": [
    "### Maximum independent set (MIS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73bd40b7-146e-4315-8993-45794de7b92c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_mis_sdp(nnodes, density):\n",
    "    \"\"\"Build the Lovasz-theta style SDP used as the MIS relaxation on an Erdos-Renyi graph.\n",
    "\n",
    "    Problem: minimize <-J, X> s.t. X[u, v] = 0 for every edge (u, v), Tr(X) = 1, X PSD.\n",
    "\n",
    "    Returns:\n",
    "        C: float32 ``(N, N)`` matrix with all entries -1.\n",
    "        A: float32 ``(N, N, E + 1)``; ``A[..., k]`` for ``k < E`` has ones at ``(u, v)`` and\n",
    "            ``(v, u)`` of the k-th undirected edge, and ``A[..., E]`` is the identity.\n",
    "        b: float32 ``(E + 1,)``; zeros for the edge constraints and 1 for the trace.\n",
    "    \"\"\"\n",
    "    data = erdos_renyi_generator(rng, nnodes, nnodes, density, density)\n",
    "    N = nnodes\n",
    "    # edge_index contains both (u, v) and (v, u); keep one orientation so each edge yields a\n",
    "    # single constraint instead of two linearly dependent copies\n",
    "    edge_index = data.edge_index\n",
    "    edge_index = edge_index[:, edge_index[0] < edge_index[1]].numpy()\n",
    "    E = edge_index.shape[1]\n",
    "\n",
    "    A = np.zeros((N, N, E + 1), dtype=np.float32)\n",
    "    k = np.arange(E)\n",
    "    # symmetric edge matrices: <A_k, X> = X[u, v] + X[v, u], so the constraint does not\n",
    "    # depend on which triangle of A_k is consumed\n",
    "    A[edge_index[0], edge_index[1], k] = 1.0\n",
    "    A[edge_index[1], edge_index[0], k] = 1.0\n",
    "    # trace = 1 is the last constraint\n",
    "    A[..., E] = np.eye(N, dtype=np.float32)\n",
    "\n",
    "    b = np.zeros(E + 1, dtype=np.float32)\n",
    "    b[E] = 1.0\n",
    "\n",
    "    return -np.ones((N, N), dtype=np.float32), A, b"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6902cfb-119f-4f22-9a42-ccea15bb1f89",
   "metadata": {},
   "source": [
    "### vertex cover"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "774ab44d-1600-40b7-b1e7-370381d9a9a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_cover_sdp(nnodes, density):\n",
    "    \"\"\"Build the SDP relaxation of minimum vertex cover on a random Erdos-Renyi graph.\n",
    "\n",
    "    The matrix variable is ``(N + 1) x (N + 1)``: index 0 holds the homogenising variable\n",
    "    ``x_0`` and graph node ``u`` lives at index ``u + 1``. For every edge ``(u, v)``, with\n",
    "    ``i = u + 1`` and ``j = v + 1``, the constraint ``X[i, j] - X[0, i] - X[0, j] = -1``\n",
    "    is the lifted form of ``(1 - x_0 x_i)(1 - x_0 x_j) = 0`` for +-1 variables.\n",
    "    ``diag(X) = 1`` completes the relaxation, and ``C`` has ones on row/column 0 so that\n",
    "    ``<C, X> = X[0, 0] + 2 * sum_i X[0, i]`` for symmetric ``X``.\n",
    "\n",
    "    Returns ``(C, A, b)`` as float32 arrays with the constraint matrices stacked on the\n",
    "    last axis of ``A`` (edge constraints first, then the N + 1 diagonal constraints).\n",
    "    \"\"\"\n",
    "    data = erdos_renyi_generator(rng, nnodes, nnodes, density, density)\n",
    "    N = nnodes\n",
    "    edge_index = data.edge_index\n",
    "    E = edge_index.shape[1]\n",
    "\n",
    "    A = []\n",
    "    b = []\n",
    "    for k in range(E):\n",
    "        u, v = edge_index[:, k].tolist()\n",
    "        # shift nodes by one: index 0 is reserved for x_0, so node 0 must not alias it\n",
    "        i, j = u + 1, v + 1\n",
    "        const = np.zeros((N + 1, N + 1))\n",
    "        const[i, j] = 1\n",
    "        if i <= j:\n",
    "            const[0, i] = -1\n",
    "            const[0, j] = -1\n",
    "        else:\n",
    "            # u > v: place the x_0 terms in column 0, mirroring the u < v layout\n",
    "            const[i, 0] = -1\n",
    "            const[j, 0] = -1\n",
    "        A.append(const)\n",
    "        b.append(-1)\n",
    "\n",
    "    for d in range(N + 1):\n",
    "        const = np.zeros((N + 1, N + 1))\n",
    "        const[d, d] = 1\n",
    "        A.append(const)\n",
    "        b.append(1)\n",
    "\n",
    "    C = np.zeros((N + 1, N + 1), dtype=np.float32)\n",
    "    C[0, :] = 1\n",
    "    C[:, 0] = 1\n",
    "\n",
    "    return C, np.stack(A, axis=-1).astype(np.float32), np.array(b, dtype=np.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81f68ee7-e617-432e-ad5b-6485df6105d4",
   "metadata": {},
   "source": [
    "### Max 2 SAT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6203172-2ea6-4660-8072-aa8024aadaf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_2sat_sdp(klause, var):\n",
    "    clause_vars = np.zeros((klause, 2), dtype=int)\n",
    "    for i in range(klause):\n",
    "        clause_vars[i] = np.random.choice(var, size=2, replace=False)\n",
    "    \n",
    "    signs = np.random.choice([-1, 1], size=(klause, 2), replace=True)\n",
    "    \n",
    "    # the clause matrix\n",
    "    rows = np.repeat(np.arange(klause), 2)\n",
    "    cols = clause_vars.flatten()\n",
    "    data = signs.flatten()\n",
    "    \n",
    "    M = np.zeros((klause, var), dtype=int)\n",
    "    M[rows, cols] = data\n",
    "    MM = M.T @ M\n",
    "    C = np.block([[MM - np.diag(np.diag(MM)), -M.sum(0)[:, None]], \n",
    "                  [-M.sum(0)[None], np.zeros((1, 1))]])\n",
    "\n",
    "    C /= np.abs(C).max()\n",
    "\n",
    "    A = []\n",
    "    b = []\n",
    "    # diagonals being 1\n",
    "    for i in range(var + 1):\n",
    "        const = np.zeros((var + 1, var + 1))\n",
    "        const[i, i] = 1\n",
    "        A.append(const)\n",
    "        b.append(1)\n",
    "\n",
    "    return C.astype(np.float32), np.stack(A, axis=-1).astype(np.float32), np.array(b, dtype=np.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05a7ce26-e77c-4fe4-b582-b97b92b11b44",
   "metadata": {},
   "source": [
    "### Lyapunov stability LMI (Boyd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "300a268e-7686-4656-8859-886e454154e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy.sparse as sp\n",
    "\n",
    "def generate_boyd_lyapunov_sdp(n=50, num_cuts=500, density=0.05):\n",
    "    \"\"\"\n",
    "    Generates an SDP built from the Lyapunov LMI (Boyd et al.):\n",
    "       Find P > 0 such that:\n",
    "       1. Tr(P) = 1  (constraint index 0)\n",
    "       2. A^T P + P A < 0, represented by 'num_cuts' projections\n",
    "          <A_k, P> = v_k^T (A^T P + P A) v_k onto sparse random directions v_k.\n",
    "\n",
    "    A ground-truth P_true (random PSD, unit trace) and a stable A with\n",
    "    A^T P_true + P_true A = -Q are constructed first; every right-hand side b_k is set to\n",
    "    the value P_true attains, so P_true meets each constraint with equality (up to\n",
    "    float32 rounding).\n",
    "\n",
    "    Args:\n",
    "        n: size of the matrix variable P.\n",
    "        num_cuts: number of projected Lyapunov constraints; None uses n * (n + 1) // 2.\n",
    "        density: probability that each entry of a direction v_k is kept non-zero.\n",
    "\n",
    "    Returns:\n",
    "        (C, A_tensor, b_vec): C is an all-ones (n, n) float32 objective, A_tensor is\n",
    "        (n, n, 1 + num_cuts) float32 with the trace constraint at index 0, and b_vec is\n",
    "        (1 + num_cuts,) float32. P_true itself is not returned.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if 10 * (1 + num_cuts) sampled directions give an all-zero constraint\n",
    "            (density too low to produce the requested number of cuts).\n",
    "    \"\"\"\n",
    "    if num_cuts is None:\n",
    "        num_cuts = n * (n + 1) // 2\n",
    "\n",
    "    # 1. Ground-truth P: random PSD matrix normalised to unit trace\n",
    "    aux = rng.randn(n, n)\n",
    "    P_true = aux @ aux.T\n",
    "    P_true /= np.trace(P_true)\n",
    "    P_true = P_true.astype(np.float32)\n",
    "\n",
    "    # 2. Stable dynamics: with Q = 1e-3 * I and A = -0.5 * P^{-1} Q we get\n",
    "    #    A^T P + P A = -Q, so P_true strictly satisfies the Lyapunov inequality.\n",
    "    Q_stable = np.eye(n) * 0.001\n",
    "    P_inv = np.linalg.inv(P_true)\n",
    "    A_sys = -0.5 * P_inv @ Q_stable\n",
    "\n",
    "    # 3. Constraints: index 0 is the trace, indices 1..num_cuts the projected LMI\n",
    "    m = 1 + num_cuts\n",
    "    A_tensor = np.zeros((n, n, m), dtype=np.float32)\n",
    "    b_vec = np.zeros(m, dtype=np.float32)\n",
    "\n",
    "    np.fill_diagonal(A_tensor[..., 0], 1.0)\n",
    "    b_vec[0] = 1.0\n",
    "\n",
    "    cnt = 1\n",
    "    fails = 0\n",
    "    while cnt < m:\n",
    "        # sparse random direction: each entry is zeroed with probability 1 - density\n",
    "        v = rng.randn(n, 1)\n",
    "        v[rng.rand(v.shape[0]) < 1 - density] = 0.\n",
    "        V = v @ v.T\n",
    "\n",
    "        # <A V + V A^T, P> = v^T (A^T P + P A) v; symmetrise, then scale to unit max-abs\n",
    "        Constraint_Matrix = A_sys @ V + V @ A_sys.T\n",
    "        Constraint_Matrix = 0.5 * (Constraint_Matrix + Constraint_Matrix.T)\n",
    "        if np.abs(Constraint_Matrix).max() == 0:\n",
    "            # v was zeroed out entirely; resample, but give up after too many attempts\n",
    "            fails += 1\n",
    "            if fails == m * 10:\n",
    "                raise ValueError(\n",
    "                    f'{fails} sampled directions were all-zero; increase density '\n",
    "                    f'(currently {density}) or n (currently {n})'\n",
    "                )\n",
    "            continue\n",
    "        Constraint_Matrix /= np.abs(Constraint_Matrix).max()\n",
    "        A_tensor[..., cnt] = Constraint_Matrix\n",
    "\n",
    "        # RHS = value attained by P_true (negative, since A^T P_true + P_true A = -Q)\n",
    "        b_vec[cnt] = np.trace(Constraint_Matrix @ P_true)\n",
    "        cnt += 1\n",
    "\n",
    "    # 4. Objective: constant all-ones matrix\n",
    "    C = np.ones((n, n), dtype=np.float32)\n",
    "    \n",
    "    return C, A_tensor, b_vec"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4e01d11-8f2c-40a3-be31-7c8250564aac",
   "metadata": {},
   "source": [
    "# Generate the dataset: solve random instances and save them in batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f34d94bd-8c18-4047-95a1-2b63b4caf078",
   "metadata": {},
   "outputs": [],
   "source": [
    "from cvxpy import DCPError, DGPError, DPPError, SolverError\n",
    "from utils.evaluation import map_vec, mat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07eb0423-ed04-42e9-accd-483f224901f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "graphs = []\n",
    "pkg_idx = 0\n",
    "success_cnt = 0\n",
    "\n",
    "max_iter = 12000   # maximum number of generate-and-solve attempts\n",
    "num = 10000        # target number of successfully solved instances\n",
    "chunk_size = 1000  # instances per processed/batch{k}.pt file\n",
    "\n",
    "\n",
    "def _to_hetero_data(C, A, b, X, y, dual, sol):\n",
    "    \"\"\"Convert one solved SDP instance into the cons/vals/obj HeteroData graph.\"\"\"\n",
    "    m = b.shape[0]\n",
    "    n = C.shape[0]\n",
    "    # constraint -> matrix-entry edges: one edge per nonzero of A, weighted by its value\n",
    "    A = torch.from_numpy(A).float()\n",
    "    A = A.reshape(-1, A.shape[-1]).T  # m, n**2\n",
    "    A_where = torch.where(A)\n",
    "    c2v_idx = torch.vstack(A_where)\n",
    "    c2v_value = A[A_where][:, None]\n",
    "\n",
    "    # objective -> matrix-entry edges: one edge per nonzero of the flattened C\n",
    "    C = torch.from_numpy(C).float().reshape(-1)[None]  # 1, n**2\n",
    "    C_where = torch.where(C)\n",
    "    o2v_idx = torch.vstack(C_where)\n",
    "    o2v_value = C[C_where][:, None]\n",
    "\n",
    "    return HeteroData(\n",
    "        cons={\n",
    "            'num_nodes': m,\n",
    "            'x': torch.empty(m, 0),\n",
    "        },\n",
    "        vals={\n",
    "            'num_nodes': n ** 2,\n",
    "            'x': torch.empty(n ** 2, 0),\n",
    "        },\n",
    "        obj={\n",
    "            'num_nodes': 1,\n",
    "            'x': torch.ones(1).float(),\n",
    "        },\n",
    "        cons__to__vals={'edge_index': c2v_idx,\n",
    "                        'edge_attr': c2v_value},\n",
    "        obj__to__vals={'edge_index': o2v_idx,\n",
    "                       'edge_attr': o2v_value},\n",
    "        x_solution=torch.from_numpy(X).float().reshape(-1),\n",
    "        y_solution=torch.from_numpy(y).float(),\n",
    "        dual_solution=torch.from_numpy(dual).float().reshape(-1),\n",
    "        obj_solution=torch.tensor([sol['info']['pobj']]),\n",
    "        b=torch.from_numpy(b).float(),\n",
    "    )\n",
    "\n",
    "\n",
    "def _save_chunk(chunk, idx):\n",
    "    \"\"\"Collate ``chunk`` into a single PyG Batch and write it to processed/batch{idx}.pt.\"\"\"\n",
    "    torch.save(Batch.from_data_list(chunk), f'{root}/processed/batch{idx}.pt')\n",
    "\n",
    "\n",
    "pbar = tqdm(range(max_iter))\n",
    "for i in pbar:\n",
    "    try:\n",
    "        C, A, b = generate_lovasz_theta_er_sdp(100, 0.1)\n",
    "        X, y, dual, sol = solve_sdp_scs(C, A, b, 1.e-5)\n",
    "    except (LinAlgError, DCPError, DGPError, DPPError, SolverError, AssertionError):\n",
    "        continue\n",
    "    # explicit check instead of `assert`, so unsolved instances are still skipped under python -O\n",
    "    if sol['info']['status'] != 'solved':\n",
    "        continue\n",
    "\n",
    "    graphs.append(_to_hetero_data(C, A, b, X, y, dual, sol))\n",
    "    success_cnt += 1\n",
    "    pbar.set_postfix({'suc': success_cnt})\n",
    "\n",
    "    if len(graphs) >= chunk_size:\n",
    "        _save_chunk(graphs, pkg_idx)\n",
    "        pkg_idx += 1\n",
    "        graphs = []\n",
    "\n",
    "    if success_cnt >= num:\n",
    "        break\n",
    "pbar.close()\n",
    "\n",
    "# Flush the final partial chunk; otherwise leftover instances are silently dropped whenever\n",
    "# max_iter runs out before `num` successes have been collected.\n",
    "if graphs:\n",
    "    _save_chunk(graphs, pkg_idx)\n",
    "    pkg_idx += 1\n",
    "    graphs = []"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f7f3f14-34a6-4441-a295-0add9c4f62c2",
   "metadata": {},
   "source": [
    "## save as test only"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c21b9c9f-254e-4c45-a00f-bd25dfe0554a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.data import InMemoryDataset\n",
    "\n",
    "# weights_only=False: the file holds a pickled PyG Batch, which torch>=2.6 refuses to load\n",
    "# under its weights_only=True default. Only do this for files this notebook wrote itself.\n",
    "datas = torch.load(f'{root}/processed/batch0.pt', weights_only=False)\n",
    "# only the first saved chunk (batch0) becomes the test split; train/valid are left empty\n",
    "datas = Batch.to_data_list(datas)\n",
    "torch.save(InMemoryDataset().collate(datas), f'{root}/processed/test.pt')\n",
    "torch.save(None, f'{root}/processed/train.pt')\n",
    "torch.save(None, f'{root}/processed/valid.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ea29481-8779-4fde-ba45-c14655e6fe16",
   "metadata": {},
   "source": [
    "## save as normal dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ba781a6-076b-462a-9366-0fb6a3457468",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.dataset import LPDataset\n",
    "\n",
    "# NOTE(review): the 'save as test only' cell writes processed/valid.pt as None; confirm that\n",
    "# LPDataset(root, 'valid') is meant to run after that cell (or only on a full dataset).\n",
    "ds = LPDataset(root, 'valid')\n",
    "assert not ds.data.obj_solution.isnan().any()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d30e0de-2a68-43d4-89d7-5d62749f8158",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
