{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 4,
     "status": "ok",
     "timestamp": 1691193027714,
     "user_tz": -120
    },
    "id": "nIBBcWxa4DB5"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.datasets import Planetoid\n",
    "from torch_geometric.logging import log\n",
    "from torch_geometric.nn import GCNConv\n",
    "from typing import Callable, Sequence\n",
    "from lightning import seed_everything\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ignore this section\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Aggregators\n",
    "\n",
    "It defines the projections that are in the `aggregation_modules.py`, copying them here only to make this notebook self contained."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "from typing import Sequence\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "\n",
    "log = logging.getLogger(__name__)\n",
    "\n",
    "\n",
    "class SumAggregation(nn.Module):\n",
    "    def __init__(self, subspace_dim: int, num_subspaces: int):\n",
    "        super().__init__()\n",
    "\n",
    "        self.subspace_dim = subspace_dim\n",
    "        self.num_subspaces = num_subspaces\n",
    "\n",
    "        log.info(f\"SumAggregation: subspace_dim={self.subspace_dim}, num_subspaces={self.num_subspaces}\")\n",
    "\n",
    "        self.norm_layers = nn.ModuleList([nn.LayerNorm(subspace_dim) for _ in range(num_subspaces)])\n",
    "\n",
    "    @property\n",
    "    def out_dim(self):\n",
    "        return self.subspace_dim\n",
    "\n",
    "    def forward(self, concat_subspaces: Sequence[torch.Tensor]) -> torch.Tensor:\n",
    "        concat_subspaces = concat_subspaces.split(self.subspace_dim, dim=1)\n",
    "\n",
    "        out = [norm_layer(x) for norm_layer, x in zip(self.norm_layers, concat_subspaces)]\n",
    "\n",
    "        return torch.stack(out, dim=1).sum(dim=1)\n",
    "\n",
    "\n",
    "class NonLinearSumAggregation(nn.Module):\n",
    "    def __init__(self, subspace_dim: int, num_subspaces: int):\n",
    "        super().__init__()\n",
    "\n",
    "        self.subspace_dim = subspace_dim\n",
    "        self.num_subspaces = num_subspaces\n",
    "\n",
    "        log.info(f\"SumAggregation: subspace_dim={self.subspace_dim}, num_subspaces={self.num_subspaces}\")\n",
    "\n",
    "        self.norm_layers = nn.ModuleList(\n",
    "            [\n",
    "                nn.Sequential(\n",
    "                    nn.LayerNorm(subspace_dim),\n",
    "                    nn.Linear(subspace_dim, subspace_dim),\n",
    "                    nn.Tanh(),\n",
    "                )\n",
    "                for _ in range(num_subspaces)\n",
    "            ]\n",
    "        )\n",
    "\n",
    "    @property\n",
    "    def out_dim(self):\n",
    "        return self.subspace_dim\n",
    "\n",
    "    def forward(self, concat_subspaces: Sequence[torch.Tensor]) -> torch.Tensor:\n",
    "        concat_subspaces = concat_subspaces.split(self.subspace_dim, dim=1)\n",
    "\n",
    "        out = [norm_layer(x) for norm_layer, x in zip(self.norm_layers, concat_subspaces)]\n",
    "\n",
    "        return torch.stack(out, dim=1).sum(dim=1)\n",
    "\n",
    "\n",
    "class ConcatAggregation(nn.Module):\n",
    "    def __init__(self, subspace_dim: int, num_subspaces: int):\n",
    "        super().__init__()\n",
    "\n",
    "        self.subspace_dim = subspace_dim\n",
    "        self.num_subspaces = num_subspaces\n",
    "\n",
    "        log.info(f\"ConcatAggregation: subspace_dim={self.subspace_dim}, num_subspaces={self.num_subspaces}\")\n",
    "\n",
    "        self.norm_layers = nn.ModuleList([nn.LayerNorm(subspace_dim) for _ in range(num_subspaces)])\n",
    "\n",
    "    @property\n",
    "    def out_dim(self):\n",
    "        return self.subspace_dim * self.num_subspaces\n",
    "\n",
    "    def forward(self, concat_subspaces: torch.Tensor) -> torch.Tensor:\n",
    "        concat_subspaces = concat_subspaces.split(self.subspace_dim, dim=1)\n",
    "\n",
    "        out = [norm_layer(x) for norm_layer, x in zip(self.norm_layers, concat_subspaces)]\n",
    "\n",
    "        return torch.cat(out, dim=1)\n",
    "\n",
    "\n",
    "class WeightedAvgAggregation(nn.Module):  # TODO: fix this\n",
    "    def __init__(self, subspace_dim: int, num_subspaces: int):\n",
    "        super().__init__()\n",
    "        self.subspace_dim = subspace_dim\n",
    "        self.num_subspaces = num_subspaces\n",
    "\n",
    "        log.info(f\"WeightedAvgAggregation: subspace_dim={self.subspace_dim}, num_subspaces={self.num_subspaces}\")\n",
    "\n",
    "        self.weight = nn.Parameter(torch.ones(num_subspaces))\n",
    "\n",
    "        self.norm_layers = nn.ModuleList(\n",
    "            [\n",
    "                nn.Sequential(\n",
    "                    nn.LayerNorm(subspace_dim),\n",
    "                    # nn.Linear(subspace_dim, subspace_dim),\n",
    "                    # nn.Tanh(),\n",
    "                )\n",
    "                for _ in range(num_subspaces)\n",
    "            ]\n",
    "        )\n",
    "\n",
    "    @property\n",
    "    def out_dim(self):\n",
    "        return self.subspace_dim\n",
    "\n",
    "    def forward(self, concat_subspaces: Sequence[torch.Tensor]) -> torch.Tensor:\n",
    "        concat_subspaces = concat_subspaces.split(self.subspace_dim, dim=1)\n",
    "        out = [norm_layer(x) for norm_layer, x in zip(self.norm_layers, concat_subspaces)]\n",
    "\n",
    "        softmax_weights = F.softmax(self.weight, dim=0)  # [num_subspaces]\n",
    "        concat_subspaces = torch.stack(concat_subspaces, dim=1)  # [batch_size, num_subspaces, subspace_dim]\n",
    "        out = torch.einsum(\"bns,n -> bs\", concat_subspaces, softmax_weights)  # [batch_size, subspace_dim]\n",
    "\n",
    "        return out\n",
    "\n",
    "\n",
    "class Identity(nn.Module):\n",
    "    def __init__(self, subspace_dim: int, num_subspaces: int):\n",
    "        super().__init__()\n",
    "        self.subspace_dim = subspace_dim\n",
    "        self.num_subspaces = num_subspaces\n",
    "\n",
    "        log.info(f\"Identity: subspace_dim={self.subspace_dim}, num_subspaces={self.num_subspaces}\")\n",
    "\n",
    "        assert self.num_subspaces == 1\n",
    "\n",
    "        self.norm_layers = nn.LayerNorm(subspace_dim)\n",
    "\n",
    "    @property\n",
    "    def out_dim(self):\n",
    "        return self.subspace_dim\n",
    "\n",
    "    def forward(self, relative_space: torch.Tensor) -> torch.Tensor:\n",
    "        return self.norm_layers(relative_space)\n",
    "\n",
    "\n",
    "class NonLinearIdentity(nn.Module):\n",
    "    def __init__(self, subspace_dim: int, num_subspaces: int):\n",
    "        super().__init__()\n",
    "        self.subspace_dim = subspace_dim\n",
    "        self.num_subspaces = num_subspaces\n",
    "\n",
    "        log.info(f\"Identity: subspace_dim={self.subspace_dim}, num_subspaces={self.num_subspaces}\")\n",
    "\n",
    "        assert self.num_subspaces == 1\n",
    "\n",
    "        self.norm_layers = nn.LayerNorm(subspace_dim)\n",
    "\n",
    "    @property\n",
    "    def out_dim(self):\n",
    "        return self.subspace_dim\n",
    "\n",
    "    def forward(self, relative_space: torch.Tensor) -> torch.Tensor:\n",
    "        return self.norm_layers(relative_space)\n",
    "\n",
    "\n",
    "# check also https://github.com/sooftware/attentions/blob/master/attentions.py\n",
    "class SelfAttentionLayer(torch.nn.Module):\n",
    "    def __init__(self, subspace_dim: int, num_subspaces: int):\n",
    "        super().__init__()\n",
    "        self.subspace_dim = subspace_dim\n",
    "        self.num_subspaces = num_subspaces\n",
    "\n",
    "        log.info(f\"SelfAttentionLayer: subspace_dim={self.subspace_dim}, num_subspaces={self.num_subspaces}\")\n",
    "        log.info(f\"SelfAttentionLayer: subspace_dim={self.subspace_dim}, num_subspaces={self.num_subspaces}\")\n",
    "\n",
    "        self.attention = nn.MultiheadAttention(embed_dim=self.subspace_dim, num_heads=1, batch_first=True)\n",
    "        self.norm_layers = nn.ModuleList([nn.LayerNorm(self.subspace_dim) for _ in range(self.num_subspaces)])\n",
    "\n",
    "    @property\n",
    "    def out_dim(self):\n",
    "        return self.subspace_dim\n",
    "\n",
    "    def forward(self, concat_subspaces: Sequence[torch.Tensor]):\n",
    "        query = concat_subspaces.split(self.subspace_dim, dim=1)\n",
    "        query = [norm_layer(x) for norm_layer, x in zip(self.norm_layers, query)]\n",
    "        query = torch.stack(query, dim=1)\n",
    "\n",
    "        out, _ = self.attention(query=query, key=query, value=query)\n",
    "\n",
    "        return torch.sum(out, dim=1)\n",
    "\n",
    "\n",
    "class TransformerBlock(nn.Module):\n",
    "    def __init__(self, k, heads):\n",
    "        super().__init__()\n",
    "        self.attention = nn.MultiheadAttention(embed_dim=k, num_heads=heads, batch_first=True, dropout=0.1)\n",
    "        self.norm1 = nn.LayerNorm(k)\n",
    "        self.norm2 = nn.LayerNorm(k)\n",
    "        self.ff = nn.Sequential(nn.Linear(k, k // 4), nn.ReLU(), nn.Linear(k // 4, k))\n",
    "\n",
    "    def forward(self, x):\n",
    "        attended = self.attention(query=x, key=x, value=x)[0]\n",
    "        x = self.norm1(attended)\n",
    "        fedforward = self.ff(x)\n",
    "        return self.norm2(fedforward)\n",
    "\n",
    "    def get_attention_weights(self, x):\n",
    "        return self.attention(query=x, key=x, value=x)[1]\n",
    "\n",
    "\n",
    "class Transformer(torch.nn.Module):\n",
    "    def __init__(self, subspace_dim: int, num_subspaces: int):\n",
    "        super().__init__()\n",
    "        depth = 1\n",
    "        heads = 1\n",
    "        self.subspace_dim = subspace_dim\n",
    "        self.num_subspaces = num_subspaces\n",
    "\n",
    "        self.tblocks = nn.Sequential(*[TransformerBlock(k=self.subspace_dim, heads=heads) for _ in range(depth)])\n",
    "\n",
    "    def get_attention_weights(self, concat_subspaces: Sequence[torch.Tensor], attention_idx: int = 0):\n",
    "        query = concat_subspaces.reshape(-1, self.num_subspaces, self.subspace_dim)\n",
    "        attention_weights = self.tblocks[attention_idx].get_attention_weights(query)\n",
    "        return attention_weights\n",
    "\n",
    "    @property\n",
    "    def out_dim(self):\n",
    "        return self.subspace_dim\n",
    "\n",
    "    def forward(self, concat_subspaces: Sequence[torch.Tensor]):\n",
    "        query = concat_subspaces.reshape(-1, self.num_subspaces, self.subspace_dim)\n",
    "        out = self.tblocks(query)\n",
    "        return torch.mean(out, dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Projection functions\n",
    " It defines the projections that are in the `relreps.py`, copying them here only to make this notebook self contained."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "from typing import Callable, Optional, Tuple\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def abs_to_rel(\n",
    "    anchors: torch.Tensor,\n",
    "    points: torch.Tensor,\n",
    "    normalizing_func: Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]],\n",
    "    dist_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    ") -> torch.Tensor:\n",
    "    relative_points = []\n",
    "\n",
    "    if normalizing_func is not None:\n",
    "        anchors, points = normalizing_func(anchors=anchors, points=points)\n",
    "\n",
    "    for point in points:\n",
    "        current_rel_point = []\n",
    "        for anchor in anchors:\n",
    "            current_rel_point.append(dist_func(point=point, anchor=anchor))\n",
    "        relative_points.append(current_rel_point)\n",
    "    return torch.as_tensor(relative_points, dtype=anchors.dtype)\n",
    "\n",
    "\n",
    "def basis_change_lstsq(anchors: torch.Tensor, points: torch.Tensor) -> torch.Tensor:\n",
    "    # Center the data\n",
    "    anchor_mean = torch.mean(anchors, dim=0)\n",
    "    centered_points = points - anchor_mean\n",
    "    centered_anchors = anchors - anchor_mean\n",
    "\n",
    "    # Normalize the centered points and anchors\n",
    "    normalized_points = centered_points / torch.norm(centered_points, dim=1, keepdim=True)\n",
    "    normalized_anchors = centered_anchors / torch.norm(centered_anchors, dim=1, keepdim=True)\n",
    "\n",
    "    try:\n",
    "        return torch.linalg.lstsq(normalized_anchors.T, normalized_points.T)[0].T\n",
    "    except RuntimeError:\n",
    "        return torch.zeros((points.shape[0], anchors.shape[0]))\n",
    "\n",
    "\n",
    "def abs_to_rel_cosine(\n",
    "    anchors: torch.Tensor,\n",
    "    points: torch.Tensor,\n",
    ") -> torch.Tensor:\n",
    "    norm_anchors = torch.nn.functional.normalize(anchors, dim=-1)\n",
    "    norm_points = torch.nn.functional.normalize(points, dim=-1)\n",
    "\n",
    "    return norm_points @ norm_anchors.T\n",
    "\n",
    "\n",
    "def abs_to_rel_center_cosine(\n",
    "    anchors: torch.Tensor,\n",
    "    points: torch.Tensor,\n",
    ") -> torch.Tensor:\n",
    "    anchors = anchors - points.mean(dim=0)\n",
    "    points = points - points.mean(dim=0)\n",
    "\n",
    "    norm_anchors = torch.nn.functional.normalize(anchors, dim=-1)\n",
    "    norm_points = torch.nn.functional.normalize(points, dim=-1)\n",
    "\n",
    "    return norm_points @ norm_anchors.T\n",
    "\n",
    "\n",
    "def abs_to_rel_lp(\n",
    "    anchors: torch.Tensor,\n",
    "    points: torch.Tensor,\n",
    "    p: int,\n",
    ") -> torch.Tensor:\n",
    "    return torch.cdist(points, anchors, p=p)\n",
    "\n",
    "\n",
    "def abs_to_rel_normalized_euclidean(\n",
    "    anchors: torch.Tensor,\n",
    "    points: torch.Tensor,\n",
    ") -> torch.Tensor:\n",
    "    anchors = anchors - points.mean(dim=0)\n",
    "    points = points - points.mean(dim=0)\n",
    "\n",
    "    norm_anchors = torch.nn.functional.normalize(anchors, dim=-1)\n",
    "    norm_points = torch.nn.functional.normalize(points, dim=-1)\n",
    "\n",
    "    return torch.cdist(norm_points, norm_anchors, p=2)\n",
    "\n",
    "\n",
    "def abs_to_rel_std_euclidean(\n",
    "    anchors: torch.Tensor,\n",
    "    points: torch.Tensor,\n",
    ") -> torch.Tensor:\n",
    "    anchors = anchors - points.mean(dim=0)\n",
    "    points = points - points.mean(dim=0)\n",
    "\n",
    "    norm_anchors = anchors / points.std(dim=0)\n",
    "    norm_points = points / points.std(dim=0)\n",
    "\n",
    "    return torch.cdist(norm_points, norm_anchors, p=2)\n",
    "\n",
    "\n",
    "SIMPLE_PROJECTION_TYPE = {\n",
    "    \"CoB\": basis_change_lstsq,\n",
    "    \"Cosine\": abs_to_rel_cosine,\n",
    "    \"CenterCosine\": abs_to_rel_center_cosine,\n",
    "    \"Euclidean\": functools.partial(abs_to_rel_lp, p=2),\n",
    "    \"NormEuclidean\": abs_to_rel_normalized_euclidean,\n",
    "    \"L1\": functools.partial(abs_to_rel_lp, p=1),\n",
    "    \"Linf\": functools.partial(abs_to_rel_lp, p=\"inf\"),\n",
    "    \"LPinf\": functools.partial(abs_to_rel_lp, p=42),\n",
    "    \"Absolute\": lambda points, **kwargs: points,\n",
    "    # # #\n",
    "    # Archive\n",
    "    #\n",
    "    # \"Wasserstein\": abs_to_rel_wasserstein,\n",
    "    # # \"Standardized Euclidean\": abs_to_rel_std_euclidean,\n",
    "    # # \"L1\": functools.partial(abs_to_rel_lp, p=1),\n",
    "    # # \"L3\": functools.partial(abs_to_rel_lp, p=3),\n",
    "    # \"CoB Lstsq\": basis_change_lstsq,\n",
    "    # \"Normalized Absolute\": lambda points, **kwargs: torch.nn.functional.normalize(points, dim=-1),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import chain, combinations\n",
    "\n",
    "\n",
    "def pair_powerset(iterable, pairing):\n",
    "    \"\"\"Returns the powerset of the iterable, pairing each element with the given pairing.\"\"\"\n",
    "    s = list(iterable)\n",
    "    return ((x, pairing) for x in chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) if x)\n",
    "\n",
    "\n",
    "# Example usage\n",
    "list(pair_powerset([\"Cosine\", \"CenterCosine\", \"Euclidean\", \"NormEuclidean\"], \"t\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training utilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, data, optimizer):\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    out = model(data.x, data.edge_index)\n",
    "    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    return float(loss)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def test(model, data):\n",
    "    model.eval()\n",
    "    pred = model(data.x, data.edge_index).argmax(dim=-1)\n",
    "\n",
    "    accs = []\n",
    "    for mask in [data.train_mask, data.val_mask, data.test_mask]:\n",
    "        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))\n",
    "    return accs\n",
    "\n",
    "\n",
    "def train_model(run_name, data, projection_funcs, aggregation_module, seed):\n",
    "    seed_everything(seed)\n",
    "    model = GNN(\n",
    "        dataset.num_features,\n",
    "        hidden_channels,\n",
    "        dataset.num_classes,\n",
    "        relative_block=RelativeBlock(\n",
    "            projection_funcs,\n",
    "            aggregation_module(\n",
    "                subspace_dim=anchor_indices.shape[0],\n",
    "                num_subspaces=len(projection_funcs),\n",
    "            ),\n",
    "        ),\n",
    "        anchor_indices=anchor_indices,\n",
    "    )\n",
    "    model, data = model.to(device), data.to(device)\n",
    "    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)\n",
    "\n",
    "    best_val_acc = final_test_acc = 0\n",
    "    for epoch in (bar := tqdm(range(1, epochs + 1))):\n",
    "        loss = train(model, data, optimizer)\n",
    "        train_acc, val_acc, tmp_test_acc = test(model, data)\n",
    "        if val_acc > best_val_acc:\n",
    "            best_val_acc = val_acc\n",
    "            test_acc = tmp_test_acc\n",
    "\n",
    "        bar.set_description(\n",
    "            f\"{run_name} [{epoch}: {loss:.4f}] (train: {train_acc:.4f}) (val: {val_acc:.4f}) (test: {test_acc:.4f})\"\n",
    "        )\n",
    "\n",
    "    model.eval()\n",
    "    h = model.encode(data.x, data.edge_index)\n",
    "\n",
    "    rel_h = F.normalize(h, dim=-1, p=2)\n",
    "    rel_h = rel_h @ rel_h[anchor_indices].T\n",
    "\n",
    "    return model.eval().cpu(), h, rel_h"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the model wrappers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class StitchedModel(torch.nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        encoder: torch.nn.Module,\n",
    "        decoder: torch.nn.Module,\n",
    "    ):\n",
    "        \"\"\"Utility class to stitch together an encoder and decoder model already pre-trained.\"\"\"\n",
    "        super().__init__()\n",
    "        self.encoder = encoder.eval()\n",
    "        self.decoder = decoder.eval()\n",
    "\n",
    "    def forward(self, x: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:\n",
    "        rel_x = self.encoder.encode(x, anchors)\n",
    "        return self.decoder.decode(rel_x)\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"StitchedModel({self.encoder}, {self.decoder})\"\n",
    "\n",
    "\n",
    "class RelativeBlock(torch.nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        projection_names: Sequence[str],\n",
    "        aggregation_module: torch.nn.Module,\n",
    "    ):\n",
    "        \"\"\"A block that takes in a set of projection functions and an aggregation module\n",
    "        and exposes the encode and decode methods in a generic way.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.projection_names = projection_names\n",
    "        self.projection_funcs: Sequence[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = [\n",
    "            SIMPLE_PROJECTION_TYPE[x] for x in projection_names\n",
    "        ]\n",
    "        self.aggregation_module = aggregation_module\n",
    "\n",
    "    def forward(self, x: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:\n",
    "        rel_x = self.encode(x, anchors)\n",
    "        return self.decode(rel_x)\n",
    "\n",
    "    def encode(self, x, anchors) -> torch.Tensor:\n",
    "        return torch.cat([fun(anchors=anchors, points=x) for fun in self.projection_funcs], dim=-1)\n",
    "\n",
    "    def decode(self, rel_x: torch.Tensor) -> torch.Tensor:\n",
    "        return self.aggregation_module(rel_x)\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"RelativeBlock({self.projection_names}, {self.aggregation_module})\"\n",
    "\n",
    "    def __str__(self):\n",
    "        return f\"RelativeBlock({self.projection_names=}, {self.aggregation_module=})\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tl37fspl4DB6"
   },
   "source": [
    "## Training models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Type\n",
    "from torch import nn\n",
    "import itertools\n",
    "from collections import defaultdict\n",
    "from tqdm import tqdm\n",
    "from pathlib import Path\n",
    "\n",
    "### Hyperparams\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "hidden_channels = 300\n",
    "lr = 1e-4\n",
    "epochs = 25\n",
    "### ---------------------\n",
    "\n",
    "\n",
    "path = Path(\"data\") / \"Planetoid\"\n",
    "dataset = Planetoid(path, \"Cora\", transform=T.NormalizeFeatures())\n",
    "data = dataset[0]\n",
    "num_nodes = data.x.shape[0]\n",
    "anchor_indices = torch.randperm(data.x.size(0))[:hidden_channels]\n",
    "\n",
    "\n",
    "# The modules forward is splitted in encode/decode to ease the stitching.\n",
    "# Their organization is the following:\n",
    "# - GNN:\n",
    "#   - RelativeBlock: instantiated with a list of projections names and the aggregation module\n",
    "#       - encode: Projection fuctions and their names\n",
    "#       - decode: aggregation module (e.g. attention, transformer, etc.)\n",
    "\n",
    "\n",
    "class GNN(torch.nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        in_channels,\n",
    "        hidden_channels,\n",
    "        out_channels,\n",
    "        relative_block: RelativeBlock,\n",
    "        anchor_indices: torch.Tensor,\n",
    "    ):\n",
    "        \"\"\"Main model\"\"\"\n",
    "        super().__init__()\n",
    "        self.conv1 = GCNConv(in_channels, hidden_channels, cached=True)\n",
    "        self.conv2 = GCNConv(hidden_channels, hidden_channels, cached=True)\n",
    "\n",
    "        self.register_buffer(\"anchor_indices\", anchor_indices)\n",
    "        self.relative_block = relative_block\n",
    "        self.decoder = torch.nn.Linear(relative_block.aggregation_module.out_dim, out_channels)\n",
    "\n",
    "    def encode(self, x, edge_index):\n",
    "        features = self.conv1(x, edge_index).relu()\n",
    "        features = self.conv2(features, edge_index)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            anchors = features[self.anchor_indices]\n",
    "\n",
    "        x = self.relative_block.encode(x=features, anchors=anchors)\n",
    "        return x\n",
    "\n",
    "    def decode(self, x):\n",
    "        x = self.relative_block.decode(x)\n",
    "        return self.decoder(x)\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        x = self.encode(x, edge_index)\n",
    "\n",
    "        return self.decode(x)\n",
    "\n",
    "    def get_attention_weights(self, x, edge_index, attention_idx: int = 0):\n",
    "        x = self.encode(x, edge_index)\n",
    "        return self.relative_block.aggregation_module.get_attention_weights(x, attention_idx=attention_idx)\n",
    "\n",
    "\n",
    "## Aggregations\n",
    "\n",
    "\n",
    "class TransformerBlock(nn.Module):\n",
    "    def __init__(self, k, heads):\n",
    "        \"\"\"Only used in the Transformer aggregator\"\"\"\n",
    "        super().__init__()\n",
    "        self.attention = nn.MultiheadAttention(embed_dim=k, num_heads=heads, batch_first=True, dropout=0.1)\n",
    "        self.norm1 = nn.LayerNorm(k)\n",
    "        self.norm2 = nn.LayerNorm(k)\n",
    "        self.ff = nn.Sequential(nn.Linear(k, k // 4), nn.ReLU(), nn.Linear(k // 4, k))\n",
    "\n",
    "    def forward(self, x):\n",
    "        # TODO: norme dei diversi sottospazi\n",
    "        attended = self.attention(query=x, key=x, value=x)[0]\n",
    "        # x = self.norm1(attended + x)\n",
    "        x = self.norm1(attended)\n",
    "        fedforward = self.ff(x)\n",
    "        return self.norm2(fedforward)\n",
    "\n",
    "    def get_attention_weights(self, x):\n",
    "        return self.attention(query=x, key=x, value=x)[1]\n",
    "\n",
    "\n",
    "class Transformer(torch.nn.Module):\n",
    "    def __init__(self, subspace_dim: int, num_subspaces: int, batch_size: int = 16):\n",
    "        \"\"\"Aggregator using the Transformer architecture\"\"\"\n",
    "        super().__init__()\n",
    "        depth = 1\n",
    "        heads = 1\n",
    "        self.subspace_dim = subspace_dim\n",
    "        self.num_subspaces = num_subspaces\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "        self.tblocks = nn.Sequential(*[TransformerBlock(k=self.subspace_dim, heads=heads) for _ in range(depth)])\n",
    "        self.input_norms = nn.ModuleList([nn.LayerNorm(self.subspace_dim) for _ in range(self.num_subspaces)])\n",
    "\n",
    "    def get_attention_weights(self, concat_subspaces: Sequence[torch.Tensor], attention_idx: int = 0):\n",
    "        query = concat_subspaces.reshape(-1, self.num_subspaces, self.subspace_dim)\n",
    "        query = torch.stack([norm(x) for x, norm in zip(query.unbind(dim=1), self.input_norms)], dim=1)\n",
    "\n",
    "        attention_weights = self.tblocks[attention_idx].get_attention_weights(query)\n",
    "        return attention_weights\n",
    "\n",
    "    @property\n",
    "    def out_dim(self):\n",
    "        return self.subspace_dim\n",
    "\n",
    "    def forward(self, concat_subspaces: Sequence[torch.Tensor]):\n",
    "        query = concat_subspaces.reshape(-1, self.num_subspaces, self.subspace_dim)\n",
    "        query = torch.stack([norm(x) for x, norm in zip(query.unbind(dim=1), self.input_norms)], dim=1)\n",
    "        out = self.tblocks(query)\n",
    "        return torch.mean(out, dim=1)\n",
    "\n",
    "\n",
    "class NewSelfAttentionLayer(torch.nn.Module):\n",
    "    def __init__(self, subspace_dim: int, num_subspaces: int):\n",
    "        \"\"\"Aggregator using the Attention only\"\"\"\n",
    "        super().__init__()\n",
    "        self.subspace_dim = subspace_dim\n",
    "        self.num_subspaces = num_subspaces\n",
    "\n",
    "        self.attention = nn.MultiheadAttention(embed_dim=self.subspace_dim, num_heads=1, batch_first=True, dropout=0.25)\n",
    "        self.norm_layers_in = nn.ModuleList([nn.LayerNorm(subspace_dim) for _ in range(num_subspaces)])\n",
    "        self.norm_layers_out = nn.ModuleList([nn.LayerNorm(subspace_dim) for _ in range(num_subspaces)])\n",
    "\n",
    "    @property\n",
    "    def out_dim(self):\n",
    "        return self.subspace_dim * self.num_subspaces\n",
    "\n",
    "    def pre_attention(self, concat_subspaces: Sequence[torch.Tensor]):\n",
    "        x = concat_subspaces\n",
    "\n",
    "        x = x.split(self.subspace_dim, dim=1)\n",
    "        x = [norm_layer(a) for norm_layer, a in zip(self.norm_layers_in, x)]\n",
    "        x = torch.stack(x, dim=1)\n",
    "        return x\n",
    "\n",
    "    def get_attention_weights(self, concat_subspaces: Sequence[torch.Tensor], attention_idx: int = 0):\n",
    "        assert attention_idx == 0\n",
    "        x = self.pre_attention(concat_subspaces)\n",
    "        return self.attention(query=x, key=x, value=x)[1]\n",
    "\n",
    "    def forward(self, concat_subspaces: Sequence[torch.Tensor]):\n",
    "        x = self.pre_attention(concat_subspaces)\n",
    "        x, _ = self.attention(query=x, key=x, value=x)\n",
    "        return x.reshape((x.shape[0], -1))\n",
    "\n",
    "\n",
    "class PureIdentity(torch.nn.Identity):\n",
    "    \"\"\"Aggregator that does absolutely nothing\"\"\"\n",
    "\n",
    "    pass\n",
    "\n",
    "    @property\n",
    "    def out_dim(self):\n",
    "        return 300\n",
    "\n",
    "\n",
    "# Decide what to run here\n",
    "projections_aggregation_to_use = [\n",
    "    ([\"Absolute\"], PureIdentity),\n",
    "    ([\"Absolute\"], Transformer),\n",
    "    ([\"Cosine\"], PureIdentity),\n",
    "    ([\"Euclidean\"], PureIdentity),\n",
    "    ([\"CenterCosine\"], PureIdentity),\n",
    "    ([\"NormEuclidean\"], PureIdentity),\n",
    "    #\n",
    "    ([\"Absolute\"], Identity),\n",
    "    ([\"Cosine\"], Identity),\n",
    "    ([\"Euclidean\"], Identity),\n",
    "    ([\"CenterCosine\"], Identity),\n",
    "    ([\"NormEuclidean\"], Identity),\n",
    "] + list(pair_powerset([\"Cosine\", \"CenterCosine\", \"Euclidean\", \"NormEuclidean\"], NewSelfAttentionLayer))\n",
    "\n",
    "SEEDS = list(range(3))\n",
    "\n",
    "\n",
    "def normalize_name(name) -> str:\n",
    "    \"\"\"Utility function to normalize the name of the model.\n",
    "\n",
    "    Useful to determine the key to use in the dictionary without\n",
    "    having to worry about the projection ordering when you specify them.\n",
    "\n",
    "    Do not use for labeling the attention weights, as it will not be\n",
    "    clear which projection is which due to the sorted.\n",
    "    \"\"\"\n",
    "    if isinstance(name, torch.nn.Module):\n",
    "        return name.__class__.__name__.lower()\n",
    "    if isinstance(name, Type):\n",
    "        return name.__name__.lower()\n",
    "    if isinstance(name, Sequence):\n",
    "        return \",\".join(x.lower() for x in sorted(name))\n",
    "    return name.lower()\n",
    "\n",
    "\n",
    "# Train models absolute models\n",
    "reltype2seed2model = defaultdict(dict)\n",
    "for seed, (projection_names, aggregator) in (\n",
    "    list(itertools.product(SEEDS, projections_aggregation_to_use))\n",
    "    # bar := tqdm(list(itertools.product(SEEDS, projections_aggregation_to_use)))\n",
    "):\n",
    "    run_name = f\"{normalize_name(aggregator)}({normalize_name(projection_names)})\"\n",
    "    # bar.set_description(f\"Running {run_name} with seed {seed}\")\n",
    "    model, abs1, rel1 = train_model(\n",
    "        run_name=run_name,\n",
    "        data=data,\n",
    "        projection_funcs=projection_names,\n",
    "        aggregation_module=aggregator,\n",
    "        seed=seed,\n",
    "    )\n",
    "    reltype2seed2model[run_name][seed] = model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 4,
     "status": "ok",
     "timestamp": 1691193038479,
     "user_tz": -120
    },
    "id": "qaKHLHuB5Fv1"
   },
   "outputs": [],
   "source": [
    "from collections import namedtuple\n",
    "import itertools\n",
    "import pandas as pd\n",
    "\n",
    "Result = namedtuple(\"Result\", [\"seed1\", \"seed2\", \"run_name\", \"train_acc\", \"val_acc\", \"test_acc\"])\n",
    "\n",
    "results = []\n",
    "for seed1, seed2 in tqdm(list(itertools.product(SEEDS, SEEDS))):\n",
    "    for reltype, seed2model in reltype2seed2model.items():\n",
    "        model1 = seed2model[seed1]\n",
    "        model2 = seed2model[seed2]\n",
    "        stitched_model = StitchedModel(encoder=model1, decoder=model2).eval().to(device)\n",
    "        train_acc, val_acc, test_acc = test(stitched_model, data)\n",
    "        stitched_model.cpu()\n",
    "        results.append(\n",
    "            Result(\n",
    "                seed1=seed1,\n",
    "                seed2=seed2,\n",
    "                run_name=reltype,\n",
    "                train_acc=train_acc,\n",
    "                val_acc=val_acc,\n",
    "                test_acc=test_acc,\n",
    "            )\n",
    "        )\n",
    "\n",
    "df = pd.DataFrame(results)\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Store results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# df = pd.DataFrame(results)\n",
    "# df.to_csv(\"results.csv\", index=False, sep=\"\\t\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stitching performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"All possible stitching across {len(SEEDS)} seeds\")\n",
    "df = pd.DataFrame(results)\n",
    "df = df[df.seed1 != df.seed2].groupby([\"run_name\"]).agg([\"mean\", \"std\"]).drop(columns=[\"seed1\", \"seed2\"])\n",
    "df.sort_values(by=[\"run_name\"], key=lambda x: x.str.len()).style.highlight_max(color=\"gray\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# End2End performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(results)\n",
    "df = df[df.seed1 == df.seed2].groupby([\"run_name\"]).agg([\"mean\", \"std\"]).drop(columns=[\"seed1\", \"seed2\"])\n",
    "print(\"End to End\")\n",
    "df.sort_values(by=[\"run_name\"], key=lambda x: x.str.len()).style.highlight_max(color=\"gray\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attention weights analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "\n",
    "# Select the model to use by: projection_names, aggregation_module, seed\n",
    "# Usable only with aggregator that expose the method: get_attention_weights\n",
    "SEEDS = [0, 1, 2]\n",
    "aggregator = NewSelfAttentionLayer\n",
    "projection_names = [\"CenterCosine\", \"Cosine\", \"Euclidean\", \"NormEuclidean\"]\n",
    "\n",
    "# Selection which attention to look at\n",
    "ATTENTION_INDEX = 0\n",
    "\n",
    "run_name = f\"{normalize_name(aggregator)}({normalize_name(projection_names)})\"\n",
    "\n",
    "print(run_name, f\"{ATTENTION_INDEX=} {SEEDS=}\")\n",
    "data = data.cuda()\n",
    "\n",
    "w_across_seeds = []\n",
    "for seed in SEEDS:\n",
    "    model = reltype2seed2model[run_name][seed].eval().cuda()\n",
    "    w = model.get_attention_weights(data.x, data.edge_index, attention_idx=ATTENTION_INDEX)[\n",
    "        data.test_mask\n",
    "    ]  # TODO: usa solo test set\n",
    "    w = w.mean(0)  # average over all data\n",
    "    w_across_seeds.append(w.detach().cpu())\n",
    "w = torch.stack(w_across_seeds).mean(0)  # average over seeds\n",
    "\n",
    "\n",
    "_ = sns.heatmap(\n",
    "    w.detach().cpu().numpy(),\n",
    "    cmap=\"Reds\",\n",
    "    annot=True,\n",
    "    fmt=\".2f\",\n",
    "    xticklabels=model.relative_block.projection_names,\n",
    "    yticklabels=model.relative_block.projection_names,\n",
    ")\n",
    "\n",
    "df = pd.DataFrame(results)\n",
    "df = (\n",
    "    df[df.seed1 == df.seed2][df.seed2.isin(SEEDS)]\n",
    "    .groupby([\"run_name\"])\n",
    "    .agg([\"mean\", \"std\"])\n",
    "    .drop(columns=[\"seed1\", \"seed2\"])\n",
    ")\n",
    "print(\"End to End performance over that seeds\")\n",
    "df.sort_values(by=[\"run_name\"], key=lambda x: x.str.len()).style.highlight_max(color=\"gray\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RelRep norm analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.express as px\n",
    "\n",
    "projection_names_list = [\n",
    "    [\"CenterCosine\", \"Cosine\", \"Euclidean\", \"NormEuclidean\"],\n",
    "    # [\"Cosine\", \"Euclidean\"],\n",
    "    # [\"CenterCosine\", \"Euclidean\"],\n",
    "]\n",
    "\n",
    "result = []\n",
    "\n",
    "for projection_names in projection_names_list:\n",
    "    # projection_names = [ \"Cosine\", \"NormEuclidean\"]\n",
    "    run_name = f\"{normalize_name(aggregator)}({normalize_name(projection_names)})\"\n",
    "\n",
    "    model = reltype2seed2model[run_name][2].eval().cuda()\n",
    "    relrep = model.encode(data.x, data.edge_index)[data.test_mask]\n",
    "\n",
    "    relrep = relrep.reshape(\n",
    "        -1,\n",
    "        model.relative_block.aggregation_module.num_subspaces,\n",
    "        model.relative_block.aggregation_module.subspace_dim,\n",
    "    )\n",
    "\n",
    "    fig = px.violin(\n",
    "        relrep.norm(dim=-1).cpu().detach(),\n",
    "        title=\",\".join(model.relative_block.projection_names),\n",
    "    ).show()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
