{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6ea0bcf-2afc-4cf2-b517-6a3c64748321",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# IMPORTS\n",
    "# =============================================================================\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, random_split, Subset\n",
    "from torchvision.datasets import CIFAR10\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.models import resnet34, ResNet34_Weights\n",
    "import numpy as np\n",
    "import os\n",
    "import shutil\n",
    "from tqdm.notebook import tqdm\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9da5ead9-8de3-4f83-820a-8b337782cae5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# CONFIGURATION\n",
    "# =============================================================================\n",
    "\n",
    "class Config:\n",
    "    \"\"\"Configuration class for all hyperparameters.\n",
    "\n",
    "    Every setting is a class attribute; the notebook reads them through the single\n",
    "    instance `args` created at the end of this cell and passes `args` to the helpers.\n",
    "    \"\"\"\n",
    "    # Data parameters\n",
    "    data_dir = './data'                        # Root directory where CIFAR-10 is downloaded/stored\n",
    "\n",
    "    # Model parameters\n",
    "    resnet_embedding_dim = 512                 # ResNet34 pooled-feature dimension; also the monotonic MLP's input/output size\n",
    "    mlp_hidden_dim = 1024                      # Monotonic MLP hidden dimension\n",
    "    mlp_num_layers = 1                         # Monotonic MLP number of hidden layers\n",
    "\n",
    "    # Training parameters\n",
    "    epochs = 200                               # Maximum number of epochs of the main training phase\n",
    "    patience = 20                              # Early stopping: epochs without validation k-NN improvement (counted in steps of validation_frequency)\n",
    "    batch_size = 256                           # Mini-batch size (also the number of samples the SupCon loss contrasts at once)\n",
    "    warmup_epochs = 10                         # Number of epochs to warm up the head with the encoder frozen. Set to 0 to disable.\n",
    "    lr_warmup = 1e-4                           # Learning rate of the head during the warm-up phase only.\n",
    "    lr_resnet = 5e-4                           # Learning rate for the ResNet encoder\n",
    "    lr_mlp = 5e-4                              # Learning rate for the monotonic MLP head\n",
    "    weight_decay = 1e-4                        # AdamW (decoupled) weight decay\n",
    "    temperature = 0.1                          # Temperature of the SupCon loss. Lower values put more weight on hard negatives\n",
    "    grad_clip_norm = 1.0                       # Max total gradient norm used for clipping; stabilizes training\n",
    "    seed = 42                                  # Random number generator seed\n",
    "    is_ablation = True                        # True: no monotonic MLP head (ablation run: encoder only, no warm-up, single LR)\n",
    "\n",
    "    # Evaluation parameters\n",
    "    knn_k = 5                                  # Number of nearest neighbors in the k-nearest-neighbors (k-NN) classifier\n",
    "    recall_k = 1                               # Number of top matches for the Recall@k metric\n",
    "    validation_frequency = 5                   # Validate (and write a checkpoint) every this many epochs\n",
    "    \n",
    "    # System parameters\n",
    "    base_results_dir = './results_CIFAR10_kNN_monitoring'                   # Base directory for per-run checkpoints, logs and plots\n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'                  # Use the GPU when available, otherwise the CPU\n",
    "\n",
    "args = Config()                                # Single Config instance passed to functions as needed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12337601-d57b-45b8-ac75-c49bda2ecc1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# SETUP REPRODUCIBILITY AND DIRECTORIES\n",
    "# =============================================================================\n",
    "print(f\"Using device: {args.device}\")\n",
    "# Seed every RNG the notebook draws from: PyTorch (CPU and CUDA), NumPy and Python's `random`\n",
    "# (the latter shuffles the train/validation indices).\n",
    "torch.manual_seed(args.seed)\n",
    "np.random.seed(args.seed)\n",
    "random.seed(args.seed)\n",
    "if args.device == 'cuda':\n",
    "    torch.cuda.manual_seed_all(args.seed)\n",
    "torch.backends.cudnn.deterministic = True               # Use only deterministic algorithms in the cuDNN backend\n",
    "torch.backends.cudnn.benchmark = False                  # Prevents cuDNN from potentially picking a fast but non-deterministic algorithm.\n",
    "\n",
    "# One output directory per hyperparameter configuration. Re-running the same configuration reuses\n",
    "# the directory, which is how train_model() finds and resumes from the latest checkpoint.\n",
    "run_name = (\n",
    "    f\"temp_{args.temperature}_bs_{args.batch_size}_lr_encoder_{args.lr_resnet}_num_mlp_layers_{args.mlp_num_layers}\"\n",
    "    f\"{f'_lr_head_{args.lr_mlp}' if not args.is_ablation else ''}\"        # Head learning rate only when the head exists\n",
    "    f\"{'_ablation' if args.is_ablation else ''}\"\n",
    ")\n",
    "output_dir = os.path.join(args.base_results_dir, run_name)\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "print(f\"Output directory for this run: {output_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc87ea75-24d3-42e4-902e-5cd524e04b2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# TRANSFORMS\n",
    "# =============================================================================\n",
    "\n",
    "# Use ImageNet's standard normalization statistics, matching the pre-trained ResNet weights\n",
    "imagenet_mean = [0.485, 0.456, 0.406]\n",
    "imagenet_std = [0.229, 0.224, 0.225]\n",
    "\n",
    "# Note: the order of the transforms matters, because each operation expects a specific input type.\n",
    "# TrivialAugmentWide runs on the raw image (before ToTensor); ToTensor() converts it to a float tensor,\n",
    "# on which RandomErasing and Normalize() then operate.\n",
    "# RandomErasing runs before Normalize, so erased pixels are 0 in [0, 1] space (black), not the dataset mean.\n",
    "train_transform = transforms.Compose([\n",
    "    transforms.TrivialAugmentWide(),                                                 # Applies one randomly chosen augmentation (random magnitude) per image\n",
    "    transforms.ToTensor(),                                                           # Converts the image to a PyTorch tensor\n",
    "    transforms.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3), value=0),  # With probability p, sets a random rectangle to 0\n",
    "    transforms.Normalize(mean=imagenet_mean, std=imagenet_std)                       # Per-channel (x - mean) / std\n",
    "])\n",
    "\n",
    "# Evaluation transform: no augmentation, only tensor conversion and the same normalization\n",
    "test_transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=imagenet_mean, std=imagenet_std)\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c3fa474-7de7-4234-898d-94f7ea9175b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# MODEL AND DATA CLASSES\n",
    "# =============================================================================\n",
    "\n",
    "class SupConLoss(nn.Module):\n",
    "    \"\"\"Supervised Contrastive (SupCon) loss.\n",
    "\n",
    "    For each anchor in the batch, the loss is the negative mean log-probability of its\n",
    "    positives (the other samples with the same label) under a softmax over all other\n",
    "    samples, where similarities are dot products divided by `temperature`.\n",
    "\n",
    "    Args:\n",
    "        temperature: Divisor applied to the dot products. Smaller values sharpen the\n",
    "            softmax and give hard negatives more weight.\n",
    "    \"\"\"\n",
    "    def __init__(self, temperature=0.07):\n",
    "        super().__init__()\n",
    "        self.temperature = temperature\n",
    "\n",
    "    def forward(self, features, labels):\n",
    "        \"\"\"Compute the SupCon loss of one batch.\n",
    "\n",
    "        Args:\n",
    "            features: (batch, dim) embeddings. Callers in this notebook L2-normalize them first.\n",
    "            labels: Class labels, one per row of `features` (reshaped to a column internally).\n",
    "\n",
    "        Returns:\n",
    "            Scalar tensor averaged over all anchors. An anchor without any positive in the\n",
    "            batch contributes 0 but is still counted in the average.\n",
    "        \"\"\"\n",
    "        dev = features.device\n",
    "        n = features.shape[0]\n",
    "        label_col = labels.contiguous().view(-1, 1)\n",
    "        # same_class[i, j] = 1 when samples i and j share a label (diagonal still included here)\n",
    "        same_class = torch.eq(label_col, label_col.T).float().to(dev)\n",
    "        # Pairwise similarities scaled by the temperature\n",
    "        scaled_sim = torch.div(torch.matmul(features, features.T), self.temperature)\n",
    "        # Subtracting each row's maximum leaves the log-softmax unchanged but avoids overflow in exp()\n",
    "        row_max, _ = torch.max(scaled_sim, dim=1, keepdim=True)\n",
    "        logits = scaled_sim - row_max.detach()\n",
    "        # All ones except a zero diagonal: excludes each sample's similarity with itself.\n",
    "        # (Built with scatter; it is numerically the same mask as 1 - torch.eye(n).)\n",
    "        not_self = torch.scatter(\n",
    "            torch.ones_like(same_class), 1, torch.arange(n).view(-1, 1).to(dev), 0\n",
    "        )\n",
    "        positives = same_class * not_self\n",
    "        # log p(j | i) over every j != i\n",
    "        denom = (torch.exp(logits) * not_self).sum(1, keepdim=True)\n",
    "        log_prob = logits - torch.log(denom + 1e-9)\n",
    "        # Average log-probability over each anchor's positives (epsilon guards anchors with none)\n",
    "        mean_log_prob_pos = (positives * log_prob).sum(1) / (positives.sum(1) + 1e-9)\n",
    "        return -mean_log_prob_pos.mean()\n",
    "\n",
    "\n",
    "class MonotonicMLP(nn.Module):\n",
    "    \"\"\"MLP whose outputs are non-decreasing functions of each input coordinate.\n",
    "\n",
    "    In the forward pass every weight matrix is squared element-wise, so all effective\n",
    "    weights are >= 0. Combined with the non-decreasing leaky ReLU between layers, each\n",
    "    output can only grow (or stay constant) when any input grows. Biases are unconstrained.\n",
    "\n",
    "    Args:\n",
    "        input_dim: Size of the input vector; the output has the same size.\n",
    "        hidden_dim: Width of every hidden layer.\n",
    "        num_layers: Number of hidden layers (0 leaves only the output layer).\n",
    "    \"\"\"\n",
    "    def __init__(self, input_dim, hidden_dim, num_layers):\n",
    "        super().__init__()\n",
    "        self.num_layers = num_layers\n",
    "        # Input width of each hidden layer, followed by the input width of the output layer\n",
    "        in_dims = [input_dim] + [hidden_dim] * num_layers\n",
    "        self.hidden_layers = nn.ModuleList(\n",
    "            nn.Linear(in_dims[i], hidden_dim) for i in range(num_layers)\n",
    "        )\n",
    "        self.output_layer = nn.Linear(in_dims[-1], input_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the squared-weight layers; x is presumably (batch, input_dim).\"\"\"\n",
    "        h = x\n",
    "        for lin in self.hidden_layers:\n",
    "            h = F.leaky_relu(F.linear(h, lin.weight ** 2, lin.bias))\n",
    "        out = self.output_layer\n",
    "        return F.linear(h, out.weight ** 2, out.bias)\n",
    "\n",
    "        \n",
    "class MonoCon_ResNet(nn.Module):\n",
    "    \"\"\"Metric-learning model: a ResNet34 encoder, optionally followed by a MonotonicMLP head.\n",
    "\n",
    "    The encoder is loaded with torchvision's ImageNet-1k weights and adapted to low-resolution\n",
    "    CIFAR-10 images: the stem convolution is replaced by a new 3x3, stride-1 conv (so its\n",
    "    pre-trained weights are discarded), the stem max-pool is removed, and the classifier\n",
    "    `fc` becomes an identity so the encoder returns its pooled features.\n",
    "\n",
    "    Args:\n",
    "        resnet_embedding_dim: Encoder feature size; input/output size of the head.\n",
    "        mlp_hidden_dim: Hidden width of the MonotonicMLP head.\n",
    "        mlp_num_layers: Number of hidden layers of the head.\n",
    "        is_ablation: If True, no head is created and forward() returns the encoder features.\n",
    "    \"\"\"\n",
    "    def __init__(self, resnet_embedding_dim, mlp_hidden_dim, mlp_num_layers, is_ablation=False):\n",
    "        super().__init__()\n",
    "        self.is_ablation = is_ablation\n",
    "        backbone = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)\n",
    "        # Low-resolution stem: a 3x3/stride-1 conv and no max-pool avoid over-compressing small images\n",
    "        backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        backbone.maxpool = nn.Identity()\n",
    "        # Drop the ImageNet classification layer; the pooled features are the embedding\n",
    "        backbone.fc = nn.Identity()\n",
    "        self.encoder = backbone\n",
    "        if not is_ablation:\n",
    "            self.monotonic_mlp = MonotonicMLP(resnet_embedding_dim, mlp_hidden_dim, mlp_num_layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        features = self.encoder(x)\n",
    "        if self.is_ablation:\n",
    "            return features\n",
    "        return self.monotonic_mlp(features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de66d9c1-36bc-4220-8d29-a2188cb1136d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# HELPER FUNCTIONS\n",
    "# =============================================================================\n",
    "\n",
    "def calculate_recall_at_k(query_embed, query_labels, gallery_embed, gallery_labels, k, chunk_size=1024):\n",
    "    \"\"\"Calculates Recall@K for cosine-similarity retrieval.\n",
    "\n",
    "    A query counts as recalled when at least one of its k most similar gallery items\n",
    "    (by cosine similarity) has the same label as the query.\n",
    "\n",
    "    Args:\n",
    "        query_embed: (num_queries, dim) query embeddings; L2-normalized here.\n",
    "        query_labels: (num_queries,) labels of the queries.\n",
    "        gallery_embed: (num_gallery, dim) gallery embeddings; L2-normalized here.\n",
    "        gallery_labels: (num_gallery,) labels of the gallery items.\n",
    "        k: Number of top matches to inspect; clamped to the gallery size.\n",
    "        chunk_size: Number of queries scored at once. Peak memory is a (chunk_size, num_gallery)\n",
    "            similarity block instead of the full (num_queries, num_gallery) matrix\n",
    "            (e.g. 5,000 x 45,000 float32 similarities take ~0.9 GB).\n",
    "\n",
    "    Returns:\n",
    "        Recall@K as a Python float in [0, 1]; 0.0 when there are no queries or k <= 0.\n",
    "    \"\"\"\n",
    "    query_embed = F.normalize(query_embed, p=2, dim=1)\n",
    "    gallery_embed = F.normalize(gallery_embed, p=2, dim=1)\n",
    "    num_queries = query_embed.shape[0]\n",
    "    k = min(k, gallery_embed.shape[0])\n",
    "    if num_queries == 0 or k <= 0:\n",
    "        return 0.0\n",
    "\n",
    "    hits = 0\n",
    "    for start in range(0, num_queries, chunk_size):\n",
    "        stop = start + chunk_size\n",
    "        # Cosine similarities of this chunk of queries against the whole gallery\n",
    "        sim_chunk = torch.matmul(query_embed[start:stop], gallery_embed.T)\n",
    "        top_k_indices = torch.topk(sim_chunk, k=k, dim=1).indices\n",
    "        top_k_labels = gallery_labels[top_k_indices]\n",
    "        matched = (top_k_labels == query_labels[start:stop].unsqueeze(1)).any(dim=1)\n",
    "        hits += int(matched.sum().item())\n",
    "    return hits / num_queries\n",
    "\n",
    "\n",
    "def run_validation_metrics(model, val_loader, gallery_loader, criterion, args):\n",
    "    \"\"\"Runs a full validation, returning loss, k-NN accuracy, and Recall@K.\n",
    "\n",
    "    The gallery is embedded with compute_embeddings(). The validation images (the queries)\n",
    "    are embedded in a single pass that also accumulates the per-batch loss, so each\n",
    "    validation image goes through the model only once.\n",
    "\n",
    "    Args:\n",
    "        model: Embedding model; switched to eval mode here.\n",
    "        val_loader: Loader of (images, labels) validation batches (the queries).\n",
    "        gallery_loader: Loader of (images, labels) gallery batches.\n",
    "        criterion: Loss called as criterion(normalized_embeddings, labels).\n",
    "        args: Config providing device, knn_k and recall_k.\n",
    "\n",
    "    Returns:\n",
    "        Dict with 'val_loss' (mean of the per-batch losses), 'knn_acc' and\n",
    "        f'recall@{args.recall_k}'.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If val_loader yields no batches.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    # Gallery embeddings: the k-NN reference set and the Recall@K retrieval database\n",
    "    gallery_embed, gallery_labels = compute_embeddings(model, gallery_loader, args.device)\n",
    "\n",
    "    # Single pass over the validation set: query embeddings and validation loss together\n",
    "    query_chunks, label_chunks = [], []\n",
    "    total_loss = 0.0\n",
    "    with torch.no_grad():\n",
    "        for images, labels in tqdm(val_loader, desc=\"Validation\", unit=\"batch\", leave=False):\n",
    "            label_chunks.append(labels.cpu())\n",
    "            images, labels = images.to(args.device), labels.to(args.device)\n",
    "            y_batch = model(images)\n",
    "            loss = criterion(F.normalize(y_batch, p=2, dim=1), labels)\n",
    "            total_loss += loss.item()\n",
    "            query_chunks.append(y_batch.cpu())\n",
    "    if not query_chunks:\n",
    "        raise ValueError(\"val_loader yielded no batches\")\n",
    "    avg_val_loss = total_loss / len(query_chunks)\n",
    "    query_embed = torch.cat(query_chunks, dim=0)\n",
    "    query_labels = torch.cat(label_chunks, dim=0)\n",
    "\n",
    "    # k-NN classification accuracy on L2-normalized embeddings\n",
    "    knn_accuracy = evaluate_embeddings(\n",
    "        F.normalize(gallery_embed, p=2, dim=1).numpy(), gallery_labels.numpy(),\n",
    "        F.normalize(query_embed, p=2, dim=1).numpy(), query_labels.numpy(),\n",
    "        args.knn_k, \"Validation k-NN\"\n",
    "    )\n",
    "\n",
    "    # Recall@K (calculate_recall_at_k normalizes the embeddings itself)\n",
    "    recall_at_k = calculate_recall_at_k(query_embed, query_labels, gallery_embed, gallery_labels, k=args.recall_k)\n",
    "\n",
    "    return {\n",
    "        'val_loss': avg_val_loss,\n",
    "        'knn_acc': knn_accuracy,\n",
    "        f'recall@{args.recall_k}': recall_at_k\n",
    "    }\n",
    "\n",
    "\n",
    "def compute_embeddings(model, dataloader, device, baseline_resnet_only=False):\n",
    "    \"\"\"Embed every batch of `dataloader` with gradients disabled.\n",
    "\n",
    "    Args:\n",
    "        model: Network to run. With `baseline_resnet_only=True` it must have an `.encoder`.\n",
    "        dataloader: Iterable of (images, labels) batches.\n",
    "        device: Device the images are moved to before the forward pass.\n",
    "        baseline_resnet_only: If True, run only `model.encoder` (skipping any head).\n",
    "\n",
    "    Returns:\n",
    "        (embeddings, labels): CPU tensors concatenated in loader order.\n",
    "\n",
    "    Side effect: leaves `model` in eval mode.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    net = model.encoder if baseline_resnet_only else model\n",
    "    embed_batches, label_batches = [], []\n",
    "    for images, labels in tqdm(dataloader, desc=\"Computing embeddings\", unit=\"batch\", leave=False):\n",
    "        label_batches.append(labels.cpu())\n",
    "        with torch.no_grad():\n",
    "            embed_batches.append(net(images.to(device)).cpu())\n",
    "    return torch.cat(embed_batches, dim=0), torch.cat(label_batches, dim=0)\n",
    "\n",
    "\n",
    "def evaluate_embeddings(train_embed, train_labels, test_embed, test_labels, k, name):\n",
    "    \"\"\"Fit a k-NN classifier on one set of embeddings and report its accuracy on another.\n",
    "\n",
    "    Args:\n",
    "        train_embed, train_labels: Reference embeddings and labels the classifier is fit on.\n",
    "        test_embed, test_labels: Embeddings to classify and their true labels.\n",
    "        k: Number of neighbours consulted per prediction.\n",
    "        name: Label used in the printed header.\n",
    "\n",
    "    Returns:\n",
    "        Classification accuracy as a float in [0, 1] (also printed as a percentage).\n",
    "    \"\"\"\n",
    "    print(f\"\\n--- Evaluating {name} embeddings (k={k}) ---\")\n",
    "    # n_jobs=-1 parallelizes the neighbour search over all CPU cores\n",
    "    classifier = KNeighborsClassifier(n_neighbors=k, n_jobs=-1).fit(train_embed, train_labels)\n",
    "    accuracy = accuracy_score(test_labels, classifier.predict(test_embed))\n",
    "    print(f\"k-NN Classification Accuracy: {accuracy * 100:.2f}%\")\n",
    "    return accuracy\n",
    "\n",
    "\n",
    "def visualize_embeddings(embeddings, labels, output_dir, seed, name_prefix):\n",
    "    \"\"\"Project embeddings to 2-D with t-SNE, plot them coloured by label, then save and show the figure.\n",
    "\n",
    "    Args:\n",
    "        embeddings: (num_samples, dim) array of embeddings.\n",
    "        labels: (num_samples,) integer class labels used for the colours and the legend.\n",
    "        output_dir: Directory the PNG is written to.\n",
    "        seed: random_state of t-SNE, for a reproducible layout.\n",
    "        name_prefix: Used in the title and in the file name '<name_prefix>_cifar10_tsne.png'.\n",
    "    \"\"\"\n",
    "    print(f\"\\n--- Visualizing {name_prefix} embeddings ---\")\n",
    "    tsne = TSNE(n_components=2, perplexity=30, max_iter=1000, random_state=seed, n_jobs=-1)\n",
    "    embeddings_tsne = tsne.fit_transform(embeddings)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(14, 12))\n",
    "    # 'tab10' is qualitative: each of the 10 CIFAR-10 classes gets a clearly distinct colour.\n",
    "    # A cyclic map such as 'hsv' starts and ends at red, which makes classes 0 and 9 look alike.\n",
    "    scatter = ax.scatter(embeddings_tsne[:, 0], embeddings_tsne[:, 1], c=labels, cmap='tab10', s=12, alpha=0.7)\n",
    "    ax.legend(*scatter.legend_elements(), title=\"Class\", loc=\"best\")\n",
    "    ax.set_title(f\"t-SNE of {name_prefix} Embeddings on CIFAR-10\", fontsize=16)\n",
    "    ax.set_xlabel(\"Component 1\")\n",
    "    ax.set_ylabel(\"Component 2\")\n",
    "    ax.grid(True, linestyle='--', alpha=0.6)\n",
    "    save_path = os.path.join(output_dir, f\"{name_prefix}_cifar10_tsne.png\")\n",
    "    fig.savefig(save_path)\n",
    "    plt.show()\n",
    "    plt.close(fig)\n",
    "    print(f\"Saved plot to: {save_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4386a41c-d115-4b45-adfe-3dbbfc2983cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# MODEL TRAINING FUNCTIONS\n",
    "# =============================================================================\n",
    "\n",
    "def run_epoch(model, dataloader, criterion, optimizer, args, epoch_desc, params_to_clip):\n",
    "    \"\"\"Train `model` for one pass over `dataloader` and return the mean batch loss.\n",
    "\n",
    "    Each batch's embeddings are L2-normalized before `criterion` is applied, and the\n",
    "    gradient norm of `params_to_clip` is clipped to `args.grad_clip_norm` before every step.\n",
    "\n",
    "    Args:\n",
    "        model: Network to train; switched to train mode here.\n",
    "        dataloader: Iterable of (images, labels) batches.\n",
    "        criterion: Loss called as criterion(normalized_embeddings, labels).\n",
    "        optimizer: Optimizer that updates the trainable parameters.\n",
    "        args: Config providing `device` and `grad_clip_norm`.\n",
    "        epoch_desc: Label for the progress bar.\n",
    "        params_to_clip: Tensor or iterable of tensors to clip; one-shot generators\n",
    "            such as `model.parameters()` are accepted.\n",
    "\n",
    "    Returns:\n",
    "        Mean of the per-batch losses (0.0 if `dataloader` yields no batches).\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    # Materialize once: clip_grad_norm_ iterates its argument, so a generator would be\n",
    "    # exhausted by the first batch and every later batch would silently go unclipped.\n",
    "    if isinstance(params_to_clip, torch.Tensor):\n",
    "        params_to_clip = [params_to_clip]\n",
    "    else:\n",
    "        params_to_clip = list(params_to_clip)\n",
    "    total_loss = 0.0\n",
    "    num_batches = 0\n",
    "    pbar = tqdm(dataloader, desc=epoch_desc, unit=\"batch\")\n",
    "    for images, labels in pbar:\n",
    "        images, labels = images.to(args.device), labels.to(args.device)\n",
    "        optimizer.zero_grad()\n",
    "        y_batch = model(images)\n",
    "        y_batch_norm = F.normalize(y_batch, p=2, dim=1)\n",
    "        loss = criterion(y_batch_norm, labels)\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(params_to_clip, max_norm=args.grad_clip_norm)\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item()\n",
    "        num_batches += 1\n",
    "        pbar.set_postfix({\"Loss\": f\"{loss.item():.4f}\"})\n",
    "    return total_loss / num_batches if num_batches else 0.0\n",
    "    \n",
    "\n",
    "def _training_log_header(args):\n",
    "    \"\"\"Return the column names of training_log.csv ('lr_head' is omitted in the ablation run).\"\"\"\n",
    "    header = ['epoch', 'train_loss', 'val_loss', 'knn_acc', f'recall@{args.recall_k}', 'lr_encoder', 'lr_head']\n",
    "    return header[:-1] if args.is_ablation else header\n",
    "\n",
    "\n",
    "def train_model(model, train_loader, val_loader, gallery_loader, criterion, optimizer, scheduler, args, output_dir):\n",
    "    \"\"\"Main training loop with robust checkpointing, CSV logging, and periodic validation.\n",
    "\n",
    "    Every `args.validation_frequency` epochs (and at the final epoch) the model is validated:\n",
    "    a new best validation k-NN accuracy saves the weights to best_model.pth, a full\n",
    "    checkpoint (model, optimizer, scheduler, early-stopping state) is written and copied to\n",
    "    latest_checkpoint.pth, and a row is appended to training_log.csv. When\n",
    "    latest_checkpoint.pth already exists in `output_dir`, training resumes from it.\n",
    "\n",
    "    Args:\n",
    "        model: Model to train.\n",
    "        train_loader: Loader of (augmented) training batches.\n",
    "        val_loader, gallery_loader: Loaders passed to run_validation_metrics().\n",
    "        criterion: Contrastive loss applied to L2-normalized embeddings.\n",
    "        optimizer, scheduler: Optimizer and per-epoch LR scheduler (their state is checkpointed).\n",
    "        args: Config (epochs, patience, validation_frequency, recall_k, is_ablation, device, ...).\n",
    "        output_dir: Directory for checkpoints and the CSV log.\n",
    "\n",
    "    Returns:\n",
    "        Path of best_model.pth if a best model exists (saved in this run or before the\n",
    "        resumed checkpoint), otherwise None.\n",
    "    \"\"\"\n",
    "    start_epoch = 0\n",
    "    best_knn_acc = 0.0\n",
    "    epochs_no_improve = 0\n",
    "\n",
    "    best_model_path = os.path.join(output_dir, 'best_model.pth')\n",
    "    latest_checkpoint_path = os.path.join(output_dir, 'latest_checkpoint.pth')\n",
    "    log_file_path = os.path.join(output_dir, 'training_log.csv')\n",
    "    final_best_model_path = None\n",
    "\n",
    "    resumed = os.path.exists(latest_checkpoint_path)\n",
    "    if resumed:\n",
    "        print(f\"Resuming training from checkpoint: {latest_checkpoint_path}\")\n",
    "        # map_location allows resuming on a different device than the one that saved the checkpoint\n",
    "        checkpoint = torch.load(latest_checkpoint_path, map_location=args.device)\n",
    "        model.load_state_dict(checkpoint['model_state_dict'])\n",
    "        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
    "        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n",
    "        start_epoch = checkpoint['epoch']\n",
    "        best_knn_acc = checkpoint.get('best_knn_acc', 0.0)\n",
    "        epochs_no_improve = checkpoint.get('epochs_no_improve', 0)\n",
    "        # The best model saved before the interruption stays the best unless beaten below;\n",
    "        # otherwise a resumed run that never improves would report no best model at all.\n",
    "        if os.path.exists(best_model_path):\n",
    "            final_best_model_path = best_model_path\n",
    "        print(f\"   Resumed from epoch {start_epoch}. Best k-NN accuracy: {best_knn_acc*100:.2f}%\")\n",
    "        print(f\"   Patience counter (epochs with no improvement): {epochs_no_improve}\")\n",
    "    else:\n",
    "        print(\"Starting training from scratch.\")\n",
    "\n",
    "    # New run (or resumed run whose log went missing): create the CSV log with its header row\n",
    "    if not resumed or not os.path.exists(log_file_path):\n",
    "        with open(log_file_path, 'w', newline='') as f:\n",
    "            csv.writer(f).writerow(_training_log_header(args))\n",
    "\n",
    "    if epochs_no_improve >= args.patience:\n",
    "        print(f\"Early stopping was already triggered before this checkpoint \"\n",
    "              f\"({epochs_no_improve} epochs with no improvement on k-NN accuracy); skipping training.\")\n",
    "        print(\"--- Training Finished ---\")\n",
    "        return final_best_model_path\n",
    "\n",
    "    # A list rather than the model.parameters() generator: run_epoch clips these on every\n",
    "    # batch, and a generator would be exhausted after the first clipping call.\n",
    "    params_to_clip = list(model.parameters())\n",
    "\n",
    "    for epoch in range(start_epoch, args.epochs):\n",
    "        epoch_desc = f\"Epoch {epoch+1}/{args.epochs}\"\n",
    "        avg_train_loss = run_epoch(model, train_loader, criterion, optimizer, args, epoch_desc, params_to_clip)\n",
    "        scheduler.step()\n",
    "\n",
    "        if (epoch + 1) % args.validation_frequency == 0 or (epoch + 1) == args.epochs:\n",
    "            val_metrics = run_validation_metrics(model, val_loader, gallery_loader, criterion, args)\n",
    "\n",
    "            if not args.is_ablation:\n",
    "                lr_log_str = f\"LRs (ResNet/Head): {optimizer.param_groups[0]['lr']:.1e}/{optimizer.param_groups[1]['lr']:.1e}\"\n",
    "            else:\n",
    "                lr_log_str = f\"LR (ResNet): {optimizer.param_groups[0]['lr']:.1e}\"\n",
    "\n",
    "            print(f\"Epoch {epoch+1}/{args.epochs} -> \"\n",
    "                  f\"Train Loss: {avg_train_loss:.4f} | \"\n",
    "                  f\"Val Loss: {val_metrics['val_loss']:.4f} | \"\n",
    "                  f\"Val k-NN: {val_metrics['knn_acc']*100:.2f}% | \"\n",
    "                  f\"Val Recall@{args.recall_k}: {val_metrics[f'recall@{args.recall_k}']*100:.2f}% | \"\n",
    "                  f\"{lr_log_str}\")\n",
    "\n",
    "            if val_metrics['knn_acc'] > best_knn_acc:\n",
    "                best_knn_acc = val_metrics['knn_acc']\n",
    "                epochs_no_improve = 0\n",
    "                torch.save(model.state_dict(), best_model_path)\n",
    "                final_best_model_path = best_model_path\n",
    "                print(f\"New best model saved with Val k-NN Acc: {best_knn_acc*100:.2f}%\")\n",
    "            else:\n",
    "                # Validation only runs every validation_frequency epochs\n",
    "                epochs_no_improve += args.validation_frequency\n",
    "\n",
    "            # Full training state, so an interrupted run can resume from here\n",
    "            current_epoch_checkpoint_path = os.path.join(output_dir, f'checkpoint_epoch_{epoch+1}.pth')\n",
    "            torch.save({\n",
    "                'epoch': epoch + 1, 'model_state_dict': model.state_dict(),\n",
    "                'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(),\n",
    "                'best_knn_acc': best_knn_acc, 'epochs_no_improve': epochs_no_improve,\n",
    "            }, current_epoch_checkpoint_path)\n",
    "            shutil.copyfile(current_epoch_checkpoint_path, latest_checkpoint_path)\n",
    "\n",
    "            with open(log_file_path, 'a', newline='') as f:\n",
    "                # Column order matches _training_log_header()\n",
    "                row = [epoch + 1, avg_train_loss, val_metrics['val_loss'],\n",
    "                       val_metrics['knn_acc'], val_metrics[f'recall@{args.recall_k}'],\n",
    "                       optimizer.param_groups[0]['lr']]\n",
    "                if not args.is_ablation:\n",
    "                    row.append(optimizer.param_groups[1]['lr'])\n",
    "                csv.writer(f).writerow(row)\n",
    "\n",
    "            if epochs_no_improve >= args.patience:\n",
    "                print(f\"Early stopping triggered after {epochs_no_improve} epochs with no improvement on k-NN accuracy.\")\n",
    "                break\n",
    "        else:\n",
    "            print(f\"Epoch {epoch+1}/{args.epochs} -> Train Loss: {avg_train_loss:.4f}\")\n",
    "\n",
    "    print(\"--- Training Finished ---\")\n",
    "    return final_best_model_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d31317c9-f730-4664-9b52-7526e79054a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# DATASETS AND DATALOADERS\n",
    "# =============================================================================\n",
    "\n",
    "# 1. Create two base datasets over the same CIFAR-10 training images, differing only in the transform\n",
    "train_dataset_aug = CIFAR10(root=args.data_dir, train=True, download=True, transform=train_transform)\n",
    "train_dataset_no_aug = CIFAR10(root=args.data_dir, train=True, download=True, transform=test_transform)\n",
    "\n",
    "# 2. Shuffle one shared list of indices (seeded) and split it 90% / 10% into train / validation\n",
    "num_train = len(train_dataset_aug)\n",
    "indices = list(range(num_train))\n",
    "split = int(0.9 * num_train)\n",
    "random.seed(args.seed)                                          # Reseed right before the shuffle so the split is reproducible\n",
    "random.shuffle(indices)\n",
    "train_indices, val_indices = indices[:split], indices[split:]\n",
    "\n",
    "# 3. Create three subsets. Train and gallery use the SAME indices (the gallery is the un-augmented\n",
    "#    view of the training images); both are disjoint from the validation indices.\n",
    "train_subset = Subset(train_dataset_aug, train_indices)         # For training, has augmentations\n",
    "val_subset = Subset(train_dataset_no_aug, val_indices)          # For validation queries, no augmentations\n",
    "gallery_subset = Subset(train_dataset_no_aug, train_indices)    # For the validation gallery (k-NN / Recall@K reference set), no augmentations\n",
    "\n",
    "# Final test set (official CIFAR-10 test split; used only for the baseline and final evaluations)\n",
    "test_dataset = CIFAR10(root=args.data_dir, train=False, download=True, transform=test_transform)\n",
    "\n",
    "# 4. Create the final DataLoaders. Only the training loader shuffles; evaluation loaders keep a fixed order\n",
    "train_loader = DataLoader(train_subset, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=4)\n",
    "val_loader = DataLoader(val_subset, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)\n",
    "test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)\n",
    "gallery_loader = DataLoader(gallery_subset, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fade69a0-97b8-4105-babd-10bc42a0db2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# DEFINE MODEL AND EVALUATE PRE-TRAINED ENCODER (UNTRAINED) BASELINE\n",
    "# =============================================================================\n",
    "\n",
    "model = MonoCon_ResNet(\n",
    "    resnet_embedding_dim=args.resnet_embedding_dim,\n",
    "    mlp_hidden_dim=args.mlp_hidden_dim,\n",
    "    mlp_num_layers=args.mlp_num_layers,\n",
    "    is_ablation=args.is_ablation,\n",
    ").to(args.device)\n",
    "\n",
    "print(\"\\n--- Evaluating Pre-Trained ResNet Encoder (Baseline) on CIFAR-10 ---\")\n",
    "# NOTE: MonoCon_ResNet swaps the stem conv (conv1) for a newly initialized layer, so this baseline is\n",
    "# the ImageNet-pretrained network behind an untrained stem, not the unmodified pre-trained ResNet.\n",
    "# For this initial check, the full un-augmented training set is the k-NN gallery\n",
    "# (this loader is reused by the final evaluation).\n",
    "full_train_loader_no_aug = DataLoader(train_dataset_no_aug, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)\n",
    "y_train_initial, labels_train_initial = compute_embeddings(model, full_train_loader_no_aug, args.device, baseline_resnet_only=True)\n",
    "y_test_initial, labels_test_initial = compute_embeddings(model, test_loader, args.device, baseline_resnet_only=True)\n",
    "# L2-normalize exactly like the validation and final evaluations do, so every reported k-NN\n",
    "# accuracy uses the same (cosine-equivalent) distance and the baseline is directly comparable.\n",
    "evaluate_embeddings(\n",
    "    F.normalize(y_train_initial, p=2, dim=1).numpy(), labels_train_initial.numpy(),\n",
    "    F.normalize(y_test_initial, p=2, dim=1).numpy(), labels_test_initial.numpy(),\n",
    "    args.knn_k, \"initial 'y' (pre-trained ResNet)\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81725abc-d269-4d9f-84c7-6eb370885dc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================================================================================================\n",
    "# EXECUTE HEAD WARMUP PHASE AND SET UP DIFFERENTIAL LEARNING RATE STRATEGY\n",
    "# ================================================================================================\n",
    "\n",
    "criterion = SupConLoss(temperature=args.temperature).to(args.device)\n",
    "\n",
    "# If a checkpoint exists, train_model() resumes from it and overwrites all model weights,\n",
    "# so a warm-up run now would be thrown away. (Same file name that train_model() checks.)\n",
    "resuming_from_checkpoint = os.path.exists(os.path.join(output_dir, 'latest_checkpoint.pth'))\n",
    "\n",
    "if not args.is_ablation:\n",
    "    head_params = list(model.monotonic_mlp.parameters())\n",
    "    if args.warmup_epochs > 0 and resuming_from_checkpoint:\n",
    "        print(\"--- Found a checkpoint to resume from; skipping the head warm-up. ---\")\n",
    "    elif args.warmup_epochs > 0:\n",
    "        print(f\"\\n--- Starting Head Warm-up for {args.warmup_epochs} epochs ---\")\n",
    "        # NOTE: run_epoch() calls model.train(), so the frozen encoder's BatchNorm running\n",
    "        # statistics are still updated during warm-up; only its weights are frozen.\n",
    "        for param in model.encoder.parameters():\n",
    "            param.requires_grad = False\n",
    "        try:\n",
    "            optimizer_warmup = torch.optim.AdamW(head_params, lr=args.lr_warmup, weight_decay=args.weight_decay)\n",
    "            scheduler_warmup = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_warmup, T_max=args.warmup_epochs, eta_min=0)\n",
    "            for epoch in range(args.warmup_epochs):\n",
    "                epoch_desc = f\"Warm-up Epoch {epoch+1}/{args.warmup_epochs}\"\n",
    "                avg_loss_warmup = run_epoch(model, train_loader, criterion, optimizer_warmup, args, epoch_desc, head_params)\n",
    "                scheduler_warmup.step()\n",
    "                print(f\"{epoch_desc} -> Avg Loss: {avg_loss_warmup:.4f}\")\n",
    "            print(\"--- Warm-up finished. Unfreezing encoder for main training. ---\")\n",
    "        finally:\n",
    "            # Unfreeze even if warm-up is interrupted or fails, so later cells never\n",
    "            # train with a silently frozen encoder.\n",
    "            for param in model.encoder.parameters():\n",
    "                param.requires_grad = True\n",
    "\n",
    "    # Differential learning rates: one parameter group for the encoder, one for the head\n",
    "    param_groups = [\n",
    "        {'params': model.encoder.parameters(), 'lr': args.lr_resnet},\n",
    "        {'params': head_params, 'lr': args.lr_mlp}\n",
    "    ]\n",
    "    optimizer = torch.optim.AdamW(param_groups, weight_decay=args.weight_decay)\n",
    "    print(\"--- Optimizer set up with differential learning rates for encoder and head for main training ---\")\n",
    "else:\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr_resnet, weight_decay=args.weight_decay)\n",
    "    print(\"--- Performing ablation study. Optimizer set up with the ResNet learning rate alone. ---\")\n",
    "    print(\"--- There is no warmup phase in the ablation study. ---\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44524949-3e1f-48c4-afd8-cca75f25843c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================================================================================================\n",
    "# MAIN TRAINING PHASE\n",
    "# ================================================================================================\n",
    "# Cosine-anneal each parameter group's learning rate from its initial value to 0 over args.epochs\n",
    "# (train_model() steps the scheduler once per epoch and restores its state when resuming).\n",
    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=0)\n",
    "best_model_path = train_model(\n",
    "    model=model,\n",
    "    train_loader=train_loader,\n",
    "    val_loader=val_loader,\n",
    "    gallery_loader=gallery_loader,\n",
    "    criterion=criterion,\n",
    "    optimizer=optimizer,\n",
    "    scheduler=scheduler,\n",
    "    args=args,\n",
    "    output_dir=output_dir,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d29b9b4-cb01-4560-8f20-c1ee15e1c766",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================================================================================================\n",
    "# FINAL EVALUATION PHASE\n",
    "# ================================================================================================\n",
    "print(\"\\n--- Final Evaluation on Best Model ---\")\n",
    "\n",
    "if best_model_path and os.path.exists(best_model_path):\n",
    "    print(f\"Loading best model from: {best_model_path}\")\n",
    "    # map_location keeps loading working when the weights were saved on another device\n",
    "    model.load_state_dict(torch.load(best_model_path, map_location=args.device))\n",
    "\n",
    "    # For the final evaluation, we use the full training set as the gallery against the test set\n",
    "    y_train_final, labels_train_final = compute_embeddings(model, full_train_loader_no_aug, args.device)\n",
    "    y_test_final, labels_test_final = compute_embeddings(model, test_loader, args.device)\n",
    "\n",
    "    # Same L2 normalization as used during validation\n",
    "    y_train_norm = F.normalize(y_train_final, p=2, dim=1).numpy()\n",
    "    y_test_norm = F.normalize(y_test_final, p=2, dim=1).numpy()\n",
    "\n",
    "    model_name = \"MonoCon Ablation\" if args.is_ablation else \"MonoCon\"\n",
    "    evaluate_embeddings(y_train_norm, labels_train_final.numpy(), y_test_norm, labels_test_final.numpy(), args.knn_k, f\"final 'y' ({model_name})\")\n",
    "\n",
    "    visualize_embeddings(y_test_norm, labels_test_final.numpy(), output_dir, args.seed, f\"y_final_{model_name}\")\n",
    "\n",
    "    print(\"\\n--- PCA on Final 'y' Embeddings ---\")\n",
    "    pca = PCA(n_components=None)\n",
    "    pca.fit(y_test_norm)  # PCA of the test embeddings summarizes the final representation\n",
    "    cumulative_variance = np.cumsum(pca.explained_variance_ratio_)\n",
    "    # Smallest number of components whose cumulative explained variance reaches 99%,\n",
    "    # clamped so it is always a valid component count (and index below) even if 0.99 is never reached.\n",
    "    num_components_99_var = min(int(np.searchsorted(cumulative_variance, 0.99)) + 1, len(cumulative_variance))\n",
    "\n",
    "    print(f\"Number of components required to explain 99% of variance: {num_components_99_var}\")\n",
    "    print(f\"   (The top {num_components_99_var} components actually explain {cumulative_variance[num_components_99_var-1] * 100:.2f}% of the variance)\")\n",
    "\n",
    "    print(\"\\n--- Detailed breakdown for the first 20 components ---\")\n",
    "    for i, (var, cum_var) in enumerate(zip(pca.explained_variance_ratio_[:20], cumulative_variance[:20])):\n",
    "        print(f\"  - PC {i+1}: {var * 100:.2f}% (Cumulative: {cum_var * 100:.2f}%)\")\n",
    "else:\n",
    "    print(\"Could not find a saved model. Skipping final evaluation.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ea194ee-3f38-4744-90f9-8cd0ab63d3c4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
