{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "581daf4d-137c-4254-b274-baae97385d98",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42720790-ac77-4ad7-94b3-8f37f52ac38a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def smart_split(s):\n",
    "    \"\"\"Split one line of an SDPA sparse-format (.dat-s) file into tokens.\n",
    "\n",
    "    The SDPA format accepts commas, braces and parentheses as separators in\n",
    "    addition to whitespace (e.g. the block-size line may read `{10, -5}`),\n",
    "    so all of them are treated as delimiters and empty tokens are dropped.\n",
    "\n",
    "    Args:\n",
    "        s: A single line of text.\n",
    "\n",
    "    Returns:\n",
    "        List of non-empty token strings, e.g. '{1, 2}' -> ['1', '2'].\n",
    "        A blank or separator-only line yields [].\n",
    "    \"\"\"\n",
    "    return [tok for tok in re.split(r'[\\s,(){}]+', s) if tok]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a293ee1a-d59a-49a0-981b-3f6878e2d388",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_dat(root, filename):\n",
    "    \"\"\"Load an SDPA sparse-format (.dat-s) SDP file into dense float32 arrays.\n",
    "\n",
    "    After any leading comment/title lines the file holds, in order: the\n",
    "    number of constraint matrices m, the number of blocks, the block-size\n",
    "    vector, the vector b, and then one nonzero per line written as\n",
    "    `mat block row col value` (block/row/col are 1-based; mat 0 is the\n",
    "    constant matrix F0). All blocks are placed on the diagonal of a single\n",
    "    dense n x n matrix with n = sum(|block sizes|).\n",
    "\n",
    "    Args:\n",
    "        root: Directory containing the file.\n",
    "        filename: Name of the file inside `root`.\n",
    "\n",
    "    Returns:\n",
    "        C: (n, n) array equal to -F0.\n",
    "        A: (n, n, m) array; A[..., i - 1] is F_i.\n",
    "        b: 1-D array parsed from the fourth data line.\n",
    "    \"\"\"\n",
    "    with open(os.path.join(root, filename), 'r') as f:\n",
    "        stripped = [line.strip() for line in f]\n",
    "    # Drop blank lines and '*' comment lines.\n",
    "    lines = [line for line in stripped if line and not line.startswith('*')]\n",
    "\n",
    "    # Skip quoted title lines at the top. These used to be detected with\n",
    "    # eval(), which executes arbitrary text from the data file.\n",
    "    start = 0\n",
    "    while start < len(lines) and lines[start].startswith(('\"', \"'\")):\n",
    "        start += 1\n",
    "    lines = lines[start:]\n",
    "\n",
    "    m = int(lines[0])\n",
    "    num_blocks = int(lines[1])\n",
    "    # SDPA marks diagonal blocks with a negative size; they are stored densely.\n",
    "    block_sizes = np.abs([int(tok) for tok in smart_split(lines[2])])\n",
    "    assert len(block_sizes) == num_blocks\n",
    "    b = np.array([float(tok) for tok in smart_split(lines[3])], dtype=np.float32)\n",
    "\n",
    "    # Offset of each block's first row/column inside the block-diagonal matrix.\n",
    "    biases = np.cumsum(np.hstack([np.zeros(1, dtype=np.int32), block_sizes]))\n",
    "    n = int(block_sizes.sum())\n",
    "\n",
    "    Fs = np.zeros((n, n, m + 1), dtype=np.float32)\n",
    "    for line in lines[4:]:\n",
    "        num_mat, num_block, row, col, num = smart_split(line)\n",
    "        num = float(num)\n",
    "        if not num:\n",
    "            continue\n",
    "        k = int(num_mat)\n",
    "        bias = biases[int(num_block) - 1]\n",
    "        r = bias + int(row) - 1\n",
    "        c = bias + int(col) - 1\n",
    "        # SDPA lists each symmetric matrix by one triangle only, so write the\n",
    "        # entry to both (r, c) and (c, r). Averaging Fs with its transpose\n",
    "        # instead would halve every off-diagonal value.\n",
    "        Fs[r, c, k] = num\n",
    "        Fs[c, r, k] = num\n",
    "\n",
    "    C = -Fs[..., 0]\n",
    "    A = Fs[..., 1:]\n",
    "    return C, A, b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a7f81a5-45ca-4427-99ba-9edbe0e1e179",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the SDPLIB .dat-s files; the converted .npz files and\n",
    "# the processed/ dataset files are written here as well.\n",
    "root = \"./SDPLIB\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18a4158d-0cb4-4365-8258-7eff8531640e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a3fa0a7-4ac3-4f1c-bcb3-a5c6e0113141",
   "metadata": {},
   "outputs": [],
   "source": [
    "# All .dat-s files in root, sorted because os.listdir returns entries in\n",
    "# arbitrary order (keeps the conversion order reproducible).\n",
    "files = sorted(f for f in os.listdir(root) if f.endswith('.dat-s'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1aef998-35b3-4abf-b9ed-298448255ad0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instances left out of the conversion. NOTE(review): the reason is not\n",
    "# recorded here; presumably they are too large for the dense n x n x (m + 1)\n",
    "# array built by read_dat - confirm.\n",
    "exceptions = [\n",
    "    'maxG32.dat-s', 'maxG55.dat-s', 'maxG60.dat-s',\n",
    "    'theta4.dat-s', 'theta5.dat-s', 'theta6.dat-s',\n",
    "    'thetaG11.dat-s', 'thetaG51.dat-s',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15f5a319-f2be-49f1-a06e-3fc9590ce561",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the excluded instances from the list of files to convert.\n",
    "files = list(filter(lambda name: name not in exceptions, files))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c93d1f03-5fa5-4ca7-b7b0-7269086bf5e8",
   "metadata": {},
   "source": [
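    "## Use only the `mcp` problem family\n",
    "\n",
    "All non-excluded `.dat-s` files are converted to `.npz`, but only the `mcp*` instances are solved and turned into graphs below."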
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40d01e8a-b72b-4ea1-ba7a-f003628f7ae7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert each selected .dat-s file to <name>.npz inside root. A file is\n",
    "# skipped when its .npz already exists (for every family, not only mcp), so\n",
    "# delete a stale .npz to force re-conversion, e.g. after changing read_dat.\n",
    "for filename in files:\n",
    "    stem = os.path.splitext(filename)[0]\n",
    "    out_path = os.path.join(root, f'{stem}.npz')\n",
    "    if os.path.exists(out_path):\n",
    "        continue\n",
    "    print(filename)\n",
    "    C, A, b = read_dat(root, filename)\n",
    "    print(filename, A.shape[-1], C.shape[0])\n",
    "    # Positional arrays are stored under the keys arr_0 (C), arr_1 (A), arr_2 (b).\n",
    "    np.savez(out_path, C, A, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd5121c9-93c7-488e-9fc1-bcebd9972580",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Converted mcp instances, sorted so the graph/dataset order is reproducible\n",
    "# (os.listdir returns entries in arbitrary order).\n",
    "done_files = sorted(f for f in os.listdir(root) if f.endswith('.npz') and f.startswith('mcp'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14a62e5b-ade3-4310-b78e-c4cfaa5a2eeb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f19a9689-a883-4794-b082-1d2a7176c635",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "import os\n",
    "import torch\n",
    "from torch_geometric.data import Batch, HeteroData\n",
    "from numpy.linalg import LinAlgError\n",
    "\n",
    "from utils.evaluation import solve_sdp_cvxpy, solve_sdp_scs\n",
    "from torch_geometric.utils import to_dense_adj\n",
    "\n",
    "from cvxpy import DCPError, DGPError, DPPError, SolverError\n",
    "from utils.evaluation import map_vec, mat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ce40de6-1826-4611-9243-489a47c312bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solve each converted instance with solve_sdp_scs and encode it as a\n",
    "# HeteroData graph with constraint nodes ('cons'), one node per entry of\n",
    "# the flattened n x n matrix ('vals') and a single objective node ('obj').\n",
    "graphs = []\n",
    "\n",
    "for file in done_files:\n",
    "    print(file)\n",
    "    # NpzFile holds an open file handle; the context manager closes it once\n",
    "    # the three arrays have been read.\n",
    "    with np.load(os.path.join(root, file)) as npz:\n",
    "        C, A, b = npz['arr_0'], npz['arr_1'], npz['arr_2']\n",
    "\n",
    "    X, y, dual, sol = solve_sdp_scs(C, A, b)\n",
    "    status = sol['info']['status']\n",
    "    assert status.startswith('solved'), f'{file}: solver status {status!r}'\n",
    "    print(status, sol['info']['solve_time'])\n",
    "\n",
    "    m = b.shape[0]\n",
    "    n = C.shape[0]\n",
    "\n",
    "    # Constraint -> entry edges: row i holds constraint matrix i flattened to\n",
    "    # n**2 values, and each nonzero becomes an edge weighted by its value.\n",
    "    A = torch.from_numpy(A).float()\n",
    "    A = A.reshape(-1, A.shape[-1]).T  # m, n**2\n",
    "    A_where = torch.where(A)\n",
    "    c2v_idx = torch.vstack(A_where)\n",
    "    c2v_value = A[A_where][:, None]\n",
    "\n",
    "    # Objective -> entry edges from the nonzeros of the flattened C.\n",
    "    C = torch.from_numpy(C).float().reshape(-1)[None]  # 1, n**2\n",
    "    C_where = torch.where(C)\n",
    "    o2v_idx = torch.vstack(C_where)\n",
    "    o2v_value = C[C_where][:, None]\n",
    "\n",
    "    # Solver outputs, stored flattened on the graph as *_solution attributes.\n",
    "    x = torch.from_numpy(X).float().reshape(-1)\n",
    "    y = torch.from_numpy(y).float()\n",
    "    dual = torch.from_numpy(dual).float().reshape(-1)\n",
    "\n",
    "    data = HeteroData(\n",
    "        cons={\n",
    "            'num_nodes': m,\n",
    "            'x': torch.empty(m, 0),\n",
    "        },\n",
    "        vals={\n",
    "            'num_nodes': n ** 2,\n",
    "            'x': torch.empty(n ** 2, 0),\n",
    "        },\n",
    "        obj={\n",
    "            'num_nodes': 1,\n",
    "            'x': torch.ones(1).float(),\n",
    "        },\n",
    "        cons__to__vals={'edge_index': c2v_idx,\n",
    "                        'edge_attr': c2v_value},\n",
    "        obj__to__vals={'edge_index': o2v_idx,\n",
    "                       'edge_attr': o2v_value},\n",
    "        x_solution=x,\n",
    "        y_solution=y,\n",
    "        dual_solution=dual,\n",
    "        obj_solution=torch.tensor([sol['info']['pobj']]),\n",
    "        b=torch.from_numpy(b).float(),\n",
    "    )\n",
    "    data.name = file.split('.')[0]\n",
    "    graphs.append(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3c769a0-bfe1-44f2-b909-846a7121f3ae",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "c9cde0ac-8def-44f5-986d-3955b4450d6c",
   "metadata": {},
   "source": [
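    "## Save a training split only\n",
    "\n",
    "All graphs go into `train.pt`; `valid.pt` and `test.pt` are written as `None` placeholders."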
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6b11a7c-a5e6-4e63-88f5-7a062ba85c1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.data import InMemoryDataset\n",
    "\n",
    "# Only a training split is produced; valid.pt and test.pt hold None.\n",
    "processed_dir = os.path.join(root, 'processed')\n",
    "# torch.save does not create missing directories.\n",
    "os.makedirs(processed_dir, exist_ok=True)\n",
    "# collate is a static method, so no dataset instance is needed.\n",
    "torch.save(InMemoryDataset.collate(graphs), os.path.join(processed_dir, 'train.pt'))\n",
    "torch.save(None, os.path.join(processed_dir, 'valid.pt'))\n",
    "torch.save(None, os.path.join(processed_dir, 'test.pt'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d349b0c5-167d-4753-89f6-fdd31dfddec4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.dataset import LPDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00250f37-0465-468d-b2f3-18015378542d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the saved split. NOTE(review): LPDataset is project code; presumably\n",
    "# it reads processed/train.pt under root - confirm.\n",
    "train_set = LPDataset(root, 'train', transform=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "345f9a3a-926b-4835-b7ab-8c7f709c5eff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the stored primal solutions. NOTE(review): `.data` is presumably the\n",
    "# collated storage holding all graphs concatenated - confirm.\n",
    "train_set.data.x_solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da909a98-4e42-4d12-bcc5-c30e3237af76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the stored dual solutions. NOTE(review): `.data` is presumably the\n",
    "# collated storage holding all graphs concatenated - confirm.\n",
    "train_set.data.dual_solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4831cc77-bafc-4f42-b7ae-b546bc8e0e1b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
