{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6822e3cf-65b4-432b-b56e-2c1d0af0cabf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.models import resnet18\n",
    "from torchvision.datasets import CIFAR10\n",
    "import torchvision\n",
    "from torch import nn\n",
    "from tqdm.notebook import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Prefer the first CUDA GPU when one is available; otherwise fall back to the CPU.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a00c8821-682e-4120-8fd2-2d2d7ac4abe5",
   "metadata": {},
   "source": [
    "## Encoder architecture (source: Bonet et al., 2023)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc271a22-e950-4f88-9d9f-2d5d1045b5cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "class L2Norm(nn.Module):\n",
    "    \"\"\"Divide the input by its Euclidean norm taken along dim 1.\"\"\"\n",
    "\n",
    "    def forward(self, x):\n",
    "        # keepdim=True keeps the reduced axis so the norms broadcast against x.\n",
    "        norms_along_dim1 = x.norm(p=2, dim=1, keepdim=True)\n",
    "        return x / norms_along_dim1\n",
    "\n",
    "class ResNet(nn.Module):\n",
    "    \"\"\"ResNet-18 backbone followed by a small head that emits L2-normalized features.\n",
    "\n",
    "    Args:\n",
    "        in_channel: Number of channels of the input images.\n",
    "        feat_dim: Output size of the predictor head.\n",
    "        no_bias: If True, the backbone's final linear layer is rebuilt without a bias term.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channel: int = 3, feat_dim: int = 128, no_bias: bool = False):\n",
    "        super().__init__()\n",
    "        # The backbone's final layer outputs 32 * 32 = 1024 values.\n",
    "        self.rn = resnet18(num_classes=32 * 32)\n",
    "\n",
    "        if no_bias:\n",
    "            # weight has shape (out_features, in_features); reversed it matches nn.Linear's argument order.\n",
    "            self.rn.fc = nn.Linear(*self.rn.fc.weight.data.shape[::-1], bias=False)\n",
    "        # Replace the stem: no max-pooling, and a 3x3 stride-1 first convolution.\n",
    "        self.rn.maxpool = nn.Identity()\n",
    "        self.rn.conv1 = nn.Conv2d(in_channel, 64,\n",
    "                kernel_size=3, stride=1, padding=2, bias=False)\n",
    "\n",
    "        self.predictor = nn.Sequential(\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(32 * 32, feat_dim, bias=False),\n",
    "            L2Norm(),\n",
    "        )\n",
    "\n",
    "    def forward(self, x, layer_index: int = -1):\n",
    "        \"\"\"Encode a batch of images.\n",
    "\n",
    "        Args:\n",
    "            x: Input batch passed to the backbone.\n",
    "            layer_index: -1 returns the predictor output (L2-normalized, feat_dim wide);\n",
    "                -2 returns the raw backbone output.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If layer_index is neither -1 nor -2 (previously this silently returned None).\n",
    "        \"\"\"\n",
    "        if layer_index == -1:\n",
    "            return self.predictor(self.rn(x))\n",
    "\n",
    "        if layer_index == -2:\n",
    "            return self.rn(x)\n",
    "\n",
    "        raise ValueError(f\"Unsupported layer_index {layer_index!r}; expected -1 or -2.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e38b10b-c658-459b-b56a-45e726bf69d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# S3W encoder: build on the selected device, then restore the trained weights.\n",
    "# map_location lets the checkpoint load even when it was saved from a GPU that is not available here.\n",
    "# NOTE(review): the run directory is tagged `no_bias` but the model uses the default no_bias=False - confirm they match.\n",
    "s3w = ResNet(feat_dim=3).to(device).eval()\n",
    "checkpoint = torch.load(\"./results/method_s3w_epochs_200_feat_dim_3_batch_size_512_num_projections_200_num_rotations_1_unif_w_0.1_align_w_1.0_lr_0.05_momentum_0.9_seed_0_weight_decay_0.001_no_bias/encoder.pth\", map_location=device)\n",
    "s3w.load_state_dict(checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99a31d4c-3207-4fa2-b964-7d43f5d7fcf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# RI-S3W encoder: build on the selected device, then restore the trained weights.\n",
    "# map_location lets the checkpoint load even when it was saved from a GPU that is not available here.\n",
    "ri_s3w = ResNet(feat_dim=3).to(device).eval()\n",
    "checkpoint = torch.load(\"./results/method_ri_s3w_epochs_200_feat_dim_3_batch_size_512_num_projections_200_num_rotations_10_unif_w_0.1_align_w_1.0_lr_0.05_momentum_0.9_seed_0_weight_decay_0.001_dim3_final2/encoder.pth\", map_location=device)\n",
    "ri_s3w.load_state_dict(checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b21d9eb-bac1-4a46-98ba-729edca6a446",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ARI-S3W encoder: build on the selected device (matching map_location below), then restore the weights.\n",
    "ari_s3w = ResNet(feat_dim=3).to(device).eval()\n",
    "checkpoint = torch.load(\"./results/method_ari_s3w_epochs_200_feat_dim_3_batch_size_512_pool_size_200_num_projections_200_num_rotations_10_unif_w_0.1_align_w_1.0_lr_0.05_momentum_0.9_seed_0_weight_decay_0.001_dim3_reruns/encoder.pth\", map_location=device)\n",
    "ari_s3w.load_state_dict(checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f57b358-f528-49fc-ba3a-10d95374d799",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SSW encoder: build on the selected device, then restore the trained weights.\n",
    "# map_location lets the checkpoint load even when it was saved from a GPU that is not available here.\n",
    "ssw = ResNet(feat_dim=3).to(device).eval()\n",
    "checkpoint = torch.load(\"./results/method_ssw_epochs_200_feat_dim_3_batch_size_512_num_projections_200_num_rotations_1_unif_w_20.0_align_w_1.0_lr_0.05_momentum_0.9_seed_0_weight_decay_0.001_reruns/encoder.pth\", map_location=device)\n",
    "ssw.load_state_dict(checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f0af70c-8bc7-4fb9-9411-b24dcd19dd7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SW encoder: build on the selected device, then restore the trained weights.\n",
    "# map_location lets the checkpoint load even when it was saved from a GPU that is not available here.\n",
    "sw = ResNet(feat_dim=3).to(device).eval()\n",
    "# Alternative checkpoint path kept for reference (not loaded):\n",
    "# SSL/results/method_sw_epochs_200_feat_dim_3_batch_size_512_num_projections_200_num_rotations_1_unif_w_1.0_align_w_1.0_lr_0.05_momentum_0.9_seed_0_weight_decay_0.001_sw_updated/encoder.pth\n",
    "checkpoint = torch.load(\"./results/method_sw_epochs_200_feat_dim_3_batch_size_512_num_projections_200_num_rotations_1_unif_w_0.1_align_w_1.0_lr_0.05_momentum_0.9_seed_0_weight_decay_0.001_sw/encoder.pth\", map_location=device)\n",
    "sw.load_state_dict(checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42902503-8041-407e-85b4-7be14fc2f286",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypersphere encoder: build on the selected device, then restore the trained weights.\n",
    "# map_location lets the checkpoint load even when it was saved from a GPU that is not available here.\n",
    "# NOTE(review): the run directory is tagged `no_bias` but the model uses the default no_bias=False - confirm they match.\n",
    "hypersphere = ResNet(feat_dim=3).to(device).eval()\n",
    "checkpoint = torch.load(\"./results/method_hypersphere_epochs_200_feat_dim_3_batch_size_512_num_projections_1_num_rotations_2_unif_w_1.0_align_w_1.0_lr_0.05_momentum_0.9_seed_0_weight_decay_0.001_no_bias/encoder.pth\", map_location=device)\n",
    "hypersphere.load_state_dict(checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f7f8834-eb96-41ba-b542-948dcb35119d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SimCLR encoder: build on the selected device, then restore the trained weights.\n",
    "# map_location lets the checkpoint load even when it was saved from a GPU that is not available here.\n",
    "# NOTE(review): the run directory is tagged `no_bias` but the model uses the default no_bias=False - confirm they match.\n",
    "simclr = ResNet(feat_dim=3).to(device).eval()\n",
    "checkpoint = torch.load(\"./results/method_simclr_epochs_200_feat_dim_3_batch_size_512_num_projections_1_num_rotations_2_unif_w_20.0_align_w_1.0_lr_0.05_momentum_0.9_seed_0_weight_decay_0.001_no_bias/encoder.pth\", map_location=device)\n",
    "simclr.load_state_dict(checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "855890e6-ec39-464f-870e-af1a2985ce41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Supervised encoder: build on the selected device, then restore the trained weights.\n",
    "# map_location lets the checkpoint load even when it was saved from a GPU that is not available here.\n",
    "supervised = ResNet(feat_dim=3).to(device).eval()\n",
    "checkpoint = torch.load(\"./results_supervised_dim3_s2/encoder.pth\", map_location=device)\n",
    "supervised.load_state_dict(checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81668e10-538f-45cf-b3e3-69f173477ff3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Untrained baseline: no checkpoint is loaded; the backbone's final linear layer is built without bias.\n",
    "# Placed on `device` rather than hard-coding CUDA so the notebook also runs on CPU-only machines.\n",
    "random_encoder = ResNet(feat_dim=3, no_bias=True).to(device).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6de0a4ca-dd26-43c1-a071-7d25ac18be3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_transform(mean, std, resize, crop_size):\n",
    "    \"\"\"Build the image preprocessing pipeline.\n",
    "\n",
    "    The steps are applied in order: resize, center-crop, convert to tensor, normalize.\n",
    "\n",
    "    Args:\n",
    "        mean: Per-channel means passed to Normalize.\n",
    "        std: Per-channel standard deviations passed to Normalize.\n",
    "        resize: Target size passed to Resize.\n",
    "        crop_size: Target size passed to CenterCrop.\n",
    "\n",
    "    Returns:\n",
    "        A torchvision.transforms.Compose wrapping the four steps.\n",
    "    \"\"\"\n",
    "    return torchvision.transforms.Compose(\n",
    "        [\n",
    "            torchvision.transforms.Resize(resize),\n",
    "            torchvision.transforms.CenterCrop(crop_size),\n",
    "            torchvision.transforms.ToTensor(),\n",
    "            torchvision.transforms.Normalize(mean=mean, std=std),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "\n",
    "# Per-channel normalization statistics (presumably CIFAR-10's - confirm they match training).\n",
    "transform = get_transform(\n",
    "    mean=(0.4915, 0.4822, 0.4466),\n",
    "    std=(0.2470, 0.2435, 0.2616),\n",
    "    crop_size=32,\n",
    "    resize=32,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe5e533e-271e-42b6-a4a6-6c9eb0127854",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CIFAR-10 with train=False, stored under ./data (download=True), preprocessed with `transform`.\n",
    "cifar10 = CIFAR10(\"./data\", train=False, download=True, transform=transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44cc3d80-27ef-438c-9fac-2498280771b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batches of 256 samples; shuffling is not requested here.\n",
    "dataloader = DataLoader(cifar10, batch_size=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d116675-23a6-4981-87e3-96c7d80f3740",
   "metadata": {},
   "outputs": [],
   "source": [
    "def limited(gen, size=-1):\n",
    "    \"\"\"Return at most `size` items of `gen`.\n",
    "\n",
    "    Args:\n",
    "        gen: Any iterable. It does not need to support len().\n",
    "        size: Maximum number of items to yield; -1 means no limit.\n",
    "\n",
    "    Returns:\n",
    "        `gen` itself when no truncation is needed (so wrappers such as tqdm can still\n",
    "        read its length), otherwise an iterator over its first `size` items.\n",
    "    \"\"\"\n",
    "    from itertools import islice\n",
    "\n",
    "    if size == -1:\n",
    "        return gen\n",
    "    try:\n",
    "        if size >= len(gen):\n",
    "            return gen\n",
    "    except TypeError:\n",
    "        # Plain generators/iterators have no len(); fall through and truncate lazily.\n",
    "        pass\n",
    "    # max(..., 0) keeps the old behavior of yielding nothing for other negative sizes.\n",
    "    return islice(gen, max(size, 0))\n",
    "\n",
    "def get_embeddings(encoder, n_batches=100):\n",
    "    \"\"\"Run `encoder` over the global `dataloader` and collect its outputs on the CPU.\n",
    "\n",
    "    Args:\n",
    "        encoder: Module called as encoder(x, layer_index=-1).\n",
    "        n_batches: Maximum number of batches to process (forwarded to `limited`).\n",
    "\n",
    "    Returns:\n",
    "        (all_z, all_y): encoder outputs and labels concatenated along dim 0,\n",
    "        or (None, None) if the loader yields no batch.\n",
    "    \"\"\"\n",
    "    # Follow the encoder's own device instead of assuming CUDA.\n",
    "    encoder_device = next(encoder.parameters()).device\n",
    "\n",
    "    z_chunks, y_chunks = [], []\n",
    "    for x, y in tqdm(limited(dataloader, size=n_batches)):\n",
    "        with torch.no_grad():\n",
    "            z = encoder(x.to(encoder_device), layer_index=-1)\n",
    "        z_chunks.append(z.cpu())\n",
    "        y_chunks.append(y.cpu())\n",
    "\n",
    "    if not z_chunks:\n",
    "        return None, None\n",
    "    # Concatenate once at the end; re-concatenating every iteration copies all earlier batches again.\n",
    "    return torch.cat(z_chunks), torch.cat(y_chunks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5048cceb-22e6-470a-bb89-305a405535f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features and labels produced by the S3W encoder over `dataloader`.\n",
    "all_z_s3w, all_y_s3w = get_embeddings(s3w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "734ec778-89fa-4916-93a3-cb84d2613ebb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features and labels produced by the RI-S3W encoder over `dataloader`.\n",
    "all_z_ri_s3w, all_y_ri_s3w = get_embeddings(ri_s3w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6892b6a2-c6a3-47cd-bce7-48767241f8a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features and labels produced by the ARI-S3W encoder over `dataloader`.\n",
    "all_z_ari_s3w, all_y_ari_s3w = get_embeddings(ari_s3w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f654c35-a62a-4b7b-b77e-75a982f22cdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features and labels produced by the SW encoder over `dataloader`.\n",
    "all_z_sw, all_y_sw = get_embeddings(sw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd0e31a2-f966-4887-b445-614384ac7d71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features and labels produced by the hypersphere encoder over `dataloader`.\n",
    "all_z_hypersphere, all_y_hypersphere = get_embeddings(hypersphere)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "532ccf36-e6a7-48ad-ae45-2961e8487ec4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features and labels produced by the SSW encoder over `dataloader`.\n",
    "all_z_ssw, all_y_ssw = get_embeddings(ssw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bc5a36f-f43e-4703-a3b8-1d1adaccde70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features and labels produced by the SimCLR encoder over `dataloader`.\n",
    "all_z_simclr, all_y_simclr = get_embeddings(simclr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcad28f2-e513-48ff-be10-10853282691e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features and labels produced by the untrained (randomly initialized) encoder over `dataloader`.\n",
    "all_z_rand, all_y_rand = get_embeddings(random_encoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd37adfd-ada6-42ac-9474-14b7a1271c45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features and labels produced by the supervised encoder over `dataloader`.\n",
    "all_z_supervised, all_y_supervised = get_embeddings(supervised)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ae2f367-518f-45c3-885e-0a149bdb3148",
   "metadata": {},
   "outputs": [],
   "source": [
    "def scatter_dists(all_z, all_y, desc=None, include_legend=False, export_legend=False):\n",
    "    \"\"\"Scatter 3-D embeddings on a Mollweide projection, one color per class.\n",
    "\n",
    "    Args:\n",
    "        all_z: Tensor indexed as all_z[mask, k] for k in 0..2 (assumed to hold unit-norm 3-D points).\n",
    "        all_y: Class indices aligned with the rows of all_z.\n",
    "        desc: Label used only to build the output name features_plot_<desc>.pdf\n",
    "            (lower-cased, whitespace replaced by underscores); no title is drawn.\n",
    "            When None, the figure is shown but not saved.\n",
    "        include_legend: Draw the class legend to the right of the axes.\n",
    "        export_legend: With include_legend, also save the legend alone to features_legend.pdf.\n",
    "    \"\"\"\n",
    "    fig = plt.figure()\n",
    "    plt.subplot(111, projection=\"mollweide\")\n",
    "\n",
    "    for class_index, class_name in enumerate(cifar10.classes):\n",
    "        selector = all_y == class_index\n",
    "        # Longitude from the negated x/y components.\n",
    "        longitude = torch.atan2(-all_z[selector, 1], -all_z[selector, 0])\n",
    "        # Clamp before asin: rounding can push a coordinate slightly past +-1, which would give NaN.\n",
    "        latitude = torch.asin(all_z[selector, 2].clamp(-1.0, 1.0))\n",
    "        plt.scatter(longitude, latitude, s=.7, label=class_name)\n",
    "\n",
    "    if include_legend:\n",
    "        legend = plt.legend(bbox_to_anchor=(1,0.5), loc=\"center left\", markerscale=15, fontsize=16)\n",
    "\n",
    "        if export_legend:\n",
    "            # Render once so the legend has a window extent, then crop the saved page to it (5 px padding).\n",
    "            fig.canvas.draw()\n",
    "            bbox = legend.get_window_extent()\n",
    "            bbox = bbox.from_extents(*(bbox.extents + np.array([-5,-5,5,5])))\n",
    "            bbox = bbox.transformed(fig.dpi_scale_trans.inverted())\n",
    "            fig.savefig('features_legend.pdf', dpi=1000, bbox_inches=bbox)\n",
    "\n",
    "    plt.subplots_adjust(right=0.75)\n",
    "    plt.grid()\n",
    "\n",
    "    # desc defaults to None; calling .lower() on it used to raise AttributeError.\n",
    "    if desc is not None:\n",
    "        suffix = \"_\".join(desc.lower().split())\n",
    "        plt.savefig(f\"features_plot_{suffix}.pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74936bc8-d0f7-403c-b9fe-cfe20d5410dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes features_plot_supervised_predictive.pdf and, because export_legend=True, features_legend.pdf.\n",
    "scatter_dists(all_z_supervised, all_y_supervised, \"Supervised predictive\", include_legend=True, export_legend=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c80b083-37da-49d2-a83d-bdc2f95dc702",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes features_plot_s3w.pdf; the legend is drawn on the figure but not exported separately.\n",
    "scatter_dists(all_z_s3w, all_y_s3w, \"S3W\", include_legend=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e440f18-a9ec-4039-accc-da5f1616eb97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes features_plot_ri-s3w.pdf (no legend).\n",
    "scatter_dists(all_z_ri_s3w, all_y_ri_s3w, \"RI-S3W\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "598d3e8f-f015-491a-a75e-e9ae44df4728",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes features_plot_ari-s3w.pdf (no legend).\n",
    "scatter_dists(all_z_ari_s3w, all_y_ari_s3w, \"ARI-S3W\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42c4c9a2-c833-4e38-9b48-5000e076e8ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes features_plot_sw.pdf (no legend).\n",
    "scatter_dists(all_z_sw, all_y_sw, \"SW\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9096c46d-9ab9-4dc1-beb9-1492cf307b91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes features_plot_ssw.pdf (no legend).\n",
    "scatter_dists(all_z_ssw, all_y_ssw, \"SSW\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a3c80aa-7b50-4ab4-bb5a-05dee5e343fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The hypersphere encoder is plotted under the label \"Wang\", so the file is features_plot_wang.pdf.\n",
    "scatter_dists(all_z_hypersphere, all_y_hypersphere, \"Wang\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b7c9c66-c6f8-4057-b2dc-59b29ebca558",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes features_plot_simclr.pdf (no legend).\n",
    "scatter_dists(all_z_simclr, all_y_simclr, \"SimCLR\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de44379c-3ca2-4725-8ceb-0aab33f306aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes features_plot_random_initialization.pdf (no legend).\n",
    "scatter_dists(all_z_rand, all_y_rand, \"Random initialization\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
