{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceebbbe8-de9c-4a7f-86d2-df0e85f9ee39",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "\n",
    "# Export the installed PyTorch version as the TORCH environment variable and show it.\n",
    "os.environ['TORCH'] = torch.__version__\n",
    "print(os.environ['TORCH'])\n",
    "\n",
    "# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37657006-4cd7-4978-9532-5c8551eb16c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.manifold import TSNE\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "from torch.nn import Linear\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.transforms import Compose\n",
    "from torch_geometric.datasets import Amazon\n",
    "from torch_geometric.transforms.random_node_split import RandomNodeSplit\n",
    "from torch_geometric.datasets import Planetoid\n",
    "from torch_geometric.transforms import NormalizeFeatures\n",
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.nn import GATConv\n",
    "from torch_geometric.loader import NeighborLoader\n",
    "from torch_geometric.nn import SAGEConv\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "from torch_geometric.utils import negative_sampling\n",
    "from torch_geometric.utils import train_test_split_edges\n",
    "\n",
    "from copy import deepcopy\n",
    "import torch.nn as nn\n",
    "from IPython.display import Javascript  # Restrict height of output cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20d8e6d6-328c-488d-9da4-6af29c8358d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which graph the experiments below use: 'Amazon' (Amazon Computers) or 'Flickr'.\n",
    "dataset_name = 'Amazon'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e44663f-d2b3-4bc2-98e0-81ec142f6fc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.datasets import Planetoid, Flickr, Amazon\n",
    "from torch_geometric.transforms import NormalizeFeatures\n",
    "\n",
    "# Load the graph selected by `dataset_name`. RandomNodeSplit is passed as a `transform`,\n",
    "# which PyG applies when a graph is read from `dataset` (e.g. `dataset[0]` below).\n",
    "if dataset_name == 'Flickr':\n",
    "    transform = Compose([\n",
    "        # NormalizeFeatures(),\n",
    "        RandomNodeSplit('train_rest', num_val=2000, num_test=10000),\n",
    "    ])\n",
    "    dataset = Flickr(root='data/Flickr', transform=transform)\n",
    "elif dataset_name == 'Amazon':\n",
    "    transform = Compose([\n",
    "        # NormalizeFeatures(),\n",
    "        RandomNodeSplit('train_rest', num_val=1000, num_test=3000),\n",
    "    ])\n",
    "    dataset = Amazon(root='data/Amazon', name='Computers', transform=transform)\n",
    "else:\n",
    "    # Fail loudly instead of silently reusing a `dataset` left over from an earlier run.\n",
    "    raise ValueError(f\"Unknown dataset_name {dataset_name!r}; expected 'Flickr' or 'Amazon'\")\n",
    "\n",
    "\n",
    "print()\n",
    "print(f'Dataset: {dataset}:')\n",
    "print('======================')\n",
    "print(f'Number of graphs: {len(dataset)}')\n",
    "print(f'Number of features: {dataset.num_features}')\n",
    "print(f'Number of classes: {dataset.num_classes}')\n",
    "\n",
    "data = dataset[0]  # Get the first graph object.\n",
    "\n",
    "print()\n",
    "print(data)\n",
    "print('===========================================================================================================')\n",
    "\n",
    "# Gather some statistics about the graph.\n",
    "num_train_nodes = int(data.train_mask.sum())\n",
    "print(f'Number of nodes: {data.num_nodes}')\n",
    "print(f'Number of edges: {data.num_edges}')\n",
    "print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')\n",
    "print(f'Number of training nodes: {num_train_nodes}')\n",
    "print(f'Training node label rate: {num_train_nodes / data.num_nodes:.2f}')\n",
    "print(f'Has isolated nodes: {data.has_isolated_nodes()}')\n",
    "print(f'Has self-loops: {data.has_self_loops()}')\n",
    "print(f'Is undirected: {data.is_undirected()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9104853-b271-4b21-95f6-60b35b4a7698",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.nn import GATConv\n",
    "\n",
    "\n",
    "class GAT(torch.nn.Module):\n",
    "    \"\"\"Stack of GATConv layers for node classification.\n",
    "\n",
    "    Hidden layers use `heads` attention heads followed by ELU and dropout; the final\n",
    "    layer uses a single head and its output is returned unchanged (the training cell\n",
    "    feeds it to CrossEntropyLoss, i.e. as class logits).\n",
    "    After every forward pass each layer's output is stored as a NumPy array in\n",
    "    `self.feature_vals['conv<i>']`.\n",
    "\n",
    "    Args:\n",
    "        hidden_channels: Output channels per head in the hidden layers.\n",
    "        out_channels: Output channels of the final layer (number of classes).\n",
    "        num_layers: Total number of GATConv layers; must be >= 1.\n",
    "        heads: Attention heads in every layer except the last.\n",
    "        dropout: Dropout probability applied after each non-final layer.\n",
    "        in_channels: Input feature size; defaults to the global `dataset.num_features`.\n",
    "        seed: Seed for torch's global RNG. When None (default), a seed in [1, 99] is\n",
    "            drawn from NumPy's global RNG, as before.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `num_layers` < 1.\n",
    "    \"\"\"\n",
    "    def __init__(self, hidden_channels, out_channels, num_layers, heads=8, dropout=0.3,\n",
    "                 in_channels=None, seed=None):\n",
    "        super().__init__()\n",
    "        if num_layers < 1:\n",
    "            raise ValueError(f'num_layers must be >= 1, got {num_layers}')\n",
    "        if seed is None:\n",
    "            seed = int(np.random.choice(np.arange(1, 100)))\n",
    "        torch.manual_seed(seed)\n",
    "        if in_channels is None:\n",
    "            in_channels = dataset.num_features\n",
    "        self.feature_vals = {}\n",
    "        self.layers = nn.ModuleList()\n",
    "        self.num_layers = num_layers\n",
    "        self.hidden_channels = hidden_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.dropout = dropout\n",
    "        if num_layers == 1:\n",
    "            # A single layer maps input features straight to the output channels.\n",
    "            self.layers.append(GATConv(in_channels, out_channels, heads=1))\n",
    "        else:\n",
    "            self.layers.append(GATConv(in_channels, hidden_channels, heads=heads))\n",
    "            for _ in range(num_layers - 2):\n",
    "                self.layers.append(GATConv(hidden_channels * heads, hidden_channels, heads=heads))\n",
    "            self.layers.append(GATConv(hidden_channels * heads, out_channels, heads=1))\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        \"\"\"Run all layers on node features `x` over `edge_index` and return the last output.\"\"\"\n",
    "        last = len(self.layers) - 1\n",
    "        for i, layer in enumerate(self.layers):\n",
    "            x = layer(x, edge_index)\n",
    "            if i != last:\n",
    "                x = F.elu(x)\n",
    "                x = F.dropout(x, p=self.dropout, training=self.training)\n",
    "            # Keep an independent host-side copy of this layer's output.\n",
    "            self.feature_vals['conv' + str(i)] = x.detach().cpu().numpy().copy()\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "224a1b2a-30f4-4409-a208-6f6d1116979b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    \"\"\"Run one full-batch optimisation step on the training nodes.\n",
    "\n",
    "    Uses the notebook globals `model`, `optimizer`, `criterion` and `data`.\n",
    "    Returns the training loss as a Python float, so callers that store it do not\n",
    "    keep autograd graphs alive.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    optimizer.zero_grad()  # Clear gradients.\n",
    "    out = model(data.x, data.edge_index)  # Perform a single forward pass.\n",
    "    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.\n",
    "    loss.backward()  # Derive gradients.\n",
    "    optimizer.step()  # Update parameters based on gradients.\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def test():\n",
    "    \"\"\"Return the accuracy of `model` on the test nodes of `data`.\n",
    "\n",
    "    Runs in eval mode and without gradient tracking, since nothing is back-propagated.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    out = model(data.x, data.edge_index)\n",
    "    pred = out.argmax(dim=1)  # Use the class with highest probability.\n",
    "    test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.\n",
    "    test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.\n",
    "    return test_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "889a1dda-8931-4581-a01d-fa1fe5450719",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "\n",
    "# Train one GAT node classifier per run id, saving every layer's output after each epoch\n",
    "# to model_data/<dataset>/GAT/NC_<run_id>_<epoch>.npz and a test-accuracy curve as PNG.\n",
    "run_ids = [1, 2]\n",
    "for run_id in run_ids:\n",
    "    config = {\n",
    "        \"model_name\": \"GAT\",\n",
    "        \"task\": \"NC\",\n",
    "        \"run_id\": run_id,\n",
    "        \"dataset\": dataset_name\n",
    "    }\n",
    "    path = 'model_data/' + config['dataset'] + '/' + config['model_name'] + '/'\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    prefix = path + config['task'] + '_' + str(config['run_id'])\n",
    "    # Remove snapshots from a previous execution of this run only (NC_1_* must not match NC_10_*).\n",
    "    for stale_file in glob.glob(prefix + '_*.npz'):\n",
    "        os.remove(stale_file)\n",
    "    model = GAT(hidden_channels=16, out_channels=dataset.num_classes, num_layers=4, heads=8, dropout=0.2)\n",
    "    model, data = model.to(device), data.to(device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    loss_list = []\n",
    "    test_acc_list = []\n",
    "    for epoch in range(1, 101):\n",
    "        loss = float(train())  # store a plain number, never a graph-holding tensor\n",
    "        loss_list.append(loss)\n",
    "        test_acc = test()\n",
    "        test_acc_list.append(test_acc)\n",
    "        # model.feature_vals now holds the layer outputs of the test() forward pass.\n",
    "        np.savez(prefix + '_' + str(epoch) + '.npz', **model.feature_vals)\n",
    "        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')\n",
    "    fig, ax = plt.subplots(figsize=(20, 8))\n",
    "    ax.plot(range(1, len(test_acc_list) + 1), test_acc_list)\n",
    "    ax.set_title(\"Accuracy Over Epochs\", fontsize=20)\n",
    "    ax.set_xlabel(\"Epochs\", fontsize=15)\n",
    "    ax.set_ylabel(\"Test Accuracy\", fontsize=15)\n",
    "    output_filename = f'GAT_NC_test_accuracy_{run_id}.png'\n",
    "    fig.savefig(output_filename, bbox_inches='tight')\n",
    "    plt.close(fig)  # Close the figure to release resources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cd37a9e-2204-403a-9741-dee5865b0a53",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test accuracy of the model left over from the last training run above.\n",
    "test_acc = test()\n",
    "print('Test Accuracy: {:.4f}'.format(test_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6791cdd-b3b8-4a1d-ac20-212f994b70bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of every layer output captured during the model's most recent forward pass.\n",
    "for layer_name, layer_output in model.feature_vals.items():\n",
    "    print(layer_output.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc282b17-d9c1-425e-944e-ca3b6b1e28ba",
   "metadata": {},
   "source": [
    "## Link Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f774331-cd22-4e0d-a00d-b6fd858fbb02",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.datasets import Planetoid, Flickr, Amazon\n",
    "from torch_geometric.transforms import NormalizeFeatures\n",
    "\n",
    "# Reload the graph selected by `dataset_name` for link prediction. RandomNodeSplit is passed\n",
    "# as a `transform`, which PyG applies when a graph is read from `dataset` (e.g. `dataset[0]`).\n",
    "if dataset_name == 'Flickr':\n",
    "    transform = Compose([\n",
    "        # NormalizeFeatures(),\n",
    "        RandomNodeSplit('train_rest', num_val=2000, num_test=10000),\n",
    "    ])\n",
    "    dataset = Flickr(root='data/Flickr', transform=transform)\n",
    "elif dataset_name == 'Amazon':\n",
    "    transform = Compose([\n",
    "        # NormalizeFeatures(),\n",
    "        RandomNodeSplit('train_rest', num_val=1000, num_test=3000),\n",
    "    ])\n",
    "    dataset = Amazon(root='data/Amazon', name='Computers', transform=transform)\n",
    "else:\n",
    "    # Fail loudly instead of silently reusing a `dataset` left over from an earlier run.\n",
    "    raise ValueError(f\"Unknown dataset_name {dataset_name!r}; expected 'Flickr' or 'Amazon'\")\n",
    "\n",
    "\n",
    "print()\n",
    "print(f'Dataset: {dataset}:')\n",
    "print('======================')\n",
    "print(f'Number of graphs: {len(dataset)}')\n",
    "print(f'Number of features: {dataset.num_features}')\n",
    "print(f'Number of classes: {dataset.num_classes}')\n",
    "\n",
    "data = dataset[0]  # Get the first graph object.\n",
    "\n",
    "print()\n",
    "print(data)\n",
    "print('===========================================================================================================')\n",
    "\n",
    "# Gather some statistics about the graph.\n",
    "num_train_nodes = int(data.train_mask.sum())\n",
    "print(f'Number of nodes: {data.num_nodes}')\n",
    "print(f'Number of edges: {data.num_edges}')\n",
    "print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')\n",
    "print(f'Number of training nodes: {num_train_nodes}')\n",
    "print(f'Training node label rate: {num_train_nodes / data.num_nodes:.2f}')\n",
    "print(f'Has isolated nodes: {data.has_isolated_nodes()}')\n",
    "print(f'Has self-loops: {data.has_self_loops()}')\n",
    "print(f'Is undirected: {data.is_undirected()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "710cf45d-df5c-4824-8aed-2cb063bd78ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the link-prediction split from a fresh graph so this cell can be re-run:\n",
    "# train_test_split_edges replaces `edge_index` with train/val/test edge sets, so splitting\n",
    "# an already-split `data` a second time would fail.\n",
    "lp_graph = dataset[0]\n",
    "# Node labels and node masks are not used for link prediction.\n",
    "lp_graph.train_mask = lp_graph.val_mask = lp_graph.test_mask = lp_graph.y = None\n",
    "data = train_test_split_edges(lp_graph)\n",
    "print(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "084693ce-2975-4f6d-9554-255a334f4555",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(torch.nn.Module):\n",
    "    \"\"\"GAT encoder with a dot-product decoder for link prediction.\n",
    "\n",
    "    Every layer, including the last, is a GATConv with `heads` attention heads, so with\n",
    "    GATConv's default head concatenation the embeddings returned by `encode` have\n",
    "    `out_channels * heads` features. After every `encode` call each layer's output is\n",
    "    stored as a NumPy array in `self.feature_vals['conv<i>']`.\n",
    "\n",
    "    Args:\n",
    "        hidden_channels: Output channels per head in the hidden layers.\n",
    "        out_channels: Output channels per head in the final layer.\n",
    "        num_layers: Total number of GATConv layers; must be >= 1.\n",
    "        heads: Attention heads per layer.\n",
    "        dropout: Dropout probability applied after each non-final layer.\n",
    "        in_channels: Input feature size; defaults to the global `dataset.num_features`.\n",
    "        seed: Seed for torch's global RNG. When None (default), a seed in [1, 99] is\n",
    "            drawn from NumPy's global RNG, as before.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `num_layers` < 1.\n",
    "    \"\"\"\n",
    "    def __init__(self, hidden_channels, out_channels, num_layers, heads=8, dropout=0.3,\n",
    "                 in_channels=None, seed=None):\n",
    "        super().__init__()\n",
    "        if num_layers < 1:\n",
    "            raise ValueError(f'num_layers must be >= 1, got {num_layers}')\n",
    "        if seed is None:\n",
    "            seed = int(np.random.choice(np.arange(1, 100)))\n",
    "        torch.manual_seed(seed)\n",
    "        if in_channels is None:\n",
    "            in_channels = dataset.num_features\n",
    "        self.feature_vals = {}\n",
    "        self.layers = nn.ModuleList()\n",
    "        self.num_layers = num_layers\n",
    "        self.hidden_channels = hidden_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.dropout = dropout\n",
    "        if num_layers == 1:\n",
    "            # A single layer maps input features straight to the embedding.\n",
    "            self.layers.append(GATConv(in_channels, out_channels, heads=heads))\n",
    "        else:\n",
    "            self.layers.append(GATConv(in_channels, hidden_channels, heads=heads))\n",
    "            for _ in range(num_layers - 2):\n",
    "                self.layers.append(GATConv(hidden_channels * heads, hidden_channels, heads=heads))\n",
    "            self.layers.append(GATConv(hidden_channels * heads, out_channels, heads=heads))\n",
    "\n",
    "    def encode(self, x, edge_index):\n",
    "        \"\"\"Return node embeddings computed from features `x` over `edge_index`.\"\"\"\n",
    "        last = len(self.layers) - 1\n",
    "        for i, layer in enumerate(self.layers):\n",
    "            x = layer(x, edge_index)\n",
    "            if i != last:\n",
    "                x = F.elu(x, alpha=1)\n",
    "                x = F.dropout(x, p=self.dropout, training=self.training)\n",
    "            # Keep an independent host-side copy of this layer's output.\n",
    "            self.feature_vals['conv' + str(i)] = x.detach().cpu().numpy().copy()\n",
    "        return x\n",
    "\n",
    "    def decode(self, z, pos_edge_index, neg_edge_index):\n",
    "        \"\"\"Return one logit per edge: dot products for the positive edges, then the negative ones.\"\"\"\n",
    "        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)\n",
    "        return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)\n",
    "\n",
    "    def decode_all(self, z, chunk_size=None):\n",
    "        \"\"\"Return a (2, E) index of every node pair whose embedding dot product is > 0.\n",
    "\n",
    "        A positive dot product means sigmoid(score) > 0.5. Self-pairs and both orientations\n",
    "        of a pair can appear. By default the dense N x N score matrix is built at once;\n",
    "        pass `chunk_size` to score that many rows at a time and bound peak memory.\n",
    "        The returned indices are the same (row-major order) either way.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If `chunk_size` is given and < 1.\n",
    "        \"\"\"\n",
    "        if chunk_size is None:\n",
    "            prob_adj = z @ z.t()  # dense N x N scores\n",
    "            return (prob_adj > 0).nonzero(as_tuple=False).t()\n",
    "        if chunk_size < 1:\n",
    "            raise ValueError(f'chunk_size must be >= 1, got {chunk_size}')\n",
    "        pieces = []\n",
    "        for start in range(0, z.size(0), chunk_size):\n",
    "            scores = z[start:start + chunk_size] @ z.t()\n",
    "            idx = (scores > 0).nonzero(as_tuple=False)\n",
    "            idx[:, 0] += start  # shift chunk-local row ids back to global node ids\n",
    "            pieces.append(idx)\n",
    "        if not pieces:\n",
    "            return torch.empty((2, 0), dtype=torch.long, device=z.device)\n",
    "        return torch.cat(pieces, dim=0).t()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d95e4cad-122b-4a57-9dd2-b9efbe4a5bf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_link_labels(pos_edge_index, neg_edge_index):\n",
    "    \"\"\"Return float labels [1, ..., 1, 0, ..., 0] in the edge order used by `decode`.\n",
    "\n",
    "    The number of ones equals the number of positive edges (columns of `pos_edge_index`)\n",
    "    and the number of zeros equals the number of negative edges. The labels live on the\n",
    "    same device as `pos_edge_index`.\n",
    "    \"\"\"\n",
    "    num_pos = pos_edge_index.size(1)\n",
    "    link_labels = torch.zeros(num_pos + neg_edge_index.size(1), dtype=torch.float,\n",
    "                              device=pos_edge_index.device)\n",
    "    link_labels[:num_pos] = 1.\n",
    "    return link_labels\n",
    "\n",
    "\n",
    "def train():\n",
    "    \"\"\"Run one optimisation step of the link predictor and return the loss as a float.\n",
    "\n",
    "    Draws as many negative edges as there are training positives, encodes on the\n",
    "    training graph and applies binary cross-entropy to the dot-product logits.\n",
    "    Uses the notebook globals `model`, `optimizer` and `data`.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "\n",
    "    neg_edge_index = negative_sampling(\n",
    "        edge_index=data.train_pos_edge_index,  # positive edges\n",
    "        num_nodes=data.num_nodes,  # number of nodes\n",
    "        num_neg_samples=data.train_pos_edge_index.size(1))  # as many negatives as positives\n",
    "\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    z = model.encode(data.x, data.train_pos_edge_index)  # encode\n",
    "    link_logits = model.decode(z, data.train_pos_edge_index, neg_edge_index)  # decode\n",
    "\n",
    "    link_labels = get_link_labels(data.train_pos_edge_index, neg_edge_index)\n",
    "    loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def test():\n",
    "    \"\"\"Return [validation ROC AUC, test ROC AUC] of the current `model`.\"\"\"\n",
    "    model.eval()\n",
    "    # Embeddings depend only on the training graph, so compute them once for both splits.\n",
    "    z = model.encode(data.x, data.train_pos_edge_index)\n",
    "    perfs = []\n",
    "    for prefix in [\"val\", \"test\"]:\n",
    "        pos_edge_index = data[f'{prefix}_pos_edge_index']\n",
    "        neg_edge_index = data[f'{prefix}_neg_edge_index']\n",
    "\n",
    "        link_logits = model.decode(z, pos_edge_index, neg_edge_index)  # decode test or val\n",
    "        link_probs = link_logits.sigmoid()\n",
    "\n",
    "        link_labels = get_link_labels(pos_edge_index, neg_edge_index)\n",
    "\n",
    "        perfs.append(roc_auc_score(link_labels.cpu(), link_probs.cpu()))\n",
    "    return perfs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6aae1ed-1c1d-4e1b-81ae-d31b72e2cb4b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import glob\n",
    "\n",
    "# Progress line printed every 10 epochs; Val/Test are taken at the best validation epoch so far.\n",
    "log = 'Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}'\n",
    "\n",
    "# Train one GAT link predictor per run id, saving every layer's output after each epoch\n",
    "# to model_data/<dataset>/GAT/LP_<run_id>_<epoch>.npz and a test ROC AUC curve as PNG.\n",
    "run_ids = [1, 2]\n",
    "for run_id in run_ids:\n",
    "    config = {\n",
    "        \"model_name\": \"GAT\",\n",
    "        \"task\": \"LP\",\n",
    "        \"run_id\": run_id,\n",
    "        \"dataset\": dataset_name\n",
    "    }\n",
    "    path = 'model_data/' + config['dataset'] + '/' + config['model_name'] + '/'\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    prefix = path + config['task'] + '_' + str(config['run_id'])\n",
    "    # Remove snapshots from a previous execution of this run only (LP_1_* must not match LP_10_*).\n",
    "    for stale_file in glob.glob(prefix + '_*.npz'):\n",
    "        os.remove(stale_file)\n",
    "    model, data = Net(hidden_channels=8, out_channels=8, num_layers=4, heads=8, dropout=0).to(device), data.to(device)\n",
    "    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)\n",
    "    best_val_perf = test_perf = 0\n",
    "    test_acc_list = []  # per-epoch test ROC AUC\n",
    "    for epoch in range(1, 201):\n",
    "        train_loss = float(train())\n",
    "        val_perf, tmp_test_perf = test()\n",
    "        if val_perf > best_val_perf:\n",
    "            best_val_perf = val_perf\n",
    "            test_perf = tmp_test_perf\n",
    "        # model.feature_vals now holds the layer outputs of the test() encoder pass.\n",
    "        np.savez(prefix + '_' + str(epoch) + '.npz', **model.feature_vals)\n",
    "        test_acc_list.append(tmp_test_perf)\n",
    "        if epoch % 10 == 0:\n",
    "            print(log.format(epoch, train_loss, best_val_perf, test_perf))\n",
    "    fig, ax = plt.subplots(figsize=(20, 8))\n",
    "    ax.plot(range(1, len(test_acc_list) + 1), test_acc_list)\n",
    "    ax.set_title(\"Test ROC AUC Over Epochs\", fontsize=20)\n",
    "    ax.set_xlabel(\"Epochs\", fontsize=15)\n",
    "    ax.set_ylabel(\"Test ROC AUC\", fontsize=15)\n",
    "    output_filename = f'GAT_LP_test_accuracy_{run_id}.png'\n",
    "    fig.savefig(output_filename, bbox_inches='tight')\n",
    "    plt.close(fig)  # Close the figure to release resources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a0a0aee-5b24-4905-8f41-a309cc3ff216",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict the full edge set with the final model. Gradients are not needed here, and\n",
    "# tracking them through the N x N score matrix in decode_all would waste a lot of memory.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    z = model.encode(data.x, data.train_pos_edge_index)\n",
    "    final_edge_index = model.decode_all(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d58991f4-95bb-4aca-baa5-14a04a558085",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of every layer output captured during the most recent encode() call.\n",
    "for layer_name, layer_output in model.feature_vals.items():\n",
    "    print(layer_output.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
