{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from pytorch_lightning import seed_everything\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm.auto import tqdm\n",
    "from transformers import AutoConfig, AutoImageProcessor, AutoModel\n",
    "from tueplots import axes, bundles, figsizes, fonts\n",
    "from tueplots.figsizes import _GOLDEN_RATIO\n",
    "\n",
    "from layskip.utils import similarities\n",
    "from layskip.utils.dictionaries import (\n",
    "    DATASET2IMAGE_COLUMN,\n",
    "    DATASET2LABEL_COLUMN,\n",
    "    DATASET2NUM_CLASSES,\n",
    "    DATASET_NAME2HF_NAME,\n",
    ")\n",
    "from layskip.utils.utils import extract_all_layers, image_encode\n",
    "\n",
    "# Run on the GPU when one is available, otherwise fall back to CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Layer-wise cosine similarity\n",
    "\n",
    "For every dataset in `DATASET_NAMES`, feed shuffled training images (capped by `MAX_SAMPLES`) through `MODEL_NAME`, collect the per-layer embeddings returned by `extract_all_layers`, and save the pairwise layer cosine-similarity matrix as a heatmap to `PLOTS_DIR / cos_matrix_<dataset>.pdf`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 0\n",
    "seed_everything(SEED)\n",
    "\n",
    "MODEL_NAME = \"facebook/deit-small-patch16-224\"\n",
    "# Alternative checkpoints:\n",
    "# \"WinKawaks/vit-small-patch16-224\"\n",
    "# \"google/vit-base-patch16-224\"\n",
    "# \"facebook/dinov2-small\"\n",
    "# \"facebook/dinov2-base\"\n",
    "\n",
    "DATASET_NAMES = [\"mnist\", \"fashion-mnist\", \"cifar10\", \"cifar100\"]\n",
    "MAX_SAMPLES = 1000  # sample cap passed to extract_all_layers for each dataset\n",
    "BATCH_SIZE = 256\n",
    "NUM_WORKERS = 8\n",
    "\n",
    "PLOTS_DIR = PROJECT_ROOT / \"plots\" / MODEL_NAME.split(\"/\")[1]\n",
    "PLOTS_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# One heatmap per figure, styled and sized with tueplots' ICLR 2024 presets.\n",
    "N_ROWS = 1\n",
    "N_COLS = 1\n",
    "RATIO = _GOLDEN_RATIO\n",
    "plt.rcParams.update({\"figure.dpi\": 150})\n",
    "plt.rcParams.update(bundles.iclr2024())\n",
    "plt.rcParams.update(figsizes.iclr2024(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=RATIO))\n",
    "plt.rcParams.update(axes.lines())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The encoder does not depend on the dataset: load it once and reuse it for all of them.\n",
    "config = AutoConfig.from_pretrained(MODEL_NAME, output_hidden_states=True, return_dict=True)\n",
    "processor = AutoImageProcessor.from_pretrained(MODEL_NAME)\n",
    "encoder = AutoModel.from_pretrained(MODEL_NAME, config=config).to(device).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def layer_similarity_matrix(dataset_name, encoder, processor):\n",
    "    \"\"\"Return the pairwise cosine similarity between the encoder's per-layer embeddings.\n",
    "\n",
    "    Loads the training split of `dataset_name`, runs shuffled batches through\n",
    "    `encoder` via `extract_all_layers` (capped by MAX_SAMPLES), and compares the\n",
    "    returned layer outputs with `similarities.pairwise_layer_cosine_similarity`.\n",
    "    \"\"\"\n",
    "    train_dataset = load_dataset(DATASET_NAME2HF_NAME[dataset_name], split=\"train\")\n",
    "    dataloader = DataLoader(\n",
    "        train_dataset,\n",
    "        batch_size=BATCH_SIZE,\n",
    "        shuffle=True,\n",
    "        num_workers=NUM_WORKERS,\n",
    "        pin_memory=True,\n",
    "        collate_fn=functools.partial(\n",
    "            image_encode,\n",
    "            processor=processor,\n",
    "            image_name=DATASET2IMAGE_COLUMN[dataset_name],\n",
    "            label_name=DATASET2LABEL_COLUMN[dataset_name],\n",
    "        ),\n",
    "    )\n",
    "    # NOTE(review): the positional `True` presumably restricts extraction to the CLS token;\n",
    "    # confirm against layskip.utils.utils.extract_all_layers.\n",
    "    layer_embeddings = extract_all_layers(encoder, MAX_SAMPLES, dataloader, True)\n",
    "    return similarities.pairwise_layer_cosine_similarity(list(layer_embeddings.values()))\n",
    "\n",
    "\n",
    "def plot_similarity_matrix(similarity_matrix, ax):\n",
    "    \"\"\"Draw `similarity_matrix` on `ax` as an annotated, square layer-by-layer heatmap.\"\"\"\n",
    "    layer_ids = range(len(similarity_matrix))\n",
    "    sns.heatmap(\n",
    "        similarity_matrix,\n",
    "        annot=True,\n",
    "        cmap=\"viridis\",\n",
    "        fmt=\".2f\",\n",
    "        xticklabels=layer_ids,\n",
    "        yticklabels=layer_ids,\n",
    "        ax=ax,\n",
    "        vmax=1,\n",
    "        annot_kws={\"size\": 8},  # small annotation font so the values fit inside each cell\n",
    "    )\n",
    "    ax.set_aspect(\"equal\")\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for dataset_name in tqdm(DATASET_NAMES, desc=\"datasets\"):\n",
    "    # Re-seed so each dataset's shuffled sample is independent of the order of\n",
    "    # DATASET_NAMES and of RNG use by earlier iterations.\n",
    "    seed_everything(SEED)\n",
    "    similarity_matrix = layer_similarity_matrix(dataset_name, encoder, processor)\n",
    "\n",
    "    fig, ax = plt.subplots(nrows=N_ROWS, ncols=N_COLS)\n",
    "    plot_similarity_matrix(similarity_matrix, ax)\n",
    "    fig.savefig(PLOTS_DIR / f\"cos_matrix_{dataset_name}.pdf\", bbox_inches=\"tight\", pad_inches=0)\n",
    "    # Render inline now; the inline backend closes shown figures, so they don't pile up across datasets.\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "layskip",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
