{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path as osp\n",
    "import random\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.nn import Linear\n",
    "\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.datasets import Planetoid\n",
    "from torch_geometric.nn import GCN2Conv\n",
    "from torch_geometric.nn.conv.gcn_conv import gcn_norm\n",
    "from rae import PROJECT_ROOT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.datasets import CitationFull\n",
    "from torch_geometric.transforms import RandomNodeSplit\n",
    "from pytorch_lightning import seed_everything\n",
    "\n",
    "dataset_name = \"Cora\"\n",
    "seed_everything(0)  # seed every RNG before the random node split below is drawn\n",
    "\n",
    "# Feature normalisation followed by a random node split (num_val=0.1, num_test=0).\n",
    "transform = T.Compose(\n",
    "    [\n",
    "        T.NormalizeFeatures(),\n",
    "        RandomNodeSplit(num_val=0.1, num_test=0),\n",
    "    ]\n",
    ")\n",
    "dataset = Planetoid(\n",
    "    root=PROJECT_ROOT / \"data\" / \"pyg\" / dataset_name,\n",
    "    name=dataset_name,\n",
    "    transform=transform,\n",
    ")\n",
    "data = dataset[0]\n",
    "\n",
    "# Pre-process GCN normalization: only the edge weights are kept, the returned edge index is discarded.\n",
    "_, edge_weight = gcn_norm(data.edge_index, num_nodes=data.x.size(0), add_self_loops=False)\n",
    "data.edge_weight = edge_weight\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rae.modules.attention import RelativeAttention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rae.modules.enumerations import Output\n",
    "from rae.modules.attention import AttentionOutput\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "class Net(torch.nn.Module):\n",
    "    \"\"\"Graph encoder with a linear classifier and a relative-projection side output.\n",
    "\n",
    "    ``forward`` projects the input features, runs the stack of graph convolutions,\n",
    "    applies a linear layer, L2-normalises the result and feeds it to ``class_proj``.\n",
    "    ``relative_proj`` is evaluated on the encoded nodes against their anchor rows and\n",
    "    its similarities are returned next to the logits.\n",
    "\n",
    "    Args:\n",
    "        relative: Whether the classifier should consume the output of ``relative_proj``.\n",
    "            ``forward`` currently asserts that this flag is ``False``.\n",
    "        hidden_proj: Module applied to the (dropped-out) input features.\n",
    "        hidden_fn: Callable applied to the output of ``hidden_proj``.\n",
    "        relative_proj: Module called as ``relative_proj(x=..., anchors=...)``; its result is\n",
    "            indexed with ``AttentionOutput`` keys.\n",
    "        class_proj: Module mapping the normalised node embeddings to the logits.\n",
    "        convs: Graph convolution layers, applied in order.\n",
    "        conv_fn: Callable applied after every convolution.\n",
    "        conv_out: Feature size of the convolution output (input and output size of ``conv_fc``).\n",
    "        dropout: Dropout probability used on the input and before every convolution.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        relative: bool,\n",
    "        hidden_proj: nn.Module,\n",
    "        hidden_fn,\n",
    "        relative_proj: RelativeAttention,\n",
    "        class_proj: nn.Module,\n",
    "        convs: nn.ModuleList,\n",
    "        conv_fn,\n",
    "        conv_out: int,\n",
    "        dropout: float,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.hidden_proj: nn.Module = hidden_proj\n",
    "        self.class_proj: nn.Module = class_proj\n",
    "\n",
    "        self.hidden_fn = hidden_fn\n",
    "\n",
    "        self.relative_proj = relative_proj\n",
    "\n",
    "        self.convs = convs\n",
    "        self.conv_fn = conv_fn\n",
    "        self.conv_fc = nn.Linear(in_features=conv_out, out_features=conv_out)\n",
    "\n",
    "        # NOTE(review): registered but never used in forward; kept so that the set of\n",
    "        # parameters (and therefore the state_dict) of the module does not change.\n",
    "        self.layer_norm = nn.LayerNorm(conv_out)\n",
    "\n",
    "        self.dropout = dropout\n",
    "        self.relative: bool = relative\n",
    "\n",
    "    def forward(self, x, edge_index, edge_weight, anchor_idxs: torch.Tensor):\n",
    "        \"\"\"Encode the nodes and classify them.\n",
    "\n",
    "        Args:\n",
    "            x: Node features.\n",
    "            edge_index: Graph connectivity, forwarded to every convolution.\n",
    "            edge_weight: Edge weights, forwarded only to ``GCN2Conv`` layers.\n",
    "            anchor_idxs: Row indices of ``x`` selecting the anchor nodes.\n",
    "\n",
    "        Returns:\n",
    "            Dict with ``Output.LOGITS`` (the output of ``class_proj``) and\n",
    "            ``Output.SIMILARITIES`` (taken from the output of ``relative_proj``).\n",
    "        \"\"\"\n",
    "        x = F.dropout(x, self.dropout, training=self.training)\n",
    "        x = self.hidden_proj(x)\n",
    "\n",
    "        x = x_0 = self.hidden_fn(x)\n",
    "\n",
    "        # The keyword arguments are identical for every layer: build them once instead of\n",
    "        # re-checking the type of the first convolution at each iteration.\n",
    "        conv_kwargs = {\"edge_index\": edge_index}\n",
    "        if len(self.convs) > 0 and type(self.convs[0]).__name__ == \"GCN2Conv\":\n",
    "            # GCN2Conv additionally receives the initial representation and the edge weights.\n",
    "            conv_kwargs[\"x_0\"] = x_0\n",
    "            conv_kwargs[\"edge_weight\"] = edge_weight\n",
    "\n",
    "        for conv in self.convs:\n",
    "            x = F.dropout(x, self.dropout, training=self.training)\n",
    "            x = conv(x, **conv_kwargs)\n",
    "            x = self.conv_fn(x)\n",
    "\n",
    "        x = self.conv_fc(x)\n",
    "        anchors: torch.Tensor = x[anchor_idxs, :]\n",
    "\n",
    "        # Always evaluated: the similarities are part of the output even in absolute mode.\n",
    "        rel_out = self.relative_proj(x=x, anchors=anchors)\n",
    "\n",
    "        # Guard left by the experiment: only the absolute (non relative) classifier is\n",
    "        # expected here, so the branch below is reachable only when asserts are disabled.\n",
    "        assert not self.relative, \"Net.forward expects relative=False\"\n",
    "        if self.relative:\n",
    "            x = rel_out[AttentionOutput.OUTPUT]\n",
    "\n",
    "        x = F.normalize(x, p=2, dim=-1)\n",
    "        x = self.class_proj(x)\n",
    "        return {Output.LOGITS: x, Output.SIMILARITIES: rel_out[AttentionOutput.SIMILARITIES]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_anchors: int = 300\n",
    "\n",
    "# Anchors are drawn without replacement among the indices of the training nodes.\n",
    "train_node_idxs = data.train_mask.nonzero().squeeze().cpu().tolist()\n",
    "data.anchors = torch.as_tensor(random.sample(train_node_idxs, num_anchors))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU whenever CUDA is available, otherwise fall back to the CPU.\n",
    "DEVICE = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.nn import GATConv, GCN2Conv, GCNConv, GINConv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encoder_factory(encoder_type, num_layers: int, in_channels: int, out_channels: int, **params):\n",
    "    \"\"\"Build the stack of graph convolutions used as encoder.\n",
    "\n",
    "    Args:\n",
    "        encoder_type: One of ``\"GCN2Conv\"``, ``\"GCNConv\"``, ``\"GATConv\"`` or ``\"GINConv\"``.\n",
    "        num_layers: Number of convolutions to stack, must be positive.\n",
    "        in_channels: Input size of the first layer. Not used by ``GCN2Conv``, whose layers\n",
    "            are all built with ``channels=out_channels``.\n",
    "        out_channels: Output size of every layer.\n",
    "        **params: Extra keyword arguments forwarded to every convolution.\n",
    "\n",
    "    Returns:\n",
    "        An ``nn.ModuleList`` holding ``num_layers`` convolutions.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If ``num_layers`` is not positive.\n",
    "        NotImplementedError: If ``encoder_type`` is not one of the supported names.\n",
    "    \"\"\"\n",
    "    assert num_layers > 0, f\"num_layers must be positive, got {num_layers}\"\n",
    "\n",
    "    if encoder_type == \"GCN2Conv\":\n",
    "        # Layers are numbered starting from 1.\n",
    "        return nn.ModuleList(\n",
    "            [GCN2Conv(layer=layer + 1, channels=out_channels, **params) for layer in range(num_layers)]\n",
    "        )\n",
    "\n",
    "    # For the remaining encoders only the first layer sees `in_channels`;\n",
    "    # every following layer maps out_channels -> out_channels.\n",
    "    layer_in_channels = [in_channels] + [out_channels] * (num_layers - 1)\n",
    "\n",
    "    if encoder_type == \"GCNConv\":\n",
    "        return nn.ModuleList(\n",
    "            [\n",
    "                GCNConv(in_channels=layer_in, out_channels=out_channels, **params)\n",
    "                for layer_in in layer_in_channels\n",
    "            ]\n",
    "        )\n",
    "\n",
    "    if encoder_type == \"GATConv\":\n",
    "        return nn.ModuleList(\n",
    "            [\n",
    "                GATConv(in_channels=layer_in, out_channels=out_channels, **params)\n",
    "                for layer_in in layer_in_channels\n",
    "            ]\n",
    "        )\n",
    "\n",
    "    if encoder_type == \"GINConv\":\n",
    "        # Every GIN layer wraps a single linear map.\n",
    "        return nn.ModuleList(\n",
    "            [\n",
    "                GINConv(nn=nn.Linear(in_features=layer_in, out_features=out_channels), **params)\n",
    "                for layer_in in layer_in_channels\n",
    "            ]\n",
    "        )\n",
    "\n",
    "    raise NotImplementedError(f\"Unsupported encoder type: {encoder_type!r}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import functools\n",
    "from pytorch_lightning.utilities.seed import log as seed_log\n",
    "from pytorch_lightning import seed_everything\n",
    "\n",
    "\n",
    "import random\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# General SWEEP: every combination of the values below is one experiment.\n",
    "sweep = {\n",
    "    \"seed\": list(range(5)),\n",
    "    # \"seed_index\": [0],\n",
    "    \"num_epochs\": [10, 30, 50],\n",
    "    \"in_channels\": [num_anchors],\n",
    "    # \"out_channels\": [10, 32, 64],\n",
    "    \"out_channels\": [num_anchors],\n",
    "    \"num_layers\": [64, 32],\n",
    "    \"dropout\": [0.1, 0.5],\n",
    "    # \"hidden_fn\": [torch.relu, torch.tanh, torch.sigmoid],\n",
    "    # \"conv_fn\": [torch.relu, torch.tanh, torch.sigmoid],\n",
    "    \"hidden_fn\": [torch.nn.ReLU(), torch.nn.Tanh()],\n",
    "    \"conv_fn\": [torch.nn.ReLU(), torch.nn.Tanh()],\n",
    "    \"optimizer\": [torch.optim.Adam, torch.optim.SGD],\n",
    "    \"lr\": [0.01, 0.02],\n",
    "    # Each encoder entry is a (name, builder) pair; the builder is `encoder_factory`\n",
    "    # with the encoder type (and its fixed keyword arguments) already bound.\n",
    "    \"encoder\": [\n",
    "        (\n",
    "            \"GCN2Conv\",\n",
    "            functools.partial(\n",
    "                encoder_factory,\n",
    "                encoder_type=\"GCN2Conv\",\n",
    "                **dict(alpha=0.1, theta=0.5, shared_weights=True, normalize=False),\n",
    "            ),\n",
    "        ),\n",
    "        # (\"GCNConv\", functools.partial(encoder_factory, encoder_type=\"GCNConv\")),\n",
    "        # (\"GATConv\", functools.partial(encoder_factory, encoder_type=\"GATConv\")),\n",
    "        (\"GINConv\", functools.partial(encoder_factory, encoder_type=\"GINConv\")),\n",
    "    ],\n",
    "}\n",
    "\n",
    "# Best model config (kept for reference, swap it with the sweep above to re-run it)\n",
    "# sweep = {\n",
    "#     \"seed\": [1],\n",
    "#     # \"seed_index\": [0],\n",
    "#     \"num_epochs\": [500],\n",
    "#     \"in_channels\": [num_anchors],\n",
    "#     # \"out_channels\": [10, 32, 64],\n",
    "#     \"out_channels\": [num_anchors],\n",
    "#     \"num_layers\": [32],\n",
    "#     \"dropout\": [0.5],\n",
    "#     # \"hidden_fn\": [torch.relu, torch.tanh, torch.sigmoid],\n",
    "#     # \"conv_fn\": [torch.relu, torch.tanh, torch.sigmoid],\n",
    "#     \"hidden_fn\": [torch.nn.ReLU()],\n",
    "#     \"conv_fn\": [torch.nn.ReLU()],\n",
    "#     \"optimizer\": [torch.optim.Adam],\n",
    "#     \"lr\": [0.02],\n",
    "#     \"encoder\": [\n",
    "#         (\n",
    "#             \"GCN2Conv\",\n",
    "#             functools.partial(\n",
    "#                 encoder_factory,\n",
    "#                 encoder_type=\"GCN2Conv\",\n",
    "#                 **dict(alpha=0.1, theta=0.5, shared_weights=True, normalize=False),\n",
    "#             ),\n",
    "#         ),\n",
    "#         # (\"GCNConv\", functools.partial(encoder_factory, encoder_type=\"GCNConv\")),\n",
    "#         # (\"GATConv\", functools.partial(encoder_factory, encoder_type=\"GATConv\")),\n",
    "#         # (\"GINConv\", functools.partial(encoder_factory, encoder_type=\"GINConv\")),\n",
    "#     ],\n",
    "# }\n",
    "\n",
    "\n",
    "# The training loop passes `relative_proj` to every `Net`; without this definition the\n",
    "# notebook fails with a NameError on a fresh kernel.\n",
    "# NOTE(review): the constructor arguments are restored from the definition that used to be\n",
    "# commented out here; confirm them against rae.modules.attention.RelativeAttention.\n",
    "relative_proj = RelativeAttention(\n",
    "    n_anchors=num_anchors,\n",
    "    n_classes=dataset.num_classes,\n",
    "    similarity_mode=\"inner\",\n",
    "    values_mode=\"similarities\",\n",
    "    normalization_mode=\"l2\",\n",
    ")\n",
    "\n",
    "from sklearn.model_selection import ParameterGrid\n",
    "\n",
    "experiments = ParameterGrid(sweep)\n",
    "f\"Total available experiments={len(experiments)}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_step(model, optimizer, data):\n",
    "    \"\"\"Run one full-batch optimisation step and return the training loss as a float.\n",
    "\n",
    "    The model is switched to training mode and evaluated on the whole graph; the\n",
    "    cross-entropy restricted to the nodes selected by ``data.train_mask`` is\n",
    "    back-propagated before ``optimizer`` takes a step.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    model_out = model(\n",
    "        data.x,\n",
    "        edge_index=data.edge_index,\n",
    "        edge_weight=data.edge_weight,\n",
    "        anchor_idxs=data.anchors,\n",
    "    )\n",
    "    train_logits = model_out[Output.LOGITS][data.train_mask]\n",
    "    train_targets = data.y[data.train_mask]\n",
    "    train_loss = F.cross_entropy(train_logits, train_targets)\n",
    "\n",
    "    train_loss.backward()\n",
    "    optimizer.step()\n",
    "    return float(train_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the mask of the split actually used for training: `dataset[0]` would re-apply\n",
    "# the random split transform and show a freshly drawn mask instead.\n",
    "data.train_mask.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def test(model, data):\n",
    "    \"\"\"Evaluate ``model`` on the whole graph without tracking gradients.\n",
    "\n",
    "    Returns:\n",
    "        A pair ``(out, accs)``: the raw model output and the list of accuracies measured\n",
    "        on the ``train_mask`` and ``val_mask`` nodes, in this order.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    out = model(\n",
    "        data.x,\n",
    "        edge_index=data.edge_index,\n",
    "        edge_weight=data.edge_weight,\n",
    "        anchor_idxs=data.anchors,\n",
    "    )\n",
    "    predictions = out[Output.LOGITS].argmax(dim=-1)\n",
    "\n",
    "    def masked_accuracy(mask):\n",
    "        # Fraction of the masked nodes whose predicted class matches the label.\n",
    "        n_correct = int((predictions[mask] == data.y[mask]).sum())\n",
    "        return n_correct / int(mask.sum())\n",
    "\n",
    "    accs = [masked_accuracy(mask) for _, mask in data(\"train_mask\", \"val_mask\")]\n",
    "    return out, accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference latents: the relative representation stored with the best run saved for the\n",
    "# current dataset (same naming scheme as the cells that write the results).\n",
    "old_best = torch.load(PROJECT_ROOT / \"experiments\" / \"sec:data-manifold\" / f\"{dataset_name}_best_run.pt\")\n",
    "reference_latents = [old_best[\"best_epoch\"][\"rel_x\"]]\n",
    "reference_latents[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import *\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def get_distance(latents1: torch.Tensor, latents_ref: Sequence[torch.Tensor]):\n",
    "    \"\"\"Average cosine similarity between ``latents1`` and each reference latent.\n",
    "\n",
    "    Despite its name the function returns a similarity: for every tensor in\n",
    "    ``latents_ref`` the cosine similarity with ``latents1`` is averaged over all its\n",
    "    entries, and the per-reference averages are then averaged again.\n",
    "\n",
    "    Args:\n",
    "        latents1: Latents to compare.\n",
    "        latents_ref: Collection of reference latents; a bare array or tensor is rejected.\n",
    "\n",
    "    Returns:\n",
    "        The mean of the per-reference mean cosine similarities.\n",
    "    \"\"\"\n",
    "    # A single tensor/array would be iterated row by row, which is not what is wanted.\n",
    "    assert not isinstance(latents_ref, (np.ndarray, torch.Tensor))\n",
    "\n",
    "    per_reference_scores = []\n",
    "    for reference in latents_ref:\n",
    "        similarities = F.cosine_similarity(latents1, reference)\n",
    "        per_reference_scores.append(similarities.mean().item())\n",
    "    return np.mean(per_reference_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forwarded to `Net(relative=...)` by the training loop; `Net.forward` asserts that it is False.\n",
    "relative: bool = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import ParameterSampler, ParameterGrid\n",
    "import logging\n",
    "from tqdm import tqdm\n",
    "from pprint import pprint\n",
    "from rae.utils.utils import to_device\n",
    "\n",
    "# Every sweep configuration trains a fresh model: per-epoch metrics are collected in `stats`,\n",
    "# the configuration together with its best epoch is collected in `experiments`.\n",
    "experiments = []\n",
    "stats = {x: [] for x in (\"experiment\", \"epoch\", \"loss\", \"train_acc\", \"val_acc\", \"reference_distance\")}\n",
    "\n",
    "# The graph is the same for every experiment: move it to the target device only once.\n",
    "data = data.to(DEVICE)\n",
    "\n",
    "# for i, experiment in enumerate(pbar := tqdm(ParameterSampler(sweep, n_iter=100, random_state=42), desc=\"Experiment\")):\n",
    "for i, experiment in enumerate(pbar := tqdm(ParameterGrid(sweep), desc=\"Experiment\")):\n",
    "    encoder_name, encoder_build = experiment[\"encoder\"]\n",
    "    if encoder_name == \"GCN2Conv\":\n",
    "        # GCN2Conv layers are built with a single channel size, so both sizes are pinned.\n",
    "        experiment[\"out_channels\"] = num_anchors\n",
    "        experiment[\"in_channels\"] = num_anchors\n",
    "\n",
    "    # Seed everything while silencing the seed logger, then restore its level.\n",
    "    seed: int = experiment[\"seed\"]\n",
    "    temp_log_level = seed_log.getEffectiveLevel()\n",
    "    seed_log.setLevel(logging.ERROR)\n",
    "    seed_everything(seed)\n",
    "    seed_log.setLevel(temp_log_level)\n",
    "\n",
    "    hidden_proj = nn.Linear(dataset.num_features, experiment[\"in_channels\"])\n",
    "\n",
    "    convs = encoder_build(\n",
    "        num_layers=experiment[\"num_layers\"],\n",
    "        in_channels=experiment[\"in_channels\"],\n",
    "        out_channels=experiment[\"out_channels\"],\n",
    "    )\n",
    "    class_proj = nn.Linear(experiment[\"out_channels\"], dataset.num_classes)\n",
    "\n",
    "    model = Net(\n",
    "        relative=relative,\n",
    "        hidden_proj=hidden_proj,\n",
    "        hidden_fn=experiment[\"hidden_fn\"],\n",
    "        relative_proj=relative_proj,\n",
    "        class_proj=class_proj,\n",
    "        convs=convs,\n",
    "        conv_fn=experiment[\"conv_fn\"],\n",
    "        conv_out=experiment[\"out_channels\"],\n",
    "        dropout=experiment[\"dropout\"],\n",
    "    ).to(DEVICE)\n",
    "    optimizer = experiment[\"optimizer\"](\n",
    "        model.parameters(),\n",
    "        lr=experiment[\"lr\"],\n",
    "    )\n",
    "\n",
    "    best_val_acc = 0\n",
    "    best_epoch = None\n",
    "    for epoch in range(experiment[\"num_epochs\"]):\n",
    "        loss = train_step(model=model, optimizer=optimizer, data=data)\n",
    "        model_out, (train_acc, val_acc) = test(model=model, data=data)\n",
    "\n",
    "        # Blocking device-to-host copy: the similarities are read on the CPU right below,\n",
    "        # and a non-blocking copy gives no guarantee that the data has already arrived.\n",
    "        similarities = model_out[Output.SIMILARITIES].cpu()\n",
    "\n",
    "        stats[\"experiment\"].append(i)\n",
    "        stats[\"epoch\"].append(epoch)\n",
    "        stats[\"loss\"].append(loss)\n",
    "        stats[\"train_acc\"].append(train_acc)\n",
    "        stats[\"val_acc\"].append(val_acc)\n",
    "        stats[\"reference_distance\"].append(\n",
    "            get_distance(latents1=similarities, latents_ref=reference_latents)\n",
    "            if reference_latents is not None\n",
    "            else None\n",
    "        )\n",
    "\n",
    "        # The first epoch is always recorded, so `best_epoch` is set even when the\n",
    "        # validation accuracy never rises above zero.\n",
    "        if best_epoch is None or val_acc > best_val_acc:\n",
    "            best_val_acc = val_acc\n",
    "            best_epoch = {\n",
    "                \"rel_x\": similarities,\n",
    "                \"epoch\": epoch,\n",
    "                \"loss\": loss,\n",
    "                \"train_acc\": train_acc,\n",
    "                \"val_acc\": val_acc,\n",
    "            }\n",
    "\n",
    "    experiment[\"best_epoch\"] = best_epoch\n",
    "    if best_epoch is not None:  # None only when the experiment ran for zero epochs\n",
    "        pbar.set_description(\n",
    "            f\"Epoch: {best_epoch['epoch']:04d}, Loss: {best_epoch['loss']:.4f} Train: {best_epoch['train_acc']:.4f}, \"\n",
    "            f\"Val: {best_epoch['val_acc']:.4f} Dist: {stats['reference_distance'][-1]}\"\n",
    "        )\n",
    "\n",
    "    experiments.append(experiment)\n",
    "    model.cpu()\n",
    "\n",
    "stats = pd.DataFrame(stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the collected experiments and the per-epoch statistics of the sweep.\n",
    "output_dir = PROJECT_ROOT / \"experiments\" / \"sec:data-manifold\"\n",
    "\n",
    "torch.save(experiments, output_dir / f\"{dataset_name}_data_manifold_experiments.pt\")\n",
    "stats.to_csv(output_dir / f\"{dataset_name}_data_manifold_stats.tsv\", sep=\"\\t\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.save(experiments[0], PROJECT_ROOT / \"experiments\" / \"sec:data-manifold\" / f\"{dataset_name}_best_run.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
