{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dab0095f-35c4-474a-b354-9305e4e756c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# IMPORTS\n",
    "# =============================================================================\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, random_split, Subset\n",
    "from torchvision.datasets import CIFAR100\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.models import resnet34, ResNet34_Weights\n",
    "import numpy as np\n",
    "import os\n",
    "import shutil\n",
    "from tqdm.notebook import tqdm\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.metrics import accuracy_score\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87321cc0-dcf9-4a0a-bcd9-169a99531cff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# CONFIGURATION\n",
    "# =============================================================================\n",
    "\n",
    "class Config:\n",
    "    \"\"\"Configuration class for all hyperparameters.\"\"\"\n",
    "    # Data parameters\n",
    "    data_dir = './data'\n",
    "\n",
    "    # Model parameters\n",
    "    resnet_embedding_dim = 512                 # ResNet34 output dimension\n",
    "    mlp_hidden_dim = 1024                      # Monotonic MLP hidden dimension\n",
    "    mlp_num_layers = 1                         # Monotonic MLP number of hidden layers\n",
    "\n",
    "    # Training parameters\n",
    "    batch_size = 256                           # Mini-batch size\n",
    "    lr_resnet = 2e-4                           # Learning rate for the ResNet encoder\n",
    "    lr_mlp = 2e-4                              # Learning rate for the monotonic MLP head\n",
    "    seed = 42                                  # Random number generator seed\n",
    "    temperature = 0.1\n",
    "    is_ablation = False                        # Set to True to bypass the monotonic MLP head for the ablation study training run.\n",
    "\n",
    "    # Evaluation parameters\n",
    "    knn_k = 5                                  # Number of nearest neighbors in the k nearest neighbors (knn) algorithm\n",
    "    \n",
    "    # System parameters\n",
    "    base_results_dir = './results_CIFAR100_kNN_monitoring_initial_train_dynamics' # Name of the base directory to save checkpoints and other results\n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'                  # Specify whether using GPU or CPU\n",
    "\n",
    "args = Config()                                # Create an instance of the Config class to pass arguments to functions as needed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9676548-12d2-4fd1-9fce-71e8859e473e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# SETUP REPRODUCIBILITY AND DIRECTORIES\n",
    "# =============================================================================\n",
    "torch.manual_seed(args.seed)\n",
    "np.random.seed(args.seed)\n",
    "\n",
    "run_name = (\n",
    "    f\"temp_{args.temperature}_bs_{args.batch_size}_lr_encoder_{args.lr_resnet}_num_mlp_layers_{args.mlp_num_layers}\"\n",
    "    f\"{f'_lr_head_{args.lr_mlp}' if not args.is_ablation else ''}\"        # Conditionally add head learning rate\n",
    "    f\"{'_ablation' if args.is_ablation else ''}\"       \n",
    ")\n",
    "loading_dir = os.path.join(args.base_results_dir, run_name)\n",
    "os.makedirs(loading_dir, exist_ok=True)\n",
    "print(f\"Checkpoints will be loaded from this directory: {loading_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ee61a76-82cc-43ef-8f7b-80ec19b4aa9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# TRANSFORMS\n",
    "# =============================================================================\n",
    "\n",
    "# Use ImageNet's standard normalization statistics for the pre-trained ResNet\n",
    "imagenet_mean = [0.485, 0.456, 0.406]\n",
    "imagenet_std = [0.229, 0.224, 0.225]\n",
    "\n",
    "eval_transform = transforms.Compose([\n",
    "    transforms.ToTensor(),                                                           # Converts image to a Pytorch tensor\n",
    "    transforms.Normalize(mean=imagenet_mean, std=imagenet_std)                       # Normalizes by subtracting mean and dividing by std\n",
    "])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e95e61f4-d62f-44c7-8694-66b65222bb7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# MODEL CLASSES\n",
    "# =============================================================================\n",
    "\n",
    "class MonotonicMLP(nn.Module):\n",
    "    \"\"\"Monotonic MLP with positive weights to ensure an increasing function.\"\"\"\n",
    "    def __init__(self, input_dim, hidden_dim, num_layers):\n",
    "        super().__init__()\n",
    "        self.num_layers = num_layers\n",
    "        layers = []\n",
    "        current_dim = input_dim\n",
    "        for _ in range(num_layers):\n",
    "            layers.append(nn.Linear(current_dim, hidden_dim))\n",
    "            current_dim = hidden_dim\n",
    "        self.hidden_layers = nn.ModuleList(layers)\n",
    "        self.output_layer = nn.Linear(current_dim, input_dim)\n",
    "\n",
    "    def forward(self, x):                                   # Implement monotonicity constraints via squared weights and non-decreasing activations\n",
    "        z = x   \n",
    "        for layer in self.hidden_layers:\n",
    "            positive_weight = layer.weight**2               # Using squared weights to ensure positivity in the hidden layers during the forward pass\n",
    "            z = F.leaky_relu(F.linear(z, positive_weight, layer.bias))  # Uses the non-decreasing Leaky Relu activation function\n",
    "        positive_weight_out = self.output_layer.weight**2\n",
    "        z = F.linear(z, positive_weight_out, self.output_layer.bias)    # Squared weights are used for positivity in the output as well\n",
    "        return z\n",
    "\n",
    "        \n",
    "class MonoCon_ResNet(nn.Module):\n",
    "    \"\"\"Metric learning model using a pre-trained ResNet and a Monotonic MLP head.\"\"\"\n",
    "    def __init__(self, resnet_embedding_dim, mlp_hidden_dim, mlp_num_layers, is_ablation=False):\n",
    "        super().__init__()\n",
    "        self.is_ablation = is_ablation                                       # Flag indicating whether to run the full model or ablation study\n",
    "        self.encoder = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)      # Initialize ResNet34 encoder with pre-trained weights\n",
    "        self.encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) # Modify convolutional layer for low res CIFAR-100 images\n",
    "        self.encoder.maxpool = nn.Identity()             # Remove max pooling layer to prevent over-compressing low res CIFAR-100 images\n",
    "        self.encoder.fc = nn.Identity()                  # Remove the classification layer at the end\n",
    "        if not self.is_ablation:\n",
    "            self.monotonic_mlp = MonotonicMLP(resnet_embedding_dim, mlp_hidden_dim, mlp_num_layers) # Ablation study is run without the MLP head\n",
    "\n",
    "    def forward(self, x):\n",
    "        y_encoder = self.encoder(x)\n",
    "        y = y_encoder if self.is_ablation else self.monotonic_mlp(y_encoder)    \n",
    "        return y"
   ]
  },
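  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f8a2b1c-5d6e-4a7b-8c9d-0e1f2a3b4c5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# SANITY CHECK: MONOTONICITY OF THE MLP HEAD (illustrative, not part of training)\n",
    "# =============================================================================\n",
    "# A minimal sketch: with squared weights and non-decreasing activations, every\n",
    "# output coordinate of MonotonicMLP should be non-decreasing in every input\n",
    "# coordinate. We probe this numerically by adding a non-negative perturbation\n",
    "# to a random input and checking that no output coordinate decreases. The\n",
    "# small dimensions below are arbitrary, chosen only to keep the check fast.\n",
    "_mlp = MonotonicMLP(input_dim=8, hidden_dim=16, num_layers=2)\n",
    "_x = torch.randn(4, 8)\n",
    "_delta = torch.rand(4, 8)                       # Non-negative perturbation\n",
    "with torch.no_grad():\n",
    "    _y1 = _mlp(_x)\n",
    "    _y2 = _mlp(_x + _delta)\n",
    "assert torch.all(_y2 >= _y1 - 1e-6), \"Monotonicity check failed\"\n",
    "print(\"Monotonicity sanity check passed.\")"
   ]
  },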
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bb09bf3-827f-43f2-92fe-f231c641c7db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# FUNCTION FOR COMPUTING EMBEDDINGS\n",
    "# =============================================================================\n",
    "\n",
    "def compute_embeddings(model, dataloader, device):\n",
    "    model.eval()\n",
    "    y_list, labels_list = [], []\n",
    "    pbar = tqdm(dataloader, desc=\"Computing embeddings\", unit=\"batch\", leave=False)\n",
    "    for images, labels in pbar:\n",
    "        images = images.to(device)\n",
    "        labels_list.append(labels.cpu())\n",
    "        with torch.no_grad():\n",
    "            y_batch = model(images)\n",
    "        y_list.append(y_batch.cpu())\n",
    "    y_embed = torch.cat(y_list, dim=0)\n",
    "    all_labels = torch.cat(labels_list, dim=0)\n",
    "    return y_embed, all_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cbb2f7d-045f-406e-b26a-db25f9d6265c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# DATASETS AND DATALOADERS\n",
    "# =============================================================================\n",
    "# Create train and test datasets\n",
    "train_dataset = CIFAR100(root=args.data_dir, train=True, download=True, transform=eval_transform)   # Tensor-ified and normalized train dataset\n",
    "test_dataset = CIFAR100(root=args.data_dir, train=False, download=True, transform=eval_transform)   # Tensor-ified and normalized test dataset\n",
    "\n",
    "num_train = len(train_dataset)\n",
    "indices = list(range(num_train))\n",
    "split = int(0.98 * num_train)\n",
    "print(split)\n",
    "random.seed(args.seed)\n",
    "random.shuffle(indices)\n",
    "subset_indices = indices[split:]\n",
    "\n",
    "train_data_subset = Subset(train_dataset,subset_indices)\n",
    "print(len(train_data_subset))\n",
    "\n",
    "# 4. Create train and test dataloaders\n",
    "train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)\n",
    "test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)\n",
    "train_subset_loader =  DataLoader(train_data_subset, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)\n",
    "\n",
    "print(\"Data loaded successfully.\");"
   ]
  },
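  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c4d9e2f-1a2b-4c3d-9e5f-6a7b8c9d0e1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# SANITY CHECK: ONE BATCH FROM THE MONITORING SUBSET (illustrative)\n",
    "# =============================================================================\n",
    "# An optional quick inspection of the monitoring subset; the expected values\n",
    "# below assume the CIFAR-100 setup above (3x32x32 images, 100 classes).\n",
    "_images, _labels = next(iter(train_subset_loader))\n",
    "print(f\"Batch image tensor shape: {tuple(_images.shape)}\")          # (batch, 3, 32, 32)\n",
    "print(f\"Label range: {_labels.min().item()}..{_labels.max().item()} (expected within 0..99)\")"
   ]
  },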
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bd36ec1-290d-4f9e-b77d-766f23058e9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ========================================================================================================\n",
    "# FUNCTIONS FOR COMPUTE EFFECTIVE DIMENSIONALITY AND GENERATING AND SAVING CLUSTERMAPS\n",
    "# ========================================================================================================\n",
    "\n",
    "def get_pca_components_for_variance(data, variance_threshold=0.99):\n",
    "    \"\"\"\n",
    "    Fits PCA to the data and returns the number of components \n",
    "    needed to explain a certain amount of variance.\n",
    "    \"\"\"\n",
    "    # 1. Initialize and fit PCA to find all components\n",
    "    pca = PCA(n_components=None, random_state=42)\n",
    "    pca.fit(data)\n",
    "    \n",
    "    # 2. Calculate the cumulative sum of explained variance\n",
    "    cumulative_variance = np.cumsum(pca.explained_variance_ratio_)\n",
    "    \n",
    "    # 3. Find the number of components to reach the threshold\n",
    "    # np.searchsorted finds the index where the threshold would be inserted\n",
    "    # We add 1 to convert the 0-based index to a 1-based count\n",
    "    n_components = np.searchsorted(cumulative_variance, variance_threshold) + 1\n",
    "    \n",
    "    return n_components\n",
    "\n",
    "def generate_and_save_clustermap(embeddings, title, save_path):\n",
    "    \"\"\"\n",
    "    Computes a feature correlation matrix from embeddings, generates a clustermap,\n",
    "    and saves it to the specified path.\n",
    "    \"\"\"\n",
    "    print(f\"Generating clustermap for: {title}\")\n",
    "    \n",
    "    # 1. Compute the feature correlation matrix\n",
    "    corr_matrix = np.corrcoef(embeddings.cpu().numpy(), rowvar=False)\n",
    "    \n",
    "    # 2. Generate the clustermap\n",
    "    g = sns.clustermap(\n",
    "        corr_matrix,\n",
    "        cmap='RdBu_r',\n",
    "        vmin=-1,\n",
    "        vmax=1,\n",
    "        figsize=(10, 10),\n",
    "        cbar_pos=(0.02, 0.8, 0.03, 0.15)\n",
    "    )\n",
    "    \n",
    "    # 3. Set labels, title, and save the figure\n",
    "    g.ax_heatmap.set_xlabel('')\n",
    "    g.ax_heatmap.set_ylabel('')\n",
    "    g.fig.suptitle(title, fontsize=40, y=1.04)\n",
    "    g.savefig(save_path, bbox_inches='tight')\n",
    "    \n",
    "    # 4. Close the plot to free up memory\n",
    "    plt.close(g.fig)\n",
    "    \n",
    "    print(f\"Plot saved to {save_path}\")\n"
   ]
  },
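  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b0c1d2e-3f4a-4b5c-8d6e-7f8a9b0c1d2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# SANITY CHECK: EFFECTIVE DIMENSIONALITY ON SYNTHETIC DATA (illustrative)\n",
    "# =============================================================================\n",
    "# A minimal sketch on synthetic data: points lying (up to small noise) in a\n",
    "# 5-dimensional subspace should need far fewer PCA components for 99% of the\n",
    "# variance than isotropic Gaussian data of the same ambient dimension. The\n",
    "# sizes below are arbitrary, chosen only for a quick check.\n",
    "_rng = np.random.default_rng(args.seed)\n",
    "_iso = _rng.normal(size=(1000, 64))                                 # Full-rank, isotropic\n",
    "_low = _rng.normal(size=(1000, 5)) @ _rng.normal(size=(5, 64))      # Rank-5 subspace\n",
    "_low += 0.01 * _rng.normal(size=_low.shape)                         # Small noise floor\n",
    "print(f\"Isotropic data: {get_pca_components_for_variance(_iso)} components for 99% variance\")\n",
    "print(f\"Rank-5 data:    {get_pca_components_for_variance(_low)} components for 99% variance\")"
   ]
  },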
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "178e824f-46a2-4398-b760-2a058bbf9ae8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===================================================================================================================\n",
    "# LOAD CHECKPOINTS, COMPUTE EFFECTIVE DIMENSIONALITY, PLOT AND SAVE CLUSTERMAPS FOR ENCODER AS WELL AS MODEL OUTPUTS\n",
    "# ===================================================================================================================\n",
    "\n",
    "model = MonoCon_ResNet(\n",
    "    resnet_embedding_dim=args.resnet_embedding_dim,\n",
    "    mlp_hidden_dim=args.mlp_hidden_dim,\n",
    "    mlp_num_layers=args.mlp_num_layers,\n",
    "    is_ablation=args.is_ablation,\n",
    ").to(args.device)\n",
    "\n",
    "model.eval()\n",
    "\n",
    "epoch_num = list(range(1,201))\n",
    "\n",
    "encoder_dim = np.zeros(len(epoch_num))\n",
    "unnorm_mlp_dim = np.zeros(len(epoch_num))\n",
    "norm_mlp_dim = np.zeros(len(epoch_num))\n",
    "\n",
    "# --- Setup Directories for saving clustermaps ---\n",
    "encoder_plot_dir = os.path.join(loading_dir, 'clustermaps_encoder')\n",
    "unnorm_plot_dir = os.path.join(loading_dir, 'clustermaps_unnormalized')\n",
    "norm_plot_dir = os.path.join(loading_dir, 'clustermaps_normalized')\n",
    "\n",
    "os.makedirs(encoder_plot_dir, exist_ok=True)\n",
    "os.makedirs(unnorm_plot_dir, exist_ok=True)\n",
    "os.makedirs(norm_plot_dir, exist_ok=True)\n",
    "\n",
    "for i in epoch_num:\n",
    "    \n",
    "    print(i)\n",
    "    checkpoint_path = os.path.join(loading_dir, f\"checkpoint_epoch_{i}.pth\")\n",
    "\n",
    "    if not os.path.exists(checkpoint_path):\n",
    "        raise FileNotFoundError(f\"Model file not found at {checkpoint_path}.\")\n",
    "        \n",
    "    checkpoint = torch.load(checkpoint_path, map_location=args.device)\n",
    "    model.load_state_dict(checkpoint['model_state_dict'])\n",
    "    #print(\"Model loaded successfully.\")\n",
    "\n",
    "    # Compute embeddings for encoder output, unnormalized model output, and normalized model output\n",
    "    train_subset_encoder_embed, train_subset_encoder_labels = compute_embeddings(model.encoder, train_subset_loader, args.device)\n",
    "    train_subset_output_embed, train_subset_output_labels = compute_embeddings(model, train_subset_loader, args.device)\n",
    "    train_subset_output_embed_norm = F.normalize(train_subset_output_embed, p=2, dim=1)\n",
    "\n",
    "    # Convert tensors to numpy arrays \n",
    "    enc_embed_np = train_subset_encoder_embed.cpu().numpy()\n",
    "    out_embed_np = train_subset_output_embed.cpu().numpy()\n",
    "    out_embed_norm_np = train_subset_output_embed_norm.cpu().numpy()\n",
    "\n",
    "    # Compute effective dimensionality of the representation, defined as number of PCA components that explain 99% variance\n",
    "    # PCA and clustermaps for encoder\n",
    "    encoder_dim[i-1] = get_pca_components_for_variance(enc_embed_np)\n",
    "    unnorm_mlp_dim[i-1] = get_pca_components_for_variance(out_embed_np)\n",
    "    norm_mlp_dim[i-1] = get_pca_components_for_variance(out_embed_norm_np)\n",
    "\n",
    "    # --- C: Generate and Save Clustermaps ---\n",
    "    # Encoder Clustermap\n",
    "    enc_save_path = os.path.join(encoder_plot_dir, f\"clustermap_epoch_{i}.png\")\n",
    "    generate_and_save_clustermap(train_subset_encoder_embed, f\"Encoder Clustermap - Epoch {i}\", enc_save_path)\n",
    "    \n",
    "    # Unnormalized Output Clustermap\n",
    "    unnorm_save_path = os.path.join(unnorm_plot_dir, f\"clustermap_epoch_{i}.png\")\n",
    "    generate_and_save_clustermap(train_subset_output_embed, f\"Unnormalized Output - Epoch {i}\", unnorm_save_path)\n",
    "\n",
    "    # Normalized Output Clustermap\n",
    "    norm_save_path = os.path.join(norm_plot_dir, f\"clustermap_epoch_{i}.png\")\n",
    "    generate_and_save_clustermap(train_subset_output_embed_norm, f\"Normalized Output - Epoch {i}\", norm_save_path)\n",
    "\n",
    "    print(\"-\" * 50) # Separator for clarity\n"
   ]
  },
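  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e6f7a8b-9c0d-4e1f-8a2b-3c4d5e6f7a8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# OPTIONAL: kNN EVALUATION OF THE FINAL CHECKPOINT (illustrative sketch)\n",
    "# =============================================================================\n",
    "# After the loop above, `model` holds the last loaded checkpoint. This is a\n",
    "# minimal sketch of the kNN probe implied by args.knn_k: fit a k-nearest-\n",
    "# neighbors classifier on the train embeddings and score it on the test set.\n",
    "# Using cosine distance on L2-normalized embeddings is an assumption here,\n",
    "# chosen to match the normalized outputs monitored above.\n",
    "train_embed, train_labels = compute_embeddings(model, train_loader, args.device)\n",
    "test_embed, test_labels = compute_embeddings(model, test_loader, args.device)\n",
    "train_embed_np = F.normalize(train_embed, p=2, dim=1).numpy()\n",
    "test_embed_np = F.normalize(test_embed, p=2, dim=1).numpy()\n",
    "\n",
    "knn = KNeighborsClassifier(n_neighbors=args.knn_k, metric='cosine')\n",
    "knn.fit(train_embed_np, train_labels.numpy())\n",
    "test_pred = knn.predict(test_embed_np)\n",
    "print(f\"kNN (k={args.knn_k}) test accuracy: {accuracy_score(test_labels.numpy(), test_pred):.4f}\")"
   ]
  },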
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5deb990-b732-4a89-99ca-4f847563be77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a figure with two subplots side-by-side\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))\n",
    "axis_label_font_size = 18\n",
    "title_font_size = 20\n",
    "tick_label_fontsize = 16\n",
    "\n",
    "# --- Plot 1: Encoder Dimensionality ---\n",
    "ax1.plot(epoch_num, encoder_dim, linestyle='-', label='Encoder')\n",
    "ax1.set_title('Encoder',fontsize=title_font_size)\n",
    "ax1.set_xlabel('Epoch', fontsize=axis_label_font_size)\n",
    "ax1.set_ylabel('Effective dim', fontsize=axis_label_font_size)\n",
    "ax1.tick_params(axis='both', labelsize=tick_label_fontsize)\n",
    "ax1.grid(False)\n",
    "\n",
    "# --- Plot 2: MLP Head Dimensionality ---\n",
    "ax2.plot(epoch_num, unnorm_mlp_dim, linestyle='-', label='Unnormalized monotonic MLP Output')\n",
    "ax2.plot(epoch_num, norm_mlp_dim, linestyle='--', label='Normalized monotonic MLP Output')\n",
    "ax2.set_title('Monotonic MLP Head', fontsize=title_font_size)\n",
    "ax2.set_xlabel('Epoch', fontsize=axis_label_font_size)\n",
    "ax2.set_ylabel('Effective dim', fontsize=axis_label_font_size)\n",
    "ax2.tick_params(axis='both', labelsize=tick_label_fontsize)\n",
    "ax2.legend(fontsize = 16)\n",
    "ax2.grid(False)\n",
    "\n",
    "# Adjust layout to prevent overlap and save the figure\n",
    "plt.tight_layout()\n",
    "\n",
    "file_path = os.path.join(loading_dir, 'dimensionality_vs_epochs.png')\n",
    "plt.savefig(file_path)\n",
    "\n",
    "print(f\"Plot saved to {file_path}\")"
   ]
  },
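  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a3b4c5d-6e7f-4a8b-9c0d-1e2f3a4b5c6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# OPTIONAL: SAVE THE DIMENSIONALITY CURVES TO CSV (illustrative sketch)\n",
    "# =============================================================================\n",
    "# A small convenience step using the csv module imported above; the file name\n",
    "# dimensionality_vs_epochs.csv is an assumption, chosen to mirror the figure.\n",
    "csv_path = os.path.join(loading_dir, 'dimensionality_vs_epochs.csv')\n",
    "with open(csv_path, 'w', newline='') as f:\n",
    "    writer = csv.writer(f)\n",
    "    writer.writerow(['epoch', 'encoder_dim', 'unnorm_mlp_dim', 'norm_mlp_dim'])\n",
    "    for e, d_enc, d_un, d_n in zip(epoch_num, encoder_dim, unnorm_mlp_dim, norm_mlp_dim):\n",
    "        writer.writerow([e, int(d_enc), int(d_un), int(d_n)])\n",
    "print(f\"Dimensionality curves saved to {csv_path}\")"
   ]
  },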
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43c2dfd8-a1a8-47a7-8bde-e93a72e5b846",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(encoder_dim.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
