{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "received-advance",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path\n",
    "import os\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.nn as nn\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.datasets import Planetoid,TUDataset\n",
    "from torch_geometric.nn import GATConv, GCNConv,GINConv\n",
    "import numpy as np\n",
    "import igraph as ig\n",
    "from torch_geometric.nn import global_mean_pool,global_add_pool\n",
    "from functools import reduce\n",
    "import pickle\n",
    "from sklearn.model_selection import KFold,StratifiedKFold\n",
    "import torch.optim as optim\n",
    "import csv\n",
    "import pandas as pd\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.utils import degree\n",
    "import networkx as nx\n",
    "from torch_geometric.utils.convert import to_networkx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "continent-overhead",
   "metadata": {},
   "outputs": [],
   "source": [
    "import utils"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "regulation-passion",
   "metadata": {},
   "source": [
    "# 1. Load dataset\n",
    "\n",
    "### A. dataset_name\n",
    "\n",
    "DD, MUTAG, IMDB-BINARY, REDDIT-BINARY, COLLAB, NCI1, NCI109\n",
    "\n",
    "### B. Check data information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "advisory-apparatus",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Dataset: MUTAG(188):\n",
      "====================\n",
      "Number of graphs: 188\n",
      "Number of features: 7\n",
      "Number of classes: 2\n",
      "=============================================================\n",
      "AVERAGE # H1 CYCLES: 2.8617021276595747\n",
      "AVERAGE MAGNITUDE # CYCLES: 6.243389057750762\n",
      "# GRAPH WITH CYCLES: 188\n"
     ]
    }
   ],
   "source": [
    "dataset_name = 'MUTAG'\n",
    "dataset = TUDataset(root='dataset', name=dataset_name, use_node_attr=True)\n",
    "\n",
    "# Headline facts about the benchmark.\n",
    "print()\n",
    "print(f'Dataset: {dataset}:')\n",
    "print('====================')\n",
    "print(f'Number of graphs: {len(dataset)}')\n",
    "print(f'Number of features: {dataset.num_features}')\n",
    "print(f'Number of classes: {dataset.num_classes}')\n",
    "\n",
    "print('=============================================================')\n",
    "# Cycle statistics, accumulated graph by graph.\n",
    "num_graphs = len(dataset)\n",
    "total_cycle = 0            # sum over graphs of E - V + C (number of independent cycles)\n",
    "total_magnitude_cycle = 0  # sum over graphs of the mean cycle length in the cycle basis\n",
    "nonzero_count = 0          # number of graphs whose cycle basis is non-empty\n",
    "for graph_idx in range(num_graphs):\n",
    "    nx_graph = to_networkx(dataset[graph_idx], to_undirected=True)\n",
    "    total_cycle += (\n",
    "        nx_graph.number_of_edges()\n",
    "        - nx_graph.number_of_nodes()\n",
    "        + nx.number_connected_components(nx_graph)\n",
    "    )\n",
    "\n",
    "    basis = nx.cycle_basis(nx_graph)\n",
    "    if not basis:\n",
    "        # Acyclic graph: it adds nothing to the magnitude total.\n",
    "        continue\n",
    "    nonzero_count += 1\n",
    "    total_magnitude_cycle += sum(len(cycle) for cycle in basis) / len(basis)\n",
    "\n",
    "# Both averages are taken over all graphs, including the acyclic ones.\n",
    "avg_total_cycle = total_cycle / num_graphs\n",
    "avg_total_magnitude = total_magnitude_cycle / num_graphs\n",
    "print(f'AVERAGE # H1 CYCLES: {avg_total_cycle}')\n",
    "print(f'AVERAGE MAGNITUDE # CYCLES: {avg_total_magnitude}')\n",
    "print(f'# GRAPH WITH CYCLES: {nonzero_count}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "annoying-survivor",
   "metadata": {},
   "source": [
    "# 2. Preprocessing\n",
    "## A. Normalize \n",
    "- DD, MUTAG, IMDB-BINARY, REDDIT-BINARY, COLLAB, NCI1, NCI109 (normalize=False)\n",
    "- ENZYMES, PROTEINS_full (normalize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "indian-closer",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "28"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Project-specific loading step; normalisation is switched off for this dataset.\n",
    "dataset = utils.data_load(dataset, normalize=False)\n",
    "# Used below as the side length of the dense, zero-padded adjacency matrices.\n",
    "max_node = utils.max_node_dataset(dataset)\n",
    "max_node"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "congressional-sport",
   "metadata": {},
   "source": [
    "## B. Make clique adjacency matrix \n",
    "- Uses the Cy2C algorithm implemented in `utils`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "optimum-bulletin",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "file\n"
     ]
    }
   ],
   "source": [
    "file = f'./dataset/{dataset_name}/H1_ver2'\n",
    "\n",
    "if os.path.isfile(file):\n",
    "    # Cached result of a previous run. NOTE: torch.load unpickles the file, so only\n",
    "    # load caches that this notebook produced itself.\n",
    "    NEWDATA = torch.load(file)\n",
    "    print('file')\n",
    "else:\n",
    "    NEWDATA = []\n",
    "    for i in range(len(dataset)):\n",
    "        data = dataset[i]\n",
    "\n",
    "        # Dense adjacency matrix, zero-padded to max_node x max_node.\n",
    "        adj = torch.zeros((max_node, max_node))\n",
    "        adj[data.edge_index[0, :], data.edge_index[1, :]] = 1\n",
    "        list_adj = adj.numpy()\n",
    "        # Flag graphs whose adjacency is not symmetric instead of silently discarding the test.\n",
    "        if not (list_adj == list_adj.T).all():\n",
    "            print(f'warning: adjacency of graph {i} is not symmetric')\n",
    "\n",
    "        # Only the last return value (the cycle-based adjacency) is used here.\n",
    "        _, _, _, _, sum_sub_adj = utils.make_cycle_adj_speed_nosl(list_adj, data)\n",
    "\n",
    "        if i % 100 == 0:\n",
    "            print(i)\n",
    "\n",
    "        # new_adj[0] is the graph adjacency, new_adj[1] the cycle adjacency (all zeros when none was returned).\n",
    "        if len(sum_sub_adj) > 0:\n",
    "            new_adj = np.stack((list_adj, sum_sub_adj), 0)\n",
    "        else:\n",
    "            sum_sub_adj = np.zeros((1, list_adj.shape[0], list_adj.shape[1]))\n",
    "            new_adj = np.concatenate((list_adj.reshape(1, list_adj.shape[0], list_adj.shape[1]), sum_sub_adj), 0)\n",
    "\n",
    "        # Fetch the graph again so the stored object does not depend on anything the helper did to its argument.\n",
    "        data = dataset[i]\n",
    "\n",
    "        # Sanity check: the non-zero cells of new_adj[0] must be exactly the edges of the graph.\n",
    "        # The edges are sorted row-major first, so the comparison is independent of edge order and\n",
    "        # differences cannot cancel out the way they do when index differences are summed.\n",
    "        adj_rows, adj_cols = np.nonzero(new_adj[0] == 1)\n",
    "        edges = data.edge_index.numpy()\n",
    "        row_major = np.lexsort((edges[1], edges[0]))\n",
    "        if not np.array_equal(edges[:, row_major], np.vstack((adj_rows, adj_cols))):\n",
    "            print('error')\n",
    "\n",
    "        # Store the cycle adjacency in edge-list form, shape (2, number of non-zero cells).\n",
    "        cycle_rows, cycle_cols = np.nonzero(new_adj[1])\n",
    "        data.cycle_index = torch.stack((torch.LongTensor(cycle_rows), torch.LongTensor(cycle_cols)), 0)\n",
    "        NEWDATA.append(data)\n",
    "\n",
    "    torch.save(NEWDATA, file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "declared-embassy",
   "metadata": {},
   "source": [
    "## C. Stratified 10-fold split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "seven-moisture",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "folder\n",
      "0\n",
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_14770/3633259774.py:4: FutureWarning: The input object of type 'Tensor' is an array-like implementing one of the corresponding protocols (`__array__`, `__array_interface__` or `__array_struct__`); but not a sequence (or 0-D). In the future, this object will be coerced as if it was first converted using `np.array(obj)`. To retain the old behaviour, you have to either modify the type 'Tensor', or assign to an empty array created with `np.empty(correct_shape, dtype=object)`.\n",
      "  dataset_class=np.array(dataset_class)\n",
      "/tmp/ipykernel_14770/3633259774.py:4: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
      "  dataset_class=np.array(dataset_class)\n"
     ]
    }
   ],
   "source": [
    "N_SPLITS = 10   # number of outer (test) folds and of inner (validation) folds\n",
    "SPLIT_SEED = 0  # fixed seed so that regenerated splits are reproducible\n",
    "\n",
    "# One integer class label per graph, as a plain 1-D array. Calling np.array on a list of\n",
    "# tensors builds a ragged object array and triggers NumPy deprecation warnings.\n",
    "dataset_class = np.array([int(dataset[i].y) for i in range(len(dataset))])\n",
    "\n",
    "folder = f'./dataset/{dataset_name}/kfold_data'\n",
    "if os.path.isdir(folder):\n",
    "    # Splits already exist: reload every train/valid/test triple and check that it covers the dataset.\n",
    "    print('folder')\n",
    "\n",
    "    for j in range(N_SPLITS):\n",
    "        print(j)\n",
    "        test_index = torch.as_tensor(np.loadtxt(f'{folder}/test_idx-{j}.txt', dtype=np.int32), dtype=torch.long)\n",
    "        for k in range(N_SPLITS):\n",
    "            train_index = torch.as_tensor(np.loadtxt(f'{folder}/train_total_{j}/train_idx-{k}.txt', dtype=np.int32), dtype=torch.long)\n",
    "            valid_index = torch.as_tensor(np.loadtxt(f'{folder}/train_total_{j}/valid_idx-{k}.txt', dtype=np.int32), dtype=torch.long)\n",
    "            all_index = reduce(np.union1d, (train_index, valid_index, test_index))\n",
    "            assert len(dataset) == len(all_index)\n",
    "\n",
    "else:\n",
    "    os.makedirs(folder)\n",
    "    kkf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SPLIT_SEED)\n",
    "    kkf2 = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SPLIT_SEED)\n",
    "    print(kkf)\n",
    "\n",
    "    # StratifiedKFold.split only needs the number of samples from X, so a zero array stands in for the graphs.\n",
    "    outer_splits = kkf.split(np.zeros(len(dataset_class)), dataset_class)\n",
    "    for j, (train_total_index, test_index) in enumerate(outer_splits):\n",
    "        np.savetxt(f'{folder}/train_total_idx-{j}.txt', train_total_index.astype(np.int64), fmt='%i', delimiter='\\t')\n",
    "        np.savetxt(f'{folder}/test_idx-{j}.txt', test_index.astype(np.int64), fmt='%i', delimiter='\\t')\n",
    "        assert len(dataset) == len(reduce(np.union1d, (test_index, train_total_index)))\n",
    "        os.mkdir(f'{folder}/train_total_{j}')\n",
    "\n",
    "        # Inner split of the non-test part into train/validation folds, stratified on the same labels.\n",
    "        dataset_class_train = dataset_class[train_total_index]\n",
    "        inner_splits = kkf2.split(np.zeros(len(train_total_index)), dataset_class_train)\n",
    "        for k, (ii, jj) in enumerate(inner_splits):\n",
    "            # ii / jj index into train_total_index; map them back to dataset positions.\n",
    "            valid_index = train_total_index[jj]\n",
    "            train_index = train_total_index[ii]\n",
    "            np.savetxt(f'{folder}/train_total_{j}/valid_idx-{k}.txt', valid_index.astype(np.int64), fmt='%i', delimiter='\\t')\n",
    "            np.savetxt(f'{folder}/train_total_{j}/train_idx-{k}.txt', train_index.astype(np.int64), fmt='%i', delimiter='\\t')\n",
    "            assert len(train_total_index) == len(reduce(np.union1d, (valid_index, train_index)))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dramatic-commissioner",
   "metadata": {},
   "source": [
    "# 3. Train & Test\n",
    "\n",
    "### Baseline-GNNs\n",
    "- from nets import GCN, GAT, GIN\n",
    "- Option Cy2C=False\n",
    "\n",
    "### Cy2C-GNNs\n",
    "- from nets import Cy2C_GCN, Cy2C_GAT, Cy2C_GIN\n",
    "- Option Cy2C=True (default)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "described-peace",
   "metadata": {},
   "source": [
    "## A. Cy2C-GNNs "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "junior-literature",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nets import Cy2C_GCN, Cy2C_GAT, Cy2C_GIN\n",
    "from Trainer import Trainer\n",
    "\n",
    "# Preferred GPU. Falls back to the first GPU when fewer devices are visible (a bare\n",
    "# 'cuda:2' fails on such machines) and to the CPU when CUDA is unavailable.\n",
    "GPU_INDEX = 2\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(f'cuda:{GPU_INDEX}' if torch.cuda.device_count() > GPU_INDEX else 'cuda:0')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "# Parallel lists: model_name[i] is the label used in run names for class_name[i].\n",
    "model_name = ['Cy2C_GCN', 'Cy2C_GAT', 'Cy2C_GIN']\n",
    "class_name = [Cy2C_GCN, Cy2C_GAT, Cy2C_GIN]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "sufficient-cinema",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimiser settings for the Cy2C runs.\n",
    "lr = 1e-3  # alternative value: lr=1e-4\n",
    "batch_size = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "still-naples",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=====================================\n",
      "===== (10fold)Cy2C_GCN_1(64)_drop0.0(0.0)_deacy0.0 ===== MUTAG =====\n",
      "=====================================\n",
      "folder_on\n",
      "load mainfold, subfold== 0 0\n",
      "Mainfold_index: 0, Subfold_index:0\n",
      "main & sub ===0,0,best acc & loss==,0.7895,0.0418,final acc & loss==0.7895,0.0475,best_epoch==545,final_epoch==646\n",
      "load mainfold, subfold== 1 0\n",
      "Mainfold_index: 1, Subfold_index:0\n",
      "main & sub ===1,0,best acc & loss==,0.8421,0.0289,final acc & loss==0.7368,0.0231,best_epoch==34,final_epoch==135\n",
      "load mainfold, subfold== 2 0\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Input \u001b[0;32mIn [10]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m=====================================\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m     11\u001b[0m trainer\u001b[38;5;241m=\u001b[39mTrainer(name, dataset_name,NEWDATA,device,CLASS_NAME,dataset\u001b[38;5;241m.\u001b[39mnum_node_features,dataset\u001b[38;5;241m.\u001b[39mnum_classes,lr\u001b[38;5;241m=\u001b[39mlr,hidden_dim\u001b[38;5;241m=\u001b[39mhidden_dim,n_layer\u001b[38;5;241m=\u001b[39mn_layer,num_workers\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m,drop_mid\u001b[38;5;241m=\u001b[39mdrop_mid,small_fold\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m,batch_size\u001b[38;5;241m=\u001b[39mbatch_size,decay\u001b[38;5;241m=\u001b[39mdecay, main_fold\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m)\n\u001b[0;32m---> 12\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/YY/NIPS2022_code/Trainer.py:158\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    155\u001b[0m patience \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m    157\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m,\u001b[38;5;241m50000\u001b[39m):\n\u001b[0;32m--> 158\u001b[0m     train_loss\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    159\u001b[0m     valid_acc,valid_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtest(valid_loader)\n\u001b[1;32m    160\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscheduler\u001b[38;5;241m.\u001b[39mstep(valid_loss)\n",
      "File \u001b[0;32m~/YY/NIPS2022_code/Trainer.py:281\u001b[0m, in \u001b[0;36mTrainer.train_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    279\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mtrain()\n\u001b[1;32m    280\u001b[0m total_loss\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m\n\u001b[0;32m--> 281\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m data \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_loader:  \u001b[38;5;66;03m# Iterate in batches over the training dataset.\u001b[39;00m\n\u001b[1;32m    282\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m    283\u001b[0m     data\u001b[38;5;241m=\u001b[39mdata\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n",
      "File \u001b[0;32m~/anaconda3/envs/YODECOPY/lib/python3.8/site-packages/torch/utils/data/dataloader.py:521\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    519\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    520\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset()\n\u001b[0;32m--> 521\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    522\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m    523\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m    524\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m    525\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
      "File \u001b[0;32m~/anaconda3/envs/YODECOPY/lib/python3.8/site-packages/torch/utils/data/dataloader.py:561\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    559\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    560\u001b[0m     index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index()  \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 561\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m    562\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[1;32m    563\u001b[0m         data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data)\n",
      "File \u001b[0;32m~/anaconda3/envs/YODECOPY/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py:52\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m     50\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m     51\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n\u001b[0;32m---> 52\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollate_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/YODECOPY/lib/python3.8/site-packages/torch_geometric/loader/dataloader.py:19\u001b[0m, in \u001b[0;36mCollater.__call__\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m     17\u001b[0m elem \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m     18\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(elem, BaseData):\n\u001b[0;32m---> 19\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mBatch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_data_list\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfollow_batch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     20\u001b[0m \u001b[43m                                \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexclude_keys\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     21\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(elem, torch\u001b[38;5;241m.\u001b[39mTensor):\n\u001b[1;32m     22\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m default_collate(batch)\n",
      "File \u001b[0;32m~/anaconda3/envs/YODECOPY/lib/python3.8/site-packages/torch_geometric/data/batch.py:68\u001b[0m, in \u001b[0;36mBatch.from_data_list\u001b[0;34m(cls, data_list, follow_batch, exclude_keys)\u001b[0m\n\u001b[1;32m     56\u001b[0m \u001b[38;5;129m@classmethod\u001b[39m\n\u001b[1;32m     57\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfrom_data_list\u001b[39m(\u001b[38;5;28mcls\u001b[39m, data_list: List[BaseData],\n\u001b[1;32m     58\u001b[0m                    follow_batch: Optional[List[\u001b[38;5;28mstr\u001b[39m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m     59\u001b[0m                    exclude_keys: Optional[List[\u001b[38;5;28mstr\u001b[39m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m     60\u001b[0m     \u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Constructs a :class:`~torch_geometric.data.Batch` object from a\u001b[39;00m\n\u001b[1;32m     61\u001b[0m \u001b[38;5;124;03m    Python list of :class:`~torch_geometric.data.Data` or\u001b[39;00m\n\u001b[1;32m     62\u001b[0m \u001b[38;5;124;03m    :class:`~torch_geometric.data.HeteroData` objects.\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     65\u001b[0m \u001b[38;5;124;03m    :obj:`follow_batch`.\u001b[39;00m\n\u001b[1;32m     66\u001b[0m \u001b[38;5;124;03m    Will exclude any keys given in :obj:`exclude_keys`.\"\"\"\u001b[39;00m\n\u001b[0;32m---> 68\u001b[0m     batch, slice_dict, inc_dict \u001b[38;5;241m=\u001b[39m \u001b[43mcollate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     69\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m     70\u001b[0m \u001b[43m        \u001b[49m\u001b[43mdata_list\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_list\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     71\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mincrement\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m     72\u001b[0m \u001b[43m        \u001b[49m\u001b[43madd_batch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43misinstance\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mdata_list\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mBatch\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     73\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfollow_batch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfollow_batch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     74\u001b[0m \u001b[43m        \u001b[49m\u001b[43mexclude_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mexclude_keys\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     75\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     77\u001b[0m     batch\u001b[38;5;241m.\u001b[39m_num_graphs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(data_list)\n\u001b[1;32m     78\u001b[0m     batch\u001b[38;5;241m.\u001b[39m_slice_dict \u001b[38;5;241m=\u001b[39m slice_dict\n",
      "File \u001b[0;32m~/anaconda3/envs/YODECOPY/lib/python3.8/site-packages/torch_geometric/data/collate.py:84\u001b[0m, in \u001b[0;36mcollate\u001b[0;34m(cls, data_list, increment, add_batch, follow_batch, exclude_keys)\u001b[0m\n\u001b[1;32m     81\u001b[0m     \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[1;32m     83\u001b[0m \u001b[38;5;66;03m# Collate attributes into a unified representation:\u001b[39;00m\n\u001b[0;32m---> 84\u001b[0m value, slices, incs \u001b[38;5;241m=\u001b[39m \u001b[43m_collate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mattr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalues\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata_list\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstores\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     85\u001b[0m \u001b[43m                               \u001b[49m\u001b[43mincrement\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     87\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(value, Tensor) \u001b[38;5;129;01mand\u001b[39;00m value\u001b[38;5;241m.\u001b[39mis_cuda:\n\u001b[1;32m     88\u001b[0m     device \u001b[38;5;241m=\u001b[39m value\u001b[38;5;241m.\u001b[39mdevice\n",
      "File \u001b[0;32m~/anaconda3/envs/YODECOPY/lib/python3.8/site-packages/torch_geometric/data/collate.py:131\u001b[0m, in \u001b[0;36m_collate\u001b[0;34m(key, values, data_list, stores, increment)\u001b[0m\n\u001b[1;32m    129\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m cat_dim \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m elem\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m    130\u001b[0m     values \u001b[38;5;241m=\u001b[39m [value\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m) \u001b[38;5;28;01mfor\u001b[39;00m value \u001b[38;5;129;01min\u001b[39;00m values]\n\u001b[0;32m--> 131\u001b[0m slices \u001b[38;5;241m=\u001b[39m \u001b[43mcumsum\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mvalue\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcat_dim\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mvalues\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    132\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m increment:\n\u001b[1;32m    133\u001b[0m     incs \u001b[38;5;241m=\u001b[39m get_incs(key, values, data_list, stores)\n",
      "File \u001b[0;32m~/anaconda3/envs/YODECOPY/lib/python3.8/site-packages/torch_geometric/data/collate.py:217\u001b[0m, in \u001b[0;36mcumsum\u001b[0;34m(value)\u001b[0m\n\u001b[1;32m    215\u001b[0m out \u001b[38;5;241m=\u001b[39m value\u001b[38;5;241m.\u001b[39mnew_empty((value\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m0\u001b[39m) \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m, ) \u001b[38;5;241m+\u001b[39m value\u001b[38;5;241m.\u001b[39msize()[\u001b[38;5;241m1\u001b[39m:])\n\u001b[1;32m    216\u001b[0m out[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m--> 217\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcumsum\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mout\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    218\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Full hyper-parameter grid, expanded in the same order as six nested loops would visit it.\n",
    "# NOTE(review): drop_ini appears in the run name but is not passed to Trainer, so runs that\n",
    "# differ only in drop_ini look identical from here - confirm against Trainer's signature.\n",
    "cy2c_grid = [\n",
    "    (i, CLASS_NAME, decay, hidden_dim, drop_ini, drop_mid, n_layer)\n",
    "    for i, CLASS_NAME in enumerate(class_name)\n",
    "    for decay in [0.0, 0.0001]\n",
    "    for hidden_dim in [64, 128]\n",
    "    for drop_ini in [0.0, 0.2, 0.4]\n",
    "    for drop_mid in [0.0, 0.2, 0.4]\n",
    "    for n_layer in [1, 2, 3, 4, 5]\n",
    "]\n",
    "\n",
    "banner = '====================================='\n",
    "for i, CLASS_NAME, decay, hidden_dim, drop_ini, drop_mid, n_layer in cy2c_grid:\n",
    "    name = f'(10fold){model_name[i]}_{n_layer}({hidden_dim})_drop{drop_ini}({drop_mid})_deacy{decay}'\n",
    "    print(banner)\n",
    "    print('=====', name, '=====', dataset_name, '=====')\n",
    "    print(banner)\n",
    "    trainer = Trainer(\n",
    "        name, dataset_name, NEWDATA, device, CLASS_NAME,\n",
    "        dataset.num_node_features, dataset.num_classes,\n",
    "        lr=lr, hidden_dim=hidden_dim, n_layer=n_layer, num_workers=1,\n",
    "        drop_mid=drop_mid, small_fold=1, batch_size=batch_size,\n",
    "        decay=decay, main_fold=10,\n",
    "    )\n",
    "    trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "invalid-impossible",
   "metadata": {},
   "source": [
    "## B. Baseline-GNNs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "upper-wrist",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nets import GCN, GAT, GIN\n",
    "from Trainer import Trainer\n",
    "\n",
    "# Preferred GPU. Falls back to the first GPU when fewer devices are visible (a bare\n",
    "# 'cuda:2' fails on such machines) and to the CPU when CUDA is unavailable.\n",
    "GPU_INDEX = 2\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(f'cuda:{GPU_INDEX}' if torch.cuda.device_count() > GPU_INDEX else 'cuda:0')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "# Parallel lists: model_name[i] is the label used in run names for class_name[i].\n",
    "model_name = ['GCN', 'GAT', 'GIN']\n",
    "class_name = [GCN, GAT, GIN]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "conceptual-colony",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimiser settings for the baseline runs.\n",
    "lr = 1e-4\n",
    "batch_size = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pleased-sunset",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "=====================================\n",
      "===== (10FOLDBF)GCN_1(136)_drop0.0(0.0)_deacy0.0 ===== MUTAG =====\n",
      "=====================================\n",
      "folder_on\n",
      "folder_on\n",
      "load mainfold, subfold== 0 0\n",
      "Mainfold_index: 0, Subfold_index:0\n",
      "main & sub ===0,0,best acc & loss==,0.7368,0.0280,final acc & loss==0.7368,0.0286,best_epoch==155,final_epoch==256\n",
      "load mainfold, subfold== 1 0\n"
     ]
    }
   ],
   "source": [
    "# Baseline grid: only the model class and the number of layers vary.\n",
    "# NOTE(review): the section text says baselines use Cy2C=False, but the call below passes\n",
    "# Cy2C=True together with base=True - confirm which is intended.\n",
    "# NOTE(review): drop_ini appears in the run name but is not passed to Trainer.\n",
    "baseline_grid = [\n",
    "    (i, CLASS_NAME, decay, hidden_dim, drop_ini, drop_mid, n_layer)\n",
    "    for i, CLASS_NAME in enumerate(class_name)\n",
    "    for decay in [0.0]\n",
    "    for hidden_dim in [136]\n",
    "    for drop_ini in [0.0]\n",
    "    for drop_mid in [0.0]\n",
    "    for n_layer in [1, 2, 3, 4, 5]\n",
    "]\n",
    "\n",
    "banner = '====================================='\n",
    "for i, CLASS_NAME, decay, hidden_dim, drop_ini, drop_mid, n_layer in baseline_grid:\n",
    "    print(n_layer)\n",
    "    name = f'(10FOLDBF){model_name[i]}_{n_layer}({hidden_dim})_drop{drop_ini}({drop_mid})_deacy{decay}'\n",
    "    print(banner)\n",
    "    print('=====', name, '=====', dataset_name, '=====')\n",
    "    print(banner)\n",
    "    trainer = Trainer(\n",
    "        name, dataset_name, NEWDATA, device, CLASS_NAME,\n",
    "        dataset.num_node_features, dataset.num_classes,\n",
    "        lr=lr, hidden_dim=hidden_dim, n_layer=n_layer, num_workers=1,\n",
    "        drop_mid=drop_mid, small_fold=1, batch_size=batch_size,\n",
    "        decay=decay, main_fold=10, Cy2C=True, base=True,\n",
    "    )\n",
    "    trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "durable-myanmar",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "YODECOPY",
   "language": "python",
   "name": "yodecopy"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
