{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path as osp\n",
    "import random\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.nn import Linear\n",
    "\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.datasets import Planetoid\n",
    "from torch_geometric.nn import GCN2Conv\n",
    "from torch_geometric.nn.conv.gcn_conv import gcn_norm\n",
    "from rae import PROJECT_ROOT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.datasets import CitationFull\n",
    "from torch_geometric.transforms import RandomNodeSplit\n",
    "from pytorch_lightning import seed_everything\n",
    "\n",
    "# Planetoid citation graph with normalized node features and a random train/val node split\n",
    "# (`num_val=0.1`, no test nodes). The seed is fixed before `dataset[0]` is accessed;\n",
    "# presumably that is where the transform (and therefore the random split) is applied.\n",
    "dataset_name = \"Cora\"\n",
    "seed_everything(0)\n",
    "node_split = RandomNodeSplit(num_val=0.1, num_test=0)\n",
    "transform = T.Compose([T.NormalizeFeatures(), node_split])\n",
    "dataset_root = PROJECT_ROOT / \"data\" / \"pyg\" / dataset_name\n",
    "dataset = Planetoid(dataset_root, dataset_name, transform=transform)\n",
    "data = dataset[0]\n",
    "\n",
    "# Pre-compute the GCN edge normalization once (without adding self-loops) and keep only the\n",
    "# per-edge weights on `data`; the edge index itself is left untouched.\n",
    "num_nodes = data.x.size(0)\n",
    "_, edge_weight = gcn_norm(data.edge_index, num_nodes=num_nodes, add_self_loops=False)\n",
    "data.edge_weight = edge_weight\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rae.modules.attention import RelativeAttention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rae.modules.enumerations import Output\n",
    "from rae.modules.attention import AttentionOutput\n",
    "from torch import nn\n",
    "from typing import *\n",
    "import numpy as np\n",
    "from sklearn.model_selection import ParameterSampler, ParameterGrid\n",
    "import logging\n",
    "from tqdm import tqdm\n",
    "from pprint import pprint\n",
    "from rae.utils.utils import to_device\n",
    "import functools\n",
    "import itertools\n",
    "import functools\n",
    "from pytorch_lightning.utilities.seed import log as seed_log\n",
    "from pytorch_lightning import seed_everything\n",
    "\n",
    "\n",
    "import random\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class Net(torch.nn.Module):\n",
    "    \"\"\"Graph encoder followed by a relative-attention block and a classification head.\n",
    "\n",
    "    `forward` applies: dropout -> `hidden_proj` -> `hidden_fn` -> for every conv\n",
    "    (dropout -> conv -> `conv_fn`) -> `conv_fc`. The resulting node embeddings are returned\n",
    "    as-is when `only_absolute` is True; otherwise they are fed to `relative_proj` (using the\n",
    "    embeddings of the anchor nodes as anchors) and to `class_proj`.\n",
    "\n",
    "    Args:\n",
    "        hidden_proj: module applied to the (dropped-out) input node features.\n",
    "        hidden_fn: activation applied after `hidden_proj`; its output is also the `x_0`\n",
    "            given to every `GCN2Conv` layer.\n",
    "        relative_proj: relative-attention module; only its `AttentionOutput.SIMILARITIES`\n",
    "            entry is returned.\n",
    "        class_proj: head mapping the `conv_out`-wide embeddings to logits.\n",
    "        convs: message-passing layers, applied in order.\n",
    "        conv_fn: activation applied after every conv layer.\n",
    "        conv_out: width of the embeddings produced by the last conv layer.\n",
    "        dropout: dropout probability, applied before `hidden_proj` and before every conv.\n",
    "        relative: stored as `self.relative`; not read by `forward`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        hidden_proj: nn.Module,\n",
    "        hidden_fn,\n",
    "        relative_proj: RelativeAttention,\n",
    "        class_proj: nn.Module,\n",
    "        convs: nn.ModuleList,\n",
    "        conv_fn,\n",
    "        conv_out: int,\n",
    "        dropout: float,\n",
    "        relative: bool = True,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.hidden_proj: nn.Module = hidden_proj\n",
    "        self.class_proj: nn.Module = class_proj\n",
    "\n",
    "        self.hidden_fn = hidden_fn\n",
    "\n",
    "        self.relative_proj = relative_proj\n",
    "\n",
    "        self.convs = convs\n",
    "        self.conv_fn = conv_fn\n",
    "        self.conv_fc = nn.Linear(in_features=conv_out, out_features=conv_out)\n",
    "\n",
    "        # Not used by `forward`; kept so the module's parameters / state_dict keys are unchanged.\n",
    "        self.layer_norm = nn.LayerNorm(conv_out)\n",
    "\n",
    "        self.dropout = dropout\n",
    "        self.relative: bool = relative\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        x,\n",
    "        edge_index,\n",
    "        edge_weight,\n",
    "        anchor_idxs: Optional[torch.Tensor],\n",
    "        only_absolute: bool = False,\n",
    "    ):\n",
    "        \"\"\"Encode the graph and, unless `only_absolute`, classify the nodes.\n",
    "\n",
    "        Args:\n",
    "            x: node features.\n",
    "            edge_index: graph connectivity, passed to every conv layer.\n",
    "            edge_weight: per-edge weights, passed only to `GCN2Conv` layers.\n",
    "            anchor_idxs: indices of the anchor nodes; may be None when `only_absolute`\n",
    "                is True (it is not read in that case).\n",
    "            only_absolute: if True, return the `conv_fc` embeddings directly.\n",
    "\n",
    "        Returns:\n",
    "            The `conv_fc` embeddings when `only_absolute` is True, otherwise a dict with\n",
    "            `Output.LOGITS` (from `class_proj`) and `Output.SIMILARITIES` (from `relative_proj`).\n",
    "        \"\"\"\n",
    "        x = F.dropout(x, self.dropout, training=self.training)\n",
    "        x = self.hidden_proj(x)\n",
    "        x = x_0 = self.hidden_fn(x)\n",
    "\n",
    "        for conv in self.convs:\n",
    "            x = F.dropout(x, self.dropout, training=self.training)\n",
    "            # GCN2Conv additionally needs the initial representation `x_0` and the edge weights.\n",
    "            # Dispatch on the layer actually being applied (not only on `self.convs[0]`) so that\n",
    "            # a stack mixing layer types is called correctly.\n",
    "            if isinstance(conv, GCN2Conv):\n",
    "                x = conv(x, edge_index=edge_index, x_0=x_0, edge_weight=edge_weight)\n",
    "            else:\n",
    "                x = conv(x, edge_index=edge_index)\n",
    "            x = self.conv_fn(x)\n",
    "\n",
    "        x = self.conv_fc(x)\n",
    "\n",
    "        if only_absolute:\n",
    "            return x\n",
    "\n",
    "        # Anchors are the embeddings of the anchor nodes, in the same space as `x`.\n",
    "        anchors: torch.Tensor = x[anchor_idxs, :]\n",
    "        rel_out = self.relative_proj(x=x, anchors=anchors)\n",
    "\n",
    "        logits = self.class_proj(x)\n",
    "        return {Output.LOGITS: logits, Output.SIMILARITIES: rel_out[AttentionOutput.SIMILARITIES]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Device used for training and inference: CUDA when available, otherwise the CPU.\n",
    "DEVICE = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.nn import GATConv, GCN2Conv, GCNConv, GINConv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encoder_factory(encoder_type, num_layers: int, in_channels: int, out_channels: int, **params):\n",
    "    \"\"\"Build a stack of `num_layers` PyG message-passing layers of a single type.\n",
    "\n",
    "    Args:\n",
    "        encoder_type: one of \"GCN2Conv\", \"GCNConv\", \"GATConv\", \"GINConv\".\n",
    "        num_layers: number of stacked layers; must be positive.\n",
    "        in_channels: width of the node features fed to the first layer. Ignored for\n",
    "            \"GCN2Conv\", whose layers all have width `out_channels`: the caller must\n",
    "            already feed `out_channels`-wide features.\n",
    "        out_channels: width produced by every layer (per head for a multi-head GATConv).\n",
    "        **params: extra keyword arguments forwarded to every layer constructor (for\n",
    "            \"GINConv\" they go to `GINConv` itself, not to its inner `nn.Linear`).\n",
    "\n",
    "    Returns:\n",
    "        An `nn.ModuleList` with the layers in application order.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if `num_layers` is not positive.\n",
    "        NotImplementedError: if `encoder_type` is not supported.\n",
    "    \"\"\"\n",
    "    assert num_layers > 0\n",
    "\n",
    "    if encoder_type == \"GCN2Conv\":\n",
    "        # Each GCN2Conv gets its 1-based layer index; all layers share the width `out_channels`.\n",
    "        return nn.ModuleList(\n",
    "            GCN2Conv(layer=depth, channels=out_channels, **params) for depth in range(1, num_layers + 1)\n",
    "        )\n",
    "\n",
    "    if encoder_type in (\"GCNConv\", \"GATConv\"):\n",
    "        conv_cls = GCNConv if encoder_type == \"GCNConv\" else GATConv\n",
    "        # A multi-head GATConv concatenates its heads by default (`concat=True`), so the next\n",
    "        # layer receives `heads * out_channels` features instead of `out_channels`.\n",
    "        concat_heads = encoder_type == \"GATConv\" and params.get(\"concat\", True)\n",
    "        next_in_channels = out_channels * params.get(\"heads\", 1) if concat_heads else out_channels\n",
    "        convs = []\n",
    "        current_in_channels = in_channels\n",
    "        for _ in range(num_layers):\n",
    "            convs.append(conv_cls(in_channels=current_in_channels, out_channels=out_channels, **params))\n",
    "            current_in_channels = next_in_channels\n",
    "        return nn.ModuleList(convs)\n",
    "\n",
    "    if encoder_type == \"GINConv\":\n",
    "        convs = []\n",
    "        current_in_channels = in_channels\n",
    "        for _ in range(num_layers):\n",
    "            convs.append(\n",
    "                GINConv(\n",
    "                    nn=nn.Linear(in_features=current_in_channels, out_features=out_channels),\n",
    "                    **params,\n",
    "                )\n",
    "            )\n",
    "            current_in_channels = out_channels\n",
    "        return nn.ModuleList(convs)\n",
    "\n",
    "    raise NotImplementedError(f\"Unsupported encoder_type: {encoder_type!r}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative-attention module, built once and referenced as a global by the models below.\n",
    "# NOTE(review): `n_anchors=None` presumably lets the anchor count follow the anchors passed at\n",
    "# call time -- confirm against RelativeAttention.\n",
    "relative_attention_kwargs = dict(\n",
    "    n_anchors=None,\n",
    "    n_classes=dataset.num_classes,\n",
    "    similarity_mode=\"inner\",\n",
    "    values_mode=\"similarities\",\n",
    "    normalization_mode=\"l2\",\n",
    ")\n",
    "relative_proj = RelativeAttention(**relative_attention_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TRAIN BEST MODEL\n",
    "# End-to-end training (encoder + relative attention + classification head) of the single\n",
    "# configuration below; the resulting model is used afterwards to embed the nodes.\n",
    "# `encoder_factory` also supports \"GCNConv\", \"GATConv\" and \"GINConv\".\n",
    "\n",
    "\n",
    "sweep = {\n",
    "    \"seed\": [1],\n",
    "    \"num_epochs\": [50],\n",
    "    # For GCN2Conv both widths are overridden with `num_anchors` in the loop below.\n",
    "    \"in_channels\": [128],\n",
    "    \"out_channels\": [128],\n",
    "    \"num_layers\": [32],\n",
    "    \"dropout\": [0.5],\n",
    "    \"hidden_fn\": [torch.nn.ReLU()],\n",
    "    \"conv_fn\": [torch.nn.ReLU()],\n",
    "    \"optimizer\": [torch.optim.Adam],\n",
    "    \"lr\": [0.02],\n",
    "    \"encoder\": [\n",
    "        (\n",
    "            \"GCN2Conv\",\n",
    "            functools.partial(\n",
    "                encoder_factory,\n",
    "                encoder_type=\"GCN2Conv\",\n",
    "                **dict(alpha=0.1, theta=0.5, shared_weights=True, normalize=False),\n",
    "            ),\n",
    "        ),\n",
    "    ],\n",
    "    \"num_anchors\": [300],\n",
    "}\n",
    "\n",
    "\n",
    "class Lambda(nn.Module):\n",
    "    \"\"\"Wrap a callable so it can be used as a layer (e.g. inside `nn.Sequential`).\"\"\"\n",
    "\n",
    "    def __init__(self, func):\n",
    "        super().__init__()\n",
    "        self.func = func\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.func(x)\n",
    "\n",
    "\n",
    "def train_step(model, optimizer, data):\n",
    "    \"\"\"Run one full-batch optimization step on the training nodes; return the loss as a float.\"\"\"\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    out = model(data.x, edge_index=data.edge_index, edge_weight=data.edge_weight, anchor_idxs=data.anchors)\n",
    "    logits = out[Output.LOGITS]\n",
    "    loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    return float(loss)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def test(model, data):\n",
    "    \"\"\"Evaluate `model`; return its raw output and `[train_acc, val_acc]` (accuracy on each mask).\"\"\"\n",
    "    model.eval()\n",
    "    out = model(data.x, edge_index=data.edge_index, edge_weight=data.edge_weight, anchor_idxs=data.anchors)\n",
    "    pred = out[Output.LOGITS].argmax(dim=-1)\n",
    "\n",
    "    accs = []\n",
    "    for _, mask in data(\"train_mask\", \"val_mask\"):\n",
    "        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))\n",
    "    return out, accs\n",
    "\n",
    "\n",
    "assert len(ParameterGrid(sweep)) == 1\n",
    "for i, experiment in enumerate(pbar := tqdm(ParameterGrid(sweep), desc=\"Experiment\")):\n",
    "    seed: int = experiment[\"seed\"]\n",
    "    # Seed quietly: raise the seed logger's level only around `seed_everything`, always restoring it.\n",
    "    temp_log_level = seed_log.getEffectiveLevel()\n",
    "    seed_log.setLevel(logging.ERROR)\n",
    "    try:\n",
    "        seed_everything(seed)\n",
    "    finally:\n",
    "        seed_log.setLevel(temp_log_level)\n",
    "\n",
    "    # Anchors are sampled among the training nodes. `flatten()` (unlike `squeeze()`) still yields\n",
    "    # a list when there is exactly one training node.\n",
    "    num_anchors: int = experiment[\"num_anchors\"]\n",
    "    train_node_idxs = data.train_mask.nonzero().flatten().cpu().tolist()\n",
    "    data.anchors = torch.as_tensor(random.sample(train_node_idxs, num_anchors))\n",
    "\n",
    "    encoder_name, encoder_build = experiment[\"encoder\"]\n",
    "    if encoder_name == \"GCN2Conv\":\n",
    "        # GCN2Conv layers keep a constant width: tie the hidden and conv widths to the anchor count.\n",
    "        experiment[\"out_channels\"] = num_anchors\n",
    "        experiment[\"in_channels\"] = num_anchors\n",
    "\n",
    "    hidden_proj = nn.Linear(dataset.num_features, experiment[\"in_channels\"])\n",
    "\n",
    "    convs = encoder_build(\n",
    "        num_layers=experiment[\"num_layers\"],\n",
    "        in_channels=experiment[\"in_channels\"],\n",
    "        out_channels=experiment[\"out_channels\"],\n",
    "    )\n",
    "    # Each permute pair swaps (nodes, features) -> (features, nodes) so InstanceNorm1d normalizes\n",
    "    # every feature over the nodes, then restores the original layout.\n",
    "    class_proj = nn.Sequential(\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.InstanceNorm1d(num_features=experiment[\"out_channels\"]),\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.Linear(in_features=experiment[\"out_channels\"], out_features=64),\n",
    "        nn.Tanh(),\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.InstanceNorm1d(num_features=64),\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.Linear(in_features=64, out_features=dataset.num_classes),\n",
    "    )\n",
    "\n",
    "    model = Net(\n",
    "        hidden_proj=hidden_proj,\n",
    "        hidden_fn=experiment[\"hidden_fn\"],\n",
    "        relative_proj=relative_proj,\n",
    "        class_proj=class_proj,\n",
    "        convs=convs,\n",
    "        conv_fn=experiment[\"conv_fn\"],\n",
    "        conv_out=experiment[\"out_channels\"],\n",
    "        dropout=experiment[\"dropout\"],\n",
    "    ).to(DEVICE)\n",
    "    data = data.to(DEVICE)\n",
    "    optimizer = experiment[\"optimizer\"](\n",
    "        model.parameters(),\n",
    "        lr=experiment[\"lr\"],\n",
    "    )\n",
    "\n",
    "    # Start below any reachable accuracy so the first epoch is always recorded and `best_epoch`\n",
    "    # is never None when the progress bar is updated (even if val_acc stays at 0).\n",
    "    best_val_acc = float(\"-inf\")\n",
    "    best_epoch = None\n",
    "    for epoch in range(experiment[\"num_epochs\"]):\n",
    "        loss = train_step(model=model, optimizer=optimizer, data=data)\n",
    "        model_out, (train_acc, val_acc) = test(model=model, data=data)\n",
    "\n",
    "        if val_acc > best_val_acc:\n",
    "            best_val_acc = val_acc\n",
    "            best_epoch = {\n",
    "                \"epoch\": epoch,\n",
    "                \"loss\": loss,\n",
    "                \"train_acc\": train_acc,\n",
    "                \"val_acc\": val_acc,\n",
    "            }\n",
    "\n",
    "        pbar.set_description(\n",
    "            f\"Epoch: {best_epoch['epoch']:04d}, Loss: {best_epoch['loss']:.4f} Train: {best_epoch['train_acc']:.4f}, \"\n",
    "            f\"Val: {best_epoch['val_acc']:.4f}\"\n",
    "        )\n",
    "\n",
    "    model.cpu()\n",
    "\n",
    "# NOTE: `best_model` holds the weights after the final epoch, not those of the best-validation epoch.\n",
    "best_model = model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node embeddings from the trained encoder only (`only_absolute=True`: no relative projection, no head).\n",
    "# The model is moved to DEVICE (where `data` was placed) instead of hard-coding CUDA, so this also\n",
    "# runs on CPU-only machines.\n",
    "best_model.to(DEVICE)\n",
    "best_model.eval()\n",
    "with torch.no_grad():\n",
    "    best_absolute = best_model(\n",
    "        data.x, edge_index=data.edge_index, edge_weight=data.edge_weight, anchor_idxs=None, only_absolute=True\n",
    "    ).detach()\n",
    "best_absolute"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# General SWEEP\n",
    "# Frozen-encoder anchor analysis: 3 seeds x anchor counts 1..49 (147 experiments). The training\n",
    "# loop in the next cell reads \"seed\", \"num_anchors\", \"num_epochs\", \"optimizer\" and \"lr\"; the other\n",
    "# keys mirror the end-to-end sweep and hold a single value, so they do not change the grid size.\n",
    "sweep = dict(\n",
    "    seed=list(range(3)),\n",
    "    num_epochs=[50],\n",
    "    in_channels=[128],\n",
    "    out_channels=[128],\n",
    "    num_layers=[32],\n",
    "    dropout=[0.5],\n",
    "    hidden_fn=[None],\n",
    "    conv_fn=[None],\n",
    "    optimizer=[torch.optim.Adam],\n",
    "    lr=[0.02],\n",
    "    encoder=[None],\n",
    "    num_anchors=list(range(1, 50)),\n",
    ")\n",
    "\n",
    "from sklearn.model_selection import ParameterGrid\n",
    "\n",
    "experiments = ParameterGrid(sweep)\n",
    "f\"Total available experiments={len(experiments)}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import *\n",
    "import numpy as np\n",
    "from sklearn.model_selection import ParameterSampler, ParameterGrid\n",
    "import logging\n",
    "from tqdm import tqdm\n",
    "from pprint import pprint\n",
    "from rae.utils.utils import to_device\n",
    "from pathlib import Path\n",
    "\n",
    "# Per-epoch metrics of every experiment (column -> values); turned into the `stats` DataFrame at the end.\n",
    "stats_records = {x: [] for x in (\"experiment\", \"epoch\", \"loss\", \"train_acc\", \"val_acc\", \"num_anchors\")}\n",
    "\n",
    "\n",
    "class ModelHead(nn.Module):\n",
    "    \"\"\"Relative classification head trained on top of frozen absolute node embeddings.\n",
    "\n",
    "    `forward` maps the embeddings to the relative space with the global `relative_proj`\n",
    "    (built in an earlier cell) against the given anchors, normalizes every relative feature\n",
    "    over the nodes, and applies a linear classifier.\n",
    "\n",
    "    Args:\n",
    "        num_anchors: width of the relative representation; the head expects `relative_proj`\n",
    "            to output `num_anchors` features per node.\n",
    "        num_classes: number of output logits.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_anchors, num_classes):\n",
    "        super().__init__()\n",
    "        self.model = nn.Sequential(\n",
    "            Lambda(lambda x: x.permute(1, 0)),\n",
    "            nn.InstanceNorm1d(num_features=num_anchors),\n",
    "            Lambda(lambda x: x.permute(1, 0)),\n",
    "            nn.Linear(in_features=num_anchors, out_features=num_classes),\n",
    "        )\n",
    "\n",
    "    def forward(self, x, anchors):\n",
    "        rel_out = relative_proj(x=x, anchors=anchors)[AttentionOutput.OUTPUT]\n",
    "        return {Output.LOGITS: self.model(rel_out)}\n",
    "\n",
    "\n",
    "def train_step(model, optimizer, absolute_latents, anchors_idxs):\n",
    "    \"\"\"One optimization step of the head on the training nodes; return the loss as a float.\n",
    "\n",
    "    Labels and `train_mask` are read from the global `data`.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    out = model(absolute_latents, absolute_latents[anchors_idxs])\n",
    "    logits = out[Output.LOGITS]\n",
    "    loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    return float(loss)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def test(model, absolute_latents, anchors_idxs):\n",
    "    \"\"\"Evaluate the head; return its raw output and `[train_acc, val_acc]` (masks/labels from global `data`).\"\"\"\n",
    "    model.eval()\n",
    "    out = model(absolute_latents, absolute_latents[anchors_idxs])\n",
    "    pred = out[Output.LOGITS].argmax(dim=-1)\n",
    "    accs = []\n",
    "    for _, mask in data(\"train_mask\", \"val_mask\"):\n",
    "        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))\n",
    "    return out, accs\n",
    "\n",
    "\n",
    "for i, experiment in enumerate(pbar := tqdm(ParameterGrid(sweep), desc=\"Experiment\")):\n",
    "    seed: int = experiment[\"seed\"]\n",
    "    # Seed quietly: raise the seed logger's level only around `seed_everything`, always restoring it.\n",
    "    temp_log_level = seed_log.getEffectiveLevel()\n",
    "    seed_log.setLevel(logging.ERROR)\n",
    "    try:\n",
    "        seed_everything(seed)\n",
    "    finally:\n",
    "        seed_log.setLevel(temp_log_level)\n",
    "\n",
    "    # Anchors are sampled among the training nodes. `flatten()` (unlike `squeeze()`) still yields\n",
    "    # a list when there is exactly one training node.\n",
    "    num_anchors: int = experiment[\"num_anchors\"]\n",
    "    train_node_idxs = data.train_mask.nonzero().flatten().cpu().tolist()\n",
    "    anchors_idxs = torch.as_tensor(random.sample(train_node_idxs, num_anchors))\n",
    "\n",
    "    model = ModelHead(num_anchors=num_anchors, num_classes=dataset.num_classes)\n",
    "    model.to(DEVICE)\n",
    "\n",
    "    data = data.to(DEVICE)\n",
    "    optimizer = experiment[\"optimizer\"](\n",
    "        model.parameters(),\n",
    "        lr=experiment[\"lr\"],\n",
    "    )\n",
    "\n",
    "    # Start below any reachable accuracy so the first epoch is always recorded.\n",
    "    best_val_acc = float(\"-inf\")\n",
    "    best_epoch = None\n",
    "    for epoch in range(experiment[\"num_epochs\"]):\n",
    "        loss = train_step(model=model, optimizer=optimizer, absolute_latents=best_absolute, anchors_idxs=anchors_idxs)\n",
    "        model_out, (train_acc, val_acc) = test(model=model, absolute_latents=best_absolute, anchors_idxs=anchors_idxs)\n",
    "\n",
    "        stats_records[\"experiment\"].append(i)\n",
    "        stats_records[\"epoch\"].append(epoch)\n",
    "        stats_records[\"loss\"].append(loss)\n",
    "        stats_records[\"train_acc\"].append(train_acc)\n",
    "        stats_records[\"val_acc\"].append(val_acc)\n",
    "        stats_records[\"num_anchors\"].append(num_anchors)\n",
    "\n",
    "        if val_acc > best_val_acc:\n",
    "            best_val_acc = val_acc\n",
    "            best_epoch = {\n",
    "                \"epoch\": epoch,\n",
    "                \"loss\": loss,\n",
    "                \"train_acc\": train_acc,\n",
    "                \"val_acc\": val_acc,\n",
    "            }\n",
    "\n",
    "    # `best_epoch` is None only when `num_epochs` is 0.\n",
    "    if best_epoch is not None:\n",
    "        pbar.set_description(\n",
    "            f\"Epoch: {best_epoch['epoch']:04d}, Loss: {best_epoch['loss']:.4f} Train: {best_epoch['train_acc']:.4f}, \"\n",
    "            f\"Val: {best_epoch['val_acc']:.4f}\"\n",
    "        )\n",
    "\n",
    "    model.cpu()\n",
    "\n",
    "stats = pd.DataFrame(stats_records)\n",
    "stats_path = (\n",
    "    PROJECT_ROOT\n",
    "    / \"experiments\"\n",
    "    / \"sec:anchor-analysis\"\n",
    "    / f\"{dataset_name}_data_manifold_stats_anchors_analysis_frozen_encoder.tsv\"\n",
    ")\n",
    "# `to_csv` does not create missing directories.\n",
    "Path(stats_path).parent.mkdir(parents=True, exist_ok=True)\n",
    "stats.to_csv(stats_path, sep=\"\\t\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.express as px\n",
    "\n",
    "# Per-experiment maximum of every column: `val_acc` is the best validation accuracy of each experiment\n",
    "# and `num_anchors` is constant within one. The other columns are independent maxima, not the values\n",
    "# of a single epoch.\n",
    "best_step = stats.groupby(\"experiment\").max()\n",
    "px.scatter(\n",
    "    best_step,\n",
    "    x=\"num_anchors\",\n",
    "    y=\"val_acc\",\n",
    "    labels={\"num_anchors\": \"Number of anchors\", \"val_acc\": \"Best validation accuracy\"},\n",
    "    title=f\"{dataset_name}: anchor analysis with a frozen encoder\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Write the per-epoch stats (same path as the training cell), creating the output directory if needed.\n",
    "stats_path = (\n",
    "    PROJECT_ROOT\n",
    "    / \"experiments\"\n",
    "    / \"sec:anchor-analysis\"\n",
    "    / f\"{dataset_name}_data_manifold_stats_anchors_analysis_frozen_encoder.tsv\"\n",
    ")\n",
    "Path(stats_path).parent.mkdir(parents=True, exist_ok=True)\n",
    "stats.to_csv(stats_path, sep=\"\\t\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
