{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86067865-0742-4b44-a0e2-bc87447a977a",
   "metadata": {
    "is_executing": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%load_ext line_profiler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c613b1ca-4366-4c13-b0d4-47e581195d5b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "from torch_geometric.data import Batch, HeteroData, Data\n",
    "import networkx as nx\n",
    "from torch_geometric.utils import from_networkx, to_networkx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07824662-90d5-4e94-9748-13975f2c28d7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5253b78c-64f5-4b2d-b4d9-fd7107bbdebd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "SEED = 1\n",
    "# Notebook-wide NumPy random generator (legacy RandomState API).\n",
    "rng = np.random.RandomState(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca5baa7a-a9fa-41b9-aa05-755f29c333b1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "root = 'datasets/geo_1000_10d_100g_dense-5_cs0.35'\n",
    "# makedirs also creates a missing `datasets/` parent. No exist_ok on purpose:\n",
    "# the cell fails if root/processed already exists instead of silently\n",
    "# overwriting previously generated batches.\n",
    "os.makedirs(os.path.join(root, 'processed'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "488e204b-4b1e-4603-83f6-095c9cc6c97a",
   "metadata": {},
   "source": [
    "| dataset | # nodes | dim | # Gaussian components | density | cluster_spread |\n",
    "|--- | --- | --- | --- | --- | --- |\n",
    "|geo-1000-2 | 1000 | 2 | 100 | 0.5 | 1 |\n",
    "|geo-2000-2 | 2000 | 2 | 100 | 0.5 | 1.3 |\n",
    "|geo-3000-2 | 3000 | 2 | 100 | 0.5 | 1.7 |\n",
    "|geo-5000-2 | 5000 | 2 | 100 | 0.5 | 2.1 |\n",
    "|geo-10000-2 | 10000 | 2 | 100 | 0.5 | 2.9 |\n",
    "|geo-1000-5 | 1000 | 5 | 100 | 1.e-5 | 0.4 |\n",
    "|geo-1000-10 | 1000 | 10 | 100 | 1.e-5 | 0.25 |\n",
    "|geo-2000-10 | 2000 | 10 | 100 | 1.e-5 | 0.28 |\n",
    "|geo-3000-10 | 3000 | 10 | 100 | 1.e-5 | 0.3 |\n",
    "|geo-5000-10 | 5000 | 10 | 100 | 1.e-5 | 0.33 |\n",
    "|geo-10000-10 | 10000 | 10 | 100 | 1.e-5 | 0.36 |\n",
    "|geo-1000-10-dense | 1000 | 10 | 100 | 1.e-5 | 0.19 |\n",
    "|geo-1000-10-sparse | 1000 | 10 | 100 | 1.e-5 | 0.35 |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fd208f4-5892-47fb-ba4b-0419d651e0e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial import distance_matrix\n",
    "\n",
    "def generate_gmm_unit_disk_graph(num_nodes,\n",
    "                                 dim=2,\n",
    "                                 num_components=3,\n",
    "                                 density=2.0,\n",
    "                                 cluster_spread=0.1,\n",
    "                                 seed=None):\n",
    "    \"\"\"\n",
    "    Sample node positions from a Gaussian mixture and connect every pair of\n",
    "    nodes whose Euclidean distance is at most 1.0 (a unit-disk / radius graph).\n",
    "\n",
    "    Args:\n",
    "        num_nodes (int): Number of nodes.\n",
    "        dim (int): Dimension of the node positions (e.g., 2, 3).\n",
    "        num_components (int): Number of Gaussian clusters.\n",
    "        density (float): Nodes per unit volume. Sets the box size\n",
    "            L = (num_nodes / density) ** (1 / dim); cluster centers are drawn\n",
    "            uniformly from [-L/2, L/2]^dim. Each cluster has unit covariance.\n",
    "        cluster_spread (float): Global scale factor applied to ALL positions\n",
    "            (cluster centers and within-cluster offsets alike) before edges are\n",
    "            built. The connection radius is fixed at 1.0, so a larger value puts\n",
    "            nodes further apart relative to the radius (sparser graph) and a\n",
    "            smaller value packs them closer (denser graph).\n",
    "        seed (int, optional): If given, draw from a fresh\n",
    "            np.random.RandomState(seed) for this call; otherwise draw from the\n",
    "            notebook-level `rng`.\n",
    "\n",
    "    Returns:\n",
    "        data (Data): PyG graph with `edge_index` of shape (2, E), `edge_weight`\n",
    "            (E,) holding the pairwise distances, `pos` (num_nodes, dim) as\n",
    "            float32, and `num_nodes`. Edges are stored in both directions and\n",
    "            every node has a self-loop of weight 0 (its distance to itself).\n",
    "        pos (np.ndarray): The scaled (num_nodes, dim) positions.\n",
    "    \"\"\"\n",
    "    # All randomness goes through one generator so results are reproducible.\n",
    "    gen = rng if seed is None else np.random.RandomState(seed)\n",
    "\n",
    "    # 1. Characteristic box size: volume V = N / density, length L = V ** (1/d).\n",
    "    box_size = (num_nodes / density) ** (1.0 / dim)\n",
    "\n",
    "    # 2. Random mixture weights and the number of nodes in each component.\n",
    "    mixture_weights = gen.dirichlet(np.ones(num_components))\n",
    "    component_counts = gen.multinomial(num_nodes, mixture_weights)\n",
    "\n",
    "    all_points = []\n",
    "    for count in component_counts:\n",
    "        if count == 0:\n",
    "            continue\n",
    "        # 3. Cluster center uniform in [-L/2, L/2]^dim with unit covariance.\n",
    "        #    Drawn from `gen` (not the global np.random) so the seed controls it.\n",
    "        mean = gen.uniform(-box_size / 2, box_size / 2, size=(dim,))\n",
    "        cov = np.eye(dim)\n",
    "        all_points.append(gen.multivariate_normal(mean, cov, size=count))\n",
    "\n",
    "    # 4. Stack all components and apply the global scale.\n",
    "    pos = np.vstack(all_points) if all_points else np.empty((0, dim))\n",
    "    pos = pos * cluster_spread\n",
    "\n",
    "    # 5. Radius graph: connect every ordered pair with distance <= 1.0.\n",
    "    #    The diagonal of dist_mat is 0, so each node also gets a self-loop.\n",
    "    dist_mat = distance_matrix(pos, pos)\n",
    "    rows, cols = np.where(dist_mat <= 1.0)\n",
    "\n",
    "    # 6. Pack into a PyG Data object.\n",
    "    edge_index = torch.tensor(np.array([rows, cols]), dtype=torch.long)\n",
    "    edge_weight = torch.tensor(dist_mat[rows, cols], dtype=torch.float)\n",
    "    data = Data(edge_index=edge_index,\n",
    "                edge_weight=edge_weight,\n",
    "                num_nodes=num_nodes, pos=torch.from_numpy(pos).float())\n",
    "\n",
    "    return data, pos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "267ac96b-8a2e-4e79-a2ab-d52861d8df88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One sample graph for inspection and plotting (parameters of the geo-1000-10-sparse row).\n",
    "data, pos = generate_gmm_unit_disk_graph(\n",
    "    num_nodes=1000,\n",
    "    dim=10,\n",
    "    num_components=100,\n",
    "    density=1e-5,\n",
    "    cluster_spread=0.35,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10c67e3e-1212-40b6-adc6-8afa71d3c04a",
   "metadata": {},
   "outputs": [],
   "source": [
    "data  # inspect the sample graph's stored attributes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c403a459-15cd-4113-a6d5-693f46f6b7a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape plus a short preview rather than printing every position.\n",
    "data.pos.shape, data.pos[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd1ef028-13fa-4236-a9cb-ffed93352add",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8e677fc-bf71-4dd0-820e-248b68be538d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8398ff23-fce2-464d-9302-8d4b3077b968",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41e672ad-1e0f-45a9-9498-8421fc904510",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot only the first two of the `dim` coordinates (a slice, not a projection).\n",
    "df = pd.DataFrame(pos[:, :2], columns=['X', 'Y'])\n",
    "\n",
    "# s=15 controls dot size, alpha=0.6 makes them slightly transparent\n",
    "ax = sns.scatterplot(\n",
    "    data=df, x='X', y='Y',\n",
    "    color='darkblue', s=15, alpha=0.6, edgecolor=None\n",
    ")\n",
    "ax.set(title='Sample graph node positions (dims 0 and 1)',\n",
    "       xlabel='dim 0', ylabel='dim 1');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "567a9645-d6d7-4369-8e4f-f2729bede0b6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e4e01d11-8f2c-40a3-be31-7c8250564aac",
   "metadata": {},
   "source": [
    "# Create problem instances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07eb0423-ed04-42e9-accd-483f224901f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "num = 10000         # number of graphs to generate\n",
    "max_iter = 12000    # upper bound on generation attempts\n",
    "chunk_size = 1000   # graphs per saved batch file\n",
    "\n",
    "graphs = []\n",
    "pkg_idx = 0\n",
    "success_cnt = 0\n",
    "\n",
    "# Progress is measured in accepted graphs, so the bar total is `num`, not `max_iter`.\n",
    "with tqdm(total=num) as pbar:\n",
    "    for attempt in range(max_iter):\n",
    "        # `graph` (not `data`) so the sample graph inspected above is not clobbered.\n",
    "        graph, _pos = generate_gmm_unit_disk_graph(1000,\n",
    "                                                   dim=10,\n",
    "                                                   num_components=100,\n",
    "                                                   density=1e-5,\n",
    "                                                   cluster_spread=0.35)\n",
    "        # Every generated graph is currently accepted.\n",
    "        success_cnt += 1\n",
    "        graphs.append(graph)\n",
    "        pbar.update(1)\n",
    "\n",
    "        if len(graphs) >= chunk_size:\n",
    "            torch.save(Batch.from_data_list(graphs), f'{root}/processed/batch{pkg_idx}.pt')\n",
    "            pkg_idx += 1\n",
    "            graphs = []\n",
    "\n",
    "        if success_cnt >= num:\n",
    "            break\n",
    "\n",
    "# Flush the final partial chunk (num not a multiple of chunk_size, or max_iter reached).\n",
    "if graphs:\n",
    "    torch.save(Batch.from_data_list(graphs), f'{root}/processed/batch{pkg_idx}.pt')\n",
    "    pkg_idx += 1\n",
    "    graphs = []"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f7f3f14-34a6-4441-a295-0add9c4f62c2",
   "metadata": {},
   "source": [
    "## Save batch 0 as the test split (train/valid are identical copies)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c21b9c9f-254e-4c45-a00f-bd25dfe0554a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.data import InMemoryDataset\n",
    "\n",
    "# Only the first saved chunk (batch0) is used; it is written unchanged as the\n",
    "# test, train and valid splits, so all three files are identical.\n",
    "# weights_only=False: the file is a pickled PyG Batch written by this notebook,\n",
    "# not plain tensors; newer torch versions default to weights_only=True and refuse it.\n",
    "batch0 = torch.load(f'{root}/processed/batch0.pt', weights_only=False)\n",
    "graph_list = batch0.to_data_list()\n",
    "collated = InMemoryDataset().collate(graph_list)\n",
    "for split in ('test', 'train', 'valid'):\n",
    "    torch.save(collated, f'{root}/processed/{split}.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0e05636-3dad-41e0-ba59-287f205ed741",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "1ea29481-8779-4fde-ba45-c14655e6fe16",
   "metadata": {},
   "source": [
    "## save as normal dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ba781a6-076b-462a-9366-0fb6a3457468",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.dataset import LPDataset\n",
    "\n",
    "# Project dataset class; presumably reads processed/test.pt under `root` (TODO confirm).\n",
    "ds = LPDataset(root, 'test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d948b00-ca6e-4972-9480-49f7e67f2d81",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f21719b0-8d74-49fc-9931-f9f23ceff2fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.utils import degree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84ba4086-a49e-4424-8ba6-a8acd499ab65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean node degree of each graph in `ds`: counts entries of edge_index[0] per\n",
    "# node, so self-loops (if present) add 1. num_nodes is passed so nodes that\n",
    "# never appear in edge_index[0] still contribute a degree of 0 to the mean.\n",
    "dd = [degree(g.edge_index[0], num_nodes=g.num_nodes).mean().item() for g in ds]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "717d1339-4c6a-4148-9c25-f6f92a9c57aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average of the per-graph mean degrees.\n",
    "torch.tensor(dd).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "893478e4-3a9b-409c-84e8-62e3b9532860",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean degree recorded from an earlier run of the cell above: 34.9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f59538c-7671-4a1b-afa8-516fb8f5ccdc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
