{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "driving-thread",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:The OGB package is out of date. Your version is 1.3.3, while the latest version is 1.3.4.\n"
     ]
    }
   ],
   "source": [
    "import os.path\n",
    "import os\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.nn as nn\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.datasets import Planetoid,TUDataset\n",
    "from torch_geometric.nn import GATConv, GCNConv,GINConv\n",
    "import numpy as np\n",
    "import igraph as ig\n",
    "from torch_geometric.nn import global_mean_pool,global_add_pool\n",
    "\n",
    "from functools import reduce\n",
    "import pickle\n",
    "from sklearn.model_selection import KFold,StratifiedKFold\n",
    "import torch.optim as optim\n",
    "import csv\n",
    "import pandas as pd\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.utils import degree\n",
    "from ogb.graphproppred import PygGraphPropPredDataset\n",
    "from torch_geometric.data import DataLoader\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ab1bb62f-e0ce-4940-ab48-cb7288647a05",
   "metadata": {},
   "outputs": [],
   "source": [
    "import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "hazardous-liver",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both names start out as the hyphenated OGB identifier; dataset_name2 is the\n",
    "# one later handed to Trainer as `dataset_name2`.\n",
    "dataset_name = dataset_name2 = \"ogbg-moltox21\"\n",
    "\n",
    "# Instantiate the OGB graph-property-prediction dataset for that identifier.\n",
    "dataset = PygGraphPropPredDataset(name=dataset_name)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "theoretical-delicious",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Dataset: PygGraphPropPredDataset(7831):\n",
      "====================\n",
      "Number of graphs: 7831\n",
      "Number of features: 9\n",
      "Number of classes: 2\n"
     ]
    }
   ],
   "source": [
    "# Headline statistics of the loaded dataset, printed one per line\n",
    "# (the leading empty entry produces the blank first line).\n",
    "summary_lines = [\n",
    "    '',\n",
    "    f'Dataset: {dataset}:',\n",
    "    '====================',\n",
    "    f'Number of graphs: {len(dataset)}',\n",
    "    f'Number of features: {dataset.num_features}',\n",
    "    f'Number of classes: {dataset.num_classes}',\n",
    "]\n",
    "for summary_line in summary_lines:\n",
    "    print(summary_line)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "sexual-islam",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "132"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rebind `dataset` to whatever utils.data_load returns for it (normalize=False).\n",
    "# NOTE(review): re-running this cell feeds the already-processed object back\n",
    "# into data_load - confirm that this is safe.\n",
    "dataset=utils.data_load(dataset,normalize=False)\n",
    "# max_node is used further down as the side length of the dense adjacency\n",
    "# matrices; presumably the largest node count over all graphs - verify in utils.\n",
    "max_node=utils.max_node_dataset(dataset)\n",
    "max_node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cloudy-bloom",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch to the underscore spelling: from here on dataset_name is used in the\n",
    "# ./dataset/<name>/ cache path and passed to Trainer, while dataset_name2\n",
    "# keeps the hyphenated identifier.\n",
    "dataset_name=\"ogbg_moltox21\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "broadband-links",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "file on\n"
     ]
    }
   ],
   "source": [
    "# Build NEWDATA - one graph object per dataset entry with an added\n",
    "# `cycle_index` attribute - or load it from the on-disk cache when present.\n",
    "file = f'./dataset/{dataset_name}/H1_ver2'\n",
    "\n",
    "if os.path.isfile(file):\n",
    "    # NOTE(review): torch.load unpickles arbitrary objects - only load cache\n",
    "    # files that this notebook wrote itself.\n",
    "    NEWDATA = torch.load(file)\n",
    "    print('file on')\n",
    "\n",
    "else:\n",
    "    NEWDATA=[]\n",
    "    for i in range(len(dataset)):\n",
    "        data=dataset[i]\n",
    "        v1=data.edge_index[0,:]\n",
    "        v2=data.edge_index[1,:]\n",
    "        # Dense 0/1 adjacency padded to max_node x max_node, so every graph\n",
    "        # yields a matrix of the same shape.\n",
    "        adj = torch.zeros((max_node,max_node))\n",
    "        adj[v1,v2]=1\n",
    "        adj=adj.numpy()\n",
    "        # This symmetry test used to be evaluated and thrown away; report a\n",
    "        # failure instead of silently ignoring it.\n",
    "        if not (adj==adj.T).all():\n",
    "            print(f'warning: adjacency of graph {i} is not symmetric')\n",
    "\n",
    "        # Only the fifth return value of the helper is used here.\n",
    "        # NOTE(review): presumably a cycle-derived adjacency with the same\n",
    "        # shape as adj - confirm in utils.make_cycle_adj_speed_nosl.\n",
    "        _, _, _, _, sum_sub_adj = utils.make_cycle_adj_speed_nosl(adj,data)\n",
    "\n",
    "        if i % 100 == 0:\n",
    "            print(i)\n",
    "\n",
    "        # Channel 0 is the plain adjacency, channel 1 the helper output\n",
    "        # (all zeros when the helper returned nothing).\n",
    "        if len(sum_sub_adj)>0:\n",
    "            new_adj=np.stack((adj,sum_sub_adj),0)\n",
    "        else :\n",
    "            sum_sub_adj=np.zeros((1, adj.shape[0], adj.shape[1]))\n",
    "            new_adj=np.concatenate((adj.reshape(1, adj.shape[0], adj.shape[1]),sum_sub_adj),0)\n",
    "\n",
    "        # Merge step: re-fetch the graph and attach the channel-1 edges to it.\n",
    "        data=dataset[i]\n",
    "        # Sanity check: the summed differences between edge_index and the\n",
    "        # nonzero positions of channel 0 must be zero. The comparison is\n",
    "        # sum-based, so it does not depend on the order of the edges.\n",
    "        edge_rows,edge_cols=np.where(new_adj[0]==1)\n",
    "        check1=torch.sum(data.edge_index[0]-edge_rows)+torch.sum(data.edge_index[1]-edge_cols)\n",
    "        if check1 != 0 :\n",
    "            print('error')\n",
    "\n",
    "        # cycle_index has shape (2, k): row and column indices of the nonzero\n",
    "        # entries of channel 1.\n",
    "        cycle_rows,cycle_cols=np.where(new_adj[1]!=0)\n",
    "        data.cycle_index=torch.stack((torch.LongTensor(cycle_rows), torch.LongTensor(cycle_cols)),1).T.contiguous()\n",
    "        NEWDATA.append(data)\n",
    "\n",
    "    # Create the cache directory if it is missing so torch.save cannot fail\n",
    "    # after all graphs have been processed.\n",
    "    os.makedirs(os.path.dirname(file), exist_ok=True)\n",
    "    torch.save(NEWDATA,file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "consolidated-senator",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the label tensor `y` of every graph, in dataset order.\n",
    "dataset_class=[dataset[idx].y for idx in range(len(dataset))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "domestic-grocery",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nets import Cy2C_GCN_OGB_1,Cy2C_GCN_OGB_3\n",
    "from Trainer_ogb import Trainer\n",
    "\n",
    "# Model classes to sweep over and the labels used for them in run names\n",
    "# (the two lists are index-aligned).\n",
    "class_name=[Cy2C_GCN_OGB_1,Cy2C_GCN_OGB_3]\n",
    "model_name=['Cy2C_GCN_OGB_1','Cy2C_GCN_OGB_3']\n",
    "\n",
    "# Use the first CUDA device when one is available, otherwise the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda:0')\n",
    "else:\n",
    "    device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "relative-somewhere",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split indices provided by the dataset object itself.\n",
    "split_idx = dataset.get_idx_split()\n",
    "\n",
    "# Materialise each split as a plain list of the cycle-augmented graphs held\n",
    "# in NEWDATA (built in the same order as the original: train, test, valid).\n",
    "train_dataset=[NEWDATA[idx] for idx in split_idx[\"train\"]]\n",
    "test_dataset=[NEWDATA[idx] for idx in split_idx[\"test\"]]\n",
    "valid_dataset=[NEWDATA[idx] for idx in split_idx[\"valid\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "sensitive-novelty",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The import cell binds DataLoader twice and the deprecated\n",
    "# torch_geometric.data alias wins (it emits a UserWarning when used);\n",
    "# rebind the supported loader explicitly before building the loaders.\n",
    "from torch_geometric.loader import DataLoader\n",
    "\n",
    "# Mini-batch size shared by all three loaders.\n",
    "BATCH_SIZE = 128\n",
    "\n",
    "# Only the training loader shuffles; validation and test keep a fixed order.\n",
    "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "195b7c83-2b7e-4a08-ba77-6fc0c0f35dab",
   "metadata": {},
   "outputs": [],
   "source": [
    "lr=1e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c1b5d4b-36c0-4b7c-8810-86ff17d61bfb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7\n",
      "=====================================\n",
      "===== Cy2C_GCN_OGB_1_7_300_0.0(0.5)_0.001 ===== ogbg_moltox21 =====\n",
      "=====================================\n",
      "checkpoint\n",
      "result\n"
     ]
    }
   ],
   "source": [
    "# Fixed for the whole sweep; forwarded to Trainer as `opt_fun`.\n",
    "opt_fun=1\n",
    "\n",
    "# Exhaustive grid: every combination trains one model. The nesting fixes the\n",
    "# run order (n_layer varies fastest, the model class slowest).\n",
    "for i,CLASS_NAME in enumerate(class_name):\n",
    "    for hidden_dim in (300,):\n",
    "        for drop_mid in (0.5,):\n",
    "            for dropout in (0.0,0.05,0.1):\n",
    "                for decay in (0.001,0.00001):\n",
    "                    for n_layer in (7,6,8):\n",
    "                        print(n_layer)\n",
    "                        # Run label encoding the model and its hyper-parameters.\n",
    "                        name=f'{model_name[i]}_{n_layer}_{hidden_dim}_{dropout}({drop_mid})_{decay}'\n",
    "                        print('=====================================')\n",
    "                        print('=====',name,'=====',dataset_name,'=====')\n",
    "                        print('=====================================')\n",
    "                        trainer=Trainer(\n",
    "                            name,\n",
    "                            dataset_name,\n",
    "                            NEWDATA,\n",
    "                            device,\n",
    "                            CLASS_NAME,\n",
    "                            dataset.num_node_features,\n",
    "                            dataset[0].y.shape[1],\n",
    "                            train_loader,\n",
    "                            valid_loader,\n",
    "                            test_loader,\n",
    "                            lr=lr,\n",
    "                            hidden_dim=hidden_dim,\n",
    "                            opt_fun=opt_fun,\n",
    "                            n_layer=n_layer,\n",
    "                            num_workers=1,\n",
    "                            eval_metric=dataset.eval_metric,\n",
    "                            dataset_name2=dataset_name2,\n",
    "                            dropout=dropout,\n",
    "                            decay=decay,\n",
    "                            n_fold=10,\n",
    "                            drop_mid=drop_mid,\n",
    "                        )\n",
    "                        trainer.train()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Environment (conda_yodecopy22)",
   "language": "python",
   "name": "conda_yodecopy22"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
