{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "objective-traffic",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/anaconda3/envs/YODECOPY/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os.path\n",
    "import os\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.nn as nn\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.datasets import Planetoid,TUDataset\n",
    "from torch_geometric.nn import GATConv, GCNConv,GINConv\n",
    "import numpy as np\n",
    "import igraph as ig\n",
    "from torch_geometric.nn import global_mean_pool,global_add_pool\n",
    "import utils\n",
    "from functools import reduce\n",
    "import pickle\n",
    "from sklearn.model_selection import KFold,StratifiedKFold\n",
    "import torch.optim as optim\n",
    "import csv\n",
    "import pandas as pd\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.utils import degree\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "976ccb9c-77fb-427c-b0c0-47eaef0c2941",
   "metadata": {},
   "outputs": [],
   "source": [
    "import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "vanilla-alliance",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading https://www.chrsmrrs.com/graphkerneldatasets/BZR_MD.zip\n",
      "Extracting dataset/BZR_MD/BZR_MD.zip\n",
      "Processing...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Dataset: BZR_MD(306):\n",
      "====================\n",
      "Number of graphs: 306\n",
      "Number of features: 8\n",
      "Number of classes: 2\n",
      "\n",
      "Data(edge_index=[2, 342], x=[19, 8], edge_attr=[342, 6], y=[1])\n",
      "=============================================================\n",
      "Number of nodes: 19\n",
      "Number of edges: 342\n",
      "Average node degree: 18.00\n",
      "Has isolated nodes: False\n",
      "Has self-loops: False\n",
      "Is undirected: True\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Done!\n"
     ]
    }
   ],
   "source": [
    "dataset_name = 'BZR_MD'\n",
    "dataset = TUDataset(root='dataset', name=dataset_name,use_edge_attr=True )\n",
    "\n",
    "print()\n",
    "print(f'Dataset: {dataset}:')\n",
    "print('====================')\n",
    "print(f'Number of graphs: {len(dataset)}')\n",
    "print(f'Number of features: {dataset.num_features}')\n",
    "print(f'Number of classes: {dataset.num_classes}')\n",
    "\n",
    "\n",
    "data = dataset[0]  # Get the first graph object.\n",
    "print()\n",
    "print(data)\n",
    "print('=============================================================')\n",
    "\n",
    "# Gather some statistics about the first graph.\n",
    "print(f'Number of nodes: {data.num_nodes}')\n",
    "print(f'Number of edges: {data.num_edges}')\n",
    "print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')\n",
    "print(f'Has isolated nodes: {data.has_isolated_nodes()}')\n",
    "print(f'Has self-loops: {data.has_self_loops()}')\n",
    "print(f'Is undirected: {data.is_undirected()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "accepted-tribune",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "33"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset=utils.data_load(dataset,normalize=False)\n",
    "max_node=utils.max_node_dataset(dataset)\n",
    "max_node"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "federal-israeli",
   "metadata": {},
   "source": [
    "### 데이터 체크"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "quick-connecticut",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "100\n",
      "200\n",
      "300\n"
     ]
    }
   ],
   "source": [
    "file = f'./dataset/{dataset_name}/H1_ver2'\n",
    "\n",
    "if os.path.isfile(file):\n",
    "    NEWDATA = torch.load(file)     \n",
    "    print('file on')\n",
    "\n",
    "else:        \n",
    "    SUB_ADJ=[]\n",
    "    RAW_SUB_ADJ=[]\n",
    "    NEWDATA=[]\n",
    "    for i in range(len(dataset)):                    \n",
    "        \n",
    "        data=dataset[i]\n",
    "        v1=data.edge_index[0,:]\n",
    "        v2=data.edge_index[1,:]\n",
    "        #print(torch.max(v1))\n",
    "        adj = torch.zeros((max_node,max_node))\n",
    "        adj[v1,v2]=1\n",
    "        adj=adj.numpy()\n",
    "        (adj==adj.T).all()\n",
    "        list_feature=(data.x)\n",
    "        list_adj=(adj)       \n",
    "        \n",
    "        #print(dataset[i])\n",
    "        _, _, _, _, sum_sub_adj = utils.make_cycle_adj_speed_nosl(list_adj,data)\n",
    "        \n",
    "        if i % 100 == 0:\n",
    "            print(i)\n",
    "            \n",
    "        #_sub_adj=np.array(sub_adj)\n",
    "\n",
    "        if len(sum_sub_adj)>0:    \n",
    "            new_adj=np.stack((list_adj,sum_sub_adj),0)\n",
    "        else :\n",
    "            sum_sub_adj=np.zeros((1, list_adj.shape[0], list_adj.shape[1]))\n",
    "            new_adj=np.concatenate((list_adj.reshape(1, list_adj.shape[0], list_adj.shape[1]),sum_sub_adj),0)\n",
    "\n",
    "        #SUB_ADJ.append(new_adj)\n",
    "        SUB_ADJ=new_adj\n",
    "        #------합치기\n",
    "        data=dataset[i]\n",
    "        check1=torch.sum(data.edge_index[0]-np.where(SUB_ADJ[0]==1)[0])+torch.sum(data.edge_index[1]-np.where(SUB_ADJ[0]==1)[1])\n",
    "        if check1 != 0 :\n",
    "            print('error')\n",
    "\n",
    "        data.cycle_index=torch.stack((torch.LongTensor(np.where(SUB_ADJ[1]!=0)[0]), torch.LongTensor(np.where(SUB_ADJ[1]!=0)[1])),1).T.contiguous()\n",
    "        #data.cycle_attr = torch.FloatTensor(SUB_ADJ[1][np.where(SUB_ADJ[1]!=0)[0],np.where(SUB_ADJ[1]!=0)[1]]) \n",
    "        #FloatTensor 형태여야됨 \n",
    "        NEWDATA.append(data)\n",
    "        \n",
    "    torch.save(NEWDATA,file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "behind-absorption",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_class=[]\n",
    "for i in range(len(dataset)):\n",
    "    dataset_class.append(dataset[i].y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c074f9b2-2b5f-45ba-bc2b-d7017e5045d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_class=torch.FloatTensor(dataset_class).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "adjustable-humor",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StratifiedKFold(n_splits=10, random_state=None, shuffle=True)\n"
     ]
    }
   ],
   "source": [
    "folder = f'./dataset/{dataset_name}/kfold_data'\n",
    "if os.path.isdir(folder):\n",
    "    print('folder_on')\n",
    "    for j in range(10):\n",
    "        print(j)\n",
    "        test_index = torch.as_tensor(np.loadtxt(f'./dataset/{dataset_name}/kfold_data/test_idx-{j}.txt',dtype=np.int32), dtype=torch.long)\n",
    "        for k in range(10):\n",
    "            train_index = torch.as_tensor(np.loadtxt(f'./dataset/{dataset_name}/kfold_data/train_total_{j}/train_idx-{k}.txt',dtype=np.int32), dtype=torch.long)\n",
    "            valid_index = torch.as_tensor(np.loadtxt(f'./dataset/{dataset_name}/kfold_data/train_total_{j}/valid_idx-{k}.txt',dtype=np.int32), dtype=torch.long)    \n",
    "            all_index = reduce(np.union1d, (train_index, valid_index, test_index))\n",
    "            assert len(dataset) == len(all_index)\n",
    "\n",
    "else :        \n",
    "    os.makedirs(folder)\n",
    "    kkf=StratifiedKFold(n_splits=10, shuffle=True)\n",
    "    #kf = KFold(n_splits=10, shuffle=True)\n",
    "    kkf2=StratifiedKFold(n_splits=10, shuffle=True)\n",
    "    #kf2 = KFold(n_splits=10, shuffle=True)\n",
    "    kkf.get_n_splits(dataset,dataset_class)\n",
    "    print(kkf)\n",
    "    j=0\n",
    "    for train_total_index, test_index in kkf.split(dataset,dataset_class):\n",
    "        #print(train_index, test_index)\n",
    "        np.savetxt(f'./dataset/{dataset_name}/kfold_data/train_total_idx-{j}.txt',(train_total_index.astype(np.int64)), fmt='%i', delimiter='\\t')\n",
    "        np.savetxt(f'./dataset/{dataset_name}/kfold_data/test_idx-{j}.txt',(test_index.astype(np.int64)), fmt='%i', delimiter='\\t')\n",
    "        assert len(dataset)==len(reduce(np.union1d, (test_index, train_total_index)))\n",
    "        k=0\n",
    "        os.mkdir(f'./dataset/{dataset_name}/kfold_data/train_total_{j}') \n",
    "        \n",
    "        dataset_class_train=[]\n",
    "        dataset_train=[]\n",
    "        for i in train_total_index:\n",
    "            dataset_class_train.append(dataset[i].y)\n",
    "            dataset_train.append(dataset[i])\n",
    "        dataset_class_train=torch.FloatTensor(dataset_class_train).numpy()\n",
    "        dataset_train=(dataset_train)\n",
    "        kkf2.get_n_splits(train_total_index,dataset_class_train)\n",
    "        \n",
    "        for ii, jj in kkf2.split(dataset_train,dataset_class_train):        \n",
    "            valid_index=train_total_index[jj]\n",
    "            train_index=train_total_index[ii]\n",
    "            np.savetxt(f'./dataset/{dataset_name}/kfold_data/train_total_{j}/valid_idx-{k}.txt',(valid_index.astype(np.int64)), fmt='%i', delimiter='\\t')\n",
    "            np.savetxt(f'./dataset/{dataset_name}/kfold_data/train_total_{j}/train_idx-{k}.txt',(train_index.astype(np.int64)), fmt='%i', delimiter='\\t')\n",
    "            k+=1\n",
    "            assert len(train_total_index)==len(reduce(np.union1d, (valid_index, train_index)))\n",
    "        j+=1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "forced-composite",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:The OGB package is out of date. Your version is 1.3.3, while the latest version is 1.3.4.\n"
     ]
    }
   ],
   "source": [
    "from nets_attr import Cy2C_GCN_attr_1_concat,Cy2C_GCN_attr_3_concat\n",
    "from Trainer_CB_attr import Trainer\n",
    "device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')\n",
    "model_name=['Cy2C_GCN_attr_1_concat','Cy2C_GCN_attr_3_concat']\n",
    "class_name=[Cy2C_GCN_attr_1_concat,Cy2C_GCN_attr_3_concat]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "closed-conclusion",
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 1e-3\n",
    "batch_size=32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0aa68215-ebc8-4e90-8706-b29f97db3024",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cy2C_GCN_attr_1_concat_1_32_0.0(0.0)_0.0\n",
      "=====================================\n",
      "===== Cy2C_GCN_attr_1_concat_1_32_0.0(0.0)_0.0 ===== BZR_MD =====\n",
      "=====================================\n",
      "load mainfold, subfold== 0 0\n",
      "Mainfold_index: 0, Subfold_index:0\n",
      "main & sub ===0,0,best acc & loss==,0.6774,0.0191,final acc & loss==0.7419,0.0176,best_epoch==11,final_epoch==112\n",
      "load mainfold, subfold== 1 0\n",
      "Mainfold_index: 1, Subfold_index:0\n",
      "main & sub ===1,0,best acc & loss==,0.6452,0.0212,final acc & loss==0.5806,0.0259,best_epoch==43,final_epoch==144\n",
      "load mainfold, subfold== 2 0\n",
      "Mainfold_index: 2, Subfold_index:0\n",
      "main & sub ===2,0,best acc & loss==,0.8065,0.0174,final acc & loss==0.6774,0.0174,best_epoch==11,final_epoch==112\n",
      "load mainfold, subfold== 3 0\n"
     ]
    }
   ],
   "source": [
    "for i,CLASS_NAME in enumerate(class_name):\n",
    "    for hidden_dim in [32,64]:\n",
    "        for decay in [0.0,0.0001]:\n",
    "            for mid_drop in [0.0,0.2, 0.4]:\n",
    "                for dropout in [0.0, 0.2, 0.4]:\n",
    "                    for n_layer in [1,3,5]:\n",
    "                        name=f'{model_name[i]}_{n_layer}_{hidden_dim}_{dropout}({mid_drop})_{decay}'\n",
    "                        print(name)\n",
    "                        print('=====================================')\n",
    "                        print('=====',name,'=====',dataset_name,'=====')\n",
    "                        print('=====================================')\n",
    "                        trainer=Trainer(name, dataset_name,NEWDATA,device,CLASS_NAME,dataset.num_node_features,dataset.num_classes,batch_size=batch_size,lr=lr,hidden_dim=hidden_dim,n_layer=n_layer,num_workers=1,dropout=dropout,decay=decay,mid_drop=mid_drop)\n",
    "                        trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "quantitative-perspective",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "radio-logistics",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "brutal-young",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Environment (conda_yodecopy)",
   "language": "python",
   "name": "conda_yodecopy"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
