{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceebbbe8-de9c-4a7f-86d2-df0e85f9ee39",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "os.environ['TORCH'] = torch.__version__\n",
    "print(torch.__version__)\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25dfc940-0467-4a47-8b46-66301404bd92",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-1.13.0+cu117.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37657006-4cd7-4978-9532-5c8551eb16c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.manifold import TSNE\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "from torch.nn import Linear\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.transforms import Compose\n",
    "from torch_geometric.datasets import Amazon\n",
    "from torch_geometric.transforms.random_node_split import RandomNodeSplit\n",
    "from torch_geometric.datasets import Planetoid\n",
    "from torch_geometric.transforms import NormalizeFeatures\n",
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.nn import GATConv\n",
    "from torch_geometric.loader import NeighborLoader\n",
    "from torch_geometric.nn import SAGEConv\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "from torch_geometric.utils import negative_sampling\n",
    "from torch_geometric.utils import train_test_split_edges\n",
    "\n",
    "from copy import deepcopy\n",
    "import torch.nn as nn\n",
    "from IPython.display import Javascript  # Restrict height of output cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53e093eb-2f85-49d1-ba41-c0e784889f14",
   "metadata": {},
   "outputs": [],
   "source": [
    "#dataset_name='Flickr'\n",
    "dataset_name='Amazon'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e44663f-d2b3-4bc2-98e0-81ec142f6fc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.datasets import Planetoid, Flickr, Amazon\n",
    "from torch_geometric.transforms import NormalizeFeatures\n",
    "\n",
    "\n",
    "if dataset_name=='Flickr':\n",
    "    transform = Compose([\n",
    "        #NormalizeFeatures(),\n",
    "        RandomNodeSplit('train_rest',num_val = 2000, num_test = 10000)\n",
    "    ])\n",
    "    dataset = Flickr(root='data/Flickr', \\\n",
    "                     transform =transform)\n",
    "elif dataset_name=='Amazon':\n",
    "    transform = Compose([\n",
    "        #NormalizeFeatures(),\n",
    "        RandomNodeSplit('train_rest',num_val = 1000, num_test = 3000)\n",
    "    ])\n",
    "    dataset = Amazon(root='data/Amazon', name='Computers', \\\n",
    "                     transform =transform)\n",
    "\n",
    "\n",
    "print()\n",
    "print(f'Dataset: {dataset}:')\n",
    "print('======================')\n",
    "print(f'Number of graphs: {len(dataset)}')\n",
    "print(f'Number of features: {dataset.num_features}')\n",
    "print(f'Number of classes: {dataset.num_classes}')\n",
    "\n",
    "data = dataset[0]  # Get the first graph object.\n",
    "\n",
    "print()\n",
    "print(data)\n",
    "print('===========================================================================================================')\n",
    "\n",
    "# Gather some statistics about the graph.\n",
    "print(f'Number of nodes: {data.num_nodes}')\n",
    "print(f'Number of edges: {data.num_edges}')\n",
    "print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')\n",
    "print(f'Number of training nodes: {data.train_mask.sum()}')\n",
    "print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')\n",
    "print(f'Has isolated nodes: {data.has_isolated_nodes()}')\n",
    "print(f'Has self-loops: {data.has_self_loops()}')\n",
    "print(f'Is undirected: {data.is_undirected()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e541c9c-e13c-4348-9151-b3f7d8d62491",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.loader import ClusterData, ClusterLoader\n",
    "seed=12345\n",
    "torch.manual_seed(seed)\n",
    "cluster_data = ClusterData(data, num_parts=128)  # 1. Create subgraphs.\n",
    "train_loader = ClusterLoader(cluster_data, batch_size=32, shuffle=True)  # 2. Stochastic partioning scheme.\n",
    "\n",
    "print()\n",
    "total_num_nodes = 0\n",
    "labels=[]\n",
    "for step, sub_data in enumerate(train_loader):\n",
    "    print(f'Step {step + 1}:')\n",
    "    print('=======')\n",
    "    print(f'Number of nodes in the current batch: {sub_data.num_nodes}')\n",
    "    print(sub_data)\n",
    "    print()\n",
    "    total_num_nodes += sub_data.num_nodes\n",
    "    labels.append(sub_data.y)\n",
    "\n",
    "print(f'Iterated over {total_num_nodes} of {data.num_nodes} nodes!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9104853-b271-4b21-95f6-60b35b4a7698",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CGCN(torch.nn.Module):\n",
    "    def __init__(self, hidden_channels, out_channels, num_layers, dropout=0.3):\n",
    "        super().__init__()\n",
    "        torch.manual_seed(np.random.choice(np.arange(1,100)))\n",
    "        self.feature_vals = {}\n",
    "        self.layers = nn.ModuleList()\n",
    "        self.num_layers = num_layers\n",
    "        self.hidden_channels = hidden_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.dropout = dropout\n",
    "        self.layers.append(GCNConv(dataset.num_features, hidden_channels))\n",
    "        for i in range(num_layers-2):\n",
    "            self.layers.append(GCNConv(hidden_channels, hidden_channels))\n",
    "        self.layers.append(GCNConv(hidden_channels, out_channels))\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        for i,layer in enumerate(self.layers):\n",
    "            x = layer(x, edge_index)\n",
    "            if i!= (len(self.layers)-1):\n",
    "                x = x.relu()\n",
    "                x = F.dropout(x, p=self.dropout, training=self.training)\n",
    "        return x\n",
    "\n",
    "    def inference(self, x, edge_index):\n",
    "        for i,layer in enumerate(self.layers):\n",
    "            x = layer(x, edge_index)\n",
    "            if i!= (len(self.layers)-1):\n",
    "                x = x.relu()\n",
    "                x = F.dropout(x, p=self.dropout, training=self.training)\n",
    "            self.feature_vals['conv'+str(i)] = deepcopy(x.detach().cpu().numpy())\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3640955e-9147-417b-b6da-6323b030de13",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    model.train()\n",
    "    final_loss = 0\n",
    "    for sub_data in train_loader:\n",
    "        sub_data = sub_data.to(device)\n",
    "        out = model(sub_data.x, sub_data.edge_index)\n",
    "        loss = criterion(out[sub_data.train_mask], sub_data.y[sub_data.train_mask])\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "        final_loss+=loss\n",
    "        \n",
    "    return final_loss\n",
    "\n",
    "def test():\n",
    "    model.eval()\n",
    "    out = model.inference(data.x, data.edge_index)\n",
    "    pred = out.argmax(dim=1)  # Use the class with highest probability.\n",
    "    accs = []\n",
    "    for mask in [data.train_mask, data.val_mask, data.test_mask]:\n",
    "      correct = pred[mask] == data.y[mask]  # Check against ground-truth labels.\n",
    "      accs.append(int(correct.sum()) / int(mask.sum()))  # Derive ratio of correct predictions.\n",
    "    return accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "889a1dda-8931-4581-a01d-fa1fe5450719",
   "metadata": {},
   "outputs": [],
   "source": [
    "run_ids=[1,2]\n",
    "for run_id in run_ids:\n",
    "    config = {\n",
    "        \"model_name\":\"CGCN\",\n",
    "        \"task\":\"NC\",\n",
    "        \"run_id\":run_id,\n",
    "        \"dataset\":dataset_name\n",
    "    }\n",
    "    path = 'model_data/'+config['dataset']+\"/\"+config['model_name']+\"/\"\n",
    "    !mkdir -p $path\n",
    "    path2 = 'model_data/'+config['dataset']+\"/\"+config['model_name']+\"/\"+config['task']+\"_\"+str(config['run_id'])+\"*\"\n",
    "    !rm $path2\n",
    "    model = CGCN(hidden_channels=128, out_channels=dataset.num_classes,num_layers=4,dropout=0.2)\n",
    "    model, data = model.to(device), data.to(device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    loss_list = []\n",
    "    test_acc_list = []\n",
    "    for epoch in range(1, 101):\n",
    "        loss = train()\n",
    "        loss_list.append(loss)\n",
    "        train_acc, val_acc, test_acc = test()\n",
    "        test_acc_list.append(test_acc)\n",
    "        feature_vals = deepcopy(model.feature_vals)\n",
    "        feature_path =  path+config['task']+\"_\"+str(config['run_id'])+'_'+str(epoch)+'.npz'\n",
    "        np.savez(feature_path, **feature_vals)\n",
    "    \n",
    "        print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')\n",
    "        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')\n",
    "    #from matplotlib.pyplot import plt\n",
    "    plt.figure(figsize=(20,8))\n",
    "    plt.plot(test_acc_list)\n",
    "    plt.title(\"Accuracy Over Epochs\", fontsize=20)\n",
    "    plt.xlabel(\"Epochs\", fontsize=15)\n",
    "    plt.ylabel(\"Test Accuracy\", fontsize=15)\n",
    "    output_filename = f'CGCN_NC_test_accuracy_{run_id}.png'\n",
    "    plt.savefig(output_filename, bbox_inches='tight')\n",
    "    plt.close()  # Close the plot to release resources\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cd37a9e-2204-403a-9741-dee5865b0a53",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_acc, val_acc, test_acc = test()\n",
    "print(f'Test Accuracy: {test_acc:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6791cdd-b3b8-4a1d-ac20-212f994b70bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k in model.feature_vals.keys():\n",
    "    print(model.feature_vals[k].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc282b17-d9c1-425e-944e-ca3b6b1e28ba",
   "metadata": {},
   "source": [
    "## Link Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fef6a49c-192f-4ace-b629-981c625e4169",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "import warnings\n",
    "\n",
    "# Use the warnings filter to ignore specific warning categories or all warnings\n",
    "# To ignore all warnings (not recommended for production code):\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f774331-cd22-4e0d-a00d-b6fd858fbb02",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.datasets import Planetoid, Flickr, Amazon\n",
    "from torch_geometric.transforms import NormalizeFeatures\n",
    "\n",
    "\n",
    "if dataset_name=='Flickr':\n",
    "    transform = Compose([\n",
    "        #NormalizeFeatures(),\n",
    "        RandomNodeSplit('train_rest',num_val = 2000, num_test = 10000)\n",
    "    ])\n",
    "    dataset = Flickr(root='data/Flickr', \\\n",
    "                     transform =transform)\n",
    "elif dataset_name=='Amazon':\n",
    "    transform = Compose([\n",
    "        #NormalizeFeatures(),\n",
    "        RandomNodeSplit('train_rest',num_val = 1000, num_test = 3000)\n",
    "    ])\n",
    "    dataset = Amazon(root='data/Amazon', name='Computers', \\\n",
    "                     transform =transform)\n",
    "\n",
    "\n",
    "print()\n",
    "print(f'Dataset: {dataset}:')\n",
    "print('======================')\n",
    "print(f'Number of graphs: {len(dataset)}')\n",
    "print(f'Number of features: {dataset.num_features}')\n",
    "print(f'Number of classes: {dataset.num_classes}')\n",
    "\n",
    "data = dataset[0]  # Get the first graph object.\n",
    "\n",
    "print()\n",
    "print(data)\n",
    "print('===========================================================================================================')\n",
    "\n",
    "# Gather some statistics about the graph.\n",
    "print(f'Number of nodes: {data.num_nodes}')\n",
    "print(f'Number of edges: {data.num_edges}')\n",
    "print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')\n",
    "print(f'Number of training nodes: {data.train_mask.sum()}')\n",
    "print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')\n",
    "print(f'Has isolated nodes: {data.has_isolated_nodes()}')\n",
    "print(f'Has self-loops: {data.has_self_loops()}')\n",
    "print(f'Is undirected: {data.is_undirected()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "262d029e-4c5b-4dfc-bf46-13d1811a8c4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.loader import ClusterData, ClusterLoader\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "cluster_data = ClusterData(data, num_parts=128)  # 1. Create subgraphs.\n",
    "train_loader = ClusterLoader(cluster_data, batch_size=32, shuffle=True)  # 2. Stochastic partioning scheme.\n",
    "\n",
    "print()\n",
    "total_num_nodes = 0\n",
    "labels=[]\n",
    "for step, sub_data in enumerate(train_loader):\n",
    "    print(f'Step {step + 1}:')\n",
    "    print('=======')\n",
    "    print(f'Number of nodes in the current batch: {sub_data.num_nodes}')\n",
    "    print(sub_data)\n",
    "    print()\n",
    "    total_num_nodes += sub_data.num_nodes\n",
    "    labels.append(sub_data.y)\n",
    "\n",
    "print(f'Iterated over {total_num_nodes} of {data.num_nodes} nodes!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "710cf45d-df5c-4824-8aed-2cb063bd78ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.train_mask = data.val_mask = data.test_mask = data.y = None\n",
    "data = train_test_split_edges(data)\n",
    "print(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "084693ce-2975-4f6d-9554-255a334f4555",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(torch.nn.Module):\n",
    "    def __init__(self, hidden_channels, out_channels, num_layers, dropout=0.3):\n",
    "        super(Net, self).__init__()\n",
    "        super().__init__()\n",
    "        torch.manual_seed(np.random.choice(np.arange(1,100)))\n",
    "        self.feature_vals = {}\n",
    "        self.layers = nn.ModuleList()\n",
    "        self.num_layers = num_layers\n",
    "        self.hidden_channels = hidden_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.dropout = dropout\n",
    "        self.layers.append(GCNConv(dataset.num_features, hidden_channels))\n",
    "        for i in range(num_layers-2):\n",
    "            self.layers.append(GCNConv(hidden_channels, hidden_channels))\n",
    "        self.layers.append(GCNConv(hidden_channels, out_channels))\n",
    "\n",
    "    def encode(self, x, edge_index):\n",
    "        for i,layer in enumerate(self.layers):\n",
    "            x = layer(x, edge_index)\n",
    "            if i!= (len(self.layers)-1):\n",
    "                x = x.relu()\n",
    "                x = F.dropout(x, p=self.dropout, training=self.training)\n",
    "            self.feature_vals['conv'+str(i)] = deepcopy(x.detach().cpu().numpy())\n",
    "        return x\n",
    "\n",
    "\n",
    "    def encode_infer(self, x, edge_index):\n",
    "        for i,layer in enumerate(self.layers):\n",
    "            x = layer(x, edge_index)\n",
    "            if i!= (len(self.layers)-1):\n",
    "                x = x.relu()\n",
    "                x = F.dropout(x, p=self.dropout, training=self.training)\n",
    "            self.feature_vals['conv'+str(i)] = deepcopy(x.detach().cpu().numpy())\n",
    "        return x\n",
    "\n",
    "    def decode(self, z, pos_edge_index, neg_edge_index): # only pos and neg edges\n",
    "        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) # concatenate pos and neg edges\n",
    "        logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)  # dot product \n",
    "        return logits\n",
    "\n",
    "    def decode_all(self, z): \n",
    "        prob_adj = z @ z.t() # get adj NxN\n",
    "        return (prob_adj > 0).nonzero(as_tuple=False).t() # get predicted edge_list "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d95e4cad-122b-4a57-9dd2-b9efbe4a5bf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_link_labels(pos_edge_index, neg_edge_index):\n",
    "    # returns a tensor:\n",
    "    # [1,1,1,1,...,0,0,0,0,0,..] with the number of ones is equel to the length of pos_edge_index\n",
    "    # and the number of zeros is equal to the length of neg_edge_index\n",
    "    E = pos_edge_index.size(1) + neg_edge_index.size(1)\n",
    "    link_labels = torch.zeros(E, dtype=torch.float, device=device)\n",
    "    link_labels[:pos_edge_index.size(1)] = 1.\n",
    "    return link_labels\n",
    "\n",
    "\n",
    "def train():\n",
    "    model.train()\n",
    "    final_loss=0\n",
    "    for sub_data in train_loader:  # Iterate over each mini-batch.\n",
    "        sub_data = sub_data.to(device)\n",
    "        sub_data.train_mask = sub_data.val_mask = sub_data.test_mask = sub_data.y = None\n",
    "        sub_data = train_test_split_edges(sub_data)\n",
    "        \n",
    "        neg_edge_index = negative_sampling(\n",
    "            edge_index=sub_data.train_pos_edge_index, #positive edges\n",
    "            num_nodes=sub_data.num_nodes, # number of nodes\n",
    "            num_neg_samples=sub_data.train_pos_edge_index.size(1)) # number of neg_sample equal to number of pos_edges\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "    \n",
    "        z = model.encode(sub_data.x, sub_data.train_pos_edge_index) #encode\n",
    "        link_logits = model.decode(z, sub_data.train_pos_edge_index, neg_edge_index) # decode\n",
    "    \n",
    "        link_labels = get_link_labels(sub_data.train_pos_edge_index, neg_edge_index)\n",
    "        loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        final_loss+=loss\n",
    "    return final_loss\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def test():\n",
    "    model.eval()\n",
    "    perfs = []\n",
    "    z = model.encode_infer(data.x, data.train_pos_edge_index) # encode train\n",
    "    for prefix in [\"val\", \"test\"]:\n",
    "        pos_edge_index = data[f'{prefix}_pos_edge_index']\n",
    "        neg_edge_index = data[f'{prefix}_neg_edge_index']\n",
    "        link_logits = model.decode(z, pos_edge_index, neg_edge_index) # decode test or val\n",
    "        link_probs = link_logits.sigmoid() # apply sigmoid\n",
    "        \n",
    "        link_labels = get_link_labels(pos_edge_index, neg_edge_index) # get link\n",
    "        \n",
    "        perfs.append(roc_auc_score(link_labels.cpu(), link_probs.cpu())) #compute roc_auc score\n",
    "    return perfs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6aae1ed-1c1d-4e1b-81ae-d31b72e2cb4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "run_ids=[1,2]\n",
    "for run_id in run_ids:\n",
    "    config = {\n",
    "        \"model_name\":\"CGCN\",\n",
    "        \"task\":\"LP\",\n",
    "        \"run_id\":run_id,\n",
    "        \"dataset\":dataset_name\n",
    "    }\n",
    "    path = 'model_data/'+config['dataset']+\"/\"+config['model_name']+\"/\"\n",
    "    !mkdir -p $path\n",
    "    path2 = 'model_data/'+config['dataset']+\"/\"+config['model_name']+\"/\"+config['task']+\"_\"+str(config['run_id'])+\"*\"\n",
    "    !rm $path2\n",
    "    model, data = Net(hidden_channels=128,out_channels=128,num_layers=4,dropout=0).to(device), data.to(device)\n",
    "    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001, weight_decay=5e-4)\n",
    "    best_val_perf = test_perf = 0\n",
    "    test_acc_list=[]\n",
    "    for epoch in range(1, 201):\n",
    "        train_loss = train()\n",
    "        val_perf, tmp_test_perf = test()\n",
    "        if val_perf > best_val_perf:\n",
    "            best_val_perf = val_perf\n",
    "            test_perf = tmp_test_perf\n",
    "        log = 'Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}'\n",
    "        feature_vals = deepcopy(model.feature_vals)\n",
    "        feature_path =  path+config['task']+\"_\"+str(config['run_id'])+'_'+str(epoch)+'.npz'\n",
    "        np.savez(feature_path, **feature_vals)\n",
    "        test_acc_list.append(tmp_test_perf)\n",
    "        if epoch % 10 == 0:\n",
    "            print(log.format(epoch, train_loss, best_val_perf, test_perf))\n",
    "    #from matplotlib.pyplot import plt\n",
    "    plt.figure(figsize=(20,8))\n",
    "    plt.plot(test_acc_list)\n",
    "    plt.title(\"Accuracy Over Epochs\", fontsize=20)\n",
    "    plt.xlabel(\"Epochs\", fontsize=15)\n",
    "    plt.ylabel(\"Test Accuracy\", fontsize=15)\n",
    "    output_filename = f'CGCN_LP_test_accuracy_{run_id}.png'\n",
    "    plt.savefig(output_filename, bbox_inches='tight')\n",
    "    plt.close()  # Close the plot to release resources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d58991f4-95bb-4aca-baa5-14a04a558085",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k in model.feature_vals.keys():\n",
    "    print(model.feature_vals[k].shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
