{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "import copy\n",
    "from itertools import product\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "from torch.functional import F\n",
    "import torch_geometric\n",
    "from torch_geometric.data import Dataset\n",
    "from torch_geometric.nn import GATConv, GCNConv\n",
    "from scipy import sparse\n",
    "\n",
    "# Specifically for the flight data processing\n",
    "import pycountry\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Experiment parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "dataset_name = \"Flight\"\n",
    "\n",
    "# Training/validation/calibration/test split proportions (must sum to 1)\n",
    "props = np.array([0.2, 0.1, 0.35, 0.35])\n",
    "assert np.isclose(props.sum(), 1.0), \"split proportions must sum to 1\"\n",
    "\n",
    "# Miscoverage rate for conformal prediction: the target coverage is 1 - alpha\n",
    "alpha = 0.1\n",
    "\n",
    "# Number of experiments\n",
    "num_train_trans = 10  # transductive: models trained per method/GNN pair\n",
    "num_permute_trans = 100  # transductive: calibration/test reshuffles per model\n",
    "num_train_semi_ind = 50  # semi-inductive: models trained per method/GNN pair\n",
    "\n",
    "# GNN model parameters\n",
    "num_epochs = 30\n",
    "num_channels_GCN = 32  # hidden width of the GCN\n",
    "num_channels_GAT = 32  # hidden width of the GAT\n",
    "\n",
    "# Save results\n",
    "results_file = f\"results/Conformal_GNN_{dataset_name}_Results.pkl\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processing the flight data...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2646/2646 [01:22<00:00, 32.26it/s]\n"
     ]
    }
   ],
   "source": [
    "# %%\n",
    "print(\"Processing the flight data...\")\n",
    "\n",
    "datapath = \"datasets/flight_data/\"\n",
    "\n",
    "# Load the T monthly adjacency matrices\n",
    "T = 36\n",
    "As = [sparse.load_npz(datapath + \"As_\" + str(t) + \".npz\") for t in range(T)]\n",
    "n = As[0].shape[0]\n",
    "\n",
    "# Load labels: Z is compared against continent codes (e.g. \"EU\") below, and\n",
    "# nodes maps each airport ident to its node (column) index in the As matrices.\n",
    "Z = np.load(datapath + \"Z.npy\")\n",
    "nodes = np.load(datapath + \"nodes.npy\", allow_pickle=True).item()\n",
    "\n",
    "months = [\n",
    "    \"January\",\n",
    "    \"February\",\n",
    "    \"March\",\n",
    "    \"April\",\n",
    "    \"May\",\n",
    "    \"June\",\n",
    "    \"July\",\n",
    "    \"August\",\n",
    "    \"September\",\n",
    "    \"October\",\n",
    "    \"November\",\n",
    "    \"December\",\n",
    "]\n",
    "# One label per snapshot: \"January2019\", ..., \"December2021\"\n",
    "labels = np.char.add(\n",
    "    np.array(months * 3), np.repeat([\"2019\", \"2020\", \"2021\"], 12)\n",
    ")\n",
    "\n",
    "euro_nodes_idx = np.where(np.array(Z) == \"EU\")[0]\n",
    "\n",
    "airports = pd.read_csv(datapath + \"airports.csv\")\n",
    "# ident -> ISO country code, built once (keep=\"first\" matches taking the first\n",
    "# matching row) instead of scanning the whole table for every airport.\n",
    "airports_unique = airports.drop_duplicates(subset=\"ident\", keep=\"first\")\n",
    "ident_to_iso = dict(zip(airports_unique[\"ident\"], airports_unique[\"iso_country\"]))\n",
    "\n",
    "\n",
    "def _country_name(iso_code):\n",
    "    \"\"\"Return the pycountry name for an ISO alpha-2 code (\"Kosovo\" for XK, else \"Unknown\").\"\"\"\n",
    "    try:\n",
    "        return pycountry.countries.get(alpha_2=iso_code).name\n",
    "    except (LookupError, AttributeError):\n",
    "        # Depending on the pycountry version an unknown code gives None\n",
    "        # (-> AttributeError on .name) or raises a LookupError.\n",
    "        if iso_code == \"XK\":\n",
    "            return \"Kosovo\"\n",
    "        print(\"can't find country for code: \", iso_code)\n",
    "        return \"Unknown\"\n",
    "\n",
    "\n",
    "country_codes_in_data = []\n",
    "airports_not_found = []\n",
    "flight_records = []\n",
    "for code in tqdm(np.array(list(nodes.keys()))[euro_nodes_idx]):\n",
    "    if code not in ident_to_iso:\n",
    "        airports_not_found.append(code)\n",
    "        continue\n",
    "\n",
    "    country_of_airport = ident_to_iso[code]\n",
    "    country_codes_in_data.append(country_of_airport)\n",
    "    country_name = _country_name(country_of_airport)\n",
    "\n",
    "    # One record per month: column sum of A_t for this airport\n",
    "    for t in range(T):\n",
    "        flight_records.append(\n",
    "            {\n",
    "                \"month\": labels[t],\n",
    "                \"airport\": code,\n",
    "                \"number_of_flights\": np.sum(As[t][:, nodes[code]]),\n",
    "                \"country_code\": country_of_airport,\n",
    "                \"country\": country_name,\n",
    "            }\n",
    "        )\n",
    "\n",
    "print(f\"Airports not found in airports.csv: {len(airports_not_found)}\")\n",
    "\n",
    "flights_for_each_country = pd.DataFrame.from_records(\n",
    "    flight_records,\n",
    "    columns=[\"month\", \"airport\", \"number_of_flights\", \"country_code\", \"country\"],\n",
    ")\n",
    "\n",
    "# Shorten a few pycountry names. Unresolved codes keep the \"Unknown\" label\n",
    "# rather than being relabelled as Kosovo (XK is already mapped above).\n",
    "flights_for_each_country[\"country\"] = flights_for_each_country[\"country\"].replace(\n",
    "    {\n",
    "        \"Czechia\": \"Czech Republic\",\n",
    "        \"Slovakia\": \"Slovak Republic\",\n",
    "        \"Russian Federation\": \"Russia\",\n",
    "        \"Moldova, Republic of\": \"Moldova\",\n",
    "    }\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "# Country name of every European airport, in node order\n",
    "euro_airports = np.array(list(nodes.keys()))[euro_nodes_idx]\n",
    "airport_to_country = {\n",
    "    airport: country\n",
    "    for airport, country in zip(\n",
    "        flights_for_each_country[\"airport\"], flights_for_each_country[\"country\"]\n",
    "    )\n",
    "}\n",
    "country_labels = np.array([airport_to_country[airport] for airport in euro_airports])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "# Continent-level target: integer-encode Z and repeat it for each of the T snapshots.\n",
    "# Note: the following cell replaces node_labels / node_labels_enc with country labels.\n",
    "node_labels_enc = pd.Categorical(Z)\n",
    "node_labels = np.tile(np.asarray(node_labels_enc.codes), T)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "# Keep only the European airports (rows and columns) in every snapshot\n",
    "As_euro = [A[euro_nodes_idx, :][:, euro_nodes_idx] for A in As]\n",
    "As = As_euro\n",
    "n = As[0].shape[0]\n",
    "\n",
    "# Country-level target, integer-encoded and repeated for each snapshot\n",
    "node_labels_enc = pd.Categorical(country_labels)\n",
    "node_labels = np.tile(node_labels_enc.codes, T)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "# (For the dataset plotting script - not required for the experiments to run)\n",
    "# Save the European snapshots (as CSR) and their country-level labels.\n",
    "As_sparse = [sparse.csr_matrix(A) for A in As]\n",
    "for i, A_sparse in enumerate(As_sparse):\n",
    "    sparse.save_npz(datapath + f\"flight_As_{i}.npz\", A_sparse)\n",
    "\n",
    "np.save(datapath + \"flight_node_labels.npy\", node_labels)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "class Dynamic_Network(Dataset):\n",
    "    \"\"\"\n",
    "    A pytorch geometric dataset for a dynamic network.\n",
    "\n",
    "    Item ``t`` is snapshot ``t`` as a PyG ``Data`` object: an n x n sparse\n",
    "    identity as node features, the nonzero entries of ``As[t]`` as weighted\n",
    "    edges, and the labels of the n nodes at time t.\n",
    "\n",
    "    Args:\n",
    "        As: sequence of T scipy sparse adjacency matrices (n x n).\n",
    "        labels: integer class labels of length n * T, time-major (the labels of\n",
    "            snapshot t are ``labels[n * t : n * (t + 1)]``).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, As, labels):\n",
    "        self.As = As\n",
    "        self.T = len(As)\n",
    "        self.n = As[0].shape[0]\n",
    "        self.classes = labels\n",
    "\n",
    "        assert len(labels) == self.n * self.T\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.As)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        x = torch.sparse.spdiags(\n",
    "            torch.ones(self.n),\n",
    "            offsets=torch.tensor([0]),\n",
    "            shape=(self.n, self.n),\n",
    "        )\n",
    "\n",
    "        # Take indices and weights from the same COO view so they stay aligned\n",
    "        # (A.nonzero() drops explicitly stored zeros but A.data keeps them).\n",
    "        coo = sparse.coo_matrix(self.As[idx])\n",
    "        keep = coo.data != 0\n",
    "        edge_index = torch.tensor(\n",
    "            np.vstack((coo.row[keep], coo.col[keep])), dtype=torch.long\n",
    "        )\n",
    "        edge_weight = torch.tensor(coo.data[keep], dtype=torch.float)\n",
    "        y = torch.tensor(\n",
    "            self.classes[self.n * idx : self.n * (idx + 1)], dtype=torch.long\n",
    "        )\n",
    "\n",
    "        # Create a PyTorch Geometric data object\n",
    "        data = torch_geometric.data.Data(\n",
    "            x=x, edge_index=edge_index, edge_weight=edge_weight, y=y\n",
    "        )\n",
    "        data.num_nodes = self.n\n",
    "\n",
    "        return data\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "class Block_Diagonal_Network(Dataset):\n",
    "    \"\"\"\n",
    "    A pytorch geometric dataset for the block diagonal version of a Dynamic Network object.\n",
    "\n",
    "    The T snapshots are placed on the diagonal of a single nT x nT adjacency\n",
    "    matrix, so the dataset holds one graph (``len() == 1``) whose node\n",
    "    ``t * n + i`` is node ``i`` at time ``t``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dataset):\n",
    "        self.A = sparse.block_diag(dataset.As)\n",
    "        self.T = dataset.T\n",
    "        self.n = dataset.n\n",
    "        self.classes = dataset.classes\n",
    "\n",
    "        assert len(self.classes) == self.n * self.T\n",
    "\n",
    "    def __len__(self):\n",
    "        return 1\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        num_nodes = self.n * self.T\n",
    "        x = torch.sparse.spdiags(\n",
    "            torch.ones(num_nodes),\n",
    "            offsets=torch.tensor([0]),\n",
    "            shape=(num_nodes, num_nodes),\n",
    "        )\n",
    "\n",
    "        # Take indices and weights from the same COO view so they stay aligned\n",
    "        # (A.nonzero() drops explicitly stored zeros but A.data keeps them).\n",
    "        coo = sparse.coo_matrix(self.A)\n",
    "        keep = coo.data != 0\n",
    "        edge_index = torch.tensor(\n",
    "            np.vstack((coo.row[keep], coo.col[keep])), dtype=torch.long\n",
    "        )\n",
    "        edge_weight = torch.tensor(coo.data[keep], dtype=torch.float)\n",
    "        y = torch.tensor(self.classes, dtype=torch.long)\n",
    "\n",
    "        # Create a PyTorch Geometric data object\n",
    "        data = torch_geometric.data.Data(\n",
    "            x=x, edge_index=edge_index, edge_weight=edge_weight, y=y\n",
    "        )\n",
    "        data.num_nodes = num_nodes\n",
    "\n",
    "        return data\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "def general_unfolded_matrix(As, sparse_matrix=False):\n",
    "    \"\"\"Form the dilated unfolded adjacency matrix of an adjacency series.\n",
    "\n",
    "    The rectangular unfolding A = [A_1 | ... | A_T] (n x nT) is placed in the\n",
    "    off-diagonal blocks of an (n + nT) x (n + nT) matrix [[0, A], [A^T, 0]].\n",
    "\n",
    "    Args:\n",
    "        As: sequence of T n x n adjacency matrices (scipy sparse when\n",
    "            ``sparse_matrix`` is True, numpy arrays otherwise).\n",
    "        sparse_matrix: if True, build and return a CSR matrix; otherwise a\n",
    "            dense numpy array.\n",
    "\n",
    "    Returns:\n",
    "        The (n + nT) x (n + nT) dilated unfolded adjacency matrix.\n",
    "    \"\"\"\n",
    "    T = len(As)\n",
    "    n = As[0].shape[0]\n",
    "\n",
    "    if sparse_matrix:\n",
    "        # Stack all snapshots in one call rather than re-stacking a growing\n",
    "        # matrix T - 1 times (which copies O(T^2) data).\n",
    "        A = sparse.hstack(list(As))\n",
    "        DA = sparse.csr_matrix(sparse.bmat([[None, A], [A.T, None]]))\n",
    "    else:\n",
    "        A = np.hstack(list(As))\n",
    "        DA = np.zeros((n + n * T, n + n * T))\n",
    "        DA[0:n, n:] = A\n",
    "        DA[n:, 0:n] = A.T\n",
    "\n",
    "    return DA\n",
    "\n",
    "\n",
    "class Unfolded_Network(Dataset):\n",
    "    \"\"\"\n",
    "    A pytorch geometric dataset for the dilated unfolding of a Dynamic Network object.\n",
    "\n",
    "    The single graph (``len() == 1``) has n anchor nodes followed by the nT\n",
    "    time-indexed nodes: node ``n + t * n + i`` is node ``i`` at time ``t``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dataset):\n",
    "        self.A = general_unfolded_matrix(dataset.As, sparse_matrix=True)\n",
    "        self.T = dataset.T\n",
    "        self.n = dataset.n\n",
    "        self.classes = dataset.classes\n",
    "\n",
    "        assert len(self.classes) == self.n * self.T\n",
    "\n",
    "    def __len__(self):\n",
    "        return 1\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        num_nodes = self.n * (self.T + 1)\n",
    "        x = torch.sparse.spdiags(\n",
    "            torch.ones(num_nodes),\n",
    "            offsets=torch.tensor([0]),\n",
    "            shape=(num_nodes, num_nodes),\n",
    "        )\n",
    "\n",
    "        # Take indices and weights from the same COO view so they stay aligned\n",
    "        # (A.nonzero() drops explicitly stored zeros but A.data keeps them).\n",
    "        coo = sparse.coo_matrix(self.A)\n",
    "        keep = coo.data != 0\n",
    "        edge_index = torch.tensor(\n",
    "            np.vstack((coo.row[keep], coo.col[keep])), dtype=torch.long\n",
    "        )\n",
    "        edge_weight = torch.tensor(coo.data[keep], dtype=torch.float)\n",
    "\n",
    "        # The n anchor nodes come first; give them placeholder label 0\n",
    "        # (the experiment cells pad every mask with False for them).\n",
    "        y = torch.tensor(\n",
    "            np.concatenate((np.zeros(self.n), self.classes)), dtype=torch.long\n",
    "        )\n",
    "\n",
    "        # Create a PyTorch Geometric data object\n",
    "        data = torch_geometric.data.Data(\n",
    "            x=x, edge_index=edge_index, edge_weight=edge_weight, y=y\n",
    "        )\n",
    "        data.num_nodes = num_nodes\n",
    "\n",
    "        return data\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "dataset = Dynamic_Network(As, node_labels)\n",
    "\n",
    "# Both wrappers hold a single graph; indexing with 0 materialises it as a PyG Data object.\n",
    "dataset_BD, dataset_UA = (\n",
    "    wrapper(dataset)[0] for wrapper in (Block_Diagonal_Network, Unfolded_Network)\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " Define general GCN and GAT neural networks to be applied to both the block diagonal and dilated unfolded networks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "class GCN(torch.nn.Module):\n",
    "    \"\"\"Two-layer GCN: GCNConv -> ReLU -> dropout(p=0.5) -> GCNConv.\n",
    "\n",
    "    Args:\n",
    "        num_nodes: input feature size (the datasets use identity node\n",
    "            features, so this is the number of nodes).\n",
    "        num_channels: hidden layer width.\n",
    "        num_classes: number of output logits per node.\n",
    "        seed: torch seed applied before the layers are initialised.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_nodes, num_channels, num_classes, seed):\n",
    "        super().__init__()\n",
    "        torch.manual_seed(seed)  # reproducible weight initialisation\n",
    "        self.conv1 = GCNConv(num_nodes, num_channels)\n",
    "        self.conv2 = GCNConv(num_channels, num_classes)\n",
    "\n",
    "    def forward(self, x, edge_index, edge_weight):\n",
    "        hidden = self.conv1(x, edge_index, edge_weight).relu()\n",
    "        hidden = F.dropout(hidden, p=0.5, training=self.training)\n",
    "        return self.conv2(hidden, edge_index, edge_weight)\n",
    "\n",
    "\n",
    "class GAT(torch.nn.Module):\n",
    "    \"\"\"Two-layer GAT: GATConv -> ReLU -> dropout(p=0.5) -> GATConv.\n",
    "\n",
    "    Takes the same arguments as ``GCN``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_nodes, num_channels, num_classes, seed):\n",
    "        super().__init__()\n",
    "        torch.manual_seed(seed)  # reproducible weight initialisation\n",
    "        self.conv1 = GATConv(num_nodes, num_channels)\n",
    "        self.conv2 = GATConv(num_channels, num_classes)\n",
    "\n",
    "    def forward(self, x, edge_index, edge_weight):\n",
    "        # NOTE(review): GATConv's third positional argument is edge_attr; per the\n",
    "        # PyG docs it only affects attention when the layer is built with edge_dim,\n",
    "        # which is not done here - confirm whether edge weights are actually used.\n",
    "        hidden = self.conv1(x, edge_index, edge_weight).relu()\n",
    "        hidden = F.dropout(hidden, p=0.5, training=self.training)\n",
    "        return self.conv2(hidden, edge_index, edge_weight)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "def train(model, data, train_mask, opt=None):\n",
    "    \"\"\"Run one full-batch training step and return the training loss.\n",
    "\n",
    "    Args:\n",
    "        model: GNN called as ``model(x, edge_index, edge_weight)``.\n",
    "        data: object with ``x``, ``edge_index``, ``edge_weight`` and ``y``.\n",
    "        train_mask: boolean mask of the nodes that contribute to the loss.\n",
    "        opt: optimizer to step. Defaults to the notebook-level ``optimizer``\n",
    "            global, which is what existing calls rely on.\n",
    "\n",
    "    Returns:\n",
    "        Cross-entropy loss (tensor) over the training nodes.\n",
    "    \"\"\"\n",
    "    if opt is None:\n",
    "        opt = optimizer  # global created in the experiment cells\n",
    "\n",
    "    model.train()\n",
    "    opt.zero_grad()\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "    out = model(data.x, data.edge_index, data.edge_weight)\n",
    "    loss = criterion(out[train_mask], data.y[train_mask])\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "\n",
    "    return loss\n",
    "\n",
    "\n",
    "def valid(model, data, valid_mask):\n",
    "    \"\"\"Return the accuracy of ``model`` on the nodes selected by ``valid_mask``.\n",
    "\n",
    "    Switches the model to eval mode (disabling dropout) and runs the forward\n",
    "    pass without building an autograd graph.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        out = model(data.x, data.edge_index, data.edge_weight)\n",
    "    pred = out.argmax(dim=1)\n",
    "    correct = pred[valid_mask] == data.y[valid_mask]\n",
    "    acc = int(correct.sum()) / int(valid_mask.sum())\n",
    "\n",
    "    return acc\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Data split functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " Create a data mask that keeps only the node/time pairs useful for training and testing: a node is included in a particular time window only if it has degree greater than zero (a nonzero column sum of the adjacency matrix) in that window. No label classes are excluded for this dataset, so every country label is kept."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Percentage of usable node/time pairs: 100.0%\n"
     ]
    }
   ],
   "source": [
    "# %%\n",
    "# data_mask[i, t] is True when node i has a nonzero column sum in snapshot t.\n",
    "data_mask = np.ones((n, T), dtype=bool)\n",
    "\n",
    "for t in range(T):\n",
    "    # Summing a scipy sparse matrix gives a (1, n) np.matrix; flatten it to a\n",
    "    # 1-D array so the boolean mask is indexed by node.\n",
    "    degrees = np.asarray(As[t].sum(axis=0)).ravel()\n",
    "    data_mask[degrees == 0, t] = False\n",
    "\n",
    "num_classes = np.unique(node_labels).shape[0]\n",
    "print(\n",
    "    f\"Percentage of usable node/time pairs: {100 * np.sum(data_mask) / (n * T) :02.1f}%\"\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "def mask_split(mask, split_props, seed=0, mode=\"transductive\"):\n",
    "    \"\"\"Split the usable node/time pairs into disjoint boolean masks.\n",
    "\n",
    "    Args:\n",
    "        mask: (n, T) boolean array of usable node/time pairs.\n",
    "        split_props: numpy array of split proportions (e.g. train, valid,\n",
    "            calib, test), expected to sum to 1.\n",
    "        seed: seed for NumPy's global RNG, which is reseeded on every call.\n",
    "        mode: \"transductive\" shuffles usable pairs from all times together;\n",
    "            \"semi-inductive\" assigns every pair from the final time steps to\n",
    "            the last split and shuffles the earlier pairs among the others.\n",
    "\n",
    "    Returns:\n",
    "        (len(split_props), n * T) boolean array. Flattened positions are\n",
    "        time-major: position t * n + i is node i at time t.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if mode is not one of the two supported values.\n",
    "    \"\"\"\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    n, T = mask.shape\n",
    "\n",
    "    if mode == \"transductive\":\n",
    "        # Flatten mask array into one dimension in blocks of nodes per time\n",
    "        flat_mask = mask.T.reshape(-1)\n",
    "        n_masks = np.sum(flat_mask)\n",
    "\n",
    "        # Split the shuffled usable positions into the requested proportions\n",
    "        flat_mask_idx = np.where(flat_mask)[0]\n",
    "        np.random.shuffle(flat_mask_idx)\n",
    "        split_ns = np.cumsum([round(n_masks * prop) for prop in split_props[:-1]])\n",
    "        split_idx = np.split(flat_mask_idx, split_ns)\n",
    "\n",
    "    elif mode == \"semi-inductive\":\n",
    "        # T_trunc is the first time at which the cumulative share of usable\n",
    "        # pairs reaches 1 - split_props[-1]; pairs from T_trunc on form the last split.\n",
    "        T_trunc = np.where(\n",
    "            np.cumsum(np.sum(mask, axis=0) / np.sum(mask)) >= 1 - split_props[-1]\n",
    "        )[0][0]\n",
    "\n",
    "        # Flatten mask arrays into one dimension in blocks of nodes per time\n",
    "        flat_mask_start = mask[:, :T_trunc].T.reshape(-1)\n",
    "        flat_mask_end = mask[:, T_trunc:].T.reshape(-1)\n",
    "        n_masks_start = np.sum(flat_mask_start)\n",
    "\n",
    "        # Shuffle the earlier pairs and split them among the remaining splits,\n",
    "        # with their proportions renormalised to sum to 1\n",
    "        flat_mask_start_idx = np.where(flat_mask_start)[0]\n",
    "        np.random.shuffle(flat_mask_start_idx)\n",
    "        split_props_start = split_props[:-1] / np.sum(split_props[:-1])\n",
    "        split_ns = np.cumsum(\n",
    "            [round(n_masks_start * prop) for prop in split_props_start[:-1]]\n",
    "        )\n",
    "        split_idx = np.split(flat_mask_start_idx, split_ns)\n",
    "\n",
    "        # Place finishing flatten mask array at the end\n",
    "        split_idx.append(n * T_trunc + np.where(flat_mask_end)[0])\n",
    "\n",
    "    else:\n",
    "        raise ValueError(\n",
    "            f\"mode must be 'transductive' or 'semi-inductive', got {mode!r}\"\n",
    "        )\n",
    "\n",
    "    split_masks = np.zeros((len(split_props), n * T), dtype=bool)\n",
    "    for i, idx in enumerate(split_idx):\n",
    "        split_masks[i, idx] = True\n",
    "\n",
    "    return split_masks\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "def mask_mix(mask_1, mask_2, seed=0):\n",
    "    \"\"\"Randomly re-partition the union of two masks into two disjoint masks.\n",
    "\n",
    "    Positions set in either mask are shuffled with NumPy's global RNG\n",
    "    (reseeded with ``seed``); the first ``sum(mask_1)`` of them form the\n",
    "    first returned mask and the rest form the second.\n",
    "\n",
    "    Returns:\n",
    "        (2, len(mask_1)) boolean array ``[new_mask_1, new_mask_2]``.\n",
    "    \"\"\"\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    pooled = np.where(mask_1 + mask_2)[0]\n",
    "    np.random.shuffle(pooled)\n",
    "    first_size = np.sum(mask_1)\n",
    "\n",
    "    mixed = np.zeros((2, len(mask_1)), dtype=bool)\n",
    "    mixed[0, pooled[:first_size]] = True\n",
    "    mixed[1, pooled[first_size:]] = True\n",
    "    return mixed\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " Testing the training/validation/calibration/test data split functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Percentage train: 19.7%\n",
      "Percentage valid: 9.8%\n",
      "Percentage calib: 34.4%\n",
      "Percentage test:  36.1%\n"
     ]
    }
   ],
   "source": [
    "# %%\n",
    "# Uses the `props` set in the experiment-parameters cell (no local override, so\n",
    "# changes made there are reflected here and in the experiments).\n",
    "train_mask, valid_mask, calib_mask, test_mask = mask_split(\n",
    "    data_mask, props, mode=\"semi-inductive\"\n",
    ")\n",
    "\n",
    "print(f\"Percentage train: {100 * np.sum(train_mask) / np.sum(data_mask) :02.1f}%\")\n",
    "print(f\"Percentage valid: {100 * np.sum(valid_mask) / np.sum(data_mask) :02.1f}%\")\n",
    "print(f\"Percentage calib: {100 * np.sum(calib_mask) / np.sum(data_mask) :02.1f}%\")\n",
    "print(f\"Percentage test:  {100 * np.sum(test_mask)  / np.sum(data_mask) :02.1f}%\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Conformal prediction functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "def get_prediction_sets(output, data, calib_mask, test_mask, alpha=0.1):\n",
    "    \"\"\"Conformal prediction sets using Adaptive Prediction Sets (APS) scores.\n",
    "\n",
    "    A calibration score is the cumulative softmax mass, taken in descending\n",
    "    probability order, up to and including the true class. A test set holds\n",
    "    every class whose cumulative mass is at most the conformal quantile qhat.\n",
    "\n",
    "    Args:\n",
    "        output: (num_nodes, num_classes) tensor of logits.\n",
    "        data: object whose ``y`` attribute is a tensor of integer labels.\n",
    "        calib_mask: boolean mask of the calibration nodes.\n",
    "        test_mask: boolean mask of the test nodes.\n",
    "        alpha: miscoverage rate; the target coverage is 1 - alpha.\n",
    "\n",
    "    Returns:\n",
    "        (num_test, num_classes) boolean array; row k is the set for the k-th\n",
    "        node selected by test_mask.\n",
    "    \"\"\"\n",
    "    n_calib = int(np.sum(np.asarray(calib_mask)))\n",
    "\n",
    "    # Softmax probabilities, moved to CPU so .numpy() also works for GPU outputs\n",
    "    smx = torch.nn.Softmax(dim=1)\n",
    "    calib_heuristic = smx(output[calib_mask]).detach().cpu().numpy()\n",
    "    test_heuristic = smx(output[test_mask]).detach().cpu().numpy()\n",
    "    calib_labels = data.y[calib_mask].cpu().numpy()\n",
    "\n",
    "    # APS\n",
    "    calib_pi = calib_heuristic.argsort(1)[:, ::-1]\n",
    "    calib_srt = np.take_along_axis(calib_heuristic, calib_pi, axis=1).cumsum(axis=1)\n",
    "    calib_scores = np.take_along_axis(calib_srt, calib_pi.argsort(axis=1), axis=1)[\n",
    "        np.arange(n_calib), calib_labels\n",
    "    ]\n",
    "\n",
    "    # Finite-sample corrected quantile level. If it exceeds 1 the calibration set\n",
    "    # is too small for this alpha (np.quantile would raise); qhat = inf then puts\n",
    "    # every class in every set, which preserves the coverage guarantee.\n",
    "    q_level = np.ceil((n_calib + 1) * (1 - alpha)) / n_calib if n_calib else np.inf\n",
    "    if q_level > 1:\n",
    "        qhat = np.inf\n",
    "    else:\n",
    "        qhat = np.quantile(calib_scores, q_level, method=\"higher\")\n",
    "\n",
    "    test_pi = test_heuristic.argsort(1)[:, ::-1]\n",
    "    test_srt = np.take_along_axis(test_heuristic, test_pi, axis=1).cumsum(axis=1)\n",
    "    pred_sets = np.take_along_axis(test_srt <= qhat, test_pi.argsort(axis=1), axis=1)\n",
    "\n",
    "    return pred_sets\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "def accuracy(output, data, test_mask):\n",
    "    \"\"\"Fraction of nodes in test_mask whose argmax prediction matches data.y.\"\"\"\n",
    "    predicted = output.argmax(dim=1)[test_mask]\n",
    "    n_correct = (predicted == data.y[test_mask]).sum().item()\n",
    "    return n_correct / int(test_mask.sum())\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "def avg_set_size(pred_sets, test_mask):\n",
    "    \"\"\"Mean number of classes per prediction set (one row of pred_sets per set).\n",
    "\n",
    "    test_mask is not used.\n",
    "    \"\"\"\n",
    "    set_sizes = np.sum(pred_sets, axis=1)\n",
    "    return set_sizes.mean()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "def coverage(pred_sets, data, test_mask):\n",
    "    \"\"\"Empirical coverage: fraction of test nodes whose true label is in their set.\n",
    "\n",
    "    Rows of pred_sets must follow the order of the nodes selected by test_mask.\n",
    "    \"\"\"\n",
    "    true_labels = data.y[test_mask]\n",
    "    label_in_set = np.fromiter(\n",
    "        (row[label] for row, label in zip(pred_sets, true_labels)), dtype=bool\n",
    "    )\n",
    "    return np.mean(label_in_set)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## GNN training"
   ]
  },
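  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " Initialise the nested results dictionary, indexed as `results[method][GNN model][regime][metric][time]`, where `time` is `All` or a snapshot index."
   ]
  },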
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "methods = [\"BD\", \"UA\"]\n",
    "GNN_models = [\"GCN\", \"GAT\"]\n",
    "regimes = [\"Trans\", \"Semi-Ind\"]\n",
    "outputs = [\"Accuracy\", \"Avg Size\", \"Coverage\"]\n",
    "times = [\"All\"] + list(range(T))\n",
    "\n",
    "# results[method][GNN_model][regime][metric][time] -> list of per-run values, where\n",
    "# time is \"All\" (pooled over snapshots) or a snapshot index 0..T-1. Comprehension\n",
    "# variables stay local, so names like `time` do not leak into the notebook namespace.\n",
    "results = {\n",
    "    method_key: {\n",
    "        model_key: {\n",
    "            regime_key: {\n",
    "                metric: {time_key: [] for time_key in times} for metric in outputs\n",
    "            }\n",
    "            for regime_key in regimes\n",
    "        }\n",
    "        for model_key in GNN_models\n",
    "    }\n",
    "    for method_key in methods\n",
    "}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ### Transductive experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%time\n",
    "\n",
    "for method, GNN_model in product(methods, GNN_models):\n",
    "    if method == \"BD\":\n",
    "        method_str = \"Block Diagonal\"\n",
    "        data = dataset_BD\n",
    "    if method == \"UA\":\n",
    "        method_str = \"Unfolded\"\n",
    "        data = dataset_UA\n",
    "\n",
    "    # One boolean mask per snapshot over the nodes of `data`, built once per method.\n",
    "    # The unfolded graph has n anchor nodes first, so snapshot t is block t + 1.\n",
    "    num_blocks = T + 1 if method == \"UA\" else T\n",
    "    time_masks = []\n",
    "    for t in range(T):\n",
    "        time_mask = np.zeros((num_blocks, n), dtype=bool)\n",
    "        time_mask[t + num_blocks - T] = True\n",
    "        time_masks.append(time_mask.reshape(-1))\n",
    "\n",
    "    for i in range(num_train_trans):\n",
    "        # Split data into training/validation/calibration/test\n",
    "        train_mask, valid_mask, calib_mask, test_mask = mask_split(\n",
    "            data_mask, props, seed=i, mode=\"transductive\"\n",
    "        )\n",
    "\n",
    "        if method == \"UA\":\n",
    "            # Pad masks to include anchor nodes\n",
    "            anchor_pad = np.zeros(n, dtype=bool)\n",
    "            train_mask = np.concatenate((anchor_pad, train_mask))\n",
    "            valid_mask = np.concatenate((anchor_pad, valid_mask))\n",
    "            calib_mask = np.concatenate((anchor_pad, calib_mask))\n",
    "            test_mask = np.concatenate((anchor_pad, test_mask))\n",
    "\n",
    "        if GNN_model == \"GCN\":\n",
    "            model = GCN(data.num_nodes, num_channels_GCN, num_classes, seed=i)\n",
    "        if GNN_model == \"GAT\":\n",
    "            model = GAT(data.num_nodes, num_channels_GAT, num_classes, seed=i)\n",
    "\n",
    "        optimizer = torch.optim.Adam(model.parameters())\n",
    "\n",
    "        print(f\"Training {method_str} {GNN_model} Number {i}\")\n",
    "        # Reset per run: starting below any reachable accuracy guarantees best_model\n",
    "        # comes from this run, never from a previous method/model/seed.\n",
    "        max_valid_acc = -1.0\n",
    "        best_model = copy.deepcopy(model)\n",
    "\n",
    "        for epoch in tqdm(range(num_epochs)):\n",
    "            train(model, data, train_mask)\n",
    "            valid_acc = valid(model, data, valid_mask)\n",
    "\n",
    "            if valid_acc > max_valid_acc:\n",
    "                max_valid_acc = valid_acc\n",
    "                best_model = copy.deepcopy(model)\n",
    "\n",
    "        best_model.eval()\n",
    "        with torch.no_grad():\n",
    "            best_output = best_model(data.x, data.edge_index, data.edge_weight)\n",
    "        print(f\"Best valid: {max_valid_acc:.4f}\")\n",
    "        print(f\"Evaluating {method_str} {GNN_model} Number {i}\")\n",
    "\n",
    "        coverage_list = []\n",
    "        for j in tqdm(range(num_permute_trans)):\n",
    "            # Permute the calibration and test datasets\n",
    "            calib_mask, test_mask = mask_mix(calib_mask, test_mask, seed=j)\n",
    "\n",
    "            pred_sets = get_prediction_sets(\n",
    "                best_output, data, calib_mask, test_mask, alpha\n",
    "            )\n",
    "\n",
    "            cov = coverage(pred_sets, data, test_mask)\n",
    "            coverage_list.append(cov)\n",
    "            results[method][GNN_model][\"Trans\"][\"Accuracy\"][\"All\"].append(\n",
    "                accuracy(best_output, data, test_mask)\n",
    "            )\n",
    "            results[method][GNN_model][\"Trans\"][\"Avg Size\"][\"All\"].append(\n",
    "                avg_set_size(pred_sets, test_mask)\n",
    "            )\n",
    "            results[method][GNN_model][\"Trans\"][\"Coverage\"][\"All\"].append(cov)\n",
    "\n",
    "            for t, time_mask in enumerate(time_masks):\n",
    "                # Consider test nodes only at time t\n",
    "                test_mask_t = time_mask & test_mask\n",
    "                if not test_mask_t.any():\n",
    "                    continue\n",
    "\n",
    "                # pred_sets rows follow np.where(test_mask)[0], so restricting\n",
    "                # time_mask to the test positions selects the rows for time t.\n",
    "                pred_sets_t = pred_sets[time_mask[test_mask]]\n",
    "\n",
    "                results[method][GNN_model][\"Trans\"][\"Accuracy\"][t].append(\n",
    "                    accuracy(best_output, data, test_mask_t)\n",
    "                )\n",
    "                results[method][GNN_model][\"Trans\"][\"Avg Size\"][t].append(\n",
    "                    avg_set_size(pred_sets_t, test_mask_t)\n",
    "                )\n",
    "                results[method][GNN_model][\"Trans\"][\"Coverage\"][t].append(\n",
    "                    coverage(pred_sets_t, data, test_mask_t)\n",
    "                )\n",
    "\n",
    "        print(\"Coverage: \", np.mean(coverage_list))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ### Semi-inductive experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%time\n",
    "\n",
    "# Semi-inductive experiments: for every (method, GNN model) pair, train one model\n",
    "# per seed on a fresh split and evaluate conformal prediction sets on that split's\n",
    "# single calibration/test partition, both overall and per time snapshot.\n",
    "\n",
    "for method, GNN_model in product(methods, GNN_models):\n",
    "\n",
    "    for i in range(num_train_semi_ind):\n",
    "        # Split data into training/validation/calibration/test\n",
    "        train_mask, valid_mask, calib_mask, test_mask = mask_split(\n",
    "            data_mask, props, seed=i, mode=\"semi-inductive\"\n",
    "        )\n",
    "\n",
    "        if method == \"BD\":\n",
    "            method_str = \"Block Diagonal\"\n",
    "            data = dataset_BD\n",
    "        elif method == \"UA\":\n",
    "            method_str = \"Unfolded\"\n",
    "            data = dataset_UA\n",
    "            # Pad masks with n leading False entries to include anchor nodes\n",
    "            anchor_pad = np.zeros(n, dtype=bool)\n",
    "            train_mask = np.concatenate((anchor_pad, train_mask))\n",
    "            valid_mask = np.concatenate((anchor_pad, valid_mask))\n",
    "            calib_mask = np.concatenate((anchor_pad, calib_mask))\n",
    "            test_mask = np.concatenate((anchor_pad, test_mask))\n",
    "        else:\n",
    "            # Fail loudly instead of silently reusing `data` from a previous iteration\n",
    "            raise ValueError(f\"Unknown method: {method}\")\n",
    "\n",
    "        if GNN_model == \"GCN\":\n",
    "            model = GCN(data.num_nodes, num_channels_GCN, num_classes, seed=i)\n",
    "        elif GNN_model == \"GAT\":\n",
    "            model = GAT(data.num_nodes, num_channels_GAT, num_classes, seed=i)\n",
    "        else:\n",
    "            raise ValueError(f\"Unknown GNN model: {GNN_model}\")\n",
    "\n",
    "        # NOTE(review): `train` presumably reads this global `optimizer`; keep the name\n",
    "        optimizer = torch.optim.Adam(model.parameters())\n",
    "\n",
    "        print(f\"Training {method_str} {GNN_model} Number {i}\")\n",
    "        # Start below any attainable accuracy and snapshot the initial model so that\n",
    "        # best_model always comes from this run, even if validation accuracy stays\n",
    "        # at 0 (a `> 0` threshold would keep the previous iteration's best_model).\n",
    "        max_valid_acc = -np.inf\n",
    "        best_model = copy.deepcopy(model)\n",
    "\n",
    "        for epoch in tqdm(range(num_epochs)):\n",
    "            train(model, data, train_mask)\n",
    "            valid_acc = valid(model, data, valid_mask)\n",
    "\n",
    "            if valid_acc > max_valid_acc:\n",
    "                max_valid_acc = valid_acc\n",
    "                best_model = copy.deepcopy(model)\n",
    "\n",
    "        # Inference only: avoid building an autograd graph over the whole dataset\n",
    "        with torch.no_grad():\n",
    "            best_output = best_model(data.x, data.edge_index, data.edge_weight)\n",
    "        print(f\"Best valid: {max_valid_acc:.4f}\")\n",
    "        print(f\"Evaluating {method_str} {GNN_model} Number {i}\")\n",
    "\n",
    "        # Cannot permute the calibration and test datasets in semi-inductive experiments\n",
    "\n",
    "        pred_sets = get_prediction_sets(best_output, data, calib_mask, test_mask, alpha)\n",
    "\n",
    "        cov = coverage(pred_sets, data, test_mask)\n",
    "        print(f\"Coverage: {cov:.4f}\")\n",
    "        semi_ind_results = results[method][GNN_model][\"Semi-Ind\"]\n",
    "        semi_ind_results[\"Accuracy\"][\"All\"].append(\n",
    "            accuracy(best_output, data, test_mask)\n",
    "        )\n",
    "        semi_ind_results[\"Avg Size\"][\"All\"].append(\n",
    "            avg_set_size(pred_sets, test_mask)\n",
    "        )\n",
    "        semi_ind_results[\"Coverage\"][\"All\"].append(cov)\n",
    "\n",
    "        # Nodes are stacked in blocks of n; UA prepends one extra block of anchor\n",
    "        # nodes, so time snapshot t lives in block t + offset.\n",
    "        num_blocks = T + 1 if method == \"UA\" else T\n",
    "        offset = num_blocks - T\n",
    "        # Rows of pred_sets are indexed in np.where(test_mask)[0] order\n",
    "        test_idx = np.flatnonzero(test_mask)\n",
    "\n",
    "        for t in range(T):\n",
    "            # Consider test nodes only at time t\n",
    "            time_mask = np.zeros((num_blocks, n), dtype=bool)\n",
    "            time_mask[t + offset] = True\n",
    "            time_mask = time_mask.reshape(-1)\n",
    "\n",
    "            test_mask_t = time_mask * test_mask\n",
    "            if np.sum(test_mask_t) == 0:\n",
    "                continue\n",
    "\n",
    "            # Get prediction sets corresponding to time t: the positions within\n",
    "            # test_idx of the test nodes at time t (one vectorised lookup instead\n",
    "            # of a per-node np.where search).\n",
    "            pred_sets_t = pred_sets[np.flatnonzero(time_mask[test_idx])]\n",
    "\n",
    "            semi_ind_results[\"Accuracy\"][t].append(\n",
    "                accuracy(best_output, data, test_mask_t)\n",
    "            )\n",
    "            semi_ind_results[\"Avg Size\"][t].append(\n",
    "                avg_set_size(pred_sets_t, test_mask_t)\n",
    "            )\n",
    "            semi_ind_results[\"Coverage\"][t].append(\n",
    "                coverage(pred_sets_t, data, test_mask_t)\n",
    "            )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " Save results to pickle file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "import os\n",
    "\n",
    "# open(..., \"wb\") does not create missing parent directories (e.g. results/),\n",
    "# so make sure the output directory exists before writing the pickle.\n",
    "os.makedirs(os.path.dirname(results_file) or \".\", exist_ok=True)\n",
    "with open(results_file, \"wb\") as file:\n",
    "    pickle.dump(results, file)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "TGB_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
