{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "higher-klein",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:The OGB package is out of date. Your version is 1.3.3, while the latest version is 1.3.4.\n"
     ]
    }
   ],
   "source": [
    "import os.path\n",
    "import os\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.nn as nn\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.datasets import Planetoid,TUDataset\n",
    "from torch_geometric.nn import GATConv, GCNConv,GINConv\n",
    "import numpy as np\n",
    "import igraph as ig\n",
    "from torch_geometric.nn import global_mean_pool,global_add_pool\n",
    "\n",
    "from functools import reduce\n",
    "import pickle\n",
    "from sklearn.model_selection import KFold,StratifiedKFold\n",
    "import torch.optim as optim\n",
    "import csv\n",
    "import pandas as pd\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.utils import degree\n",
    "from ogb.graphproppred import PygGraphPropPredDataset\n",
    "from torch_geometric.data import DataLoader\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1e98805f-9e96-40af-a2ac-fb28aca9b894",
   "metadata": {},
   "outputs": [],
   "source": [
    "import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "accompanied-meeting",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name=\"ogbg-molhiv\"\n",
    "dataset_name2=\"ogbg-molhiv\"\n",
    "\n",
    "dataset = PygGraphPropPredDataset(name = dataset_name) \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "informational-debut",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Dataset: PygGraphPropPredDataset(41127):\n",
      "====================\n",
      "Number of graphs: 41127\n",
      "Number of features: 9\n",
      "Number of classes: 2\n"
     ]
    }
   ],
   "source": [
    "print()\n",
    "print(f'Dataset: {dataset}:')\n",
    "print('====================')\n",
    "print(f'Number of graphs: {len(dataset)}')\n",
    "print(f'Number of features: {dataset.num_features}')\n",
    "print(f'Number of classes: {dataset.num_classes}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "monetary-apparel",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "222"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset=utils.data_load(dataset,normalize=False)\n",
    "max_node=utils.max_node_dataset(dataset)\n",
    "max_node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "disciplinary-projector",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name=\"ogbg_molhiv\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "junior-familiar",
   "metadata": {},
   "source": [
    "### 데이터 체크"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "accredited-england",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "file on\n"
     ]
    }
   ],
   "source": [
    "file = f'./dataset/{dataset_name}/H1_ver2'\n",
    "\n",
    "if os.path.isfile(file):\n",
    "    NEWDATA = torch.load(file)     \n",
    "    print('file on')\n",
    "\n",
    "else:        \n",
    "    SUB_ADJ=[]\n",
    "    RAW_SUB_ADJ=[]\n",
    "    NEWDATA=[]\n",
    "    for i in range(len(dataset)):                    \n",
    "        \n",
    "        data=dataset[i]\n",
    "        v1=data.edge_index[0,:]\n",
    "        v2=data.edge_index[1,:]\n",
    "        #print(torch.max(v1))\n",
    "        adj = torch.zeros((max_node,max_node))\n",
    "        adj[v1,v2]=1\n",
    "        adj=adj.numpy()\n",
    "        (adj==adj.T).all()\n",
    "        list_feature=(data.x)\n",
    "        list_adj=(adj)       \n",
    "        \n",
    "        #print(dataset[i])\n",
    "        _, _, _, _, sum_sub_adj = utils.make_cycle_adj_speed_nosl(list_adj,data)\n",
    "        \n",
    "        if i % 100 == 0:\n",
    "            print(i)\n",
    "            \n",
    "        #_sub_adj=np.array(sub_adj)\n",
    "\n",
    "        if len(sum_sub_adj)>0:    \n",
    "            new_adj=np.stack((list_adj,sum_sub_adj),0)\n",
    "        else :\n",
    "            sum_sub_adj=np.zeros((1, list_adj.shape[0], list_adj.shape[1]))\n",
    "            new_adj=np.concatenate((list_adj.reshape(1, list_adj.shape[0], list_adj.shape[1]),sum_sub_adj),0)\n",
    "\n",
    "        #SUB_ADJ.append(new_adj)\n",
    "        SUB_ADJ=new_adj\n",
    "        #------합치기\n",
    "        data=dataset[i]\n",
    "        check1=torch.sum(data.edge_index[0]-np.where(SUB_ADJ[0]==1)[0])+torch.sum(data.edge_index[1]-np.where(SUB_ADJ[0]==1)[1])\n",
    "        if check1 != 0 :\n",
    "            print('error')\n",
    "\n",
    "        data.cycle_index=torch.stack((torch.LongTensor(np.where(SUB_ADJ[1]!=0)[0]), torch.LongTensor(np.where(SUB_ADJ[1]!=0)[1])),1).T.contiguous()\n",
    "        #data.cycle_attr = torch.FloatTensor(SUB_ADJ[1][np.where(SUB_ADJ[1]!=0)[0],np.where(SUB_ADJ[1]!=0)[1]]) \n",
    "        #FloatTensor 형태여야됨 \n",
    "        NEWDATA.append(data)\n",
    "        \n",
    "    torch.save(NEWDATA,file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "sensitive-sword",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_class=[]\n",
    "for i in range(len(dataset)):\n",
    "    dataset_class.append(dataset[i].y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "numerous-finland",
   "metadata": {},
   "outputs": [
    {
     "ename": "SyntaxError",
     "evalue": "invalid syntax (968553625.py, line 5)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;36m  Input \u001b[0;32mIn [2]\u001b[0;36m\u001b[0m\n\u001b[0;31m    class_name=[Cy2C_GCN_OGB_1','Cy2C_GCN_OGB_3]\u001b[0m\n\u001b[0m                              ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n"
     ]
    }
   ],
   "source": [
    "from nets import Cy2C_GCN_OGB_1,Cy2C_GCN_OGB_3\n",
    "from Trainer_ogb import Trainer\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "model_name=['Cy2C_GCN_OGB_1','Cy2C_GCN_OGB_3']\n",
    "class_name=[Cy2C_GCN_OGB_1','Cy2C_GCN_OGB_3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d78aaba8-a834-4e7d-8345-245e0fc88b2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "lr=1e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "appreciated-motorcycle",
   "metadata": {},
   "outputs": [],
   "source": [
    "split_idx = dataset.get_idx_split() \n",
    "\n",
    "train_dataset=[]\n",
    "test_dataset=[]\n",
    "valid_dataset=[]\n",
    "for i in split_idx[\"train\"]:\n",
    "    train_dataset.append(NEWDATA[i])\n",
    "for i in split_idx[\"test\"]:\n",
    "    test_dataset.append(NEWDATA[i])\n",
    "for i in split_idx[\"valid\"]:\n",
    "    valid_dataset.append(NEWDATA[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "processed-airline",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/anaconda3/envs/YODECOPY22/lib/python3.8/site-packages/torch_geometric/deprecation.py:12: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
      "  warnings.warn(out)\n"
     ]
    }
   ],
   "source": [
    "train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)\n",
    "valid_loader = DataLoader(valid_dataset, batch_size=128, shuffle=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "broad-gather",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n",
      "=====================================\n",
      "===== Cy2C_GCN_OGB_5_80_0.05(0.3)_1e-05 ===== ogbg_molhiv =====\n",
      "=====================================\n",
      "checkpoint\n",
      "result\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Input \u001b[0;32mIn [13]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m=====================================\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m     13\u001b[0m trainer\u001b[38;5;241m=\u001b[39mTrainer(name, dataset_name,NEWDATA,device,CLASS_NAME,dataset\u001b[38;5;241m.\u001b[39mnum_node_features,dataset[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39my\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m], train_loader, valid_loader, test_loader,lr\u001b[38;5;241m=\u001b[39mlr,hidden_dim\u001b[38;5;241m=\u001b[39mhidden_dim,opt_fun\u001b[38;5;241m=\u001b[39mopt_fun,n_layer\u001b[38;5;241m=\u001b[39mn_layer,num_workers\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m4\u001b[39m,eval_metric\u001b[38;5;241m=\u001b[39mdataset\u001b[38;5;241m.\u001b[39meval_metric,dataset_name2\u001b[38;5;241m=\u001b[39mdataset_name2,dropout\u001b[38;5;241m=\u001b[39mdropout,decay\u001b[38;5;241m=\u001b[39mdecay,n_fold\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m,drop_mid\u001b[38;5;241m=\u001b[39mdrop_mid)\n\u001b[0;32m---> 14\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Bench_re/Trainer_molhiv.py:132\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    130\u001b[0m test_curve\u001b[38;5;241m=\u001b[39m[]\n\u001b[1;32m    131\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m,\u001b[38;5;241m50000\u001b[39m):\n\u001b[0;32m--> 132\u001b[0m     train_loss\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    133\u001b[0m     valid_acc \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtest(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvalid_loader)\n\u001b[1;32m    134\u001b[0m     \u001b[38;5;66;03m#test_acc = self.test(self.test_loader)\u001b[39;00m\n",
      "File \u001b[0;32m~/Bench_re/Trainer_molhiv.py:238\u001b[0m, in \u001b[0;36mTrainer.train_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    236\u001b[0m \u001b[38;5;66;03m#print(data.y.shape,data.y[0])\u001b[39;00m\n\u001b[1;32m    237\u001b[0m data\u001b[38;5;241m=\u001b[39mdata\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)            \n\u001b[0;32m--> 238\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcycle_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# Perform a single forward pass.\u001b[39;00m\n\u001b[1;32m    239\u001b[0m is_labeled \u001b[38;5;241m=\u001b[39m data\u001b[38;5;241m.\u001b[39my \u001b[38;5;241m==\u001b[39m data\u001b[38;5;241m.\u001b[39my\n\u001b[1;32m    240\u001b[0m \u001b[38;5;66;03m#print(out.shape,out[0])\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/YODECOPY22/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1098\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1099\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1101\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1103\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1104\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/Bench_re/nets.py:28\u001b[0m, in \u001b[0;36mCy2C_GCN_OGB.forward\u001b[0;34m(self, x, edge_index, cycle_index, batch)\u001b[0m\n\u001b[1;32m     26\u001b[0m x\u001b[38;5;241m=\u001b[39mx_0\n\u001b[1;32m     27\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_layer):\n\u001b[0;32m---> 28\u001b[0m     xa\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mact(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbn_layers[i](\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv_layers\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m)\u001b[49m))\n\u001b[1;32m     29\u001b[0m     x\u001b[38;5;241m=\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout2(xa)\u001b[38;5;241m+\u001b[39mx)\n\u001b[1;32m     30\u001b[0m     a\u001b[38;5;241m.\u001b[39mappend(x)  \n",
      "File \u001b[0;32m~/anaconda3/envs/YODECOPY22/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1098\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1099\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1101\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1103\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1104\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/anaconda3/envs/YODECOPY22/lib/python3.8/site-packages/torch_geometric/nn/conv/gcn_conv.py:172\u001b[0m, in \u001b[0;36mGCNConv.forward\u001b[0;34m(self, x, edge_index, edge_weight)\u001b[0m\n\u001b[1;32m    170\u001b[0m cache \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_cached_edge_index\n\u001b[1;32m    171\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m cache \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 172\u001b[0m     edge_index, edge_weight \u001b[38;5;241m=\u001b[39m \u001b[43mgcn_norm\u001b[49m\u001b[43m(\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# yapf: disable\u001b[39;49;00m\n\u001b[1;32m    173\u001b[0m \u001b[43m        \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_weight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnode_dim\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    174\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimproved\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_self_loops\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    175\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcached:\n\u001b[1;32m    176\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_cached_edge_index \u001b[38;5;241m=\u001b[39m (edge_index, edge_weight)\n",
      "File \u001b[0;32m~/anaconda3/envs/YODECOPY22/lib/python3.8/site-packages/torch_geometric/nn/conv/gcn_conv.py:58\u001b[0m, in \u001b[0;36mgcn_norm\u001b[0;34m(edge_index, edge_weight, num_nodes, improved, add_self_loops, dtype)\u001b[0m\n\u001b[1;32m     54\u001b[0m     edge_weight \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mones((edge_index\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m1\u001b[39m), ), dtype\u001b[38;5;241m=\u001b[39mdtype,\n\u001b[1;32m     55\u001b[0m                              device\u001b[38;5;241m=\u001b[39medge_index\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m     57\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m add_self_loops:\n\u001b[0;32m---> 58\u001b[0m     edge_index, tmp_edge_weight \u001b[38;5;241m=\u001b[39m \u001b[43madd_remaining_self_loops\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     59\u001b[0m \u001b[43m        \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_weight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfill_value\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_nodes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     60\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m tmp_edge_weight \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m     61\u001b[0m     edge_weight \u001b[38;5;241m=\u001b[39m tmp_edge_weight\n",
      "File \u001b[0;32m~/anaconda3/envs/YODECOPY22/lib/python3.8/site-packages/torch_geometric/utils/loop.py:228\u001b[0m, in \u001b[0;36madd_remaining_self_loops\u001b[0;34m(edge_index, edge_attr, fill_value, num_nodes)\u001b[0m\n\u001b[1;32m    224\u001b[0m     loop_attr[edge_index[\u001b[38;5;241m0\u001b[39m][inv_mask]] \u001b[38;5;241m=\u001b[39m edge_attr[inv_mask]\n\u001b[1;32m    226\u001b[0m     edge_attr \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat([edge_attr[mask], loop_attr], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[0;32m--> 228\u001b[0m edge_index \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat([\u001b[43medge_index\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmask\u001b[49m\u001b[43m]\u001b[49m, loop_index], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m    229\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m edge_index, edge_attr\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "for i,CLASS_NAME in enumerate(class_name):\n",
    "    for hidden_dim in [80]:\n",
    "        for drop_mid in [0.3]:\n",
    "            for dropout in [0.05]:\n",
    "                for decay in [0.00001]:\n",
    "                    opt_fun=1\n",
    "                    for n_layer in [3,5,7]:\n",
    "                        print(n_layer)\n",
    "                        name=f'{model_name[i]}_{n_layer}_{hidden_dim}_{dropout}({drop_mid})_{decay}'\n",
    "                        print('=====================================')\n",
    "                        print('=====',name,'=====',dataset_name,'=====')\n",
    "                        print('=====================================')\n",
    "                        trainer=Trainer(name, dataset_name,NEWDATA,device,CLASS_NAME,dataset.num_node_features,dataset[0].y.shape[1], train_loader, valid_loader, test_loader,lr=lr,hidden_dim=hidden_dim,opt_fun=opt_fun,n_layer=n_layer,num_workers=4,eval_metric=dataset.eval_metric,dataset_name2=dataset_name2,dropout=dropout,decay=decay,n_fold=10,drop_mid=drop_mid)\n",
    "                        trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c8ca0c6-b2e5-46a9-8f04-ff3a5cae0abb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Environment (conda_yodecopy22)",
   "language": "python",
   "name": "conda_yodecopy22"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
