{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 4,
     "status": "ok",
     "timestamp": 1691193027714,
     "user_tz": -120
    },
    "id": "nIBBcWxa4DB5"
   },
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os.path as osp\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.datasets import Planetoid\n",
    "from torch_geometric.logging import log\n",
    "from torch_geometric.nn import GCNConv\n",
    "from typing import Callable, Sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from latent_invariances.modules.aggregation_modules import (\n",
    "    Identity,\n",
    "    SumAggregation,\n",
    "    ConcatAggregation,\n",
    "    SelfAttentionLayer,\n",
    "    WeightedAvgAggregation,\n",
    "    NormSelfAttentionLayer,\n",
    "    NonLinearSumAggregation,\n",
    "    NonLinearWeightedAvgAggregation,\n",
    ")\n",
    "from latent_invariances.utils.relreps import SIMPLE_PROJECTION_TYPE\n",
    "\n",
    "\n",
    "# x = torch.randn(10, 5)\n",
    "# anchors = torch.randn(5, 5)\n",
    "# for projection_funcs, aggregation_module in projections_aggregation_to_use:\n",
    "#     block = RelativeBlock(\n",
    "#         projection_funcs,\n",
    "#         aggregation_module(subspace_dim=anchors.shape[0], num_subspaces=len(projection_funcs)),\n",
    "#     )\n",
    "#     print(block)\n",
    "#     print(block(x, anchors).shape)\n",
    "#     break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class StitchedModel(torch.nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        encoder: torch.nn.Module,\n",
    "        decoder: torch.nn.Module,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.encoder = encoder.eval()\n",
    "        self.decoder = decoder.eval()\n",
    "\n",
    "    def forward(self, x: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:\n",
    "        rel_x = self.encoder.encode(x, anchors)\n",
    "        return self.decoder.decode(rel_x)\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"StitchedModel({self.encoder}, {self.decoder})\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 4,
     "status": "ok",
     "timestamp": 1691193027714,
     "user_tz": -120
    },
    "id": "Q2dHH67D4DB6"
   },
   "outputs": [],
   "source": [
    "def train(model, data, optimizer):\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    out = model(data.x, data.edge_index)\n",
    "    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    return float(loss)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def test(model, data):\n",
    "    model.eval()\n",
    "    pred = model(data.x, data.edge_index).argmax(dim=-1)\n",
    "\n",
    "    accs = []\n",
    "    for mask in [data.train_mask, data.val_mask, data.test_mask]:\n",
    "        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))\n",
    "    return accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 4,
     "status": "ok",
     "timestamp": 1691193027715,
     "user_tz": -120
    },
    "id": "8LPG5zXv4DB6"
   },
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "\n",
    "\n",
    "def plot(h, y, pca=None):\n",
    "    if h.shape[1] > 2:\n",
    "        if pca is not None:\n",
    "            h = pca.transform(h)\n",
    "        else:\n",
    "            pca = PCA(n_components=2)\n",
    "            h = pca.fit_transform(h)\n",
    "\n",
    "    assert h.shape[1] == 2\n",
    "    # Extract x, y, and color information from the points\n",
    "    x_values = [x[0].item() for x in h]\n",
    "    y_values = [x[1].item() for x in h]\n",
    "    colors = [str(y.item()) for y in y]\n",
    "\n",
    "    # Create a DataFrame from the extracted data\n",
    "    d = {\"x\": x_values, \"y\": y_values, \"class\": colors}\n",
    "    df = pd.DataFrame(d)\n",
    "\n",
    "    # Plot the scatterplot\n",
    "    sns.scatterplot(data=df, x=\"x\", y=\"y\", hue=\"class\")\n",
    "\n",
    "    # Show the plot\n",
    "    plt.show()\n",
    "\n",
    "    return pca"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 4,
     "status": "ok",
     "timestamp": 1691193027715,
     "user_tz": -120
    },
    "id": "Nk4VMlD14DB6"
   },
   "outputs": [],
   "source": [
    "from lightning import seed_everything\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tl37fspl4DB6"
   },
   "source": [
    "## Training models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import chain, combinations\n",
    "\n",
    "\n",
    "def pair_powerset(iterable, pairing):\n",
    "    \"powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)\"\n",
    "    s = list(iterable)\n",
    "    return ((x, pairing) for x in chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) if x)\n",
    "\n",
    "\n",
    "list(pair_powerset([\"Cosine\", \"CenterCosine\", \"Euclidean\", \"NormEuclidean\"], \"t\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Type\n",
    "from torch import nn\n",
    "\n",
    "### Hyperparams\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "hidden_channels = 300\n",
    "lr = 1e-4\n",
    "epochs = 150\n",
    "### ---------------------\n",
    "\n",
    "\n",
    "path = osp.join(\"data\", \"Planetoid\")\n",
    "dataset = Planetoid(path, \"Cora\", transform=T.NormalizeFeatures())\n",
    "data = dataset[0]\n",
    "num_nodes = data.x.shape[0]\n",
    "anchor_indices = torch.randperm(data.x.size(0))[:hidden_channels]\n",
    "\n",
    "\n",
    "def train_model_and_plot_embeddings(run_name, data, projection_funcs, aggregation_module, seed):\n",
    "    # TODO: decommenta\n",
    "    # seed_everything(seed)\n",
    "    model = GNN(\n",
    "        dataset.num_features,\n",
    "        hidden_channels,\n",
    "        dataset.num_classes,\n",
    "        relative_block=RelativeBlock(\n",
    "            projection_funcs,\n",
    "            aggregation_module(\n",
    "                subspace_dim=anchor_indices.shape[0],\n",
    "                num_subspaces=len(projection_funcs),\n",
    "            ),\n",
    "        ),\n",
    "        anchor_indices=anchor_indices,\n",
    "    )\n",
    "    model, data = model.to(device), data.to(device)\n",
    "    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)\n",
    "\n",
    "    best_val_acc = final_test_acc = 0\n",
    "    for epoch in (bar := tqdm(range(1, epochs + 1))):\n",
    "        loss = train(model, data, optimizer)\n",
    "        train_acc, val_acc, tmp_test_acc = test(model, data)\n",
    "        if val_acc > best_val_acc:\n",
    "            best_val_acc = val_acc\n",
    "            test_acc = tmp_test_acc\n",
    "\n",
    "        bar.set_description(\n",
    "            f\"{run_name} [{epoch}: {loss:.4f}] (train: {train_acc:.4f}) (val: {val_acc:.4f}) (test: {test_acc:.4f})\"\n",
    "        )\n",
    "\n",
    "    model.eval()\n",
    "    h = model.encode(data.x, data.edge_index)\n",
    "\n",
    "    rel_h = F.normalize(h, dim=-1, p=2)\n",
    "    rel_h = rel_h @ rel_h[anchor_indices].T\n",
    "\n",
    "    return model.eval().cpu(), h, rel_h\n",
    "\n",
    "\n",
    "class PureIdentity(torch.nn.Identity):\n",
    "    pass\n",
    "\n",
    "    @property\n",
    "    def out_dim(self):\n",
    "        return 300\n",
    "\n",
    "\n",
    "class RelativeBlock(torch.nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        projection_names: Sequence[str],\n",
    "        aggregation_module: torch.nn.Module,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.projection_names = projection_names\n",
    "        self.projection_funcs: Sequence[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = [\n",
    "            SIMPLE_PROJECTION_TYPE[x] for x in projection_names\n",
    "        ]\n",
    "        self.aggregation_module = aggregation_module\n",
    "\n",
    "    def forward(self, x: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:\n",
    "        rel_x = self.encode(x, anchors)\n",
    "        return self.decode(rel_x)\n",
    "\n",
    "    def encode(self, x, anchors) -> torch.Tensor:\n",
    "        return torch.cat([fun(anchors=anchors, points=x) for fun in self.projection_funcs], dim=-1)\n",
    "\n",
    "    def decode(self, rel_x: torch.Tensor) -> torch.Tensor:\n",
    "        return self.aggregation_module(rel_x)\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"RelativeBlock({self.projection_names}, {self.aggregation_module})\"\n",
    "\n",
    "    def __str__(self):\n",
    "        return f\"RelativeBlock({self.projection_names=}, {self.aggregation_module=})\"\n",
    "\n",
    "\n",
    "class GNN(torch.nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        in_channels,\n",
    "        hidden_channels,\n",
    "        out_channels,\n",
    "        relative_block: RelativeBlock,\n",
    "        anchor_indices: torch.Tensor,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.conv1 = GCNConv(in_channels, hidden_channels, cached=True)\n",
    "        self.conv2 = GCNConv(hidden_channels, hidden_channels, cached=True)\n",
    "\n",
    "        self.register_buffer(\"anchor_indices\", anchor_indices)\n",
    "        self.relative_block = relative_block\n",
    "        self.decoder = torch.nn.Linear(relative_block.aggregation_module.out_dim, out_channels)\n",
    "\n",
    "    def encode(self, x, edge_index):\n",
    "        features = self.conv1(x, edge_index).relu()\n",
    "        features = self.conv2(features, edge_index)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            anchors = features[self.anchor_indices]\n",
    "\n",
    "        x = self.relative_block.encode(x=features, anchors=anchors)\n",
    "        return x\n",
    "\n",
    "    def decode(self, x):\n",
    "        x = self.relative_block.decode(x)\n",
    "        return self.decoder(x)\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        x = self.encode(x, edge_index)\n",
    "\n",
    "        return self.decode(x)\n",
    "\n",
    "    def get_attention_weights(self, x, edge_index, attention_idx: int = 0):\n",
    "        x = self.encode(x, edge_index)\n",
    "        return self.relative_block.aggregation_module.get_attention_weights(x, attention_idx=attention_idx)\n",
    "\n",
    "\n",
    "class TransformerBlock(nn.Module):\n",
    "    def __init__(self, k, heads):\n",
    "        super().__init__()\n",
    "        self.attention = nn.MultiheadAttention(embed_dim=k, num_heads=heads, batch_first=True, dropout=0.1)\n",
    "        self.norm1 = nn.LayerNorm(k)\n",
    "        self.norm2 = nn.LayerNorm(k)\n",
    "        self.ff = nn.Sequential(nn.Linear(k, k // 4), nn.ReLU(), nn.Linear(k // 4, k))\n",
    "\n",
    "    def forward(self, x):\n",
    "        attended = self.attention(query=x, key=x, value=x)[0]\n",
    "        x = self.norm1(attended + x)\n",
    "        fedforward = self.ff(x)\n",
    "        return self.norm2(fedforward + x)\n",
    "\n",
    "    def get_attention_weights(self, x):\n",
    "        return self.attention(query=x, key=x, value=x)[1]\n",
    "\n",
    "\n",
    "class Transformer(torch.nn.Module):\n",
    "    def __init__(self, subspace_dim: int, num_subspaces: int, batch_size: int = 16):\n",
    "        super().__init__()\n",
    "        depth = 1\n",
    "        heads = 1\n",
    "        self.subspace_dim = subspace_dim\n",
    "        self.num_subspaces = num_subspaces\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "        self.tblocks = nn.Sequential(*[TransformerBlock(k=self.subspace_dim, heads=heads) for _ in range(depth)])\n",
    "\n",
    "    def get_attention_weights(self, concat_subspaces: Sequence[torch.Tensor], attention_idx: int = 0):\n",
    "        query = concat_subspaces.reshape(-1, self.num_subspaces, self.subspace_dim)\n",
    "        attention_weights = self.tblocks[attention_idx].get_attention_weights(query)\n",
    "        return attention_weights\n",
    "\n",
    "    @property\n",
    "    def out_dim(self):\n",
    "        return self.subspace_dim\n",
    "\n",
    "    def forward(self, concat_subspaces: Sequence[torch.Tensor]):\n",
    "        query = concat_subspaces.reshape(-1, self.num_subspaces, self.subspace_dim)\n",
    "        out = self.tblocks(query)\n",
    "        return torch.mean(out, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "executionInfo": {
     "elapsed": 5987,
     "status": "ok",
     "timestamp": 1691193033698,
     "user_tz": -120
    },
    "id": "iivQReZ94DB6",
    "outputId": "02e5d6c7-0fb3-46ba-b71b-4735b5e7373e"
   },
   "outputs": [],
   "source": [
    "SEEDS = list(range(3))\n",
    "\n",
    "projections_aggregation_to_use = [\n",
    "    # ([\"Absolute\"], PureIdentity),\n",
    "    # ([\"Absolute\"], Transformer),\n",
    "    # ([\"Cosine\"], PureIdentity),\n",
    "    # ([\"Euclidean\"], PureIdentity),\n",
    "    # ([\"CenterCosine\"], PureIdentity),\n",
    "    # ([\"NormEuclidean\"], PureIdentity),\n",
    "    # #\n",
    "    # ([\"Absolute\"], Identity),\n",
    "    # ([\"Cosine\"], Identity),\n",
    "    # ([\"Euclidean\"], Identity),\n",
    "    # ([\"CenterCosine\"], Identity),\n",
    "    # ([\"NormEuclidean\"], Identity),\n",
    "] + list(pair_powerset([\"Cosine\", \"CenterCosine\", \"Euclidean\", \"NormEuclidean\"], Transformer))\n",
    "\n",
    "\n",
    "def normalize_name(name) -> str:\n",
    "    if isinstance(name, torch.nn.Module):\n",
    "        return name.__class__.__name__.lower()\n",
    "    if isinstance(name, Type):\n",
    "        return name.__name__.lower()\n",
    "    if isinstance(name, Sequence):\n",
    "        return \",\".join(x.lower() for x in sorted(name))\n",
    "    return name.lower()\n",
    "\n",
    "\n",
    "# Train models\n",
    "import itertools\n",
    "from collections import defaultdict\n",
    "from tqdm import tqdm\n",
    "\n",
    "reltype2seed2model = defaultdict(dict)\n",
    "for seed, (projection_names, aggregator) in (\n",
    "    list(itertools.product(SEEDS, projections_aggregation_to_use))\n",
    "    # bar := tqdm(list(itertools.product(SEEDS, projections_aggregation_to_use)))\n",
    "):\n",
    "    run_name = f\"{normalize_name(aggregator)}({normalize_name(projection_names)})\"\n",
    "    # bar.set_description(f\"Running {run_name} with seed {seed}\")\n",
    "    model, abs1, rel1 = train_model_and_plot_embeddings(\n",
    "        run_name=run_name,\n",
    "        data=data,\n",
    "        projection_funcs=projection_names,\n",
    "        aggregation_module=aggregator,\n",
    "        seed=seed,\n",
    "    )\n",
    "    reltype2seed2model[run_name][seed] = model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug\n",
    "# list(pair_powerset([\"Cosine\", \"CenterCosine\", \"Euclidean\", \"NormEuclidean\"], Transformer))[-1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 4,
     "status": "ok",
     "timestamp": 1691193038479,
     "user_tz": -120
    },
    "id": "qaKHLHuB5Fv1"
   },
   "outputs": [],
   "source": [
    "# Stitching results\n",
    "\n",
    "from collections import namedtuple\n",
    "import itertools\n",
    "\n",
    "Result = namedtuple(\"Result\", [\"seed1\", \"seed2\", \"run_name\", \"train_acc\", \"val_acc\", \"test_acc\"])\n",
    "\n",
    "results = []\n",
    "for seed1, seed2 in tqdm(list(itertools.product(SEEDS, SEEDS))):\n",
    "    for reltype, seed2model in reltype2seed2model.items():\n",
    "        model1 = seed2model[seed1]\n",
    "        model2 = seed2model[seed2]\n",
    "        stitched_model = StitchedModel(encoder=model1, decoder=model2).eval().to(device)\n",
    "        train_acc, val_acc, test_acc = test(stitched_model, data)\n",
    "        stitched_model.cpu()\n",
    "        results.append(\n",
    "            Result(\n",
    "                seed1=seed1,\n",
    "                seed2=seed2,\n",
    "                run_name=reltype,\n",
    "                train_acc=train_acc,\n",
    "                val_acc=val_acc,\n",
    "                test_acc=test_acc,\n",
    "            )\n",
    "        )\n",
    "        # train_acc1, val_acc1, test_acc1 = test(model1, data)\n",
    "        # train_acc2, val_acc2, test_acc2 = test(model2, data)\n",
    "        # results.append(Result(seed1, reltype, train_acc1, val_acc1, test_acc1))\n",
    "        # results.append(Result(seed2, reltype, train_acc2, val_acc2, test_acc2))\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "df = pd.DataFrame(results)\n",
    "df = df[df.seed1 != df.seed2].groupby([\"run_name\"]).agg([\"mean\", \"std\"]).drop(columns=[\"seed1\", \"seed2\"])\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(results)\n",
    "df.to_csv(\"graph_results_depth1.csv\", index=False, sep=\"\\t\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"All possible stitching across 10 seeds\")\n",
    "df = pd.DataFrame(results)\n",
    "df = df[df.seed1 != df.seed2].groupby([\"run_name\"]).agg([\"mean\", \"std\"]).drop(columns=[\"seed1\", \"seed2\"])\n",
    "df.sort_values(by=[\"run_name\"], key=lambda x: x.str.len()).style.highlight_max(color=\"gray\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(results)\n",
    "df = df[df.seed1 == df.seed2].groupby([\"run_name\"]).agg([\"mean\", \"std\"]).drop(columns=[\"seed1\", \"seed2\"])\n",
    "print(\"End to End\")\n",
    "df.sort_values(by=[\"run_name\"], key=lambda x: x.str.len()).style.highlight_max(color=\"gray\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attention weights analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "ATTENTION_INDEX = 0\n",
    "\n",
    "aggregator = Transformer\n",
    "projection_names = [\"CenterCosine\", \"Cosine\", \"Euclidean\", \"NormEuclidean\"]\n",
    "# projection_names = [ \"Cosine\", \"NormEuclidean\"]\n",
    "run_name = f\"{normalize_name(aggregator)}({normalize_name(projection_names)})\"\n",
    "\n",
    "print(run_name, f\"{ATTENTION_INDEX=}\")\n",
    "\n",
    "w_across_seeds = []\n",
    "for seed in SEEDS:\n",
    "    print(f\"{seed=}\")\n",
    "    model = reltype2seed2model[run_name][seed].eval().cuda()\n",
    "    w = model.get_attention_weights(data.x, data.edge_index, attention_idx=ATTENTION_INDEX)\n",
    "    w = w.mean(0)  # average over all data\n",
    "    w_across_seeds.append(w.detach().cpu())\n",
    "w = torch.stack(w_across_seeds).mean(0)\n",
    "\n",
    "_ = sns.heatmap(\n",
    "    w.detach().cpu().numpy(),\n",
    "    cmap=\"Reds\",\n",
    "    annot=True,\n",
    "    fmt=\".2f\",\n",
    "    xticklabels=model.relative_block.projection_names,\n",
    "    yticklabels=model.relative_block.projection_names,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEEDS"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
