{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gnnboundary import *\n",
    "from scripts.experiments import get_model_kwargs, CKPT_PATHS\n",
    "from gnnboundary.utils.random_baseline import PAPER_CLASS_COMBINATIONS\n",
    "from visualization.helpers import get_embeddings_for_class_pair, generate_umap, generate_pca\n",
    "from visualization.plotting import plot_latent_space_2d, plot_latent_space_3d\n",
    "\n",
    "\n",
    "SEED = 12345\n",
    "# Additional datasets that can be appended to the list: CollabDataset(seed=SEED), ENZYMESDataset(seed=SEED), IMDBDataset(seed=SEED)\n",
    "dataset_list = [MotifDataset(seed=SEED)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project each class pair's embeddings to 2D and 3D with PCA and UMAP, then plot them.\n",
    "REDUCERS = {\"PCA\": generate_pca, \"UMAP\": generate_umap}\n",
    "PLOTTERS = {2: plot_latent_space_2d, 3: plot_latent_space_3d}\n",
    "\n",
    "for dataset in dataset_list:\n",
    "    for cls_pair in PAPER_CLASS_COMBINATIONS[dataset.__class__.__name__]:\n",
    "        # Embeddings do not depend on the reduction method or dimension, so compute them once per pair.\n",
    "        embeds = get_embeddings_for_class_pair(cls_pair, dataset)\n",
    "        for method, reduce_fn in REDUCERS.items():\n",
    "            for dim, plot_fn in PLOTTERS.items():\n",
    "                # Dict dispatch always binds a fresh (viz_emb, labels); an if/elif chain without\n",
    "                # an else would silently reuse values from a previous iteration for an unknown key.\n",
    "                viz_emb, labels = reduce_fn(embeds, dataset, cls_pair, n_components=dim)\n",
    "                plot_fn(viz_emb, labels, dataset.name, method=method)\n",
    "                print(f\"Finished plotting for {dataset.name}, {cls_pair}, {method}, {dim}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnnboundary",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
