{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from os.path import dirname\n",
    "sys.path.append(dirname(__name__)+'..')\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "# %matplotlib widget\n",
    "from matplotlib import pyplot as plt\n",
    "import os\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.nn as nn\n",
    "from torch_geometric.loader import DataLoader\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# from models.layer import MaxPooling, MaxPoolingX\n",
    "import torch_geometric\n",
    "import torch_geometric.transforms as T\n",
    "import tqdm\n",
    "\n",
    "import networkx as nx\n",
    "from torch_geometric.utils import to_networkx\n",
    "\n",
    "\n",
    "from torch_geometric.utils import to_dense_batch, to_dense_adj, degree, add_remaining_self_loops, remove_self_loops\n",
    "from torch_geometric.nn.conv import SplineConv, SAGEConv\n",
    "from torch_geometric.nn.norm import BatchNorm\n",
    "from torch_geometric.transforms import Cartesian\n",
    "from torch_geometric.nn import GATConv\n",
    "\n",
    "import logging\n",
    "import time\n",
    "\n",
    "from UNET_attns.UNET_attn import Unet\n",
    "\n",
    "from sklearn.metrics import roc_auc_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "training size 800, testing set 200, Number of graphs per batch: 16\n"
     ]
    }
   ],
   "source": [
    "# IMDB-BINARY graph-classification benchmark, rooted at ../data/.\n",
    "dataset = torch_geometric.datasets.TUDataset('../data/', 'IMDB-BINARY')\n",
    "\n",
    "# Shuffle with a dedicated, seeded generator so the train/test split (and therefore every\n",
    "# reported metric) is identical after Restart Kernel -> Run All. Using a local generator\n",
    "# leaves the global torch RNG state untouched.\n",
    "SPLIT_SEED = 0\n",
    "TRAIN_FRACTION = 0.8\n",
    "idx = torch.randperm(len(dataset), generator=torch.Generator().manual_seed(SPLIT_SEED))\n",
    "cut_off = int(len(dataset) * TRAIN_FRACTION)\n",
    "train_dataset = dataset[idx[:cut_off]]\n",
    "test_dataset = dataset[idx[cut_off:]]\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(device)\n",
    "\n",
    "# Training batches are reshuffled every epoch; evaluation order stays fixed.\n",
    "train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)\n",
    "print(f'training size {len(train_dataset)}, testing set {len(test_dataset)}, Number of graphs per batch: {next(iter(train_loader)).y.shape[0]}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "feature_dim: 135.0\n"
     ]
    }
   ],
   "source": [
    "# Largest node degree found in any graph of the dataset; it fixes the one-hot width below.\n",
    "max_degrees = [degree(graph.edge_index[0], graph.num_nodes).max().item() for graph in dataset]\n",
    "feature_dim = max([0] + max_degrees)\n",
    "print(f'feature_dim: {feature_dim}')\n",
    "# One slot per degree value 0..feature_dim, capped at 400 slots.\n",
    "FEAT_DIM = min(int(feature_dim) + 1, 400)\n",
    "\n",
    "def gen_node_feature(data):\n",
    "    '''Return one-hot node features that encode the degree of each node.\n",
    "\n",
    "    Degrees are counted from the first row of ``data.edge_index``; any degree of\n",
    "    ``FEAT_DIM`` or more falls into the last bucket (``FEAT_DIM - 1``).\n",
    "    The result is a float tensor with ``FEAT_DIM`` columns and one row per node.\n",
    "    '''\n",
    "    node_degree = degree(data.edge_index[0], data.num_nodes).long()\n",
    "    bucket = node_degree.clamp(max=FEAT_DIM - 1)\n",
    "    return F.one_hot(bucket, num_classes=FEAT_DIM).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_metrics(true_labels, predicted_labels):\n",
    "    '''Return the ROC AUC of ``predicted_labels`` scored against ``true_labels``.'''\n",
    "    return roc_auc_score(true_labels, predicted_labels)\n",
    "\n",
    "def evaluate(model, train_loader, device):\n",
    "    '''Return the ROC AUC of the reconstructed adjacency entries over a loader.\n",
    "\n",
    "    Args:\n",
    "        model: callable as ``model(gs, hs)`` returning ``(loss, reconstructed_graphs)``.\n",
    "        train_loader: any graph loader (train or test), despite the name.\n",
    "        device: device every batch is moved to before the forward pass.\n",
    "\n",
    "    Returns:\n",
    "        ``calculate_metrics`` applied to the flattened true entries and predicted scores.\n",
    "    '''\n",
    "    model.eval()\n",
    "\n",
    "    trueLabel = []\n",
    "    predLabel = []\n",
    "    with torch.no_grad():\n",
    "        for data in train_loader:\n",
    "            # Graphs without node features would make to_dense_batch fail on a None input,\n",
    "            # so fall back to the degree one-hot features used elsewhere in this notebook.\n",
    "            if getattr(data, 'x', None) is None:\n",
    "                data.x = gen_node_feature(data)\n",
    "            data = data.to(device)\n",
    "            data.edge_index = add_remaining_self_loops(data.edge_index)[0]\n",
    "            hs, mask = to_dense_batch(data.x, data.batch)\n",
    "            hs = [h[m] for h, m in zip(hs, mask)]\n",
    "            gs = to_dense_adj(data.edge_index, data.batch)\n",
    "            # Crop each dense adjacency to the node count of its own graph.\n",
    "            gs = [g[:len(h), :len(h)] for g, h in zip(gs, hs)]\n",
    "            _, o_gs = model(gs, hs)\n",
    "\n",
    "            for og, g in zip(o_gs, gs):\n",
    "                trueLabel.extend(g.int().cpu().numpy().flatten().tolist())\n",
    "                predLabel.extend(og.cpu().numpy().flatten().tolist())\n",
    "    trueLabel = np.array(trueLabel)\n",
    "    predLabel = np.array(predLabel)\n",
    "    return calculate_metrics(trueLabel, predLabel)\n",
    "\n",
    "class ARG:\n",
    "    '''Empty attribute container; hyperparameters are attached to an instance as plain attributes.'''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# variational"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from GraphUNET.ops import GCN, Pool, norm_g, Initializer, Unpool\n",
    "from UNET_attns.UNET_attn import NewEncoder, Decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Unet(nn.Module):\n",
    "    '''\n",
    "    Two-way graph U-Net that reconstructs adjacency matrices from sampled node embeddings.\n",
    "\n",
    "    A source GCN + LayerNorm embeds the nodes, ``NewEncoder`` encodes the graph, a bottleneck\n",
    "    GCN + LayerNorm refines the result, and four ``Decoder`` heads produce a mean and a\n",
    "    log-std for each of the two embedding ways. The reconstruction is the sigmoid of the\n",
    "    symmetrised inner product of the two sampled embeddings.\n",
    "\n",
    "    Args:\n",
    "        in_dim: input node-feature width passed to the source GCN.\n",
    "        args: object exposing ``act`` (name of an ``nn`` activation class), ``dim``,\n",
    "            ``drop_p``, ``ks`` and ``mask_ratio``.\n",
    "        s_gcn_state: optional state dict; when given, the source GCN is loaded and frozen.\n",
    "        encoder_state: optional state dict; when given, the encoder is loaded and frozen.\n",
    "        s_ln_state: optional state dict; when given, the source LayerNorm is loaded and frozen.\n",
    "    '''\n",
    "    def __init__(self, in_dim=None, args=None, s_gcn_state=None, encoder_state=None, s_ln_state=None) -> None:\n",
    "        # Zero-argument super() keeps working even when the module-level name ``Unet`` is\n",
    "        # rebound later (this notebook re-imports another ``Unet`` further down).\n",
    "        super().__init__()\n",
    "        self.act = getattr(nn, args.act)()\n",
    "        self.mask_ratio = args.mask_ratio\n",
    "\n",
    "        self.s_gcn = GCN(in_dim, args.dim, self.act, args.drop_p)\n",
    "        self.s_ln = nn.LayerNorm(args.dim)\n",
    "        if s_gcn_state:\n",
    "            self.s_gcn.load_state_dict(s_gcn_state)\n",
    "            self._freeze(self.s_gcn)\n",
    "        if s_ln_state:\n",
    "            self.s_ln.load_state_dict(s_ln_state)\n",
    "            self._freeze(self.s_ln)\n",
    "\n",
    "        self.g_enc = NewEncoder(args.ks, args.dim, self.act, args.drop_p)\n",
    "        if encoder_state:\n",
    "            self.g_enc.load_state_dict(encoder_state)\n",
    "            self._freeze(self.g_enc)\n",
    "\n",
    "        self.bot_gcn = GCN(args.dim, args.dim, self.act, args.drop_p)\n",
    "        self.bot_ln = nn.LayerNorm(args.dim)\n",
    "        self.g_dec_mean1 = Decoder(args.ks, args.dim, self.act, args.drop_p)\n",
    "        self.g_dec_logstd1 = Decoder(args.ks, args.dim, self.act, args.drop_p)\n",
    "        self.g_dec_mean2 = Decoder(args.ks, args.dim, self.act, args.drop_p)\n",
    "        self.g_dec_logstd2 = Decoder(args.ks, args.dim, self.act, args.drop_p)\n",
    "\n",
    "    @staticmethod\n",
    "    def _freeze(module):\n",
    "        '''Exclude every parameter of ``module`` from gradient updates.'''\n",
    "        for param in module.parameters():\n",
    "            param.requires_grad = False\n",
    "\n",
    "    def forward(self, gs, hs):\n",
    "        '''Return ``(loss, reconstructions)`` for paired lists of adjacencies and node features.'''\n",
    "        o_gs = self.embed(gs, hs)\n",
    "        return self.customBCE(o_gs, gs), o_gs\n",
    "\n",
    "    def embed(self, gs, hs):\n",
    "        '''Reconstruct every graph independently; returns one tensor per ``(g, h)`` pair.'''\n",
    "        return [self.embed_one(g, h) for g, h in zip(gs, hs)]\n",
    "\n",
    "    def embed_one(self, g, h):\n",
    "        '''Reconstruct one graph from its adjacency ``g`` and node features ``h``.\n",
    "\n",
    "        Returns the sigmoid of the symmetrised product of two sampled embeddings.\n",
    "        Noise is drawn on every call, in eval mode as well.\n",
    "        '''\n",
    "        g = norm_g(g)\n",
    "        h = self.s_gcn(g, h)\n",
    "        h = self.s_ln(h)\n",
    "        ori_h = h\n",
    "        g, h, adj_ms, down_outs, indices_list = self.g_enc(g, h)\n",
    "\n",
    "        g = norm_g(g)\n",
    "        h = self.bot_gcn(g, h)\n",
    "        h = self.bot_ln(h)\n",
    "        mean1 = self.g_dec_mean1(h, ori_h, down_outs, adj_ms, indices_list)\n",
    "        std1 = self.g_dec_logstd1(h, ori_h, down_outs, adj_ms, indices_list)\n",
    "        mean2 = self.g_dec_mean2(h, ori_h, down_outs, adj_ms, indices_list)\n",
    "        std2 = self.g_dec_logstd2(h, ori_h, down_outs, adj_ms, indices_list)\n",
    "\n",
    "        # randn_like draws the noise on the device and dtype of the decoder output, so the\n",
    "        # model no longer depends on a notebook-level ``device`` variable.\n",
    "        # The decoder outputs named std* are exponentiated, i.e. treated as log-std.\n",
    "        h1 = torch.randn_like(mean1) * torch.exp(std1) + mean1\n",
    "        h2 = torch.randn_like(mean2) * torch.exp(std2) + mean2\n",
    "\n",
    "        h = h1 @ h2.T\n",
    "        # Average with the transpose so the predicted adjacency is symmetric.\n",
    "        return torch.sigmoid((h + h.T) / 2)\n",
    "\n",
    "    def customBCE(self, o_gs, gs):\n",
    "        '''Class-balanced binary cross-entropy averaged over graphs.\n",
    "\n",
    "        Within each graph the zero entries and the one entries of the target each\n",
    "        receive half of the total weight, regardless of how many of each there are.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if no ``(prediction, target)`` pair is supplied.\n",
    "        '''\n",
    "        losses = []\n",
    "        for og, g in zip(o_gs, gs):\n",
    "            tn = g.numel()\n",
    "            ones = g.sum()\n",
    "            zeros = tn - ones\n",
    "            # A class that is absent gets an infinite weight here, but torch.where never\n",
    "            # selects it because no entry of ``g`` belongs to that class.\n",
    "            one_weight = tn / 2 / ones\n",
    "            zero_weight = tn / 2 / zeros\n",
    "            weights = torch.where(g == 0, zero_weight, one_weight)\n",
    "            losses.append(F.binary_cross_entropy(og, g, weight=weights))\n",
    "        if not losses:\n",
    "            raise ValueError('customBCE needs at least one graph')\n",
    "        return sum(losses) / len(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, train_loader: DataLoader, optimizer, scheduler, device, epoch):\n",
    "    '''Run one optimisation epoch and return the mean batch loss.\n",
    "\n",
    "    Args:\n",
    "        model: callable as ``model(gs, hs)`` returning ``(loss, reconstructed_graphs)``.\n",
    "        train_loader: loader yielding batched graphs.\n",
    "        optimizer: stepped once per batch.\n",
    "        scheduler: stepped once at the end of the epoch.\n",
    "        device: device every batch is moved to.\n",
    "        epoch: unused; kept so existing callers keep working.\n",
    "\n",
    "    Returns:\n",
    "        Mean of the per-batch losses (0.0 when the loader yields nothing).\n",
    "    '''\n",
    "    model.train()\n",
    "\n",
    "    total_loss = 0.0\n",
    "    num_batches = 0\n",
    "    for data in train_loader:\n",
    "        data.x = gen_node_feature(data)\n",
    "        data = data.to(device)\n",
    "        data.edge_index = add_remaining_self_loops(data.edge_index)[0]\n",
    "        hs, mask = to_dense_batch(data.x, data.batch)\n",
    "        hs = [h[m] for h, m in zip(hs, mask)]\n",
    "        gs = to_dense_adj(data.edge_index, data.batch)\n",
    "        # Crop each dense adjacency to the node count of its own graph.\n",
    "        gs = [g[:len(h), :len(h)] for g, h in zip(gs, hs)]\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss, _ = model(gs, hs)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        total_loss += loss.item()\n",
    "        num_batches += 1\n",
    "    scheduler.step()\n",
    "    return total_loss / max(num_batches, 1)\n",
    "\n",
    "def calculate_metrics(true_labels, predicted_labels):\n",
    "    '''Score predictions with the area under the ROC curve (``roc_auc_score``).'''\n",
    "    return roc_auc_score(true_labels, predicted_labels)\n",
    "\n",
    "def evaluate(model, train_loader, device):\n",
    "    '''Score adjacency reconstruction on every graph yielded by ``train_loader``.\n",
    "\n",
    "    Args:\n",
    "        model: callable as ``model(gs, hs)`` returning ``(loss, reconstructed_graphs)``.\n",
    "        train_loader: any graph loader (train or test), despite the name.\n",
    "        device: device every batch is moved to before the forward pass.\n",
    "\n",
    "    Returns:\n",
    "        ``(accuracy, roc_auc)``: accuracy is the fraction of adjacency entries whose score,\n",
    "        thresholded at 0.5, equals the true 0/1 entry; roc_auc is ``calculate_metrics``\n",
    "        applied to the raw scores.\n",
    "    '''\n",
    "    model.eval()\n",
    "\n",
    "    trueLabel = []\n",
    "    predLabel = []\n",
    "    with torch.no_grad():\n",
    "        for data in train_loader:\n",
    "            data.x = gen_node_feature(data)\n",
    "            data = data.to(device)\n",
    "            data.edge_index = add_remaining_self_loops(data.edge_index)[0]\n",
    "            hs, mask = to_dense_batch(data.x, data.batch)\n",
    "            hs = [h[m] for h, m in zip(hs, mask)]\n",
    "            gs = to_dense_adj(data.edge_index, data.batch)\n",
    "            # Crop each dense adjacency to the node count of its own graph.\n",
    "            gs = [g[:len(h), :len(h)] for g, h in zip(gs, hs)]\n",
    "            _, o_gs = model(gs, hs)\n",
    "\n",
    "            for og, g in zip(o_gs, gs):\n",
    "                trueLabel.extend(g.int().cpu().numpy().flatten().tolist())\n",
    "                predLabel.extend(og.cpu().numpy().flatten().tolist())\n",
    "    trueLabel = np.array(trueLabel)\n",
    "    predLabel = np.array(predLabel)\n",
    "    # The scores are continuous, so comparing them to 0/1 labels with == is almost never\n",
    "    # true and reported an accuracy near zero; binarise at 0.5 first.\n",
    "    hardLabel = (predLabel > 0.5).astype(trueLabel.dtype)\n",
    "    accuracy = (hardLabel == trueLabel).sum() / trueLabel.size\n",
    "    return accuracy, calculate_metrics(trueLabel, predLabel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ARG:\n",
    "    '''Plain attribute bag holding the model hyperparameters.'''\n",
    "\n",
    "\n",
    "# Hyperparameters of the variational run, attached to the bag in one step.\n",
    "args = ARG()\n",
    "vars(args).update(\n",
    "    input_dim=FEAT_DIM,\n",
    "    act='GELU',\n",
    "    dim=128,\n",
    "    drop_p=0,\n",
    "    mask_ratio=0.5,\n",
    "    ks=[0.9, 0.8, 0.7, 0.6, 0.5] + [0],\n",
    ")\n",
    "num_epoch = 200\n",
    "\n",
    "# The cosine schedule spans the whole run: one scheduler step per epoch.\n",
    "model = Unet(in_dim=args.input_dim, args=args)\n",
    "model = model.to(device)\n",
    "optimizer = torch.optim.AdamW(model.parameters())\n",
    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# basicConfig() silently does nothing when the root logger already has a handler, so a\n",
    "# re-run (or another experiment cell in this kernel) would keep logging to the old file.\n",
    "# force=True (Python 3.8+) replaces the existing handlers. The directory is created first\n",
    "# because the log file cannot be opened inside a missing directory.\n",
    "os.makedirs('resLog', exist_ok=True)\n",
    "logging.basicConfig(filename='resLog/result_variational.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True)\n",
    "\n",
    "train_acc = []\n",
    "test_acc = []\n",
    "# Epoch 0 is the untrained baseline; epochs 1..num_epoch are each logged after one training pass.\n",
    "for epoch in range(num_epoch + 1):\n",
    "    if epoch > 0:\n",
    "        train(model, train_loader, optimizer, scheduler, device, epoch - 1)\n",
    "    train_acc.append(evaluate(model, train_loader, device))\n",
    "    test_acc.append(evaluate(model, test_loader, device))\n",
    "    logging.info(f'Epoch [{epoch}/{num_epoch}], acc_train: {train_acc[-1][0]:.6f}, acc_test: {test_acc[-1][0]:.6f}, roc_train: {train_acc[-1][-1]:.6f}, roc_test: {test_acc[-1][-1]:.6f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# edge masking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from GraphUNET.ops import GCN, Pool, norm_g, Initializer, Unpool\n",
    "from UNET_attns.UNET_attn import Unet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class newUnet(Unet):\n",
    "    '''``Unet`` variant whose loss target can differ from the graphs fed to the network.\n",
    "\n",
    "    Relies on ``embed`` and ``customBCE`` inherited from the imported ``Unet``.\n",
    "    '''\n",
    "    def forward(self, gs, hs, ori_gs=None):\n",
    "        '''Return ``(loss, reconstructions)``.\n",
    "\n",
    "        Args:\n",
    "            gs: adjacency matrices given to the network (for example, masked graphs).\n",
    "            hs: node features, one tensor per graph.\n",
    "            ori_gs: target adjacencies for the loss; defaults to ``gs`` so the module can\n",
    "                still be called with the two-argument form ``model(gs, hs)``.\n",
    "        '''\n",
    "        if ori_gs is None:\n",
    "            ori_gs = gs\n",
    "        o_gs = self.embed(gs, hs)\n",
    "        return self.customBCE(o_gs, ori_gs), o_gs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_undirected_mask(gs, mask_ratio):\n",
    "    '''Randomly drop undirected edges from each dense adjacency matrix.\n",
    "\n",
    "    One keep/drop decision is drawn per node pair above the diagonal (kept when a uniform\n",
    "    sample exceeds ``mask_ratio``) and mirrored below it, so the result stays symmetric.\n",
    "    Diagonal entries are always kept.\n",
    "\n",
    "    Args:\n",
    "        gs: iterable of square 2-D adjacency tensors.\n",
    "        mask_ratio: threshold for the uniform samples; larger values drop more edges.\n",
    "\n",
    "    Returns:\n",
    "        A new list of masked adjacencies; the inputs are not modified.\n",
    "    '''\n",
    "    new_gs = []\n",
    "    for g in gs:\n",
    "        keep = (torch.rand_like(g) > mask_ratio).to(g.dtype)\n",
    "        keep = torch.triu(keep, diagonal=1)\n",
    "        # Build the identity on the device and dtype of g itself instead of relying on a\n",
    "        # notebook-level ``device`` variable.\n",
    "        keep = keep + keep.T + torch.eye(g.size(0), dtype=g.dtype, device=g.device)\n",
    "        new_gs.append(g * keep)\n",
    "    return new_gs\n",
    "\n",
    "def train(model, train_loader: DataLoader, optimizer, scheduler, device, mask_ratio, epoch):\n",
    "    '''Run one epoch of masked-edge reconstruction and return the mean batch loss.\n",
    "\n",
    "    The network sees graphs with randomly dropped edges, and the loss is measured against\n",
    "    the original, unmasked adjacencies.\n",
    "\n",
    "    Args:\n",
    "        model: callable as ``model(masked_gs, hs, target_gs)`` returning ``(loss, outputs)``.\n",
    "        train_loader: loader yielding batched graphs.\n",
    "        optimizer: stepped once per batch.\n",
    "        scheduler: stepped once at the end of the epoch.\n",
    "        device: device every batch is moved to.\n",
    "        mask_ratio: forwarded to ``apply_undirected_mask``.\n",
    "        epoch: unused; kept so existing callers keep working.\n",
    "\n",
    "    Returns:\n",
    "        Mean of the per-batch losses (0.0 when the loader yields nothing).\n",
    "    '''\n",
    "    model.train()\n",
    "\n",
    "    total_loss = 0.0\n",
    "    num_batches = 0\n",
    "    for data in train_loader:\n",
    "        data.x = gen_node_feature(data)\n",
    "        data = data.to(device)\n",
    "        data.edge_index = add_remaining_self_loops(data.edge_index)[0]\n",
    "        hs, mask = to_dense_batch(data.x, data.batch)\n",
    "        hs = [h[m] for h, m in zip(hs, mask)]\n",
    "        gs = to_dense_adj(data.edge_index, data.batch)\n",
    "        # Crop each dense adjacency to the node count of its own graph.\n",
    "        gs = [g[:len(h), :len(h)] for g, h in zip(gs, hs)]\n",
    "        new_gs = apply_undirected_mask(gs, mask_ratio)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        # Target is the unmasked graph: passing the masked graph as the target as well\n",
    "        # would only train the model to copy its input, never to recover dropped edges.\n",
    "        loss, _ = model(new_gs, hs, gs)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        total_loss += loss.item()\n",
    "        num_batches += 1\n",
    "    scheduler.step()\n",
    "    return total_loss / max(num_batches, 1)\n",
    "\n",
    "def calculate_metrics(true_labels, predicted_labels):\n",
    "    '''ROC AUC of the predicted scores with respect to the true binary labels.'''\n",
    "    return roc_auc_score(true_labels, predicted_labels)\n",
    "\n",
    "def evaluate(model, train_loader, device):\n",
    "    '''Score adjacency reconstruction of the unmasked graphs yielded by ``train_loader``.\n",
    "\n",
    "    Args:\n",
    "        model: callable as ``model(gs, hs, gs)`` returning ``(loss, reconstructed_graphs)``.\n",
    "        train_loader: any graph loader (train or test), despite the name.\n",
    "        device: device every batch is moved to before the forward pass.\n",
    "\n",
    "    Returns:\n",
    "        ``(accuracy, roc_auc)``: accuracy is the fraction of adjacency entries whose score,\n",
    "        thresholded at 0.5, equals the true 0/1 entry; roc_auc is ``calculate_metrics``\n",
    "        applied to the raw scores.\n",
    "    '''\n",
    "    model.eval()\n",
    "\n",
    "    trueLabel = []\n",
    "    predLabel = []\n",
    "    with torch.no_grad():\n",
    "        for data in train_loader:\n",
    "            data.x = gen_node_feature(data)\n",
    "            data = data.to(device)\n",
    "            data.edge_index = add_remaining_self_loops(data.edge_index)[0]\n",
    "            hs, mask = to_dense_batch(data.x, data.batch)\n",
    "            hs = [h[m] for h, m in zip(hs, mask)]\n",
    "            gs = to_dense_adj(data.edge_index, data.batch)\n",
    "            # Crop each dense adjacency to the node count of its own graph.\n",
    "            gs = [g[:len(h), :len(h)] for g, h in zip(gs, hs)]\n",
    "            _, o_gs = model(gs, hs, gs)\n",
    "\n",
    "            for og, g in zip(o_gs, gs):\n",
    "                trueLabel.extend(g.int().cpu().numpy().flatten().tolist())\n",
    "                predLabel.extend(og.cpu().numpy().flatten().tolist())\n",
    "    trueLabel = np.array(trueLabel)\n",
    "    predLabel = np.array(predLabel)\n",
    "    # The scores are continuous, so comparing them to 0/1 labels with == is almost never\n",
    "    # true and reported an accuracy near zero; binarise at 0.5 first.\n",
    "    hardLabel = (predLabel > 0.5).astype(trueLabel.dtype)\n",
    "    accuracy = (hardLabel == trueLabel).sum() / trueLabel.size\n",
    "    return accuracy, calculate_metrics(trueLabel, predLabel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ARG:\n",
    "    '''Plain attribute bag holding the model hyperparameters.'''\n",
    "\n",
    "\n",
    "# Hyperparameters of the edge-masking run, attached to the bag in one step.\n",
    "args = ARG()\n",
    "vars(args).update(\n",
    "    input_dim=FEAT_DIM,\n",
    "    act='GELU',\n",
    "    dim=128,\n",
    "    drop_p=0,\n",
    "    mask_ratio=0.5,\n",
    "    ks=[0.9, 0.8, 0.7, 0.6, 0.5] + [0],\n",
    ")\n",
    "num_epoch = 200\n",
    "\n",
    "# The cosine schedule spans the whole run: one scheduler step per epoch.\n",
    "model = newUnet(in_dim=args.input_dim, args=args)\n",
    "model = model.to(device)\n",
    "optimizer = torch.optim.AdamW(model.parameters())\n",
    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# basicConfig() silently does nothing once the root logger has a handler, which is the case\n",
    "# after an earlier experiment in this kernel configured logging; without force=True these\n",
    "# results would be appended to the earlier log file instead of result_masking.log.\n",
    "# force=True (Python 3.8+) replaces the existing handlers. The directory is created first\n",
    "# because the log file cannot be opened inside a missing directory.\n",
    "os.makedirs('resLog', exist_ok=True)\n",
    "logging.basicConfig(filename='resLog/result_masking.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True)\n",
    "\n",
    "train_acc = []\n",
    "test_acc = []\n",
    "# Epoch 0 is the untrained baseline; epochs 1..num_epoch are each logged after one training pass.\n",
    "for epoch in range(num_epoch + 1):\n",
    "    if epoch > 0:\n",
    "        train(model, train_loader, optimizer, scheduler, device, 0.5, epoch - 1)\n",
    "    train_acc.append(evaluate(model, train_loader, device))\n",
    "    test_acc.append(evaluate(model, test_loader, device))\n",
    "    logging.info(f'Epoch [{epoch}/{num_epoch}], acc_train: {train_acc[-1][0]:.6f}, acc_test: {test_acc[-1][0]:.6f}, roc_train: {train_acc[-1][-1]:.6f}, roc_test: {test_acc[-1][-1]:.6f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# L2-norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from GraphUNET.ops import GCN, Pool, norm_g, Initializer, Unpool\n",
    "from UNET_attns.UNET_attn import Unet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class newUnet(Unet):\n",
    "    '''``Unet`` variant that scores node pairs by the L2 distance between two decoded embeddings.\n",
    "\n",
    "    Relies on the sub-modules of the imported ``Unet`` (``s_gcn``, ``s_ln``, ``g_enc``,\n",
    "    ``bot_gcn``, ``bot_ln``, ``g_dec1``, ``g_dec2``).\n",
    "    '''\n",
    "    def embed_one(self, g, h):\n",
    "        '''Return ``sigmoid(10 * (1 - d))`` where ``d`` holds the pairwise L2 distances.'''\n",
    "        g = norm_g(g)\n",
    "        ori_h = self.s_ln(self.s_gcn(g, h))\n",
    "        g, h, adj_ms, down_outs, indices_list = self.g_enc(g, ori_h)\n",
    "\n",
    "        h = self.bot_ln(self.bot_gcn(norm_g(g), h))\n",
    "        # Both decoders receive exactly the same inputs.\n",
    "        decoder_inputs = (h, ori_h, down_outs, adj_ms, indices_list)\n",
    "        h1 = self.g_dec1(*decoder_inputs)\n",
    "        h2 = self.g_dec2(*decoder_inputs)\n",
    "\n",
    "        distances = torch.cdist(h1, h2, p=2)\n",
    "        # Distances below 1 give scores above 0.5; the factor 10 sharpens the transition.\n",
    "        return torch.sigmoid(10 * (1 - distances))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, train_loader: DataLoader, optimizer, scheduler, device, mask_ratio, epoch):\n",
    "    '''Run one optimisation epoch and return the mean batch loss.\n",
    "\n",
    "    Args:\n",
    "        model: callable as ``model(gs, hs)`` returning ``(loss, reconstructed_graphs)``.\n",
    "        train_loader: loader yielding batched graphs.\n",
    "        optimizer: stepped once per batch.\n",
    "        scheduler: stepped once at the end of the epoch.\n",
    "        device: device every batch is moved to.\n",
    "        mask_ratio: unused here; kept so existing callers keep working.\n",
    "        epoch: unused; kept so existing callers keep working.\n",
    "\n",
    "    Returns:\n",
    "        Mean of the per-batch losses (0.0 when the loader yields nothing).\n",
    "    '''\n",
    "    model.train()\n",
    "\n",
    "    total_loss = 0.0\n",
    "    num_batches = 0\n",
    "    for data in train_loader:\n",
    "        data.x = gen_node_feature(data)\n",
    "        data = data.to(device)\n",
    "        data.edge_index = add_remaining_self_loops(data.edge_index)[0]\n",
    "        hs, mask = to_dense_batch(data.x, data.batch)\n",
    "        hs = [h[m] for h, m in zip(hs, mask)]\n",
    "        gs = to_dense_adj(data.edge_index, data.batch)\n",
    "        # Crop each dense adjacency to the node count of its own graph.\n",
    "        gs = [g[:len(h), :len(h)] for g, h in zip(gs, hs)]\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss, _ = model(gs, hs)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        total_loss += loss.item()\n",
    "        num_batches += 1\n",
    "    scheduler.step()\n",
    "    return total_loss / max(num_batches, 1)\n",
    "\n",
    "def calculate_metrics(true_labels, predicted_labels):\n",
    "    '''Area under the ROC curve for ``predicted_labels`` given ``true_labels``.'''\n",
    "    return roc_auc_score(true_labels, predicted_labels)\n",
    "\n",
    "def evaluate(model, train_loader, device):\n",
    "    '''Score adjacency reconstruction on every graph yielded by ``train_loader``.\n",
    "\n",
    "    Args:\n",
    "        model: callable as ``model(gs, hs)`` returning ``(loss, reconstructed_graphs)``.\n",
    "        train_loader: any graph loader (train or test), despite the name.\n",
    "        device: device every batch is moved to before the forward pass.\n",
    "\n",
    "    Returns:\n",
    "        ``(accuracy, roc_auc)``: accuracy is the fraction of adjacency entries whose score,\n",
    "        thresholded at 0.5, equals the true 0/1 entry; roc_auc is ``calculate_metrics``\n",
    "        applied to the raw scores.\n",
    "    '''\n",
    "    model.eval()\n",
    "\n",
    "    trueLabel = []\n",
    "    predLabel = []\n",
    "    with torch.no_grad():\n",
    "        for data in train_loader:\n",
    "            data.x = gen_node_feature(data)\n",
    "            data = data.to(device)\n",
    "            data.edge_index = add_remaining_self_loops(data.edge_index)[0]\n",
    "            hs, mask = to_dense_batch(data.x, data.batch)\n",
    "            hs = [h[m] for h, m in zip(hs, mask)]\n",
    "            gs = to_dense_adj(data.edge_index, data.batch)\n",
    "            # Crop each dense adjacency to the node count of its own graph.\n",
    "            gs = [g[:len(h), :len(h)] for g, h in zip(gs, hs)]\n",
    "            _, o_gs = model(gs, hs)\n",
    "\n",
    "            for og, g in zip(o_gs, gs):\n",
    "                trueLabel.extend(g.int().cpu().numpy().flatten().tolist())\n",
    "                predLabel.extend(og.cpu().numpy().flatten().tolist())\n",
    "    trueLabel = np.array(trueLabel)\n",
    "    predLabel = np.array(predLabel)\n",
    "    # The scores are sigmoid outputs, so comparing them to 0/1 labels with == is almost\n",
    "    # never true and reported an accuracy near zero; binarise at 0.5 first.\n",
    "    hardLabel = (predLabel > 0.5).astype(trueLabel.dtype)\n",
    "    accuracy = (hardLabel == trueLabel).sum() / trueLabel.size\n",
    "    return accuracy, calculate_metrics(trueLabel, predLabel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ARG:\n",
    "    '''Plain attribute bag holding the model hyperparameters.'''\n",
    "\n",
    "\n",
    "# Hyperparameters of the L2-norm run, attached to the bag in one step.\n",
    "args = ARG()\n",
    "vars(args).update(\n",
    "    input_dim=FEAT_DIM,\n",
    "    act='GELU',\n",
    "    dim=128,\n",
    "    drop_p=0,\n",
    "    mask_ratio=0.5,\n",
    "    ks=[0.9, 0.8, 0.7, 0.6, 0.5] + [0],\n",
    ")\n",
    "num_epoch = 200\n",
    "\n",
    "# The cosine schedule spans the whole run: one scheduler step per epoch.\n",
    "model = newUnet(in_dim=args.input_dim, args=args)\n",
    "model = model.to(device)\n",
    "optimizer = torch.optim.AdamW(model.parameters())\n",
    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# basicConfig() silently does nothing once the root logger has a handler, which is the case\n",
    "# after an earlier experiment in this kernel configured logging; without force=True these\n",
    "# results would be appended to the earlier log file instead of result_l2norm.log.\n",
    "# force=True (Python 3.8+) replaces the existing handlers. The directory is created first\n",
    "# because the log file cannot be opened inside a missing directory.\n",
    "os.makedirs('resLog', exist_ok=True)\n",
    "logging.basicConfig(filename='resLog/result_l2norm.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True)\n",
    "\n",
    "train_acc = []\n",
    "test_acc = []\n",
    "# Epoch 0 is the untrained baseline; epochs 1..num_epoch are each logged after one training pass.\n",
    "for epoch in range(num_epoch + 1):\n",
    "    if epoch > 0:\n",
    "        train(model, train_loader, optimizer, scheduler, device, 0.5, epoch - 1)\n",
    "    train_acc.append(evaluate(model, train_loader, device))\n",
    "    test_acc.append(evaluate(model, test_loader, device))\n",
    "    logging.info(f'Epoch [{epoch}/{num_epoch}], acc_train: {train_acc[-1][0]:.6f}, acc_test: {test_acc[-1][0]:.6f}, roc_train: {train_acc[-1][-1]:.6f}, roc_test: {test_acc[-1][-1]:.6f}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "EventGNN",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
