{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daf20fa2-01e3-46ff-9378-3a4b7949ff8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# IMPORTS\n",
    "# =============================================================================\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "import numpy as np\n",
    "import os\n",
    "import shutil\n",
    "from tqdm.notebook import tqdm\n",
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import spearmanr\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28411834-016d-4b83-8ba0-75220191abe7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# CONFIGURATION\n",
    "# =============================================================================\n",
    "\n",
    "####################### Create a separate class to define all parameters #######################\n",
    "class Config:\n",
    "    \"\"\"Configuration class for all hyperparameters.\"\"\"\n",
    "    # Model parameters\n",
    "    model_name = 'sentence-transformers/all-MiniLM-L6-v2'\n",
    "    lm_embedding_dim = 384      # MiniLM's output dimension\n",
    "    mlp_hidden_dim = 768        # Monotonic MLP hidden dimension\n",
    "    mlp_num_layers = 1          # Monotonic MLP number of hidden layers\n",
    "    max_seq_length = 128        # Max sequence length for tokenizer\n",
    "\n",
    "    # Training parameters\n",
    "    epochs = 40                 # Maximum number of epochs used for training\n",
    "    patience = 10                # Number of epochs after which early stopping is triggered\n",
    "    batch_size = 128            # Mini-batch size\n",
    "    warmup_epochs = 1           # Number of epochs to warm up the head. Set to 0 to disable.\n",
    "    lr_warmup = 1e-4            # Learning rate specifically for the warm-up phase.\n",
    "    lr_lm = 2e-4                # Learning rate for fine-tuning the language model encoder\n",
    "    lr_mlp = 2e-4               # Learning rate for the monotonic MLP head \n",
    "    weight_decay = 1e-4         # Weight decay (L2) regularization\n",
    "    temperature = 0.05          # Temperature parameter for SupCon loss. Controls the importance of hard negatives in the learning task\n",
    "    grad_clip_norm = 1.0        # Max norm of gradients. Leads to more stable training\n",
    "    seed = 42                   # random number generator seed\n",
    "    is_ablation = False         # Set to True to bypass the monotonic MLP head for the ablation study training run.\n",
    "\n",
    "    # System parameters\n",
    "    base_results_dir = './results_SNLI' # Only the base path is needed here\n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "args = Config()                  # Create an instance of the Config class to pass arguments to functions as needed\n",
    "#################################################################################################\n",
    "\n",
    "\n",
    "############### Set up the code for reproducibility ###################\n",
    "print(f\"Using device: {args.device}\")\n",
    "torch.manual_seed(args.seed)\n",
    "np.random.seed(args.seed)\n",
    "random.seed(args.seed)\n",
    "if args.device == 'cuda':\n",
    "    torch.cuda.manual_seed_all(args.seed)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False\n",
    "#######################################################################\n",
    "\n",
    "\n",
    "############# Prepare output directories ##############################\n",
    "# Create a unique, descriptive name for the experiment run\n",
    "run_name = (\n",
    "    f\"temp_{args.temperature}_bs_{args.batch_size}_lr_encoder_{args.lr_lm}_epochs_{args.epochs}_patience_{args.patience}\"\n",
    "    f\"{f'_lr_head_{args.lr_mlp}' if not args.is_ablation else ''}\"  # <-- Conditionally add head learning rate\n",
    "    f\"{'_ablation' if args.is_ablation else ''}\"\n",
    ")\n",
    "\n",
    "# Define the final output directory for this specific run\n",
    "output_dir = os.path.join(args.base_results_dir, run_name)\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "print(f\"Output directory for this run: {output_dir}\")\n",
    "#######################################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e0dbb52-e261-41a5-b24c-3abdff026886",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# MODEL AND DATA CLASSES\n",
    "# =============================================================================\n",
    "\n",
    "class SupConLoss(nn.Module):\n",
    "    \"\"\"Supervised Contrastive Learning loss function.\"\"\"\n",
    "    def __init__(self, temperature=0.07):\n",
    "        super().__init__()\n",
    "        self.temperature = temperature\n",
    "\n",
    "    def forward(self, features, labels):\n",
    "        device = features.device                                       # Identify the device on which the features tensor is stored\n",
    "        batch_size = features.shape[0]                                 # Extract the batch_size. Useful for later operations\n",
    "        labels = labels.contiguous().view(-1, 1)                       # Assign a continuous memory block to sotre labels. Reshape the labels tensor\n",
    "        mask = torch.eq(labels, labels.T).float().to(device)           # Creates a square matrix mask via broadcasting, to identify positive pairs\n",
    "        anchor_dot_contrast = torch.div(torch.matmul(features, features.T), self.temperature)  # Matrix of scaled feature dot products\n",
    "        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)            # Compute maximum value of the scaled dot product in each row\n",
    "        # The next line subtracts a constant from each row, which leaves the loss unchanged but prevents numerical instability\n",
    "        logits = anchor_dot_contrast - logits_max.detach()\n",
    "        logits_mask = torch.scatter(\n",
    "            torch.ones_like(mask), 1, torch.arange(batch_size).view(-1, 1).to(device), 0\n",
    "        )\n",
    "        mask = mask * logits_mask                     # Removes diagonal elements, so that positive pairs of samples with themselves are not counted\n",
    "        \n",
    "        # ====== Main SupCon loss calculation =============\n",
    "        exp_logits = torch.exp(logits) * logits_mask\n",
    "        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-9)\n",
    "        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-9)\n",
    "        loss = -mean_log_prob_pos.mean()\n",
    "        # =================================================\n",
    "        return loss\n",
    "\n",
    "class MonotonicMLP(nn.Module):\n",
    "    \"\"\"Monotonic MLP implemented using positive weights and nondecreasing activation functions.\"\"\"\n",
    "    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int):  \n",
    "        super().__init__()\n",
    "        self.num_layers = num_layers\n",
    "        layers = []\n",
    "        current_dim = input_dim\n",
    "        for _ in range(num_layers):\n",
    "            layers.append(nn.Linear(current_dim, hidden_dim))\n",
    "            current_dim = hidden_dim\n",
    "        self.hidden_layers = nn.ModuleList(layers)\n",
    "        self.output_layer = nn.Linear(current_dim, input_dim)\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:    # Implement monotonicity constraints via squared weights and non-decreasing activations\n",
    "        z = x\n",
    "        for layer in self.hidden_layers:\n",
    "            positive_weight = layer.weight**2    # Using squared weights to ensure positivity in the hidden layers during the forward pass\n",
    "            z = F.leaky_relu(F.linear(z, positive_weight, layer.bias))  # Uses the non-decreasing Leaky Relu activation function\n",
    "        positive_weight_out = self.output_layer.weight**2\n",
    "        z = F.linear(z, positive_weight_out, self.output_layer.bias)    # Squared weights are used for positivity in the output as well\n",
    "        return z\n",
    "\n",
    "class MonoCon_LM(nn.Module):\n",
    "    \"\"\"Metric learning model using a pre-trained LM and a Monotonic MLP head.\"\"\"\n",
    "    def __init__(self, model_name: str, lm_embedding_dim: int, mlp_hidden_dim: int, mlp_num_layers: int, is_ablation: bool = False):\n",
    "        super().__init__()\n",
    "        self.is_ablation = is_ablation      # Flag indicating whether to run the full model or ablation study\n",
    "        self.encoder = AutoModel.from_pretrained(model_name)     # Sentence transformer encoder (MiniLM) specified in the Config class\n",
    "        if not self.is_ablation:\n",
    "            self.monotonic_mlp = MonotonicMLP(lm_embedding_dim, mlp_hidden_dim, mlp_num_layers)  # Monotonic MLP is created only for the full model\n",
    "\n",
    "    def _mean_pooling(self, model_output, attention_mask):    # Generates a single average vector from multiple contextual embedding vectors \n",
    "        token_embeddings = model_output[0]                    # Extracts token embeddings from the model output\n",
    "        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()   # Accounts for variable sentence lengths by using padded dimensions that don't contribute to averaging\n",
    "        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n",
    "\n",
    "    def forward(self, **kwargs):\n",
    "        y_encoder_tokens = self.encoder(**kwargs)     # The MiniLM encoder generates sequences of tokens with contextual embeddings for each token\n",
    "        y_encoder = self._mean_pooling(y_encoder_tokens, kwargs['attention_mask'])    # Averages the contextual embeddings to yield a single sentence embedding vector \n",
    "        y = y_encoder if self.is_ablation else self.monotonic_mlp(y_encoder)          # MLP is only used to get embeddings for the full model\n",
    "        return y\n",
    "\n",
    "class SNLIDataset(Dataset):\n",
    "    \"\"\"Custom PyTorch dataset that contains only those sentence pairs from the full SNLI dataset that have an entailment relationship. \n",
    "       These constitute positive pairs for the purposes of supervised contrastive learning\"\"\"\n",
    "    def __init__(self, data):\n",
    "        self.samples = []\n",
    "        for example in data:\n",
    "            if example['label'] == 0: # The label 0 indicates entailment\n",
    "                self.samples.append((example['premise'], example['hypothesis']))\n",
    "\n",
    "    def __len__(self):      # Returns the total number of entailment pairs, i.e. the total dataset size\n",
    "        return len(self.samples)\n",
    "\n",
    "    def __getitem__(self, idx):    # Returns a sentence pair based on its index\n",
    "        return self.samples[idx]"
   ]
  },
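  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f6c1a2e-9b4d-4c7a-8e2f-5d0a6b7c8d9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# OPTIONAL SANITY CHECKS (illustrative sketch only, not part of the pipeline)\n",
    "# =============================================================================\n",
    "# The MonotonicMLP above enforces monotonicity through squared weights and a\n",
    "# non-decreasing activation. The quick check below verifies that increasing any\n",
    "# input coordinate never decreases any output coordinate. The dimensions, batch\n",
    "# size, and tolerance are arbitrary values chosen only for this check.\n",
    "_mlp_check = MonotonicMLP(input_dim=8, hidden_dim=16, num_layers=1)\n",
    "_x = torch.randn(32, 8)\n",
    "_delta = torch.rand(32, 8)                 # Non-negative perturbation\n",
    "with torch.no_grad():\n",
    "    _y_base = _mlp_check(_x)\n",
    "    _y_shift = _mlp_check(_x + _delta)\n",
    "assert torch.all(_y_shift >= _y_base - 1e-6), \"Monotonicity check failed\"\n",
    "print(\"MonotonicMLP monotonicity check passed.\")\n",
    "\n",
    "# A toy SupConLoss call on made-up features: 4 samples from two classes. The\n",
    "# features are L2-normalized, mirroring how embeddings are normalized before\n",
    "# the loss during training.\n",
    "_feats = F.normalize(torch.randn(4, 8), p=2, dim=1)\n",
    "_labels = torch.tensor([0, 0, 1, 1])\n",
    "print(f\"Toy SupCon loss: {SupConLoss(temperature=0.05)(_feats, _labels).item():.4f}\")"
   ]
  },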
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "244f2a47-1681-401b-97fe-624ccd1012b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# HELPER FUNCTIONS\n",
    "# =============================================================================\n",
    "\n",
    "def create_collate_fn(tokenizer, max_length):\n",
    "    def collate_fn(batch):\n",
    "        premises = [item[0] for item in batch]\n",
    "        hypotheses = [item[1] for item in batch]\n",
    "        sentences = premises + hypotheses\n",
    "        tokenized = tokenizer(sentences, padding=True, truncation=True, max_length=max_length, return_tensors='pt')\n",
    "        labels = torch.arange(len(premises)).repeat(2)\n",
    "        return tokenized, labels\n",
    "    return collate_fn\n",
    "\n",
    "\n",
    "def compute_embeddings_for_stsb(model, tokenizer, stsb_data, device, args):\n",
    "    model.eval()\n",
    "    sents1 = [ex['sentence1'] for ex in stsb_data]\n",
    "    sents2 = [ex['sentence2'] for ex in stsb_data]\n",
    "    \n",
    "    @torch.no_grad()\n",
    "    def get_embeds(sentences):\n",
    "        all_embeds = []\n",
    "        pbar_embed = tqdm(range(0, len(sentences), args.batch_size), desc=\"Computing embeddings\", leave=False)\n",
    "        for start_idx in pbar_embed:\n",
    "            batch_sents = sentences[start_idx : start_idx + args.batch_size]\n",
    "            tokenized = tokenizer(batch_sents, padding=True, truncation=True, max_length=args.max_seq_length, return_tensors='pt')\n",
    "            tokenized = {k: v.to(device) for k, v in tokenized.items()}\n",
    "            y_batch = model(**tokenized)\n",
    "            all_embeds.append(y_batch.cpu())\n",
    "        return torch.cat(all_embeds, dim=0)\n",
    "\n",
    "    embeds1 = get_embeds(sents1)\n",
    "    embeds2 = get_embeds(sents2)\n",
    "    return embeds1, embeds2\n",
    "\n",
    "\n",
    "def evaluate_stsb(model, tokenizer, stsb_data, args, silent=False):\n",
    "    if not silent:\n",
    "        print(f\"\\n--- Evaluating on STSb Benchmark ---\")\n",
    "    \n",
    "    embeddings1, embeddings2 = compute_embeddings_for_stsb(model, tokenizer, stsb_data, args.device, args)\n",
    "    \n",
    "    embeddings1 = F.normalize(embeddings1, p=2, dim=1)\n",
    "    embeddings2 = F.normalize(embeddings2, p=2, dim=1)\n",
    "    \n",
    "    cosine_scores = torch.sum(embeddings1 * embeddings2, dim=1)\n",
    "    labels = torch.tensor([ex['score'] for ex in stsb_data])\n",
    "    \n",
    "    spearman_corr, _ = spearmanr(cosine_scores.numpy(), labels.numpy())\n",
    "    \n",
    "    if not silent:\n",
    "        print(f\"Spearman Correlation: {spearman_corr:.4f}\")\n",
    "    return spearman_corr\n",
    "\n",
    "\n",
    "def evaluate_encoder_only_stsb(encoder_model, tokenizer, stsb_data, args):\n",
    "    \"\"\"Evaluates the encoder part of the model only.\"\"\"\n",
    "    encoder_model.eval()\n",
    "    device = args.device\n",
    "    \n",
    "    sents1 = [ex['sentence1'] for ex in stsb_data]\n",
    "    sents2 = [ex['sentence2'] for ex in stsb_data]\n",
    "    \n",
    "    @torch.no_grad()\n",
    "    def get_encoder_embeds(sentences):\n",
    "        all_embeds = []\n",
    "        pbar_embed = tqdm(range(0, len(sentences), args.batch_size), desc=\"Computing encoder embeddings\", leave=False)\n",
    "        for start_idx in pbar_embed:\n",
    "            batch_sents = sentences[start_idx : start_idx + args.batch_size]\n",
    "            tokenized = tokenizer(batch_sents, padding=True, truncation=True, max_length=args.max_seq_length, return_tensors='pt')\n",
    "            tokenized = {k: v.to(device) for k, v in tokenized.items()}\n",
    "            \n",
    "            # Manually perform the encoder forward pass and pooling\n",
    "            model_output = encoder_model(**tokenized)\n",
    "            token_embeddings = model_output[0]\n",
    "            input_mask_expanded = tokenized['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()\n",
    "            pooled_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n",
    "            \n",
    "            all_embeds.append(pooled_embeddings.cpu())\n",
    "        return torch.cat(all_embeds, dim=0)\n",
    "\n",
    "    embeddings1 = get_encoder_embeds(sents1)\n",
    "    embeddings2 = get_encoder_embeds(sents2)\n",
    "    \n",
    "    embeddings1 = F.normalize(embeddings1, p=2, dim=1)\n",
    "    embeddings2 = F.normalize(embeddings2, p=2, dim=1)\n",
    "    \n",
    "    cosine_scores = torch.sum(embeddings1 * embeddings2, dim=1)\n",
    "    labels = torch.tensor([ex['score'] for ex in stsb_data])\n",
    "    spearman_corr, _ = spearmanr(cosine_scores.numpy(), labels.numpy())\n",
    "    \n",
    "    print(f\"Spearman Correlation: {spearman_corr:.4f}\")\n",
    "    return spearman_corr\n"
   ]
  },
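  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d2e4f61-3a8b-4c5d-9e0f-1a2b3c4d5e6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# OPTIONAL: MASKED MEAN POOLING ON A TOY TENSOR (illustrative sketch only)\n",
    "# =============================================================================\n",
    "# A minimal sketch of the masked mean pooling used by MonoCon_LM._mean_pooling\n",
    "# and by evaluate_encoder_only_stsb: positions where the attention mask is 0\n",
    "# (padding) are excluded from the average. The tensor and mask values below are\n",
    "# made up purely for illustration.\n",
    "_tok_emb = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)   # (batch=2, seq_len=3, dim=4)\n",
    "_attn_mask = torch.tensor([[1, 1, 0],                               # First sequence: last token is padding\n",
    "                           [1, 1, 1]])                              # Second sequence: no padding\n",
    "_mask_exp = _attn_mask.unsqueeze(-1).expand(_tok_emb.size()).float()\n",
    "_pooled = torch.sum(_tok_emb * _mask_exp, 1) / torch.clamp(_mask_exp.sum(1), min=1e-9)\n",
    "print(_pooled)   # Row 0 averages only the first two token vectors; row 1 averages all three"
   ]
  },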
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b758fcc5-7e8d-40bd-89fe-fe1e934926d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# MODEL TRAINING FUNCTIONS\n",
    "# =============================================================================\n",
    "\n",
    "def run_epoch(model, dataloader, criterion, optimizer, args, epoch_desc, params_to_clip):\n",
    "    \"\"\"Runs a single training epoch and returns the average loss.\"\"\"\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    pbar = tqdm(dataloader, desc=epoch_desc, unit=\"batch\")\n",
    "\n",
    "    for tokenized_batch, labels in pbar:\n",
    "        # Make sure tokenized embeddings and lables are on the same device as other tensors\n",
    "        tokenized_batch = {k: v.to(args.device) for k, v in tokenized_batch.items()}   \n",
    "        labels = labels.to(args.device)\n",
    "        \n",
    "        optimizer.zero_grad()                                             # Zero out gradients from the previous iteration\n",
    "\n",
    "        y_batch = model(**tokenized_batch)                   # Forward pass. Compute model output\n",
    "        y_batch_norm = F.normalize(y_batch, p=2, dim=1)      # Normalize output embeddings\n",
    "\n",
    "        loss = criterion(y_batch_norm, labels)               # Compute SupCon loss\n",
    "        loss.backward()                                      # Calculate gradients (backward pass)\n",
    "        \n",
    "        # Clip gradients for the specified parameters\n",
    "        torch.nn.utils.clip_grad_norm_(params_to_clip, max_norm=args.grad_clip_norm)   # Implement gradient norm clipping for stability\n",
    "        optimizer.step()    # Update model weights\n",
    "\n",
    "        total_loss += loss.item()                              # Add batch loss to total loss for the epoch\n",
    "        pbar.set_postfix({\"Loss\": f\"{loss.item():.4f}\"})        # Attach postfix description to progress bar\n",
    "\n",
    "    return total_loss / len(dataloader)                        # Return average loss\n",
    "\n",
    "\n",
    "def train_model(model, train_loader, stsb_val_data, tokenizer, criterion, optimizer, scheduler, args, output_dir):\n",
    "    \"\"\"\n",
    "    Main training loop with robust, fully resumable checkpointing.\n",
    "\n",
    "    Args:\n",
    "        model (nn.Module): The model to be trained.\n",
    "        train_loader (DataLoader): DataLoader for the training set.\n",
    "        val_loader (DataLoader): DataLoader for the STSb validation set.\n",
    "        tokenizer (Tokenizer): Generator of word tokens\n",
    "        criterion (nn.Module): The loss function.\n",
    "        optimizer (torch.optim.Optimizer): The optimizer.\n",
    "        scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.\n",
    "        args (Config): Configuration object with hyperparameters.\n",
    "        output_dir (str): The directory to save checkpoints and the best model.\n",
    "\n",
    "    Returns:\n",
    "        str: The file path to the best performing model's weights, or None if no model was saved.\n",
    "    \"\"\"\n",
    "    # --- 1. INITIALIZATION AND PATH DEFINITION ---\n",
    "    start_epoch = 0\n",
    "    best_val_corr = -1\n",
    "    epochs_no_improve = 0\n",
    "\n",
    "    # Define all file paths used in the function for clarity\n",
    "    best_model_path = os.path.join(output_dir, 'best_model.pth')\n",
    "    latest_checkpoint_path = os.path.join(output_dir, 'latest_checkpoint.pth')\n",
    "    final_best_model_path = None # Will store the path of the best model if one is saved\n",
    "\n",
    "    # --- 2. LOAD CHECKPOINT IF IT EXISTS ---\n",
    "    if os.path.exists(latest_checkpoint_path):\n",
    "        print(f\"Resuming training from checkpoint: {latest_checkpoint_path}\")\n",
    "        checkpoint = torch.load(latest_checkpoint_path)\n",
    "        \n",
    "        model.load_state_dict(checkpoint['model_state_dict'])\n",
    "        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
    "        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n",
    "        \n",
    "        start_epoch = checkpoint['epoch']\n",
    "        best_val_corr = checkpoint['best_val_corr']\n",
    "        # Load early stopping counter, defaulting to 0 if not found for backward compatibility\n",
    "        epochs_no_improve = checkpoint.get('epochs_no_improve', 0) \n",
    "        \n",
    "        print(f\"   Resumed from epoch {start_epoch}. Best validation loss: {best_val_corr:.4f}\")\n",
    "        print(f\"   Patience counter (epochs with no improvement): {epochs_no_improve}\")\n",
    "    else:\n",
    "        print(\"Starting training from scratch.\")\n",
    "\n",
    "    # --- 3. MAIN TRAINING LOOP ---\n",
    "    for epoch in range(start_epoch, args.epochs):\n",
    "        epoch_desc = f\"Epoch {epoch+1}/{args.epochs} (Train)\"\n",
    "        avg_train_loss = run_epoch(model, train_loader, criterion, optimizer, args, epoch_desc, model.parameters())\n",
    "        val_corr = evaluate_stsb(model, tokenizer, stsb_val_data, args, silent=True)\n",
    "        scheduler.step()\n",
    "\n",
    "        # Print epoch summary\n",
    "        # First, determine the learning rate string based on the mode\n",
    "        if not args.is_ablation:\n",
    "            lr_log_str = f\"LRs (MiniLM/Head): {optimizer.param_groups[0]['lr']:.1e}/{optimizer.param_groups[1]['lr']:.1e}\"\n",
    "        else:\n",
    "            lr_log_str = f\"LR (MiniLM): {optimizer.param_groups[0]['lr']:.1e}\"\n",
    "        \n",
    "        # Then, use that string in a single, clean print statement\n",
    "        print(f\"Epoch {epoch+1}/{args.epochs} -> \"\n",
    "              f\"Train Loss: {avg_train_loss:.4f} | \"\n",
    "              f\"Val Spearman Corr: {val_corr:.4f} | \"\n",
    "              f\"{lr_log_str}\")\n",
    "\n",
    "        # --- 4. SAVE/COPY OPERATIONS & EARLY STOPPING ---\n",
    "        \n",
    "        # A) Save the best model if validation loss improves\n",
    "        if val_corr > best_val_corr:\n",
    "            best_val_corr = val_corr\n",
    "            epochs_no_improve = 0\n",
    "            torch.save(model.state_dict(), best_model_path)\n",
    "            final_best_model_path = best_model_path # Update the final path\n",
    "            print(f\"New best model saved with Val Correlation: {best_val_corr:.4f}\")\n",
    "        else:\n",
    "            epochs_no_improve += 1\n",
    "\n",
    "        # B) Save a comprehensive checkpoint for the current epoch\n",
    "        current_epoch_checkpoint_path = os.path.join(output_dir, f'checkpoint_epoch_{epoch+1}.pth')\n",
    "        torch.save({\n",
    "            'epoch': epoch + 1,\n",
    "            'model_state_dict': model.state_dict(),\n",
    "            'optimizer_state_dict': optimizer.state_dict(),\n",
    "            'scheduler_state_dict': scheduler.state_dict(),\n",
    "            'best_val_corr': best_val_corr,\n",
    "            'epochs_no_improve': epochs_no_improve,\n",
    "        }, current_epoch_checkpoint_path)\n",
    "\n",
    "        # C) Update the 'latest' checkpoint to point to this epoch's file\n",
    "        shutil.copyfile(current_epoch_checkpoint_path, latest_checkpoint_path)\n",
    "\n",
    "        # D) Check for early stopping\n",
    "        if epochs_no_improve >= args.patience:\n",
    "            print(f\"Early stopping triggered after {args.patience} epochs with no improvement.\")\n",
    "            break\n",
    "            \n",
    "    print(\"--- Training Finished ---\")\n",
    "    return final_best_model_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d8e7c08-2ce5-4410-99bb-b02ec2d75461",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# DATASETS AND DATALOADERS\n",
    "# =============================================================================\n",
    "\n",
    "print(\"--- Loading tokenizer and datasets ---\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(args.model_name)\n",
    "\n",
    "snli_dataset_raw = load_dataset('snli', split='train').filter(lambda ex: ex['label'] in [0, 1, 2])\n",
    "train_dataset = SNLIDataset(snli_dataset_raw)\n",
    "collate_fn = create_collate_fn(tokenizer, args.max_seq_length)\n",
    "train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True, num_workers=0)\n",
    "\n",
    "stsb_val_data = list(load_dataset('sentence-transformers/stsb', split='validation'))\n",
    "stsb_test_data = list(load_dataset('sentence-transformers/stsb', split='test'))"
   ]
  },
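  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9c8d7e6-f5a4-4b3c-8d2e-1f0a9b8c7d6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# OPTIONAL: ILLUSTRATE THE COLLATE FUNCTION (illustrative sketch only)\n",
    "# =============================================================================\n",
    "# A small sketch of how create_collate_fn prepares a SupCon batch: premises and\n",
    "# hypotheses are tokenized together, and premise i and hypothesis i receive the\n",
    "# same label, so each entailment pair forms a positive pair within the batch.\n",
    "# The two sentence pairs below are made-up examples, not drawn from SNLI.\n",
    "_toy_batch = [\n",
    "    (\"A man is playing a guitar.\", \"A person plays an instrument.\"),\n",
    "    (\"A dog runs on the beach.\", \"An animal is outdoors.\"),\n",
    "]\n",
    "_toy_tokens, _toy_labels = collate_fn(_toy_batch)\n",
    "print(f\"input_ids shape: {tuple(_toy_tokens['input_ids'].shape)}\")   # (2 * batch size, seq_len)\n",
    "print(f\"labels: {_toy_labels.tolist()}\")                             # [0, 1, 0, 1]"
   ]
  },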
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4b6493d-997f-4077-9946-d6e0a9a88128",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# DEFINE MODEL AND EVALUATE PRE-TRAINED ENCODER BASELINE\n",
    "# =============================================================================\n",
    "\n",
    "model = MonoCon_LM(\n",
    "    model_name=args.model_name,\n",
    "    lm_embedding_dim=args.lm_embedding_dim,\n",
    "    mlp_hidden_dim=args.mlp_hidden_dim,\n",
    "    mlp_num_layers=args.mlp_num_layers,\n",
    "    is_ablation=args.is_ablation,\n",
    ").to(args.device)\n",
    "\n",
    "print(\"\\n--- Evaluating Pre-Trained MiniLM Encoder (Baseline) on STSb ---\")\n",
    "evaluate_encoder_only_stsb(model.encoder, tokenizer, stsb_test_data, args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "555d4ad3-8075-4f6e-b4eb-d8604a082443",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================================================================================================\n",
    "# EXECUTE HEAD WARMUP PHASE AND SET UP DIFFERENTIAL LEARNING RATE STRATEGY IF RUNNING FULL MODEL\n",
    "# ================================================================================================\n",
    "\n",
    "criterion = SupConLoss(temperature=args.temperature).to(args.device)    # Define SupCon loss as the loss function\n",
    "\n",
    "if not args.is_ablation:\n",
    "\n",
    "    head_params = list(model.monotonic_mlp.parameters()) # Define head parameters separately for warmup phase and differential learning rate strategy\n",
    "    \n",
    "    if args.warmup_epochs > 0:      # The warmup phase is executed only if number of warmup epochs is a positive integer\n",
    "        print(f\"\\n--- Starting Head Warm-up for {args.warmup_epochs} epochs ---\")\n",
    "        \n",
    "        # Freeze the encoder layers\n",
    "        for param in model.encoder.parameters():\n",
    "            param.requires_grad = False\n",
    "        \n",
    "        # Create a new optimizer and scheduler just for the head during warm-up\n",
    "        # The 'head_params' variable is already defined above\n",
    "        optimizer_warmup = torch.optim.AdamW(head_params, lr=args.lr_warmup, weight_decay=args.weight_decay)   # Optimizer is set up with learning rate for the head warmup phase\n",
    "        scheduler_warmup = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_warmup, T_max=args.warmup_epochs, eta_min=0)\n",
    "    \n",
    "        # Simplified training loop for the warm-up phase\n",
    "        for epoch in range(args.warmup_epochs):\n",
    "    \n",
    "            epoch_desc = f\"Warm-up Epoch {epoch+1}/{args.warmup_epochs}\"\n",
    "            avg_loss_warmup = run_epoch(model, train_loader, criterion, optimizer_warmup, args, epoch_desc, head_params)\n",
    "            scheduler_warmup.step()      # Update learning rate according to the warmup learning rate scheduler\n",
    "            print(f\"{epoch_desc} -> Avg Loss: {avg_loss_warmup:.4f}\")\n",
    "    \n",
    "        # Unfreeze the encoder layers for the main training phase\n",
    "        print(\"--- Warm-up finished. Unfreezing encoder for main training. ---\")\n",
    "        for param in model.encoder.parameters():\n",
    "            param.requires_grad = True\n",
    "\n",
    "    # Set up optimizer for main training\n",
    "    param_groups = [\n",
    "        {'params': model.encoder.parameters(), 'lr': args.lr_lm},      # Parameters and learning rate for encoder\n",
    "        {'params': head_params, 'lr': args.lr_mlp}                         # Parameters and learning rate for monotonic MLP head  \n",
    "    ]\n",
    "    # Note: 'params' and 'lr' are special dictionary keys recognized by the AdamW optimizer, and 'params' is mandatory. \n",
    "    optimizer = torch.optim.AdamW(param_groups, weight_decay=args.weight_decay)   # Set up AdamW optimizer with differential learning rates for encoder and head\n",
    "    print(\"--- Optimizer set up with differential learning rates for encoder and head for main training ---\")\n",
    "        \n",
    "else:\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr_lm, weight_decay=args.weight_decay) # No differential learning rates for the ablation study\n",
    "    print(\"--- Performing ablation study. Optimizer set up with the MiniLM learning rate alone. ---\")\n",
    "    print(\"--- There is no warmup phase in the ablation study. ---\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b6df785-b3ac-4f66-b46f-38910003791d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================================================================================================\n",
    "# MAIN TRAINING PHASE\n",
    "# ================================================================================================\n",
    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)\n",
    "best_model_path = train_model(model, train_loader, stsb_val_data, tokenizer, criterion, optimizer, scheduler, args, output_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca33e9b0-3f21-4e1c-9f02-84722d952658",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================================================================================================\n",
    "# FINAL EVALUATION PHASE\n",
    "# ================================================================================================\n",
    "print(f\"\\n--- Final Evaluation on Best Model ---\")\n",
    "\n",
    "# Check if the model was successfully trained and saved\n",
    "if best_model_path and os.path.exists(best_model_path):\n",
    "    print(f\"Loading best model from: {best_model_path}\")\n",
    "    model.load_state_dict(torch.load(best_model_path))\n",
    "    evaluate_stsb(model, tokenizer, stsb_test_data, args)\n",
    "\n",
    "    print(\"--- Starting Final Analysis: PCA and Reconstruction Error ---\")\n",
    "    \n",
    " # =============================================================================\n",
    "    # 1. HELPER FUNCTION TO GENERATE EMBEDDINGS\n",
    "    # =============================================================================\n",
    "    @torch.no_grad(   )\n",
    "    def generate_embeddings(sentences, model, tokenizer, args):\n",
    "        \"\"\"Computes embeddings for a given list of sentences.\"\"\"\n",
    "        model.eval()\n",
    "        all_embeds = []\n",
    "        # Create a progress bar for embedding generation\n",
    "        pbar_embed = tqdm(\n",
    "            range(0, len(sentences), args.batch_size),\n",
    "            desc=f\"Computing embeddings for {len(sentences)} sentences\",\n",
    "            leave=False\n",
    "        )\n",
    "        for start_idx in pbar_embed:\n",
    "            batch_sents = sentences[start_idx : start_idx + args.batch_size]\n",
    "            tokenized = tokenizer(\n",
    "                batch_sents,\n",
    "                padding=True,\n",
    "                truncation=True,\n",
    "                max_length=args.max_seq_length,\n",
    "                return_tensors='pt'\n",
    "            )\n",
    "            tokenized = {k: v.to(args.device) for k, v in tokenized.items()}\n",
    "            y_batch = model(**tokenized)\n",
    "            all_embeds.append(y_batch.cpu())\n",
    "        return torch.cat(all_embeds, dim=0)\n",
    "    \n",
    "    # =============================================================================\n",
    "    # 2. GENERATE AND NORMALIZE EMBEDDINGS FOR TRAIN AND TEST SETS\n",
    "    # =============================================================================\n",
    "    # NOTE: This script assumes 'model' is the best-performing model, already loaded.\n",
    "    model.eval()\n",
    "    \n",
    "    # --- Generate Test Embeddings (STSb) ---\n",
    "    print(\"\\nStep 1: Generating embeddings for the TEST dataset (STSb)...\")\n",
    "    sents1_test = [ex['sentence1'] for ex in stsb_test_data]\n",
    "    sents2_test = [ex['sentence2'] for ex in stsb_test_data]\n",
    "    unique_test_sentences = list(set(sents1_test + sents2_test))\n",
    "    test_embeddings_tensor = generate_embeddings(unique_test_sentences, model, tokenizer, args)\n",
    "    # Normalize embeddings for PCA, as it's sensitive to vector magnitude\n",
    "    test_embed = F.normalize(test_embeddings_tensor, p=2, dim=1).numpy()\n",
    "    print(f\"Generated {test_embed.shape[0]} unique test embeddings of dimension {test_embed.shape[1]}.\")\n",
    "    \n",
    "    # --- Generate Train Embeddings (SNLI) ---\n",
    "    print(\"\\nStep 2: Generating embeddings for the TRAIN dataset (SNLI)...\")\n",
    "    train_sentences = list(set([s for p, h in train_dataset.samples for s in (p, h)]))\n",
    "    train_embeddings_tensor = generate_embeddings(train_sentences, model, tokenizer, args)\n",
    "    # Normalize embeddings\n",
    "    train_embed = F.normalize(train_embeddings_tensor, p=2, dim=1).numpy()\n",
    "    print(f\"Generated {train_embed.shape[0]} unique train embeddings of dimension {train_embed.shape[1]}.\")\n",
    "    \n",
    "    # =============================================================================\n",
    "    # 3. PERFORM PCA AND RECONSTRUCTION ERROR ANALYSIS\n",
    "    # =============================================================================\n",
    "    print(\"\\nStep 3: Performing PCA and calculating generalization error...\")\n",
    "    \n",
    "    # --- Fit a full PCA on the training data to find the optimal number of components ---\n",
    "    print(\"\\n--- Fitting PCA on training data to analyze variance ---\")\n",
    "    pca_full = PCA(n_components=None, random_state=42)\n",
    "    pca_full.fit(train_embed)\n",
    "    \n",
    "    # Find the number of components that explain 99% of the variance\n",
    "    cumulative_variance = np.cumsum(pca_full.explained_variance_ratio_)\n",
    "    n_for_99_variance = np.searchsorted(cumulative_variance, 0.99) + 1\n",
    "    print(f\"Number of components to explain 99% of TRAIN variance: {n_for_99_variance}\")\n",
    "    \n",
    "    # --- Calculate Reconstruction Error to Quantify Generalization ---\n",
    "    print(f\"\\n--- Calculating Test Set Reconstruction Error using top {n_for_99_variance} components from train subspace ---\")\n",
    "    \n",
    "    # 1. Define a new PCA model with the determined number of components\n",
    "    pca_reconstruction = PCA(n_components=n_for_99_variance, random_state=42)\n",
    "    \n",
    "    # 2. Fit this PCA model ONLY on the training data to define the subspace\n",
    "    pca_reconstruction.fit(train_embed)\n",
    "    \n",
    "    # 3. Reconstruct the test data by projecting it onto the train subspace and then inverting\n",
    "    test_projected = pca_reconstruction.transform(test_embed)\n",
    "    test_reconstructed = pca_reconstruction.inverse_transform(test_projected)\n",
    "    \n",
    "    # 4. Calculate the Mean Squared Error between the original and reconstructed test data\n",
    "    reconstruction_mse = np.mean(np.square(test_embed - test_reconstructed))\n",
    "    \n",
    "    # 5. Calculate the Root Mean Squared Error between original and reconstructed test data\n",
    "    reconstruction_rmse = np.sqrt(np.mean(np.square(test_embed - test_reconstructed)))\n",
    "    print(f\"Test Set Reconstruction RMSE from Train Subspace: {reconstruction_rmse:.8f}\")\n",
    "\n",
    "    print(f\"Test Set Reconstruction MSE from Train Subspace: {reconstruction_mse:.8f}\")\n",
    "    \n",
    "    print(\"\\n--- Analysis Complete ---\")\n",
    "else:\n",
    "    print(\"Could not find a saved model. Skipping final evaluation.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
