{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `%reload_ext` loads the extension on first run and is safe to re-run (no \"already loaded\" notice)\n",
    "%reload_ext autoreload\n",
    "# Re-import all modules (e.g. edited project code) before executing each cell\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from torch import nn\n",
    "import torch\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from pytorch_lightning import seed_everything\n",
    "import pandas as pd\n",
    "from sklearn.decomposition import PCA\n",
    "import seaborn as sns\n",
    "from datasets import DatasetDict\n",
    "from layskip.utils.dictionaries import (\n",
    "    DATASET2LABEL_COLUMN,\n",
    "    DATASET2NUM_CLASSES,\n",
    ")\n",
    "from tqdm import tqdm\n",
    "from latentis.measure.functional.cka import cka as cka_fn\n",
    "from latentis.measure.functional.cka import kernel_hsic, linear_hsic\n",
    "import matplotlib.pyplot as plt\n",
    "from transformers import AutoConfig, AutoModel\n",
    "from layskip.utils.utils import count_parameters, convert_parameters\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from tueplots import bundles, figsizes, axes, fonts\n",
    "from tueplots.figsizes import _GOLDEN_RATIO\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "\n",
    "# NOTE(review): `device` is only displayed here; no visible cell moves tensors or models to it.\n",
    "device = torch.device(\"cpu\")\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# temp: preview of a single precomputed embeddings split.\n",
    "# NOTE(review): the dataset loop in the next cell reassigns EMBEDDINGS_DIR and embeddings before they are used.\n",
    "embeddings_path = PROJECT_ROOT / \"data\" / \"embeddings\" / \"cifar100-coarse\" / \"vit-base-patch16-224\"\n",
    "EMBEDDINGS_DIR = str(embeddings_path)\n",
    "embeddings = DatasetDict.load_from_disk(EMBEDDINGS_DIR)\n",
    "embeddings.set_format(\"torch\")\n",
    "\n",
    "embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 0\n",
    "seed_everything(seed)\n",
    "\n",
    "datasets = [\n",
    "    \"mnist\",\n",
    "    \"fashion-mnist\",\n",
    "    \"cifar10\",\n",
    "    \"cifar100\",\n",
    "]\n",
    "\n",
    "model_name = \"facebook/deit-small-patch16-224\"  # \"WinKawaks/vit-small-patch16-224\" \"facebook/dinov2-small\" \"facebook/deit-small-patch16-224\"\n",
    "model_short_name = model_name.split(\"/\")[1]\n",
    "PLOTS_DIR = PROJECT_ROOT / \"plots\" / model_short_name\n",
    "\n",
    "# Embedding column names: \"[]\" holds the original representations,\n",
    "# \"[(i, i + 1)]\" the representations obtained with the block pair (i, i + 1) approximated.\n",
    "approximations = [[]] + [[(block, block + 1)] for block in range(11)]  # pairs (0, 1) ... (10, 11)\n",
    "\n",
    "max_samples = 3000\n",
    "redundancies = {}  # dataset -> {block index: BR index}\n",
    "similarities = {}  # dataset -> {block index: MSE to the original representations}\n",
    "\n",
    "\n",
    "def _mse(a: torch.Tensor, b: torch.Tensor) -> float:\n",
    "    \"\"\"Mean of (a - b) ** 2 over the last dimension, then over the remaining ones, as a Python float.\"\"\"\n",
    "    return ((a - b) ** 2).mean(dim=-1).mean().item()\n",
    "\n",
    "\n",
    "for dataset_name in datasets:\n",
    "    # cifar100 embeddings are stored under the \"cifar100-fine\" directory\n",
    "    embeddings_subdir = \"cifar100-fine\" if dataset_name == \"cifar100\" else dataset_name\n",
    "    EMBEDDINGS_DIR = str(PROJECT_ROOT / \"data\" / \"embeddings\" / embeddings_subdir / model_short_name)\n",
    "    embeddings = DatasetDict.load_from_disk(EMBEDDINGS_DIR)\n",
    "    embeddings.set_format(\"torch\")\n",
    "\n",
    "    # Dedicated seeded generator: every approximation of this dataset is evaluated on the same subset.\n",
    "    indexes = torch.randperm(len(embeddings[\"train\"]), generator=torch.Generator().manual_seed(42))[:max_samples]\n",
    "\n",
    "    representations = [\n",
    "        (embeddings[\"train\"].select_columns([str(approx)]).rename_column(str(approx), \"images\"))[indexes][\"images\"]\n",
    "        for approx in tqdm(approximations)\n",
    "    ]\n",
    "\n",
    "    # BR index: negative MSE between consecutive entries of `approximations`.\n",
    "    redundancies[dataset_name] = {\n",
    "        block_idx: -_mse(representations[block_idx], representations[block_idx - 1])\n",
    "        for block_idx in range(1, len(representations))\n",
    "    }\n",
    "\n",
    "    # MSE of each approximation against the original (\"[]\") representations.\n",
    "    similarities[dataset_name] = {\n",
    "        block_idx: _mse(representations[block_idx], representations[0])\n",
    "        for block_idx in range(1, len(representations))\n",
    "    }\n",
    "\n",
    "N_ROWS = 2\n",
    "N_COLS = 1\n",
    "RATIO = 1.618\n",
    "\n",
    "plt.rcParams.update({\"figure.dpi\": 150})\n",
    "plt.rcParams.update({\"font.size\": 14})\n",
    "\n",
    "# `axs`, not `axes`: `axes` is the tueplots module imported above, which the PCA cell still needs.\n",
    "fig, axs = plt.subplots(nrows=N_ROWS, ncols=N_COLS, figsize=(10, 6 * RATIO))\n",
    "axs = np.atleast_1d(axs).ravel()\n",
    "\n",
    "colors = [\"#1f77b4\", \"#2ca02c\", \"#9467bd\", \"#ff7f0e\"]  # 17becf\n",
    "\n",
    "for i, dataset in enumerate(datasets):\n",
    "    color = colors[i % len(colors)]  # cycle instead of IndexError if more datasets are added\n",
    "    axs[0].plot(list(redundancies[dataset].keys()), list(redundancies[dataset].values()), marker=\"o\", linestyle=\"-\", color=color, label=dataset)\n",
    "    axs[1].plot(list(similarities[dataset].keys()), list(similarities[dataset].values()), marker=\"o\", linestyle=\"-\", color=color, label=dataset)\n",
    "\n",
    "axs[0].set_ylabel(\"BR index\")\n",
    "axs[1].set_ylabel(\"MSE vs. original\")\n",
    "axs[1].set_xlabel(\"Block\")\n",
    "\n",
    "for ax in axs:\n",
    "    ax.legend(fontsize=14)\n",
    "    ax.tick_params(axis=\"both\", labelsize=14)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "PLOTS_DIR.mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig(f\"{PLOTS_DIR}/br_all.pdf\", bbox_inches=\"tight\", pad_inches=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PCA plot: original vs. approximated last-layer representations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `embeddings` and `dataset_name` are whatever the dataset loop above loaded last\n",
    "# (\"cifar100\" with the current `datasets` list).\n",
    "# Local alias: the bare name `axes` may already be rebound to a subplot array by a plotting cell.\n",
    "from tueplots import axes as tueplots_axes\n",
    "\n",
    "max_samples = 3000\n",
    "indexes = torch.randperm(len(embeddings[\"train\"]), generator=torch.Generator().manual_seed(42))[:max_samples]\n",
    "\n",
    "# Compare the original (\"[]\") representations with the approximation of block pair (approx_layer - 1, approx_layer).\n",
    "approx_layer = 1\n",
    "approx = str([(approx_layer - 1, approx_layer)])  # embedding column name, e.g. \"[(0, 1)]\"\n",
    "label_column = DATASET2LABEL_COLUMN[dataset_name]\n",
    "\n",
    "orig_rows = embeddings[\"train\"].select_columns([\"[]\", label_column])[indexes]\n",
    "orig_repr = orig_rows[\"[]\"]\n",
    "orig_label = orig_rows[label_column]\n",
    "approx_repr = embeddings[\"train\"].select_columns([approx])[indexes][approx]\n",
    "\n",
    "orig_repr_np = orig_repr.cpu().numpy()\n",
    "approx_repr_np = approx_repr.cpu().numpy()\n",
    "labels_np = orig_label.cpu().numpy()\n",
    "\n",
    "N_ROWS = 1\n",
    "N_COLS = 2\n",
    "RATIO = _GOLDEN_RATIO\n",
    "\n",
    "# Use tueplots for iclr2024 formatting\n",
    "plt.rcParams.update({\"figure.dpi\": 150})\n",
    "plt.rcParams.update(bundles.iclr2024())\n",
    "plt.rcParams.update(figsizes.iclr2024(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=RATIO))\n",
    "plt.rcParams.update(tueplots_axes.lines())\n",
    "\n",
    "\n",
    "def _pca_frame(points: np.ndarray, labels: np.ndarray) -> pd.DataFrame:\n",
    "    \"\"\"Tabulate 2-D PCA coordinates, one row per sample, together with its class label.\"\"\"\n",
    "    return pd.DataFrame(\n",
    "        {\n",
    "            \"sample_id\": range(len(labels)),\n",
    "            \"x\": points[:, 0],\n",
    "            \"y\": points[:, 1],\n",
    "            \"labels\": labels,\n",
    "        }\n",
    "    )\n",
    "\n",
    "\n",
    "# Each representation is projected onto its own principal axes;\n",
    "# use pca_orig.transform(approx_repr_np) instead to share the original basis.\n",
    "pca_orig = PCA(n_components=2)\n",
    "transformed_orig = pca_orig.fit_transform(orig_repr_np)\n",
    "df1 = _pca_frame(transformed_orig, labels_np)\n",
    "\n",
    "pca_approx = PCA(n_components=2)\n",
    "transformed_approx = pca_approx.fit_transform(approx_repr_np)\n",
    "df2 = _pca_frame(transformed_approx, labels_np)\n",
    "\n",
    "label_codes, unique_labels = pd.factorize(df1[\"labels\"])\n",
    "\n",
    "palette = sns.color_palette(\"Set2\", len(unique_labels))\n",
    "color_map = dict(zip(unique_labels, palette))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `axs`, not `axes`: keep the tueplots `axes` module name unshadowed for re-runs of the PCA cell.\n",
    "fig, axs = plt.subplots(1, 2, figsize=(16, 6))\n",
    "\n",
    "# Left: original representations; right: approximation of layer `approx_layer`. Same class colours in both.\n",
    "for ax, frame in zip(axs, (df1, df2)):\n",
    "    for label in unique_labels:\n",
    "        idx = frame[\"labels\"] == label\n",
    "        ax.scatter(frame.loc[idx, \"x\"], frame.loc[idx, \"y\"], label=label, color=color_map[label])\n",
    "\n",
    "# axs[0].set_title(\"Last Layer Original Model\")\n",
    "# axs[1].set_title(f\"Last Layer Approximating Layer {approx_layer}\")\n",
    "# for ax in axs:\n",
    "#     ax.legend(title=\"Classes\", loc=\"upper right\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "PLOTS_DIR.mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig(f\"{PLOTS_DIR}/{dataset_name}_A{approx_layer}.pdf\", bbox_inches=\"tight\", pad_inches=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classification accuracies per approximated block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): CONSECUTIVE_APPROX is not defined in any visible cell; define it (path to the results CSV) before running.\n",
    "if \"CONSECUTIVE_APPROX\" not in globals():\n",
    "    raise NameError(\n",
    "        \"CONSECUTIVE_APPROX is not defined: set it to the path of the results CSV before running this cell.\"\n",
    "    )\n",
    "\n",
    "unused_columns = [\n",
    "    \"seed\",\n",
    "    \"num_layers\",\n",
    "    \"optimizer\",\n",
    "    \"lr\",\n",
    "    \"batch_size\",\n",
    "    \"delta_acc\",\n",
    "    \"num_epochs\",\n",
    "    \"original_accuracy\",\n",
    "    \"classifier\",\n",
    "]\n",
    "df = pd.read_csv(CONSECUTIVE_APPROX)\n",
    "df = df[df[\"model\"] == model_name].drop(columns=unused_columns)\n",
    "assert not df.empty, f\"no rows for model {model_name!r} in {CONSECUTIVE_APPROX}\"\n",
    "\n",
    "mnist = df[df[\"dataset\"] == \"mnist\"]\n",
    "fmnist = df[df[\"dataset\"] == \"fashion-mnist\"]\n",
    "cifar10 = df[df[\"dataset\"] == \"cifar10\"]\n",
    "cifar100c = df[df[\"dataset\"] == \"cifar100-coarse\"]\n",
    "cifar100f = df[df[\"dataset\"] == \"cifar100-fine\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first row of each per-dataset frame is skipped; the disabled lines below read it as the original accuracy.\n",
    "mnist_accuracies = mnist.accuracy.to_list()[1:]\n",
    "fmnist_accuracies = fmnist.accuracy.to_list()[1:]\n",
    "cifar10_accuracies = cifar10.accuracy.to_list()[1:]\n",
    "cifar100c_accuracies = cifar100c.accuracy.to_list()[1:]\n",
    "cifar100f_accuracies = cifar100f.accuracy.to_list()[1:]\n",
    "\n",
    "# mnist_orig_acc = mnist.accuracy.to_list()[0]\n",
    "# fmnist_orig_acc = fmnist.accuracy.to_list()[0]\n",
    "# cifar10_orig_acc = cifar10.accuracy.to_list()[0]\n",
    "# cifar100c_orig_acc = cifar100c.accuracy.to_list()[0]\n",
    "# cifar100f_orig_acc = cifar100f.accuracy.to_list()[0]\n",
    "\n",
    "# Every series is plotted against the same x axis below, so they must all have the same length.\n",
    "series_lengths = {\n",
    "    \"mnist\": len(mnist_accuracies),\n",
    "    \"fashion-mnist\": len(fmnist_accuracies),\n",
    "    \"cifar10\": len(cifar10_accuracies),\n",
    "    \"cifar100-coarse\": len(cifar100c_accuracies),\n",
    "    \"cifar100-fine\": len(cifar100f_accuracies),\n",
    "}\n",
    "assert len(set(series_lengths.values())) == 1, f\"accuracy series differ in length: {series_lengths}\"\n",
    "\n",
    "x_values = list(range(1, len(cifar100f_accuracies) + 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (legend label, accuracies per approximated block, line colour)\n",
    "series = [\n",
    "    (\"MNIST\", mnist_accuracies, \"blue\"),\n",
    "    (\"F-MNIST\", fmnist_accuracies, \"green\"),\n",
    "    (\"CIFAR10\", cifar10_accuracies, \"pink\"),\n",
    "    (\"CIFAR100-c\", cifar100c_accuracies, \"violet\"),\n",
    "    (\"CIFAR100-f\", cifar100f_accuracies, \"orange\"),\n",
    "]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "for label, accuracies, color in series:\n",
    "    ax.plot(x_values, accuracies, marker=\"o\", linestyle=\"-\", color=color, label=label)\n",
    "\n",
    "# ax.axhline(y=orig_acc, color=\"blue\", linestyle=\"--\", label=\"original accuracy\")\n",
    "\n",
    "ax.legend()\n",
    "ax.set_title(f\"Classification Accuracy using {model_name.split('/')[1]}\")\n",
    "ax.set_xlabel(\"Block\")\n",
    "ax.set_ylabel(\"Accuracy\")\n",
    "ax.set_xticks(x_values)\n",
    "\n",
    "PLOTS_DIR.mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig(f\"{PLOTS_DIR}/all_accuracies.pdf\", bbox_inches=\"tight\", pad_inches=0)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Number of parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# config = AutoConfig.from_pretrained(model_name, output_hidden_states=True, return_dict=True)\n",
    "# encoder = AutoModel.from_pretrained(model_name, config=config)\n",
    "# classifier = nn.Linear(encoder.config.hidden_size, DATASET2NUM_CLASSES[dataset_name])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ## manual\n",
    "# embeddings_param = count_parameters(encoder.embeddings)\n",
    "\n",
    "# encoder_params = 0\n",
    "# for layer in enumerate(encoder.encoder.layer):\n",
    "#     encoder_params = encoder_params + count_parameters(layer[1])\n",
    "\n",
    "# layernorm_params = count_parameters(encoder.layernorm)\n",
    "\n",
    "# pooler_params = 0\n",
    "# if model_name != \"facebook/dinov2-small\":\n",
    "#     pooler_params = count_parameters(encoder.pooler)\n",
    "\n",
    "# classifier_params = count_parameters(classifier)\n",
    "\n",
    "# tot_params = embeddings_param + encoder_params + layernorm_params + pooler_params + classifier_params\n",
    "# tot_params = convert_parameters(tot_params)\n",
    "# tot_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# skips = [(1, 3), (6, 8)]\n",
    "# layer_to_skip = [2, 3, 7, 8]\n",
    "# num_skip = len(layer_to_skip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ## manual\n",
    "# skip_encoder_params = 0\n",
    "# for layer in enumerate(encoder.encoder.layer):\n",
    "#     if layer[0] not in layer_to_skip:\n",
    "#         skip_encoder_params = skip_encoder_params + count_parameters(layer[1])\n",
    "\n",
    "# transf_param = (384 * 384) * num_skip  # *197\n",
    "\n",
    "# tot_skip_params = (\n",
    "#     embeddings_param + skip_encoder_params + transf_param + layernorm_params + pooler_params + classifier_params\n",
    "# )\n",
    "# tot_skip_params = convert_parameters(tot_skip_params)\n",
    "# tot_skip_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# f\"Original params: {tot_params} vs. Skip params: {tot_skip_params}\""
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "layskip",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
