{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Automatically reload edited Python modules before every cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from pytorch_lightning import seed_everything\n",
    "from datasets import load_dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from collections import defaultdict\n",
    "from transformers import (\n",
    "    AutoConfig,\n",
    "    AutoModel,\n",
    "    AutoImageProcessor,\n",
    ")\n",
    "\n",
    "import functools\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from layskip.utils import similarities\n",
    "from layskip.modules.module import HFwrapper\n",
    "from layskip.utils.plots import plot_similarity_matrix, plot_similarity_embeddings\n",
    "from layskip.utils.utils import (\n",
    "    image_encode,\n",
    "    extract_all_layers,\n",
    ")\n",
    "\n",
    "from layskip.utils.dictionaries import (\n",
    "    DATASET2IMAGE_COLUMN,\n",
    "    DATASET2LABEL_COLUMN,\n",
    "    DATASET2NUM_CLASSES,\n",
    ")\n",
    "\n",
    "from latentis.space import LatentSpace\n",
    "from latentis.sample import Uniform\n",
    "from latentis.transform import projection\n",
    "from latentis.transform.projection import RelativeProjection\n",
    "from latentis.transform.base import Centering\n",
    "\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed Python, NumPy and PyTorch RNGs (via Lightning) so sampling is reproducible.\n",
    "seed = 0\n",
    "seed_everything(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"facebook/dinov2-small\"\n",
    "\n",
    "# Candidate backbones (each needs its own trained classifier checkpoint, see TRAINED_CLASSIFIER):\n",
    "# \"WinKawaks/vit-small-patch16-224\"\n",
    "# \"facebook/dinov2-small\"\n",
    "# \"facebook/deit-small-patch16-224\"\n",
    "# \"microsoft/swinv2-tiny-patch4-window8-256\"\n",
    "# \"microsoft/beit-base-patch16-224\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name = \"cifar100-fine\"\n",
    "# other values: mnist, cifar10, cifar100-coarse, ... (must be a key of the DATASET2* dictionaries used below)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Classifier checkpoint for this backbone/dataset pair, and the root directory for saved figures.\n",
    "TRAINED_CLASSIFIER = PROJECT_ROOT / \"models\" / model_name.split(\"/\")[1] / f\"{dataset_name}_classifier.ckpt\"\n",
    "PLOTS_DIR = PROJECT_ROOT / \"results\" / \"plots\" / \"CLS\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both CIFAR-100 variants load the same Hugging Face \"cifar100\" dataset.\n",
    "is_cifar100 = dataset_name in [\"cifar100-fine\", \"cifar100-coarse\"]\n",
    "dataset = load_dataset(\"cifar100\" if is_cifar100 else dataset_name)\n",
    "\n",
    "train_dataset, test_dataset = dataset[\"train\"], dataset[\"test\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset-specific image column, label column and number of classes.\n",
    "image_name, label_name, num_classes = (\n",
    "    DATASET2IMAGE_COLUMN[dataset_name],\n",
    "    DATASET2LABEL_COLUMN[dataset_name],\n",
    "    DATASET2NUM_CLASSES[dataset_name],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): torch.load unpickles the checkpoint, so only load trusted files. On torch>=2.6 the default\n",
    "# weights_only=True may reject a pickled nn.Module; confirm what this checkpoint contains.\n",
    "trained_classifier = torch.load(TRAINED_CLASSIFIER)\n",
    "\n",
    "# output_hidden_states=True asks the encoder to also return the hidden states of every layer.\n",
    "config = AutoConfig.from_pretrained(model_name, output_hidden_states=True, return_dict=True)\n",
    "processor = AutoImageProcessor.from_pretrained(model_name)\n",
    "encoder = AutoModel.from_pretrained(model_name, config=config)\n",
    "\n",
    "# NOTE(review): `model` is not used by the analysis below, which passes `encoder` directly.\n",
    "model = HFwrapper(encoder=encoder, classifier=trained_classifier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shuffle=False keeps samples in dataset order, so (assuming extract_all_layers concatenates batches\n",
    "# in order) row i of every layer's embeddings is train_dataset[i]. The trajectory figure is named after\n",
    "# train_dataset[embedding1_index] / train_dataset[embedding2_index], which relies on this mapping;\n",
    "# with shuffle=True those labels would belong to different samples than the plotted ones.\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=256,\n",
    "    shuffle=False,\n",
    "    num_workers=8,\n",
    "    pin_memory=True,\n",
    "    collate_fn=functools.partial(image_encode, processor=processor, image_name=image_name, label_name=label_name),\n",
    ")\n",
    "\n",
    "# test_dataloader = DataLoader(\n",
    "#     test_dataset,\n",
    "#     batch_size=256,\n",
    "#     shuffle=False,\n",
    "#     num_workers=8,\n",
    "#     pin_memory=True,\n",
    "#     collate_fn=functools.partial(image_encode, processor=processor, image_name=image_name, label_name=label_name),\n",
    "# )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Analyze"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Presumably collects the hidden states of every encoder layer for up to `max_samples` samples from the\n",
    "# loader, as a dict keyed by layer index (the cells below iterate it with .items() and index it with [0]).\n",
    "# NOTE(review): confirm the meaning of the trailing positional `False` in layskip.utils.utils.extract_all_layers.\n",
    "max_samples = 1000\n",
    "layer_embeddings = extract_all_layers(encoder, max_samples, train_dataloader, False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: number of extracted layers and the shape of layer 0's embeddings.\n",
    "len(layer_embeddings), layer_embeddings[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-layer CLS (token 0) and last-token embeddings, keyed by layer index.\n",
    "# Mean-pooled embeddings are computed in the next cell.\n",
    "cls_layer_embeddings = {\n",
    "    layer_idx: layer_output[:, 0, :] for layer_idx, layer_output in layer_embeddings.items()\n",
    "}\n",
    "last_token_layer_embeddings = {\n",
    "    layer_idx: layer_output[:, -1, :] for layer_idx, layer_output in layer_embeddings.items()\n",
    "}\n",
    "\n",
    "assert len(layer_embeddings) == len(cls_layer_embeddings) == len(last_token_layer_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One tensor per layer, in layer_embeddings order: mean over axis 1 (tokens), token 0 (CLS) and the last token.\n",
    "# Built straight from layer_embeddings (not from the previous cell's dicts) so re-running this cell is safe.\n",
    "mean_layer_embeddings = [layer_output.mean(axis=1) for layer_output in layer_embeddings.values()]\n",
    "cls_layer_embeddings = [layer_output[:, 0, :] for layer_output in layer_embeddings.values()]\n",
    "last_token_layer_embeddings = [layer_output[:, -1, :] for layer_output in layer_embeddings.values()]\n",
    "\n",
    "assert mean_layer_embeddings[0].shape == cls_layer_embeddings[0].shape == last_token_layer_embeddings[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap every per-layer embedding matrix in a latentis LatentSpace, one list per aggregation strategy.\n",
    "mean_spaces = [LatentSpace(vector_source=vectors) for vectors in mean_layer_embeddings]\n",
    "cls_spaces = [LatentSpace(vector_source=vectors) for vectors in cls_layer_embeddings]\n",
    "last_token_spaces = [LatentSpace(vector_source=vectors) for vectors in last_token_layer_embeddings]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Anchors for the relative projection, drawn with latentis' Uniform sampler from the layer-0 space\n",
    "# of each aggregation strategy.\n",
    "# NOTE(review): these layer-0 anchors are reused to project every layer in the next cells. If per-layer\n",
    "# relative spaces are intended, the anchors should be the same samples embedded at each layer.\n",
    "num_anchors = 384\n",
    "\n",
    "anchors_mean = mean_spaces[0].sample(sampler=Uniform(), n=num_anchors).vectors\n",
    "anchors_cls = cls_spaces[0].sample(sampler=Uniform(), n=num_anchors).vectors\n",
    "anchors_last_token = last_token_spaces[0].sample(sampler=Uniform(), n=num_anchors).vectors\n",
    "\n",
    "assert anchors_mean.shape == anchors_cls.shape == anchors_last_token.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative projection: cosine similarity of each vector to the anchors. The Centering abs_transform is\n",
    "# presumably fitted on the anchors and applied to the absolute vectors before projecting.\n",
    "projector = RelativeProjection(projection_fn=projection.cosine_proj, abs_transform=Centering())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative representation of every layer for each aggregation strategy. The anchors do not change\n",
    "# across layers, so the projector is fitted once per strategy and then applied to each layer's vectors.\n",
    "fitted_projector = projector.fit(x=anchors_mean)\n",
    "rr_mean_spaces = [fitted_projector(x=mean_ls.vectors)[\"x\"] for mean_ls in mean_spaces]\n",
    "\n",
    "fitted_projector = projector.fit(x=anchors_cls)\n",
    "rr_cls_spaces = [fitted_projector(x=cls_ls.vectors)[\"x\"] for cls_ls in cls_spaces]\n",
    "\n",
    "fitted_projector = projector.fit(x=anchors_last_token)\n",
    "rr_last_token_spaces = [fitted_projector(x=last_token_ls.vectors)[\"x\"] for last_token_ls in last_token_spaces]\n",
    "\n",
    "assert rr_mean_spaces[0].shape == rr_cls_spaces[0].shape == rr_last_token_spaces[0].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cosine similarity between two embeddings over layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Indices of the two samples (presumably rows of each layer's embedding matrix) whose similarity is\n",
    "# traced across layers; the save cell also uses them to look up the samples' labels in train_dataset.\n",
    "embedding1_index = 50\n",
    "embedding2_index = 54"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(2, 3, figsize=(20, 10), sharey=\"row\")\n",
    "\n",
    "# Rows: absolute embeddings (green) and their relative projections (purple);\n",
    "# columns: mean-pooled, CLS and last-token embeddings.\n",
    "trajectory_rows = [\n",
    "    (\n",
    "        \"Absolute\",\n",
    "        \"green\",\n",
    "        [\n",
    "            (mean_layer_embeddings, \"Mean Similarity Trajectories over Layer\"),\n",
    "            (cls_layer_embeddings, \"CLS Similarity Trajectories over Layer\"),\n",
    "            (last_token_layer_embeddings, \"Last Token Similarity Trajectories over Layer\"),\n",
    "        ],\n",
    "    ),\n",
    "    # The relative row uses the projected spaces in every column, including the mean one.\n",
    "    (\"Relative\", \"purple\", [(rr_mean_spaces, \"\"), (rr_cls_spaces, \"\"), (rr_last_token_spaces, \"\")]),\n",
    "]\n",
    "\n",
    "for row, (row_label, color, panels) in enumerate(trajectory_rows):\n",
    "    for col, (embeddings, title) in enumerate(panels):\n",
    "        plot_similarity_embeddings(\n",
    "            embeddings=embeddings,\n",
    "            ax=axes[row, col],\n",
    "            title=title,\n",
    "            e1=embedding1_index,\n",
    "            e2=embedding2_index,\n",
    "            color=color,\n",
    "        )\n",
    "    axes[row, 0].set_ylabel(row_label)\n",
    "\n",
    "for ax in axes.flat:\n",
    "    ax.set_xticks(range(len(cls_layer_embeddings)))\n",
    "    ax.tick_params(axis=\"y\", which=\"both\", labelleft=True)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the trajectory figure, named after the dataset, the two samples' labels and the backbone.\n",
    "# NOTE(review): \"tajectory\" looks like a typo for \"trajectory\"; kept so existing output paths stay stable.\n",
    "embedding_dir = PLOTS_DIR / str(max_samples) / \"embedding_tajectory\"\n",
    "os.makedirs(embedding_dir, exist_ok=True)\n",
    "\n",
    "label1 = train_dataset[embedding1_index][label_name]\n",
    "label2 = train_dataset[embedding2_index][label_name]\n",
    "img_name = f\"{dataset_name}_{label1}_{label2}_{model_name.split('/')[1]}.png\"\n",
    "\n",
    "save_path = os.path.join(embedding_dir, img_name)\n",
    "fig.savefig(save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cosine matrix over layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(2, 3, figsize=(20, 10))\n",
    "\n",
    "# Rows: absolute embeddings (titled) and their relative projections; columns: mean / CLS / last token.\n",
    "cosine_grid = [\n",
    "    [\n",
    "        (mean_layer_embeddings, \"Mean Cosine Similarity Matrix\"),\n",
    "        (cls_layer_embeddings, \"CLS Cosine Similarity Matrix\"),\n",
    "        (last_token_layer_embeddings, \"Last Token Cosine Similarity Matrix\"),\n",
    "    ],\n",
    "    [(rr_mean_spaces, \"\"), (rr_cls_spaces, \"\"), (rr_last_token_spaces, \"\")],\n",
    "]\n",
    "for row, panels in enumerate(cosine_grid):\n",
    "    for col, (layers, title) in enumerate(panels):\n",
    "        plot_similarity_matrix(layers=layers, ax=axes[row, col], title=title, metric=\"cosine\", vmax=1)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the cosine-matrix figure under <PLOTS_DIR>/<max_samples>/cosine_matrix/.\n",
    "mat_dir = PLOTS_DIR / str(max_samples) / \"cosine_matrix\"\n",
    "os.makedirs(mat_dir, exist_ok=True)\n",
    "\n",
    "img_name = f\"{dataset_name}_{model_name.split('/')[1]}.png\"\n",
    "save_path = os.path.join(mat_dir, img_name)\n",
    "fig.savefig(save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): exact duplicate of the previous cell; it re-saves the same cosine-matrix figure to the\n",
    "# same path and can be deleted.\n",
    "mat_dir = PLOTS_DIR / str(max_samples) / \"cosine_matrix\"\n",
    "os.makedirs(mat_dir, exist_ok=True)\n",
    "\n",
    "img_name = dataset_name + \"_\" + model_name.split(\"/\")[1] + \".png\"\n",
    "\n",
    "save_path = os.path.join(mat_dir, img_name)\n",
    "fig.savefig(save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MSE matrix over layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(2, 3, figsize=(20, 10))\n",
    "\n",
    "# (vmax, panels) per row: absolute embeddings on top, relative projections below with a wider colour range.\n",
    "mse_grid = [\n",
    "    (\n",
    "        1,\n",
    "        [\n",
    "            (mean_layer_embeddings, \"Mean MSE Similarity Matrix\"),\n",
    "            (cls_layer_embeddings, \"CLS MSE Similarity Matrix\"),\n",
    "            (last_token_layer_embeddings, \"Last Token MSE Similarity Matrix\"),\n",
    "        ],\n",
    "    ),\n",
    "    (2, [(rr_mean_spaces, \"\"), (rr_cls_spaces, \"\"), (rr_last_token_spaces, \"\")]),\n",
    "]\n",
    "for row, (row_vmax, panels) in enumerate(mse_grid):\n",
    "    for col, (layers, title) in enumerate(panels):\n",
    "        plot_similarity_matrix(layers=layers, ax=axes[row, col], title=title, metric=\"MSE\", vmax=row_vmax)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the MSE-matrix figure under <PLOTS_DIR>/<max_samples>/MSE_matrix/.\n",
    "mat_dir = PLOTS_DIR / str(max_samples) / \"MSE_matrix\"\n",
    "os.makedirs(mat_dir, exist_ok=True)\n",
    "\n",
    "img_name = f\"{dataset_name}_{model_name.split('/')[1]}.png\"\n",
    "save_path = os.path.join(mat_dir, img_name)\n",
    "fig.savefig(save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CKA matrix over layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(2, 3, figsize=(20, 10))\n",
    "\n",
    "# Columns: mean / CLS / last token. Top row: absolute embeddings; bottom row: relative projections.\n",
    "absolute_panels = [\n",
    "    (mean_layer_embeddings, \"Mean CKA Similarity Matrix\"),\n",
    "    (cls_layer_embeddings, \"CLS CKA Similarity Matrix\"),\n",
    "    (last_token_layer_embeddings, \"Last Token CKA Similarity Matrix\"),\n",
    "]\n",
    "relative_panels = [rr_mean_spaces, rr_cls_spaces, rr_last_token_spaces]\n",
    "\n",
    "for col, (abs_layers, abs_title) in enumerate(absolute_panels):\n",
    "    plot_similarity_matrix(layers=abs_layers, ax=axes[0, col], title=abs_title, metric=\"CKA\", vmax=1)\n",
    "for col, rel_layers in enumerate(relative_panels):\n",
    "    plot_similarity_matrix(layers=rel_layers, ax=axes[1, col], title=\"\", metric=\"CKA\", vmax=1)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## test\n",
    "# Adjacent-layer distance of the CLS embeddings: mean squared error between layer i and layer i + 1,\n",
    "# averaged over the feature dimension and then over samples. Lower = the layer changes the CLS vector less.\n",
    "# The dict keeps the name `cls_sim` for the printing cell, but its values are distances, not similarities.\n",
    "# (Other metrics, e.g. latentis CKA/SVCCA or cosine similarity, can replace `mse`; add their imports to the\n",
    "# imports cell when doing so.)\n",
    "cls_sim = {}\n",
    "for current_layer in range(len(cls_layer_embeddings) - 1):\n",
    "    succ_layer = current_layer + 1\n",
    "    squared_diff = ((cls_layer_embeddings[current_layer] - cls_layer_embeddings[succ_layer]) ** 2).mean(dim=-1)\n",
    "    mse = squared_diff.mean(dim=-1)\n",
    "    cls_sim[f\"[({current_layer},{succ_layer})]\"] = round(mse.item(), 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print adjacent-layer pairs from the smallest to the largest value.\n",
    "for pair, value in sorted(cls_sim.items(), key=lambda item: item[1]):\n",
    "    print(f\"{pair}, {value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the CKA-matrix figure under <PLOTS_DIR>/<max_samples>/CKA_matrix/. `fig` is still the CKA figure:\n",
    "# the two cells between the plot and this one create no figure.\n",
    "mat_dir = PLOTS_DIR / str(max_samples) / \"CKA_matrix\"\n",
    "os.makedirs(mat_dir, exist_ok=True)\n",
    "\n",
    "img_name = f\"{dataset_name}_{model_name.split('/')[1]}.png\"\n",
    "save_path = os.path.join(mat_dir, img_name)\n",
    "fig.savefig(save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cosine trajectories over layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: for each CLS layer, plot its cosine similarity with every other layer against a 0.96\n",
    "# threshold (5x3 subplot grid). Uncomment to run; relies on `similarities` from the imports cell.\n",
    "# cosine_similarities = similarities.pairwise_layer_cosine_similarity(cls_layer_embeddings)\n",
    "\n",
    "# num_layers = cosine_similarities.shape[0]\n",
    "# fig = plt.figure(figsize=(18, 15))\n",
    "\n",
    "# sim_threshhold = 0.96\n",
    "\n",
    "# for i in range(num_layers):\n",
    "#     plt.subplot(5, 3, i + 1)\n",
    "#     plt.plot(range(num_layers), cosine_similarities[i].cpu().numpy(), marker=\"o\", label=f\"Layer {i}\")\n",
    "#     plt.axhline(y=sim_threshhold, color=\"r\", linestyle=\"--\", label=f\"Similarity Threshold ({sim_threshhold})\")\n",
    "#     plt.axvline(x=i, color=\"b\", linestyle=\"--\", label=f\"Current Layer {i}\")\n",
    "#     plt.xlabel(\"Layer\")\n",
    "#     plt.ylabel(\"Cosine Similarity\")\n",
    "#     plt.title(f\"Cosine Similarity of Layer {i} with Others\")\n",
    "#     plt.xticks(ticks=range(num_layers))\n",
    "#     plt.legend()\n",
    "\n",
    "# plt.tight_layout()\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: saves the figure from the cell above under <PLOTS_DIR>/<max_samples>/cosine_trajectory/.\n",
    "# traj_dir = PLOTS_DIR / str(max_samples) / \"cosine_trajectory\"\n",
    "# os.makedirs(traj_dir, exist_ok=True)\n",
    "\n",
    "# img_name = dataset_name + \"_\" + model_name.split(\"/\")[1] + \".png\"\n",
    "\n",
    "# save_path = os.path.join(traj_dir, img_name)\n",
    "# fig.savefig(save_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "layskip",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
