{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from pytorch_lightning import seed_everything\n",
    "from datasets import load_dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from collections import defaultdict\n",
    "from transformers import (\n",
    "    AutoConfig,\n",
    "    AutoModel,\n",
    "    AutoImageProcessor,\n",
    ")\n",
    "import functools\n",
    "import matplotlib.pyplot as plt\n",
    "import plotly.graph_objects as go\n",
    "from layskip.utils import similarities\n",
    "from layskip.modules.module import HFwrapper\n",
    "from layskip.utils.plots import plot_similarity_matrix, plot_similarity_embeddings\n",
    "from layskip.utils.utils import image_encode, extract_all_layers\n",
    "from layskip.utils.dictionaries import (\n",
    "    DATASET2IMAGE_COLUMN,\n",
    "    DATASET2LABEL_COLUMN,\n",
    "    DATASET2NUM_CLASSES,\n",
    "    MODEL2NUM_LAYERS,\n",
    ")\n",
    "import seaborn as sns\n",
    "\n",
    "from latentis.space import LatentSpace\n",
    "from latentis.sample import Uniform\n",
    "from latentis.transform import projection\n",
    "from latentis.transform.projection import RelativeProjection\n",
    "from latentis.transform.base import Centering\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 0\n",
    "seed_everything(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"facebook/dinov2-small\"\n",
    "\n",
    "# \"WinKawaks/vit-small-patch16-224\"\n",
    "# \"facebook/deit-small-patch16-224\"\n",
    "# \"microsoft/beit-base-patch16-224\"\n",
    "# \"facebook/dinov2-small\"\n",
    "\n",
    "dataset_name = \"cifar100-fine\"\n",
    "# mnist cifar10 cifar100\n",
    "\n",
    "PLOTS_DIR = PROJECT_ROOT / \"plots\" / model_name.split(\"/\")[1] / dataset_name\n",
    "\n",
    "if dataset_name in [\"cifar100-fine\", \"cifar100-coarse\"]:\n",
    "    dataset = load_dataset(\"cifar100\")\n",
    "else:\n",
    "    dataset = load_dataset(dataset_name)\n",
    "\n",
    "train_dataset = dataset[\"train\"]\n",
    "test_dataset = dataset[\"test\"]\n",
    "\n",
    "image_name = DATASET2IMAGE_COLUMN[dataset_name]\n",
    "label_name = DATASET2LABEL_COLUMN[dataset_name]\n",
    "num_classes = DATASET2NUM_CLASSES[dataset_name]\n",
    "original_num_layers = MODEL2NUM_LAYERS[model_name]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = AutoConfig.from_pretrained(model_name, output_hidden_states=True, return_dict=True)\n",
    "processor = AutoImageProcessor.from_pretrained(model_name)\n",
    "encoder = AutoModel.from_pretrained(model_name, config=config)\n",
    "encoder.to(device)\n",
    "encoder.eval()\n",
    "\n",
    "dataloader = DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=256,\n",
    "    shuffle=True,\n",
    "    num_workers=8,\n",
    "    pin_memory=True,\n",
    "    collate_fn=functools.partial(image_encode, processor=processor, image_name=image_name, label_name=label_name),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Analyze"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_samples = 3000\n",
    "cls_layer_embeddings = extract_all_layers(encoder, max_samples, dataloader, True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert len(cls_layer_embeddings) == original_num_layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(cls_layer_embeddings), cls_layer_embeddings[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cls_layer_embeddings = [layer_output for _, layer_output in cls_layer_embeddings.items()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cls_spaces = []\n",
    "\n",
    "# for layer_embedding in cls_layer_embeddings:\n",
    "#     cls_spaces.append(LatentSpace(vector_source=layer_embedding))\n",
    "\n",
    "# anchors_cls = cls_spaces[0].sample(sampler=Uniform(), n=384).vectors\n",
    "# projector = RelativeProjection(projection_fn=projection.cosine_proj, abs_transform=Centering())\n",
    "\n",
    "# rr_cls_spaces = []\n",
    "\n",
    "# for cls_ls in cls_spaces:\n",
    "#     rr_cls_spaces.append(projector.fit(x=anchors_cls)(x=cls_ls.vectors)[\"x\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MSE matrix over layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n",
    "\n",
    "# plot_similarity_matrix(layers=cls_layer_embeddings, ax=axes[0], title=\"Absolute\", metric=\"MSE\", vmax=2)\n",
    "\n",
    "# plt.tight_layout()\n",
    "# plt.suptitle(\"CLS MSE Similarity Matrix\", fontsize=11)\n",
    "# plt.show()\n",
    "\n",
    "# # save\n",
    "# os.makedirs(PLOTS_DIR, exist_ok=True)\n",
    "# save_path = os.path.join(PLOTS_DIR, \"MSE_matrix\")\n",
    "# fig.savefig(save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cosine & CKA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from tueplots import bundles, figsizes, axes, fonts\n",
    "from tueplots.figsizes import _GOLDEN_RATIO\n",
    "import matplotlib.patches as patches\n",
    "\n",
    "N_ROWS = 1\n",
    "N_COLS = 1\n",
    "RATIO = _GOLDEN_RATIO\n",
    "\n",
    "# Use tueplots for iclr2024 formatting\n",
    "plt.rcParams.update({\"figure.dpi\": 150})\n",
    "plt.rcParams.update(bundles.iclr2024())\n",
    "plt.rcParams.update(figsizes.iclr2024(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=RATIO))\n",
    "plt.rcParams.update(axes.lines())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "similarity_matrix = similarities.pairwise_layer_cosine_similarity(cls_layer_embeddings)\n",
    "\n",
    "fig, ax = plt.subplots(nrows=N_ROWS, ncols=N_COLS, sharex=True, sharey=True)\n",
    "\n",
    "f = sns.heatmap(\n",
    "    similarity_matrix,\n",
    "    annot=True,\n",
    "    cmap=\"viridis\",\n",
    "    fmt=\".2f\",\n",
    "    xticklabels=range(len(similarity_matrix)),\n",
    "    yticklabels=range(len(similarity_matrix)),\n",
    "    ax=ax,\n",
    "    vmax=1,\n",
    "    annot_kws={\"size\": 8},  # Reduce annotation font size here\n",
    ")\n",
    "\n",
    "ax.set_aspect(\"equal\")\n",
    "\n",
    "PLOTS_DIR.mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig(f\"{PLOTS_DIR}/all_cos_{dataset_name}.pdf\", bbox_inches=\"tight\", pad_inches=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### In-out difference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "layer_impact = {}\n",
    "\n",
    "for id, layer in enumerate(cls_layer_embeddings):\n",
    "    if id != 0:\n",
    "        squared_diff = ((cls_layer_embeddings[id - 1] - cls_layer_embeddings[id]) ** 2).mean(dim=-1)\n",
    "        mse = squared_diff.mean(dim=-1).item()\n",
    "\n",
    "        layer_impact = layer_impact | {id: mse}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_values = list(range(1, len(layer_impact) + 1))\n",
    "impact_values = list(layer_impact.values())\n",
    "\n",
    "fig = plt.figure(figsize=(8, 6))\n",
    "plt.plot(x_values, impact_values, marker=\"o\", linestyle=\"-\", color=\"green\", label=\"Impact Value\")\n",
    "\n",
    "plt.title(\"Layer Impact\", fontsize=14)  # , fontweight=\"bold\")\n",
    "plt.xlabel(\"Layer\", fontsize=12)\n",
    "plt.ylabel(\"I/O Difference\", fontsize=12)\n",
    "\n",
    "plt.xticks(x_values)\n",
    "plt.ylim([0, 0.3])\n",
    "\n",
    "# plt.legend()\n",
    "# plt.grid(False)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# save\n",
    "os.makedirs(PLOTS_DIR, exist_ok=True)\n",
    "save_path = os.path.join(PLOTS_DIR, \"impact_value\")\n",
    "fig.savefig(save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CKA matrix over layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n",
    "\n",
    "# plot_similarity_matrix(layers=cls_layer_embeddings, ax=axes[0], title=\"Absolute\", metric=\"CKA\", vmax=1)\n",
    "# plot_similarity_matrix(layers=rr_cls_spaces, ax=axes[1], title=\"Relative\", metric=\"CKA\", vmax=1)\n",
    "\n",
    "# plt.tight_layout()\n",
    "# plt.suptitle(\"CLS CKA Similarity Matrix\", fontsize=11)\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mat_dir = PLOTS_DIR / str(max_samples) / \"CKA_matrix\"\n",
    "# os.makedirs(mat_dir, exist_ok=True)\n",
    "\n",
    "# img_name = dataset_name + \"_\" + model_name.split(\"/\")[1] + \".png\"\n",
    "\n",
    "# save_path = os.path.join(mat_dir, img_name)\n",
    "# fig.savefig(save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cosine matrix over layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n",
    "\n",
    "# plot_similarity_matrix(layers=cls_layer_embeddings, ax=axes[0], title=\"Absolute\", metric=\"cosine\", vmax=1)\n",
    "# plot_similarity_matrix(layers=rr_cls_spaces, ax=axes[1], title=\"Relative\", metric=\"cosine\", vmax=1)\n",
    "\n",
    "# plt.tight_layout()\n",
    "# plt.suptitle(\"CLS Cosine Similarity Matrix\", fontsize=11)\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mat_dir = PLOTS_DIR / str(max_samples) / \"cosine_matrix\"\n",
    "# os.makedirs(mat_dir, exist_ok=True)\n",
    "\n",
    "# img_name = dataset_name + \"_\" + model_name.split(\"/\")[1] + \".png\"\n",
    "\n",
    "# save_path = os.path.join(mat_dir, img_name)\n",
    "# fig.savefig(save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cosine trajectories over layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cosine_similarities = similarities.pairwise_layer_cosine_similarity(cls_layer_embeddings)\n",
    "\n",
    "# num_layers = cosine_similarities.shape[0]\n",
    "# fig = plt.figure(figsize=(18, 15))\n",
    "\n",
    "# sim_threshhold = 0.98\n",
    "\n",
    "# for i in range(num_layers):\n",
    "#     plt.subplot(5, 3, i + 1)\n",
    "#     plt.plot(range(num_layers), cosine_similarities[i].cpu().numpy(), marker=\"o\", label=f\"Layer {i}\")\n",
    "#     plt.axhline(y=sim_threshhold, color=\"r\", linestyle=\"--\", label=f\"Similarity Threshold ({sim_threshhold})\")\n",
    "#     plt.axvline(x=i, color=\"b\", linestyle=\"--\", label=f\"Current Layer {i}\")\n",
    "#     plt.xlabel(\"Layer\")\n",
    "#     plt.ylabel(\"Cosine Similarity\")\n",
    "#     plt.title(f\"Cosine Similarity of Layer {i} with Others\")\n",
    "#     plt.xticks(ticks=range(num_layers))\n",
    "#     plt.legend()\n",
    "\n",
    "# plt.tight_layout()\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# traj_dir = PLOTS_DIR / str(max_samples) / \"cosine_trajectory\"\n",
    "# os.makedirs(traj_dir, exist_ok=True)\n",
    "\n",
    "# img_name = dataset_name + \"_\" + model_name.split(\"/\")[1] + \".png\"\n",
    "\n",
    "# save_path = os.path.join(traj_dir, img_name)\n",
    "# fig.savefig(save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cosine similarity between two embeddings over layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# embedding1_index = 50\n",
    "# embedding2_index = 54\n",
    "# max_samples = 1000\n",
    "\n",
    "# fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharey=\"row\")\n",
    "\n",
    "# plot_similarity_embeddings(\n",
    "#     embeddings=cls_layer_embeddings,\n",
    "#     ax=axes[0],\n",
    "#     title=\"Absolute\",\n",
    "#     e1=embedding1_index,\n",
    "#     e2=embedding2_index,\n",
    "#     color=\"green\",\n",
    "# )\n",
    "\n",
    "# plot_similarity_embeddings(\n",
    "#     embeddings=rr_cls_spaces,\n",
    "#     ax=axes[1],\n",
    "#     title=\"Relative\",\n",
    "#     e1=embedding1_index,\n",
    "#     e2=embedding2_index,\n",
    "#     color=\"purple\",\n",
    "# )\n",
    "\n",
    "# for ax in axes.flat:\n",
    "#     ax.set_xticks(range(len(cls_layer_embeddings)))\n",
    "#     ax.tick_params(axis=\"y\", which=\"both\", labelleft=True)\n",
    "\n",
    "# plt.tight_layout()\n",
    "\n",
    "# # Add the general title\n",
    "# plt.suptitle(\"CLS Similarity Trajectories Over Layers\", fontsize=11)\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# embedding_dir = PLOTS_DIR / str(max_samples) / \"embedding_tajectory\"\n",
    "# os.makedirs(embedding_dir, exist_ok=True)\n",
    "\n",
    "# img_name = (\n",
    "#     dataset_name\n",
    "#     + \"_\"\n",
    "#     + str(train_dataset[embedding1_index][label_name])\n",
    "#     + \"_\"\n",
    "#     + str(train_dataset[embedding2_index][label_name])\n",
    "#     + \"_\"\n",
    "#     + model_name.split(\"/\")[1]\n",
    "#     + \".png\"\n",
    "# )\n",
    "\n",
    "# save_path = os.path.join(embedding_dir, img_name)\n",
    "# fig.savefig(save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "layskip",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
