{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "7ae9891d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import os.path as osp\n",
    "import numpy as np\n",
    "import networkx as nx\n",
    "import time \n",
    "\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.metrics import roc_curve, auc\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.utils import subgraph\n",
    "from torch_geometric.datasets import TUDataset, ZINC\n",
    "from torch_geometric.data import DataLoader\n",
    "\n",
    "\n",
    "import multiprocessing\n",
    "from ogb.graphproppred import PygGraphPropPredDataset\n",
    "from ogb.graphproppred import Evaluator\n",
    "evaluator = Evaluator('ogbg-molhiv')\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from model import *\n",
    "import wandb\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "af664233",
   "metadata": {},
   "outputs": [],
   "source": [
    "#training parameter\n",
    "batch_size = 128\n",
    "epochs = 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "abcc34d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "faa54804",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"  # Arrange GPU devices starting from 0\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]= \"0,1,2,3\"  # Set the GPUs 2 and 3 to use"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7e9050bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SubgraphPartitioning(object):\n",
    "    def __init__(self, kappa=2, tau=1, device = device):\n",
    "        self.kappa = kappa  \n",
    "        self.tau = tau      \n",
    "        self.device = device\n",
    "    def __call__(self, data):\n",
    "        num_nodes_computed = max(data.num_nodes, int(data.edge_index.max()) + 1)\n",
    "        edge_index_np = data.edge_index.cpu().numpy()\n",
    "        adj = {i: set() for i in range(num_nodes_computed)}\n",
    "        for u, v in zip(edge_index_np[0], edge_index_np[1]):\n",
    "            adj[u].add(v)\n",
    "            adj[v].add(u)\n",
    "\n",
    "        G_nx = nx.Graph()\n",
    "        G_nx.add_nodes_from(range(num_nodes_computed))\n",
    "        edges = list(zip(edge_index_np[0], edge_index_np[1]))\n",
    "        G_nx.add_edges_from(edges)\n",
    "        \n",
    "        centrality_measures = {\n",
    "            \"degree\": nx.degree_centrality(G_nx),\n",
    "            \"betweenness\": nx.betweenness_centrality(G_nx, normalized=True),\n",
    "            \"closeness\": nx.closeness_centrality(G_nx)\n",
    "        }\n",
    "        subgraphs = []\n",
    "        for measure_name, measure_dict in centrality_measures.items():\n",
    "            sorted_nodes = sorted(measure_dict.items(), key=lambda x: x[1], reverse=True)\n",
    "            seeds = [node for node, _ in sorted_nodes[:self.kappa]]\n",
    "            seeds_info = {seed: {\"S_v\": {seed}, \"frontier\": {seed}} for seed in seeds}\n",
    "            global_claim = set(seeds)\n",
    "            for _ in range(self.tau):\n",
    "                candidates = {seed: set() for seed in seeds}\n",
    "                for seed in seeds:\n",
    "                    for node in seeds_info[seed][\"frontier\"]:\n",
    "                        for nbr in adj[node]:\n",
    "                            if nbr not in global_claim:\n",
    "                                candidates[seed].add(nbr)\n",
    "                for seed in seeds:\n",
    "                    seeds_info[seed][\"S_v\"].update(candidates[seed])\n",
    "                    seeds_info[seed][\"frontier\"] = candidates[seed]\n",
    "                global_claim = set()\n",
    "                for seed in seeds:\n",
    "                    global_claim.update(seeds_info[seed][\"S_v\"])\n",
    "            for seed in seeds:\n",
    "                S_v = seeds_info[seed][\"S_v\"]\n",
    "                subset = torch.tensor(sorted(list(S_v)), dtype=torch.long)\n",
    "                sub_edge_index, _ = subgraph(subset, data.edge_index,\n",
    "                                             num_nodes=num_nodes_computed,\n",
    "                                             relabel_nodes=True)\n",
    "                subgraphs.append({\n",
    "                    'subset': subset,\n",
    "                    'edge_index': sub_edge_index,\n",
    "                    'measure': measure_name\n",
    "                })\n",
    "        data.subgraphs = subgraphs\n",
    "        return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5457af07",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Subgraph partitioning\n",
    "\n",
    "dataset_TU = [\"MUTAG\", \"NCI1\",\"NCI109\",\"PROTEINS\", \"PTC_MR\", \"PTC_FR\", \"PTC_MM\", \"PTC_FM\",\"IMDB-BINARY\",\"IMDB-MULTI\",\"COLLAB\"]\n",
    "dataset_ogbg = [\"ogbg-molhiv\"]\n",
    "dataset_ZINC = [\"ZINC\"]\n",
    "transform = T.Compose([SubgraphPartitioning(kappa=2, tau=4)])\n",
    "\n",
    "\n",
    "for dataset_name in dataset_TU:\n",
    "    path = osp.join(\"./SP_dataset\", dataset_name)\n",
    "    dataset = TUDataset(root= path, name=dataset_name,transform=transform)\n",
    "\n",
    "for dataset_name in dataset_ogbg:\n",
    "    path = osp.join(\"./SP_dataset\", dataset_name)\n",
    "    dataset = PygGraphPropPredDataset(name=dataset_name, root=path, transform=transform)\n",
    "\n",
    "for dataset_name in dataset_ZINC:\n",
    "    path = osp.join(\"./SP_dataset\", dataset_name)\n",
    "    dataset = ZINC(root=path, subset=True, transform=transform)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b62a5612",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exmaple on MUTAG\n",
    "dataset_name = \"MUTAG\"\n",
    "\n",
    "path = osp.join(\"./SP_dataset\", dataset_name)\n",
    "dataset = TUDataset(root= path, name=dataset_name,transform=transform)\n",
    "\n",
    "k = 10\n",
    "kf = KFold(n_splits=k, shuffle=True, random_state=42)\n",
    "\n",
    "folds = []\n",
    "indices = np.arange(len(dataset))\n",
    "for train_idx, val_idx in kf.split(indices):\n",
    "    train_fold = dataset[train_idx.tolist()]\n",
    "    val_fold = dataset[val_idx.tolist()]\n",
    "    train_loader = DataLoader(train_fold, batch_size, shuffle=True, num_workers=0)\n",
    "    val_loader = DataLoader(val_fold, batch_size)\n",
    "    test_loader = DataLoader(dataset, batch_size)\n",
    "    folds.append((train_loader, val_loader,test_loader))\n",
    "\n",
    "print(f\"Created {len(folds)} folds.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "0094daee",
   "metadata": {},
   "outputs": [],
   "source": [
    "def num_graphs(data):\n",
    "    if hasattr(data, 'num_graphs'):\n",
    "        return data.num_graphs\n",
    "    else:\n",
    "        return data.x.size(0)\n",
    "\n",
    "    \n",
    "def train(model, optimizer, loader, dataset_type):\n",
    "    time_dict = {}\n",
    "    time_dict['data'] = []\n",
    "    time_dict['model'] = []\n",
    "    time_dict['loss'] = []\n",
    "    time_dict['backward'] = []\n",
    "    time_dict['optimizer'] = []\n",
    "    time_dict['total'] = []\n",
    "\n",
    "\n",
    "    model.train()\n",
    "\n",
    "    total_loss = 0\n",
    "    for data in loader:\n",
    "        t_total_start = time.time()\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        data.x = data.x.float()\n",
    "        data = data.to(device)\n",
    "\n",
    "        t_data = time.time() - t_total_start\n",
    "        out = model(data)\n",
    "        t_model = time.time() - t_total_start\n",
    "        if dataset_type ==\"TU\":\n",
    "            loss = F.nll_loss(out, data.y.view(-1))\n",
    "        elif dataset_type ==\"ogbg\":\n",
    "            loss = F.binary_cross_entropy_with_logits(out, data.y.to(torch.float))\n",
    "        else:\n",
    "            loss = F.l1_loss(out, data.y.view(-1))\n",
    "        t_loss = time.time() - t_total_start\n",
    "        loss.backward()\n",
    "        t_backward = time.time() - t_total_start\n",
    "        total_loss += loss.item() * num_graphs(data)\n",
    "        optimizer.step()\n",
    "        t_optimizer = time.time() - t_total_start\n",
    "        time_dict['data'].append(t_data)\n",
    "        time_dict['model'].append(t_model)\n",
    "        time_dict['loss'].append(t_loss)\n",
    "        time_dict['backward'].append(t_backward)\n",
    "        time_dict['optimizer'].append(t_optimizer)\n",
    "        time_dict['total'].append(t_data + t_model + t_loss + t_backward + t_optimizer)\n",
    "\n",
    "    return total_loss / len(loader.dataset), time_dict\n",
    "\n",
    "\n",
    "\n",
    "def evaluate_mae(model, loader):\n",
    "    model.eval()\n",
    "    total_mae = 0\n",
    "    for data in loader:\n",
    "        data.x = data.x.float()\n",
    "        data = data.to(device)\n",
    "        with torch.no_grad():\n",
    "            pred = model(data).view(-1)\n",
    "        total_mae += F.l1_loss(pred, data.y.view(-1), reduction='sum').item()\n",
    "    return total_mae / len(loader.dataset)\n",
    "    \n",
    "def evaluate(model, loader):\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    loss = 0\n",
    "    for data in loader:\n",
    "        data = data.to(device)\n",
    "        with torch.no_grad():\n",
    "            out = model(data)\n",
    "            pred = out.max(1)[1]\n",
    "        y_true = data.y.view(-1)\n",
    "        correct += pred.eq(y_true).sum().item()\n",
    "        loss += F.nll_loss(out, y_true, reduction='sum').item()\n",
    "    return correct / len(loader.dataset) , loss / len(loader.dataset)\n",
    "\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_roc(model, loader, use_encoder = False):\n",
    "    evaluator = Evaluator('ogbg-molhiv')\n",
    "    model.eval()\n",
    "    y_pred, y_true = [], []\n",
    "    for data in loader:\n",
    "        if not use_encoder:\n",
    "            data.x = data.x.float()\n",
    "        data = data.to(device)\n",
    "        pred = model(data)\n",
    "        y_pred.append(pred.cpu())\n",
    "        y_true.append(data.y.cpu())\n",
    "        \n",
    "    y_true = torch.cat(y_true, dim=0)\n",
    "    y_pred = torch.cat(y_pred, dim=0)\n",
    "    \n",
    "    rocauc = evaluator.eval({'y_true': y_true, 'y_pred': y_pred})['rocauc']\n",
    "    \n",
    "    probs = torch.sigmoid(y_pred)\n",
    "    if probs.dim() > 1 and probs.size(1) > 1:\n",
    "        probs = probs[:, 1]\n",
    "    else:\n",
    "        probs = probs.squeeze()\n",
    "    fpr, tpr, _ = roc_curve(y_true.numpy(), probs.numpy())\n",
    "    \n",
    "    return rocauc, fpr, tpr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f522134e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# set model\n",
    "\n",
    "from model.gin_regression import GIN\n",
    "\n",
    "MODEL = GIN\n",
    "\n",
    "dataset_type = \"TU\"\n",
    "\n",
    "if \"ogbg\" in dataset.name:\n",
    "    dataset_type = \"ogbg\"  # According to the gin file, there is an encoder usage setting in regression. When train on the OGBG dataset, the provided encoder usage setting can be turned on and off\n",
    "    \n",
    "if \"ZINC\" in dataset.name:\n",
    "    dataset_type = \"ZINC\"\n",
    "\n",
    "fold_dict = {}\n",
    "wandb_project_name = \"test\"\n",
    "layers = [1,2,3,4,5,6,7,8,9,10]\n",
    "hiddens = [8,16,32,64,128]\n",
    "\n",
    "for layer in layers:\n",
    "    for hidden in hiddens:\n",
    "        val_losses = []\n",
    "        accs = []\n",
    "        val_rocs = []\n",
    "        \n",
    "        for fold_idx, (train_loader_fold, val_loader_fold, test_loader_fold) in tqdm(enumerate(folds)):\n",
    "            wandb.init(project=wandb_project_name, name=f\"fold_{fold_idx}\", config={\"layer\": layer, \"hidden\": hidden, \"fold\" : fold_idx}, group=f\"{dataset_name}_molhiv_layer_{layer}_hidden_{hidden}\", job_type = f\"fold_{fold_idx}\")\n",
    "\n",
    "            model = MODEL(dataset, layer, hidden).to(device)\n",
    "            optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
    "            scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=20, min_lr=1e-5)\n",
    "            print(f\"Starting Fold {fold_idx}/{len(folds)}\")\n",
    "            \n",
    "            total_time_list = []\n",
    "            data_time_list = []\n",
    "            model_time_list = []\n",
    "            loss_time_list = []\n",
    "            backward_time_list = []\n",
    "            optimizer_time_list = []\n",
    "\n",
    "            for epoch in tqdm(range(1, epochs + 1)):\n",
    "                loss, time_dict = train(model, optimizer,train_loader_fold, dataset_type)\n",
    "                if dataset_type == \"TU\":\n",
    "                    val_acc, val_loss = evaluate(model, val_loader_fold)\n",
    "                    val_losses.append(val_loss)\n",
    "                    accs.append(val_acc)\n",
    "                elif dataset_type == \"ogbg\":\n",
    "                    val_roc_auc, _, _ = evaluate_roc(model, val_loader_fold)\n",
    "                    val_rocs.append(val_roc_auc)\n",
    "                    val_loss=val_roc_auc\n",
    "                elif dataset_type == \"ZINC\":\n",
    "                    val_loss = evaluate_mae(model,val_loader_fold)\n",
    "                    val_losses.append(val_loss)\n",
    "\n",
    "                scheduler.step(val_loss)\n",
    "                \n",
    "\n",
    "\n",
    "                wandb.log({\n",
    "                    \"epoch\": epoch,\n",
    "                    \"loss\": loss,\n",
    "                    \"val_loss\": val_loss,\n",
    "                    \"total_time\": np.sum(time_dict['total']),\n",
    "                    \"data_time\": np.mean(time_dict['data']),\n",
    "                    \"model_time\": np.mean(time_dict['model']),\n",
    "                    \"loss_time\": np.mean(time_dict['loss']),\n",
    "                    \"backward_time\": np.mean(time_dict['backward']),\n",
    "                    \"optimizer_time\": np.mean(time_dict['optimizer']),\n",
    "                })\n",
    "\n",
    "            if dataset_type == \"TU\":\n",
    "                test_acc, test_loss = evaluate(model, test_loader_fold)\n",
    "                wandb.log({\n",
    "                \"test_acc\": test_acc,\n",
    "                \"test_loss\": test_loss,\n",
    "                })\n",
    "            elif dataset_type == \"ogbg\":\n",
    "                test_roc_auc, _, _ = evaluate_roc(model, test_loader_fold)\n",
    "                wandb.log({\n",
    "                \"test_roc_auc\": test_roc_auc,\n",
    "                })\n",
    "            elif dataset_type == \"ZINC\":\n",
    "                test_loss = evaluate_mae(model, test_loader_fold)\n",
    "                wandb.log({\n",
    "                \"test_loss\": test_loss,\n",
    "                })\n",
    "\n",
    "            log_dir = os.path.join(\"test\", f\"{dataset_name}/{layer}/{hidden}\", model.__class__.__name__)\n",
    "            os.makedirs(log_dir, exist_ok=True)\n",
    "            torch.save(model.state_dict(), os.path.join(log_dir, f\"fold_{fold_idx}.pth\"))\n",
    "            wandb.finish()\n",
    "#dict 저장\n",
    "import pickle\n",
    "with open(f'test/{dataset_name}.pkl', 'wb') as f:\n",
    "    pickle.dump(fold_dict, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e214637f",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ca9e6f0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "graph",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
